# Source code for pychron.git_archive.repo_manager

# ===============================================================================
# Copyright 2013 Jake Ross
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

# ============= enthought library imports =======================
import hashlib
import os
import shutil
import subprocess
import sys
import time
from datetime import datetime

import git
from git import Repo
from git.exc import GitCommandError
from traits.api import Any, Str, List, Event

from pychron.core.helpers.filetools import fileiter
from pychron.core.progress import open_progress
from pychron.envisage.view_util import open_view
from pychron.git_archive.diff_view import DiffView, DiffModel
from pychron.git_archive.git_objects import GitSha
from pychron.git_archive.history import BaseGitHistory
from pychron.git_archive.merge_view import MergeModel, MergeView
from pychron.git_archive.utils import get_head_commit, ahead_behind, from_gitlog, LOGFMT
from pychron.git_archive.views import NewBranchView
from pychron.loggable import Loggable
from pychron.pychron_constants import DATE_FORMAT, NULL_STR
from pychron.updater.commit_view import CommitView


def get_repository_branch(path):
    """
    Return the name of the branch currently checked out in the repository at ``path``.

    :param path: path to a git working tree
    :return: active branch name
    """
    return Repo(path).active_branch.name


def grep(arg, name):
    """
    Recursively list the files under ``name`` whose contents match ``arg``.

    Runs ``grep -lr arg name`` and returns the ``(stdout, stderr)`` pair from
    ``Popen.communicate``. Only stdout is piped, so stdout is bytes (one
    matching path per line) and stderr is always None.
    """
    command = ["grep", "-lr", arg, name]
    proc = subprocess.Popen(command, stdout=subprocess.PIPE)
    return proc.communicate()


def format_date(d):
    """Format the POSIX timestamp ``d`` as ``MM/DD/YYYY HH:MM`` in UTC."""
    utc_struct = time.gmtime(d)
    return time.strftime("%m/%d/%Y %H:%M", utc_struct)


def isoformat_date(d):
    """
    Format ``d`` with the project-wide ``DATE_FORMAT``.

    :param d: a ``datetime``-like object, or a POSIX timestamp (int/float),
        which is first converted with ``datetime.fromtimestamp`` (local time)
    :return: formatted date string
    """
    moment = datetime.fromtimestamp(d) if isinstance(d, (float, int)) else d
    return moment.strftime(DATE_FORMAT)


class StashCTX(object):
    """
    Context manager that runs ``git stash`` on entry and ``git stash pop`` on exit.

    ``with StashCTX(repo) as error:`` binds the ``GitCommandError`` raised by
    ``git stash`` (or None when stashing succeeded). If stashing failed,
    nothing is popped on exit. Errors raised by ``git stash pop`` are ignored.
    """

    def __init__(self, repo):
        self._repo = repo
        self._error = None

    def __enter__(self):
        try:
            self._repo.git.stash()
        except GitCommandError as err:
            # remember the failure so __exit__ does not pop someone else's stash
            self._error = err
            return err
        return None

    def __exit__(self, *args, **kw):
        if self._error:
            return
        try:
            self._repo.git.stash("pop")
        except GitCommandError:
            pass


