# Source code for gitmopy.git
"""
`gitmopy`'s Git-related utility functions.
"""
from typing import Dict, List, Union
from git import Remote, Repo
from git.exc import GitCommandError
from gitmopy.constants import _sentinels
from gitmopy.utils import col, print
[docs]
class CatchRemoteException:
    def __init__(self, remote: str):
        """
        Context manager to catch GitCommandError when pushing to a remote.

        Every exception raised inside the ``with`` body is suppressed. When the
        exception is a ``GitCommandError``, ``error`` is set to ``True`` and the
        failure is printed.

        NOTE(review): ``set_upsteam`` is meant to tell the caller that the remote
        has no upstream branch, but nothing in this class ever sets it to ``True``;
        confirm whether callers are expected to set it themselves.

        Args:
            remote (str): Name of the remote repository that is being pushed to.
        """
        # Name of the remote; interpolated in the message printed by ``__exit__``.
        self.remote = remote
        # Initialised here so the attribute also exists after a successful push.
        self.error = False
        # Spelling ("upsteam") is kept as-is: it is a public attribute name.
        self.set_upsteam = False

    def __enter__(self):
        """
        Enter the context manager.

        Returns:
            CatchRemoteException: the context manager itself.
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the context manager and catch a potential GitCommandError exception.

        If the exception is a ``GitCommandError``, ``error`` is set to ``True`` and
        the name of the remote plus the command's ``stderr`` are printed in red.

        Args:
            exc_type (type): Class of the raised exception, ``None`` if none.
            exc_value (BaseException): Raised exception instance, ``None`` if none.
            traceback (TracebackType): Traceback of the exception, ``None`` if none.

        Returns:
            bool: Always ``True``, so any exception raised in the ``with`` body is
                suppressed (not only ``GitCommandError``).
        """
        if exc_type is GitCommandError:
            self.error = True
            print(
                f"[bold red]Error:[/bold red] could not push to {self.remote}:",
            )
            print("[red]" + exc_value.stderr + "[/red]")
        return True
[docs]
def fetch_all(repo):
    """
    Fetch every remote configured on a GitPython repository.

    Args:
        repo (git.Repo): Repository whose remotes are fetched, one after another.
    """
    for remote in repo.remotes:
        remote.fetch()
[docs]
def has_upstreams(
    repo: Repo, remotes: List[Union[str, Remote]], branch_name: str
) -> Dict[str, bool]:
    """
    Check which of the given remotes have a branch with a given name.

    All remotes of the repository are fetched first, then the repository's
    remote references are scanned for ``<remote>/<branch_name>``.

    Args:
        repo (git.Repo): Repository to check branches from.
        remotes (List[Union[str, Remote]]): Remotes to check, given either as
            names or as ``Remote`` objects.
        branch_name (str): Name of the branch to check.

    Returns:
        Dict[str, bool]: Maps the name of each requested remote to whether it has
            the branch. Remotes that were not requested are never included.
    """
    fetch_all(repo)

    # Start pessimistic: every requested remote is assumed to lack the branch.
    remote_has_upstream = {
        r.name if isinstance(r, Remote) else r: False for r in remotes
    }

    for ref in repo.refs:
        if not ref.is_remote():
            continue
        remote_name = ref.remote_name
        # Only report on remotes the caller asked about; other remotes that
        # happen to have the branch must not leak into the result.
        if remote_name not in remote_has_upstream:
            continue
        if ref.name.removeprefix(remote_name + "/") == branch_name:
            remote_has_upstream[remote_name] = True

    return remote_has_upstream
[docs]
def get_staged(repo: Repo) -> List[str]:
    """
    List the files that are staged in a GitPython repository.

    Args:
        repo (git.Repo): Repository whose index is diffed against ``HEAD``.

    Returns:
        List[str]: ``a_path`` of every entry of that diff.
    """
    staged_diffs = repo.index.diff("HEAD")
    return [diff.a_path for diff in staged_diffs]