class GitRepoManager(Loggable):
    """
    Manage a local git repository.

    Wraps a GitPython ``Repo`` (``_repo``) and exposes higher level operations
    (add/commit, pull/push, branching, file history, conflict resolution),
    reporting through the ``Loggable`` logging and dialog helpers.
    """

    # GitPython Repo; assigned by open_repo / init_repo / clone
    _repo = Any
    # root=Directory
    # path of the repository working tree
    path = Str
    # currently selected file path; changing it loads that file's history
    selected = Any
    selected_branch = Str
    # commits (GitSha objects) that touched ``selected``
    selected_path_commits = List
    selected_commits = List
    refresh_commits_table_needed = Event
    path_dirty = Event
    # "<owner>/<name>"; when set, "origin" is recreated as https://github.com/<owner>/<name>
    remote = Str

    def set_name(self, p):
        """Set the logger name to ``<basename of p><GitRepo>``."""
        self.name = "{}<GitRepo>".format(os.path.basename(p))

    def open_repo(self, name, root=None):
        """
        Open (or create) the repository ``name``.

        :param name: repository name, or its full path when ``root`` is None
        :param root: directory in which the repository lives / is created
        :return: True if the directory already existed, False if it was
            created and initialized as a new repo
        """
        if root is None:
            p = name
        else:
            p = os.path.join(root, name)

        self.path = p
        self.set_name(p)

        if os.path.isdir(p):
            self.init_repo(p)
            return True
        else:
            os.mkdir(p)
            repo = Repo.init(p)
            self.debug("created new repo {}".format(p))
            self._repo = repo
            return False

    def init_repo(self, path):
        """
        Attach to the git repository at ``path``, initializing one if the
        directory exists but has no ``.git`` folder.

        :param path: absolute path to the repo
        :return: True if a git repo already existed, otherwise None
        """
        if os.path.isdir(path):
            g = os.path.join(path, ".git")
            if os.path.isdir(g):
                self._repo = Repo(path)
                self.set_name(path)
                return True
            else:
                self.debug("{} is not a valid repo. Initializing now".format(path))
                self._repo = Repo.init(path)
                self.set_name(path)

    def delete_local_commits(self, remote="origin", branch=None):
        """Hard reset ``branch`` (default: active branch) to ``remote/branch``."""
        if branch is None:
            branch = self._repo.active_branch.name

        self._repo.git.reset("--hard", "{}/{}".format(remote, branch))

    def delete_commits(self, hexsha, remote="origin", branch=None, push=True):
        """
        Hard reset the working branch to ``hexsha`` and, if ``push``,
        force-push ``branch`` (default: active branch) to ``remote``.
        """
        if branch is None:
            branch = self._repo.active_branch.name

        self._repo.git.reset("--hard", hexsha)
        if push:
            self._repo.git.push(remote, branch, "--force")

    def add_paths_explicit(self, apaths):
        """Add ``apaths`` to the index unconditionally."""
        self.index.add(apaths)

    def add_paths(self, apaths):
        """
        Stage the subset of ``apaths`` that actually changed.

        Paths that are modified/added/renamed or untracked are added to the
        index; paths deleted from the working tree are removed from it.

        :param apaths: a path or a list/tuple of paths (compared against
            absolute paths reported by git)
        :return: True if anything was added or removed
        """
        if not isinstance(apaths, (list, tuple)):
            apaths = (apaths,)

        changes = self.get_local_changes(change_type=("A", "R", "M"))
        changes = [os.path.join(self.path, c) for c in changes]
        if changes:
            self.debug("-------- local changes ---------")
            for c in changes:
                self.debug(c)

        deletes = self.get_local_changes(change_type=("D",))
        if deletes:
            self.debug("-------- deletes ---------")
            for c in deletes:
                self.debug(c)

        untracked = self.untracked_files()
        if untracked:
            self.debug("-------- untracked paths --------")
            for t in untracked:
                self.debug(t)

        changes.extend(untracked)

        self.debug("add paths {}".format(apaths))

        # sets for O(1) membership tests
        changed_set = set(changes)
        deleted_set = set(deletes)

        ps = [p for p in apaths if p in changed_set]
        self.debug("changed paths {}".format(ps))
        changed = bool(ps)
        if ps:
            for p in ps:
                self.debug("adding to index: {}".format(os.path.relpath(p, self.path)))
            self.index.add(ps)

        ps = [p for p in apaths if p in deleted_set]
        self.debug("delete paths {}".format(ps))
        delete_changed = bool(ps)
        if ps:
            for p in ps:
                self.debug(
                    "removing from index: {}".format(os.path.relpath(p, self.path))
                )
            self.index.remove(ps, working_tree=True)

        return changed or delete_changed

    def add_ignore(self, *args):
        """
        Append the patterns in ``args`` that are not already present to
        ``.gitignore`` and stage the file (without committing).
        """
        ignores = []
        p = os.path.join(self.path, ".gitignore")
        if os.path.isfile(p):
            with open(p, "r") as rfile:
                ignores = [line.strip() for line in rfile]

        args = [a for a in args if a not in ignores]
        if args:
            with open(p, "a") as afile:
                for a in args:
                    afile.write("{}\n".format(a))
            self.add(p, commit=False)

    def get_modification_date(self, path):
        """
        Return the author date of the last commit touching ``path``,
        formatted with ``DATE_FORMAT`` (an empty value if there is none).

        :param path: path passed to ``git log -- <path>``
        """
        d = self.cmd(
            "log",
            "-1",
            '--format="%ad"',
            "--date=format:{}".format(DATE_FORMAT),
            "--",
            path,
        )
        if d:
            # strip the literal quotes requested in --format
            d = d[1:-1]
        return d

    def out_of_date(self, branchname=None):
        """
        Fetch ``branchname`` (default: active branch) from origin and report
        whether the local and remote-tracking commits differ.
        """
        repo = self._repo
        if branchname is None:
            branchname = repo.active_branch.name

        pd = open_progress(2)

        origin = repo.remotes.origin
        pd.change_message("Fetching {} {}".format(origin, branchname))

        repo.git.fetch(origin, branchname)
        pd.change_message("Complete")

        local_commit, remote_commit = self._get_local_remote_commit(branchname)
        self.debug("out of date {} {}".format(local_commit, remote_commit))
        return local_commit != remote_commit

    def _get_local_remote_commit(self, branchname=None):
        """
        Return ``(local_commit, remote_commit)`` for ``branchname`` (default:
        active branch). ``remote_commit`` is None when origin has no such
        ref; ``(None, None)`` is returned when there is no local branch.
        """
        repo = self._repo
        # resolve the default before using the name for any lookup
        if branchname is None:
            branchname = repo.active_branch.name

        origin = repo.remotes.origin
        try:
            remote_commit = origin.refs[branchname].commit
        except IndexError:
            remote_commit = None

        try:
            # IterableList raises IndexError for unknown names
            branch = repo.heads[branchname]
        except (IndexError, AttributeError):
            return None, None

        return branch.commit, remote_commit

    @classmethod
    def clone_from(cls, url, path):
        """Return a new manager cloned from ``url`` into ``path``, or None on failure."""
        repo = cls()
        if repo.clone(url, path):
            return repo

    def clone(self, url, path, reraise=False):
        """
        Clone ``url`` into ``path`` and attach to the result.

        :return: True on success; None if cloning failed (a warning dialog is
            shown and the GitCommandError is re-raised when ``reraise``)
        """
        try:
            self._repo = Repo.clone_from(url, path)
            return True
        except GitCommandError as e:
            self.warning_dialog(
                "Cloning error: {}, url={}, path={}".format(e, url, path),
                position=(100, 100),
            )
            if reraise:
                raise

    def unpack_blob(self, hexsha, p):
        """
        Return the contents of file ``p`` as stored in commit ``hexsha``.

        :param hexsha: commit identifier
        :param p: absolute path of the file in the working tree
        :return: blob contents as bytes, or "" if ``p`` is not in the commit
        """
        tree = self._repo.commit(hexsha).tree
        # traverse() walks every nested tree, not just the first level
        blob = next(
            (
                item
                for item in tree.traverse()
                if item.type == "blob" and item.abspath == p
            ),
            None,
        )
        if blob is None:
            print("failed unpacking", p)
            return ""
        return blob.data_stream.read()

    def shell(self, cmd, *args):
        """Run the git subcommand ``cmd`` with ``args`` and return its output."""
        repo = self._repo
        func = getattr(repo.git, cmd)
        return func(*args)

    def truncate_repo(self, date="1 month"):
        """
        Rewrite history so that only commits after ``date`` remain
        (graft + filter-branch + gc), after attempting a mirror backup.
        """
        repo = self._repo
        name = os.path.basename(self.path)
        backup = ".{}".format(name)
        # NOTE(review): the clone source formats to an empty string, so git
        # rejects this mirror clone; confirm the intended source URL/path.
        repo.git.clone("--mirror", "".format(name), "./{}".format(backup))
        logs = repo.git.log("--pretty=%H", "--after={}".format(date))
        logs = reversed(logs.split("\n"))
        # oldest commit within the requested window
        sha = next(logs)

        gpath = os.path.join(self.path, ".git", "info", "grafts")
        with open(gpath, "w") as wfile:
            wfile.write(sha)

        repo.git.filter_branch("--tag-name-filter", "cat", "--", "--all")
        repo.git.gc("--prune=now")

    def get_dag(self, branch=None, delim="$", limit=None, simplify=True):
        """
        Return ``git log`` output describing the commit graph, one commit per
        line with fields (hash, iso date, relative date, subject, author name,
        author email, decorations, parents) joined by ``delim``.

        :param branch: branch name, or ``NULL_STR`` for all refs
        :param limit: maximum number of commits
        :param simplify: only include decorated commits
        """
        fmt_args = ("%H", "%ai", "%ar", "%s", "%an", "%ae", "%d", "%P")
        fmt = delim.join(fmt_args)
        args = [
            "--abbrev-commit",
            "--topo-order",
            "--reverse",
            "--format={}".format(fmt),
        ]
        if simplify:
            args.append("--simplify-by-decoration")

        if branch == NULL_STR:
            args.append("--all")
        else:
            args.append("-b")
            args.append(branch)

        if limit:
            args.append("-{}".format(limit))

        return self._repo.git.log(*args)

    def commits_iter(self, p, keys=None, limit="-"):
        """
        Lazily yield ``[hexsha, *attrs]`` for each commit touching ``p``
        (following renames).

        :param p: path relative to the working tree
        :param keys: Commit attribute names to append to each row
        :param limit: maximum number of commits; "-" (default) means no limit
        """
        repo = self._repo
        p = os.path.join(repo.working_tree_dir, p)
        p = p.replace(" ", "\\ ")

        args = ["--pretty=%H", "--follow"]
        # "-" previously became a stray "--", which git treated as a second
        # pathspec and rejected ("--follow requires exactly one pathspec")
        if limit is not None and limit != "-":
            args.append("-{}".format(limit))
        args.extend(("--", p))
        hx = [h for h in repo.git.log(*args).split("\n") if h]

        def func(hi):
            commit = repo.rev_parse(hi)
            r = [hi]
            if keys:
                r.extend([getattr(commit, ki) for ki in keys])
            return r

        return (func(ci) for ci in hx)

    def odiff(self, a, b, **kw):
        """Return the GitPython diff between commit ``a`` and ``b``."""
        a = self._repo.commit(a)
        return a.diff(b, **kw)

    def diff(self, a, b, *args):
        """Return ``git diff a b *args`` output (None if git fails)."""
        return self._git_command(lambda g: g.diff(a, b, *args), "diff")

    def status(self):
        """Return ``git status`` output (None if git fails)."""
        return self._git_command(lambda g: g.status(), "status")

    def report_local_changes(self):
        """Log every locally modified path."""
        self.debug("Local Changes to {}".format(self.path))
        for p in self.get_local_changes():
            self.debug("\t{}".format(p))

    def commit_dialog(self):
        """
        Show a commit dialog for the local changes; stage the selected paths
        and commit with the entered message. Return True if committed.
        """
        from pychron.git_archive.commit_dialog import CommitDialog

        ps = self.get_local_changes()
        cd = CommitDialog(ps)
        info = cd.edit_traits()
        if info.result:
            index = self.index
            index.add([mp.path for mp in cd.valid_paths()])
            self.commit(cd.commit_message)
            return True

    def get_local_changes(self, change_type=("M",)):
        """
        Return absolute paths of working-tree changes (index vs. working tree)
        whose change type is in ``change_type`` (e.g. "M", "A", "D", "R").
        """
        diff = self._repo.index.diff(None)
        return [
            di.a_blob.abspath
            for ct in change_type
            for di in diff.iter_change_type(ct)
        ]

    def get_head_object(self):
        """Return the head commit as provided by ``get_head_commit``."""
        return get_head_commit(self._repo)

    def get_head(self, commit=True, hexsha=True):
        """Return the HEAD commit, or its hexsha when ``hexsha`` is True."""
        head = self._repo
        if commit:
            head = head.commit()

        if hexsha:
            head = head.hexsha
        return head

    def cmd(self, cmd, *args):
        """Run the git subcommand ``cmd`` with ``args`` and return its output."""
        return getattr(self._repo.git, cmd)(*args)

    def is_dirty(self):
        """Return True if the repository has uncommitted changes."""
        return self._repo.is_dirty()

    def untracked_files(self):
        """Return absolute paths of untracked files, parsed from porcelain status."""
        lines = self._repo.git.status(porcelain=True, untracked_files=True)
        # untracked files prefix in porcelain mode
        prefix = "?? "
        untracked_files = []
        iswindows = sys.platform == "win32"
        for line in lines.split("\n"):
            if not line.startswith(prefix):
                continue

            filename = line[len(prefix):].rstrip("\n")
            # git quotes names containing special characters and escapes them
            # C-style (non-ASCII bytes become octal escapes). The py2-only
            # str.decode("string_escape") crashed here on Python 3.
            if len(filename) > 1 and filename[0] == filename[-1] == '"':
                filename = (
                    filename[1:-1]
                    .encode("utf-8")
                    .decode("unicode_escape")
                    .encode("latin-1")
                    .decode("utf-8")
                )
            if iswindows:
                filename = filename.replace("/", "\\")
            untracked_files.append(os.path.join(self.path, filename))

        return untracked_files

    def has_staged(self):
        """Return the names of files differing from HEAD (empty string if none)."""
        return self._repo.git.diff("HEAD", "--name-only")

    def has_unpushed_commits(self, remote="origin", branch="master"):
        """Return one-line log of commits in HEAD not in ``remote/branch``."""
        if self._repo:
            if remote in self._repo.remotes:
                return self._repo.git.log(
                    "{}/{}..HEAD".format(remote, branch), "--oneline"
                )

    def add_unstaged(self, root=None, add_all=False, extension=None, use_diff=False):
        """
        Add files to the index.

        :param root: directory to walk (default: repository path); hidden
            directories are skipped
        :param add_all: run ``git add .`` instead of walking
        :param extension: only add files with this extension (or tuple of)
        :param use_diff: currently a no-op
        """
        if root is None:
            root = self.path

        index = self.index

        def func(ps, extension):
            if extension:
                if not isinstance(extension, tuple):
                    extension = (extension,)
                ps = [pp for pp in ps if os.path.splitext(pp)[1] in extension]

            if ps:
                self.debug("adding to index {}".format(ps))
                index.add(ps)

        if use_diff:
            pass
        elif add_all:
            self._repo.git.add(".")
        else:
            for r, ds, fs in os.walk(root):
                ds[:] = [d for d in ds if d[0] != "."]
                ps = [os.path.join(r, fi) for fi in fs]
                func(ps, extension)

    def update_gitignore(self, *args):
        """
        Append the entries of ``args`` that are not yet in ``.gitignore``,
        then stage and commit the file.
        """
        p = os.path.join(self.path, ".gitignore")
        # de-duplicate while preserving order
        args = list(dict.fromkeys(args))
        if os.path.isfile(p):
            with open(p, "r") as rfile:
                existing = set(fileiter(rfile, strip=True))
            # filter instead of popping while enumerating (which skipped items)
            args = [ai for ai in args if ai not in existing]

        if args:
            with open(p, "a") as wfile:
                for ai in args:
                    wfile.write("{}\n".format(ai))
            self._add_to_repo(p, msg="updated .gitignore")

    def get_commit(self, hexsha):
        """Return the Commit object for ``hexsha``."""
        repo = self._repo
        return repo.commit(hexsha)

    def tag_branch(self, tagname):
        """Create the tag ``tagname`` at HEAD."""
        repo = self._repo
        repo.create_tag(tagname)

    def get_current_branch(self):
        """Return the active branch name."""
        repo = self._repo
        return repo.active_branch.name

    def checkout_branch(self, name, inform=True):
        """
        Check out branch ``name``. Names starting with "origin" create a local
        tracking branch from the remote ref first.
        """
        repo = self._repo
        if name.startswith("origin"):
            # strip the "origin/" prefix
            name = name[7:]
            remote = repo.remote()
            rref = getattr(remote.refs, name)
            repo.create_head(name, rref)

            branch = repo.heads[name]
            branch.set_tracking_branch(rref)
        else:
            branch = getattr(repo.heads, name)

        try:
            branch.checkout()
            self.selected_branch = name
            self._load_branch_history()
            if inform:
                self.information_dialog('Repository now on branch "{}"'.format(name))
        except BaseException:
            self.warning_dialog(
                'There was an issue trying to checkout branch "{}"'.format(name)
            )
            raise

    def delete_branch(self, name):
        """Delete local branch ``name``."""
        self._repo.delete_head(name)

    def get_branch(self, name):
        """Return the local Head named ``name``."""
        return getattr(self._repo.heads, name)

    def create_branch(self, name=None, commit="HEAD", inform=True):
        """
        Create and check out a new branch at ``commit``. If ``name`` is None
        the user is prompted. Return the name when a branch was created.
        """
        repo = self._repo

        if name is None:
            nb = NewBranchView(branches=repo.branches)
            info = nb.edit_traits()
            if info.result:
                name = nb.name
            else:
                return

        if name not in repo.branches:
            branch = repo.create_head(name, commit=commit)
            branch.checkout()
            if inform:
                self.information_dialog('Repository now on branch "{}"'.format(name))
            return name

    def create_remote(self, url, name="origin", force=False):
        """Create remote ``name`` -> ``url`` if missing; recreate it if ``force``."""
        repo = self._repo
        if repo:
            self.debug("setting remote {} {}".format(name, url))
            # only create remote if doesnt exist
            if not hasattr(repo.remotes, name):
                self.debug("create remote {} {}".format(name, url))
                repo.create_remote(name, url)
            elif force:
                repo.delete_remote(name)
                repo.create_remote(name, url)

    def delete_remote(self, name="origin"):
        """Delete remote ``name`` if it exists."""
        repo = self._repo
        if repo:
            if hasattr(repo.remotes, name):
                repo.delete_remote(name)

    def get_branch_names(self):
        """Return local branch names followed by remote ref names (minus origin/HEAD)."""
        return [b.name for b in self._repo.branches] + [
            b.name for b in self._repo.remote().refs if b.name.lower() != "origin/head"
        ]

    def git_history_view(self, branchname):
        """
        Build a CommitView listing the commits that differ between local
        ``branchname`` and its origin counterpart.
        """
        repo = self._repo
        h = BaseGitHistory(branchname=branchname)

        origin = repo.remotes.origin
        try:
            oref = origin.refs[branchname]
            remote_commit = oref.commit
        except IndexError:
            remote_commit = None

        branch = self.get_branch(branchname)
        local_commit = branch.commit
        h.local_commit = str(local_commit)

        txt = repo.git.rev_list(
            "--left-right", "{}...{}".format(local_commit, remote_commit)
        )
        # each line is "<sha" or ">sha"; skip blanks (empty output -> [""])
        commits = [ci[1:] for ci in txt.split("\n") if ci]
        commits = [repo.commit(i) for i in commits]
        h.set_items(commits)
        commit_view = CommitView(model=h)
        return commit_view

    def pull(
        self,
        branch="master",
        remote="origin",
        handled=True,
        use_progress=True,
        use_auto_pull=False,
    ):
        """
        Fetch and merge.

        If ``use_auto_pull`` is False, ask the user whether to accept the
        available updates (showing them first) before merging FETCH_HEAD.
        A failed merge falls back to ``smart_pull``.
        """
        self.debug("pulling {} from {}".format(branch, remote))

        repo = self._repo
        try:
            remote = self._get_remote(remote)
        except AttributeError as e:
            print("repo man pull", e)
            return

        if remote:
            self.debug("pulling from url: {}".format(remote.url))
            if use_progress:
                prog = open_progress(
                    3,
                    show_percent=False,
                    title="Pull Repository {}".format(self.name),
                    close_at_end=False,
                )
                prog.change_message(
                    'Fetching branch:"{}" from "{}"'.format(branch, remote)
                )
            try:
                self.fetch(remote)
            except GitCommandError as e:
                self.debug(e)
                if not handled:
                    raise e
            self.debug("fetch complete")

            def merge():
                try:
                    repo.git.merge("FETCH_HEAD")
                except GitCommandError as e:
                    self.critical("Pull-merge FETCH_HEAD={}".format(e))
                    self.smart_pull(branch=branch, remote=remote)

            if not use_auto_pull:
                ahead, behind = self.ahead_behind(remote)
                if behind:
                    if self.confirmation_dialog(
                        'Repository "{}" is behind the official version by {} changes.\n'
                        "Would you like to pull the available changes?".format(
                            self.name, behind
                        )
                    ):
                        # show the changes
                        h = self.git_history_view(branch)
                        info = h.edit_traits(kind="livemodal")
                        if info.result:
                            merge()
            else:
                merge()

            if use_progress:
                prog.close()

        self.debug("pull complete")

    def has_remote(self, remote="origin"):
        """Return True if a remote named ``remote`` exists."""
        return bool(self._get_remote(remote))

    def push(self, branch="master", remote=None, inform=False):
        """Push ``branch`` to ``remote`` (default "origin"), logging failures."""
        if remote is None:
            remote = "origin"

        rr = self._get_remote(remote)
        if rr:
            try:
                self._repo.git.push(remote, branch)
                if inform:
                    self.information_dialog("{} push complete".format(self.name))
            except GitCommandError:
                self.debug_exception()
                if inform:
                    self.warning_dialog(
                        "{} push failed. See log file for more details".format(
                            self.name
                        )
                    )
        else:
            self.warning('No remote called "{}"'.format(remote))

    def _git_command(self, func, tag):
        """Call ``func(repo.git)``; log and return None on GitCommandError."""
        try:
            return func(self._repo.git)
        except GitCommandError as e:
            self.warning("Git command failed. {}, {}".format(e, tag))

    def rebase(self, onto_branch="master"):
        """Update ``onto_branch`` from origin and rebase the current branch onto it."""
        if self._repo:
            repo = self._repo

            branch = self.get_current_branch()
            self.checkout_branch(onto_branch)
            self.pull()

            repo.git.rebase(onto_branch, branch)

    def smart_pull(
        self,
        branch="master",
        remote="origin",
        quiet=True,
        accept_our=False,
        accept_their=False,
    ):
        """
        Bring the local branch up to date with ``remote/branch``.

        When only behind, FETCH_HEAD is merged. When both ahead and behind,
        local changes are stashed, ``remote/branch`` is merged, and on
        conflict the user may reset to the remote copy or resolve conflicts.

        :return: True when done (or when the remote does not exist); None if
            aborted
        """
        if remote not in self._repo.remotes:
            return True

        try:
            ahead, behind = self.ahead_behind(remote)
        except GitCommandError as e:
            self.debug("Smart pull error: {}".format(e))
            return

        self.debug("Smart pull ahead: {} behind: {}".format(ahead, behind))
        repo = self._repo
        if behind:
            if ahead:
                if not quiet:
                    if not self.confirmation_dialog(
                        "You are {} behind and {} commits ahead. "
                        "There are potential conflicts that you will have to resolve."
                        "\n\nWould you like to Continue?".format(behind, ahead)
                    ):
                        return

                # clear any half-finished merge before starting a new one
                try:
                    repo.git.merge("--abort")
                except GitCommandError:
                    pass

                # potentially conflicts
                with StashCTX(repo) as error:
                    if error:
                        self.warning_dialog(
                            "Failed stashing your local changes. "
                            "Fix repository {} "
                            "before proceeding. {}".format(
                                os.path.basename(repo.working_dir), error
                            )
                        )
                        return

                    # do merge
                    try:
                        repo.git.merge("{}/{}".format(remote, branch))
                    except GitCommandError:
                        if self.confirmation_dialog(
                            "There appears to be a conflict with {}."
                            "\n\nWould you like to accept the master copy (Yes).\n\nOtherwise "
                            "you will need to merge the changes manually (No)".format(
                                self.name
                            )
                        ):
                            try:
                                repo.git.merge("--abort")
                            except GitCommandError:
                                pass

                            try:
                                repo.git.reset("--hard", "{}/{}".format(remote, branch))
                            except GitCommandError:
                                pass
                        elif self.confirmation_dialog(
                            "Would you like to accept all of your current changes even "
                            "though there are newer changes available?"
                        ):
                            accept_our = True

                        self._resolve_conflicts(
                            branch, remote, accept_our, accept_their, quiet
                        )
            else:
                self.debug("merging {} commits".format(behind))
                self._git_command(
                    lambda g: g.merge("FETCH_HEAD"), "GitRepoManager.smart_pull/!ahead"
                )
        else:
            self.debug("Up-to-date with {}".format(remote))
            if not quiet:
                self.information_dialog(
                    'Repository "{}" up-to-date with {}'.format(self.name, remote)
                )

        return True

    def fetch(self, remote="origin"):
        """Run ``git fetch remote``; returns None if git fails."""
        if self._repo:
            return self._git_command(lambda g: g.fetch(remote), "GitRepoManager.fetch")

    def ahead_behind(self, remote="origin"):
        """Return ``(ahead, behind)`` as computed by the ``ahead_behind`` helper."""
        self.debug("ahead behind")

        repo = self._repo
        ahead, behind = ahead_behind(repo, remote)
        return ahead, behind

    def merge(self, from_, to_=None, inform=True):
        """Merge branch ``from_`` into ``to_`` (checked out first) or the current branch."""
        repo = self._repo
        if to_:
            dest = getattr(repo.branches, to_)
            dest.checkout()

        src = getattr(repo.branches, from_)
        try:
            repo.git.merge(src.commit)
        except GitCommandError:
            self.debug_exception()
            if inform:
                self.warning_dialog(
                    "Merging {} into {} failed. See log file for more details".format(
                        from_, to_
                    )
                )

    def commit(self, msg, author=None):
        """Commit the index with ``msg``; return True on success, None on GitError."""
        self.debug("commit message={}, author={}".format(msg, author))
        index = self.index
        if index:
            try:
                index.commit(msg, author=author, committer=author)
                return True
            except git.exc.GitError as e:
                self.warning("Commit failed: {}".format(e))

    def add(self, p, msg=None, msg_prefix=None, verbose=True, **kw):
        """
        Copy ``p`` into the working tree if needed, stage it and (by default)
        commit with message ``"<added|modified> - <msg or basename>"``.

        Extra keyword arguments (e.g. ``commit=False``) go to ``_add_to_repo``.
        """
        repo = self._repo

        bp = os.path.basename(p)
        dest = os.path.join(repo.working_dir, p)

        dest_exists = os.path.isfile(dest)
        if msg_prefix is None:
            msg_prefix = "modified" if dest_exists else "added"

        if not dest_exists:
            self.debug("copying to destination.{}>>{}".format(p, dest))
            shutil.copyfile(p, dest)

        if msg is None:
            msg = "{}".format(bp)
        msg = "{} - {}".format(msg_prefix, msg)
        if verbose:
            self.debug("add to repo msg={}".format(msg))

        self._add_to_repo(dest, msg, **kw)

    def get_log(self, branch, *args):
        """Return ``git log --oneline`` lines for ``branch`` (default: active branch)."""
        if branch is None:
            branch = self._repo.active_branch

        return self.cmd("log", branch, "--oneline", *args).split("\n")

    def get_commits_from_log(self, greps=None, max_count=None, after=None, before=None):
        """
        Return commits of the active branch parsed with ``from_gitlog``.

        :param greps: messages must start with one of these strings
        """
        repo = self._repo
        args = [repo.active_branch.name, "--remove-empty", "--simplify-merges"]
        if max_count:
            args.append("--max-count={}".format(max_count))
        if after:
            args.append("--after={}".format(after))
        if before:
            args.append("--before={}".format(before))

        if greps:
            # basic-regex alternation
            greps = r"\|".join(greps)
            args.append("--grep=^{}".format(greps))

        args.append(LOGFMT)
        cs = self._gitlog_commits(args)
        return cs

    def get_active_branch(self):
        """Return the active branch name."""
        return self._repo.active_branch.name

    def get_sha(self, path=None):
        """Return the blob/tree sha of ``path`` at HEAD, or "" if unavailable."""
        sha = ""
        if path:
            logstr = self.cmd("ls-tree", "HEAD", path)
            try:
                # "<mode> <type> <sha>\t<name>"; bounded splits tolerate spaces in name
                mode, kind, sha_name = logstr.split(" ", 2)
                sha, name = sha_name.split("\t", 1)
            except ValueError:
                pass

        return sha

    def get_branch_diff(self, from_, to_):
        """Return commits in ``to_`` not in ``from_``, parsed with ``from_gitlog``."""
        args = ("{}..{}".format(from_, to_), LOGFMT)
        return self._gitlog_commits(args)

    def add_tag(self, name, message, hexsha=None):
        """Create annotated tag ``name`` at ``hexsha`` (default HEAD)."""
        args = ("-a", name, "-m", message)
        if hexsha:
            args = args + (hexsha,)

        self.cmd("tag", *args)

    # action handlers
    def diff_selected(self):
        """Open a diff view when exactly two commits are selected."""
        if self._validate_diff():
            if len(self.selected_commits) == 2:
                l, r = self.selected_commits
                dv = self._diff_view_factory(l, r)
                open_view(dv)

    def revert_to_selected(self):
        """Revert the selected path to the last selected commit."""
        # NOTE(review): uncommitted changes to the path are overwritten without warning
        commit = self.selected_commits[-1]
        self.revert(commit.hexsha, self.selected)

    def revert(self, hexsha, path):
        """Check out ``path`` as of ``hexsha`` and refresh the active-commit markers."""
        self._repo.git.checkout(hexsha, path)
        self.path_dirty = path
        self._set_active_commit()

    def revert_commit(self, hexsha):
        """Run ``git revert hexsha``."""
        self._git_command(lambda g: g.revert(hexsha), "revert_commit")

    def load_file_history(self, p):
        """Populate ``selected_path_commits`` with the commits touching ``p``."""
        repo = self._repo
        try:
            hexshas = repo.git.log("--pretty=%H", "--follow", "--", p).split("\n")
            # empty output yields [""], which rev_parse cannot resolve
            hexshas = [h for h in hexshas if h]
            self.selected_path_commits = self._parse_commits(hexshas, p)
            self._set_active_commit()
        except GitCommandError:
            self.selected_path_commits = []

    def get_modified_files(self, hexsha):
        """Return the paths changed by commit ``hexsha`` ([] if git fails)."""

        def func(git):
            return git.diff_tree("--no-commit-id", "--name-only", "-r", hexsha)

        txt = self._git_command(func, "get_modified_files")
        if not txt:
            return []
        return txt.split("\n")

    # private
    def _gitlog_commits(self, args):
        """Run ``git log *args`` and parse each line with ``from_gitlog``."""
        txt = self._git_command(lambda g: g.log(*args), "log")
        cs = []
        if txt:
            cs = [from_gitlog(l.strip()) for l in txt.split("\n")]
        return cs

    def _resolve_conflicts(self, branch, remote, accept_our, accept_their, quiet):
        """Resolve merge conflicts automatically (ours/theirs) or via a MergeView."""
        conflict_paths = self._get_conflict_paths()
        self.debug("resolve conflict_paths: {}".format(conflict_paths))
        if conflict_paths:
            mm = MergeModel(conflict_paths, branch=branch, remote=remote, repo=self)
            if accept_our:
                mm.accept_our()
            elif accept_their:
                mm.accept_their()
            else:
                mv = MergeView(model=mm)
                mv.edit_traits(kind="livemodal")
        else:
            if not quiet:
                self.information_dialog("There were no conflicts identified")

    def _get_conflict_paths(self):
        """Return the unmerged paths ([] if none or if git fails)."""

        def func(git):
            return git.diff("--name-only", "--diff-filter=U")

        txt = self._git_command(func, "get conflict paths")
        if not txt:
            return []
        return [line for line in txt.split("\n") if line.strip()]

    def _validate_diff(self):
        return True

    def _diff_view_factory(self, a, b):
        """Build a DiffView comparing the blobs of commits ``b`` (left) and ``a`` (right)."""
        if not a.blob:
            a.blob = self.unpack_blob(a.hexsha, a.name)

        if not b.blob:
            b.blob = self.unpack_blob(b.hexsha, b.name)

        model = DiffModel(left_text=b.blob, right_text=a.blob)
        dv = DiffView(model=model)
        return dv

    def _add_to_repo(self, p, msg, commit=True):
        """
        Stage path(s) ``p`` and, if ``commit``, commit them with ``msg``.

        On an IOError a stale ``index.lock`` is removed and the add is retried
        once; if the retry fails nothing is committed.
        """
        index = self.index
        if index:
            if not isinstance(p, list):
                p = [p]
            try:
                index.add(p)
            except IOError as e:
                self.warning('Failed to add file. Error:"{}"'.format(e))

                # an IOError has been caused in the past by "'...index.lock' could not be obtained"
                lock = os.path.join(self.path, ".git", "index.lock")
                try:
                    os.remove(lock)
                except OSError as lock_error:
                    # the failure was not caused by a stale lock; retry anyway
                    self.debug("could not remove {}: {}".format(lock, lock_error))

                try:
                    self.warning('Retry after "Failed to add file"')
                    index.add(p)
                except IOError as e:
                    self.warning('Retry failed. Error:"{}"'.format(e))
                    return

            if commit:
                index.commit(msg)

    def _get_remote(self, remote):
        """Return the Remote named ``remote`` or None."""
        repo = self._repo
        try:
            return getattr(repo.remotes, remote)
        except AttributeError:
            pass

    def _get_branch_history(self):
        """Return all commit hexshas reachable from HEAD."""
        repo = self._repo
        hexshas = repo.git.log("--pretty=%H").split("\n")
        return hexshas

    def _load_branch_history(self):
        hexshas = self._get_branch_history()
        self.commits = self._parse_commits(hexshas)

    def _parse_commits(self, hexshas, p=""):
        """Convert hexshas into GitSha records tagged with path ``p``."""

        def factory(ci):
            repo = self._repo
            obj = repo.rev_parse(ci)
            cx = GitSha(
                message=obj.message,
                hexsha=obj.hexsha,
                name=p,
                date=obj.committed_datetime,
            )
            return cx

        return [factory(ci) for ci in hexshas]

    def _set_active_commit(self):
        """Flag each commit in ``selected_path_commits`` whose blob matches the file on disk."""
        p = self.selected
        # binary mode: hashlib needs bytes and blobs are compared as bytes
        with open(p, "rb") as rfile:
            chexsha = hashlib.sha1(rfile.read()).hexdigest()

        for c in self.selected_path_commits:
            blob = self.unpack_blob(c.hexsha, p)
            c.active = chexsha == hashlib.sha1(blob).hexdigest() if blob else False

        self.refresh_commits_table_needed = True

    # handlers
    def _selected_fired(self, new):
        if new:
            self._selected_hook(new)
            self.load_file_history(new)

    def _selected_hook(self, new):
        pass

    def _remote_changed(self, new):
        if new:
            self.delete_remote()
            r = "https://github.com/{}".format(new)
            self.create_remote(r)

    @property
    def index(self):
        """The repository's GitPython index."""
        return self._repo.index

    @property
    def active_repo(self):
        """The underlying GitPython Repo."""
        return self._repo
if __name__ == "__main__":
    # developer scratch entry point: open a local repo and run a smart pull
    manager = GitRepoManager()
    manager.open_repo("/Users/ross/Sandbox/mergetest/blocal")
    manager.smart_pull()

# ============= EOF =============================================