[docs]
def get_unstaged(repo: Repo) -> List[str]:
    """
    List the files that have unstaged changes in a GitPython repository.

    Args:
        repo (git.Repo): Repository whose index is diffed with ``None`` as the
            other side.

    Returns:
        List[str]: ``a_path`` of every entry of that diff.
    """
    unstaged_diffs = repo.index.diff(None)
    return [diff.a_path for diff in unstaged_diffs]
[docs]
def get_untracked(repo: Repo) -> List[str]:
    """
    List the untracked files of a GitPython repository.

    Args:
        repo (git.Repo): Repository to read ``untracked_files`` from.

    Returns:
        List[str]: A new list holding the untracked file paths.
    """
    return list(repo.untracked_files)
[docs]
def get_files_status(repo: Repo) -> Dict[str, List[str]]:
    """
    Group the files of a GitPython repository by status.

    Args:
        repo (git.Repo): Repository to get the files' status from.

    Returns:
        Dict[str, List[str]]: Dictionary with the keys "staged", "unstaged" and
            "untracked" (in that order), each mapping to a list of file paths.
    """
    # (status key, collector) pairs; evaluated in this order.
    collectors = (
        ("staged", get_staged),
        ("unstaged", get_unstaged),
        ("untracked", get_untracked),
    )
    return {status: collect(repo) for status, collect in collectors}
[docs]
def unstage(repo: Repo, files: List[str]) -> None:
    """
    Unstage files from the index.

    Runs ``git restore --staged -- <file>`` once per file.

    Args:
        repo (Repo): GitPython repository object.
        files (List[str]): List of files to unstage. An empty list is a no-op.
    """
    for path in files:
        # "--" ends option parsing so that a path starting with "-" is treated
        # as a file name and not as a command-line option.
        repo.git.restore("--staged", "--", path)
[docs]
def commits_behind(repo: Repo) -> Dict[str, Union[int, object]]:
    """
    Get the number of commits the local current branch is behind for each remote.

    For each remote ``r`` and active branch ``b``, counts the commits in the
    revision range ``b..r/b``.

    Args:
        repo (Repo): GitPython repository object.

    Returns:
        Dict[str, Union[int, object]]: Maps each remote's name to the number of
            commits the local branch is behind it. When git reports a bad
            revision for the range, the value is ``_sentinels["no-branch"]``
            instead. Remotes for which git fails with any other error are left
            out of the dictionary.
    """
    branch = repo.active_branch.name
    behinds = {}
    for remote in repo.remotes:
        try:
            # Count lazily instead of materialising every commit in a list.
            behinds[remote.name] = sum(
                1 for _ in repo.iter_commits(f"{branch}..{remote.name}/{branch}")
            )
        except GitCommandError as e:
            if "fatal: bad revision" in str(e):
                behinds[remote.name] = _sentinels["no-branch"]
            # Any other git error is deliberately ignored (best effort).
    return behinds
[docs]
def commits_ahead(repo: Repo) -> Dict[str, int]:
    """
    Get the number of commits the local current branch is ahead of each remote.

    For each remote ``r`` and active branch ``b``, counts the commits in the
    revision range ``r/b..b``.

    Args:
        repo (Repo): GitPython repository object.

    Returns:
        Dict[str, int]: Maps each remote's name to the number of commits the
            local branch is ahead of it. Remotes for which git fails are left
            out of the dictionary.
    """
    branch = repo.active_branch.name
    aheads = {}
    for remote in repo.remotes:
        try:
            # Count lazily instead of materialising every commit in a list.
            aheads[remote.name] = sum(
                1 for _ in repo.iter_commits(f"{remote.name}/{branch}..{branch}")
            )
        except GitCommandError:
            # Best effort: a remote whose range cannot be resolved is skipped
            # here; commits_behind() records the missing-branch case.
            pass
    return aheads