File indexing completed on 2024-12-01 08:16:20

0001 """
0002 Module containing classes for checking out merge requests
0003 """
0004 
0005 # SPDX-FileCopyrightText: 2020 Jonah Brüchert <jbb@kaidan.im>
0006 #
0007 # SPDX-License-Identifier: GPL-2.0-or-later
0008 
0009 import argparse
0010 import sys
0011 
0012 from gitlab.v4.objects import ProjectMergeRequest
0013 from gitlab.v4.objects import Project
0014 from gitlab.exceptions import GitlabHttpError, GitlabGetError
0015 
0016 from git.remote import Remote
0017 from git.refs.reference import Reference
0018 
0019 from lab.repositoryconnection import RepositoryConnection
0020 from lab.utils import Utils
0021 from lab.utils import LogType
0022 
0023 from typing import Optional
0024 
0025 
def parser(
    subparsers: argparse._SubParsersAction,  # pylint: disable=protected-access
) -> argparse.ArgumentParser:
    """
    Register the ``checkout`` subcommand (alias ``patch``) on the global parser.

    The subcommand takes one positional integer, stored as a one-element
    list in ``number`` (because of ``nargs=1``).

    :param subparsers: subparsers object from global parser
    :return: the argument parser of the checkout subcommand
    """
    sub_command: argparse.ArgumentParser = subparsers.add_parser(
        "checkout",
        aliases=["patch"],
        help="check out a remote merge request",
    )
    number_options = dict(
        help="Merge request number to checkout",
        nargs=1,
        type=int,
        metavar="int",
    )
    sub_command.add_argument("number", **number_options)
    return sub_command
0045 
0046 
def run(args: argparse.Namespace) -> None:
    """
    Execute the checkout command for the merge request given on the command line.

    :param args: parsed arguments; ``args.number`` holds the merge request
                 number as the first element of a list
    """
    # The checkouter is constructed before the argument is read, exactly as before.
    MergeRequestCheckout().checkout(args.number[0])
0054 
0055 
class MergeRequestCheckout(RepositoryConnection):
    """
    Check out a merge request in the current git repository
    """

    # private
    # Merge request being checked out; assigned by checkout() before add_remote() reads it.
    __mr: ProjectMergeRequest

    def __init__(self) -> None:
        RepositoryConnection.__init__(self)

    def add_remote(self) -> Optional[Reference]:
        """
        Make the merge request's source project available as a git remote,
        fetch it, and return the remote ref of the merge request's source branch.

        The remote is named ``fork-<author username>`` and created with the
        source project's SSH URL if no remote of that name exists yet; an
        existing remote of that name is reused as-is.

        Exits the process with status 1 if the source project cannot be
        retrieved from GitLab, or if the source branch is not among the
        remote's refs after fetching.

        :return: the remote ref ``fork-<user>/<source branch>``
        """
        fork_project: Project
        try:
            fork_project = self._connection.projects.get(self.__mr.source_project_id)
        except (GitlabHttpError, GitlabGetError):
            Utils.log(
                LogType.ERROR,
                "The source repository of this merge request could not be found on the GitLab instance.",
            )
            sys.exit(1)

        remote_url: str = fork_project.ssh_url_to_repo
        user: str = self.__mr.author["username"]
        remote_name = f"fork-{user}"

        remote: Remote
        if remote_name not in self._local_repo.remotes:
            remote = Remote.add(self._local_repo, remote_name, remote_url)
        else:
            remote = self._local_repo.remotes[remote_name]

        remote.fetch()

        wanted_ref_name = f"{remote_name}/{self.__mr.source_branch}"
        for ref in remote.refs:
            if ref.name == wanted_ref_name:
                return ref

        Utils.log(LogType.ERROR, "Failed to find remote ref")
        sys.exit(1)

    def checkout(self, merge_request_id: int) -> None:
        """
        Checks out the merge request with the specified id in the local worktree

        A local branch named like the merge request's source branch is created
        at the fetched remote ref, checked out, and set to track that ref. If a
        local branch of that name already exists, the user is asked whether to
        overwrite it; declining exits the process with status 1.

        :param merge_request_id: number of the merge request in the remote project
        """
        self.__mr = self._remote_project.mergerequests.get(merge_request_id, lazy=False)
        print(f'Checking out merge request "{self.__mr.title}"...')
        print("  branch:", self.__mr.source_branch)

        remote_ref = self.add_remote()
        branch: str = self.__mr.source_branch

        # Only local branches (heads) are relevant: repo.refs would also match
        # tags or remote-tracking refs that share the same short name.
        if branch in self._local_repo.heads:
            # Make sure not to overwrite local changes
            overwrite = Utils.ask_bool(
                f'Branch "{branch}" already exists locally, do you want to overwrite it?'
            )

            if not overwrite:
                print("Aborting")
                sys.exit(1)

            # A branch cannot be deleted while it is checked out.
            self._leave_branch(branch)
            self._local_repo.delete_head(branch, force=True)

        head = self._local_repo.create_head(branch, remote_ref)
        head.checkout()
        self._local_repo.active_branch.set_tracking_branch(remote_ref)

    def _leave_branch(self, branch: str) -> None:
        """
        If ``branch`` is the currently checked-out branch, switch to a local
        ``main`` or ``master`` branch (the first one that exists and is not
        ``branch`` itself). Exits the process with status 1 if no such branch
        exists. Does nothing if HEAD is detached or on a different branch.

        :param branch: name of the local branch that is about to be deleted
        """
        repo = self._local_repo
        # head.reference raises TypeError on a detached HEAD; in that case no
        # branch is checked out, so there is nothing to switch away from.
        if repo.head.is_detached or repo.head.reference.name != branch:
            return

        for fallback in ("main", "master"):
            # Switching to the branch we want to delete would not free it.
            if fallback != branch and fallback in repo.heads:
                repo.heads[fallback].checkout()
                return

        Utils.log(
            LogType.ERROR,
            "The branch that you want to overwrite is currently checked out "
            "and no other branch to temporarily switch to could be found. "
            "Please check out a different branch and try again.",
        )
        sys.exit(1)
0138         head.checkout()
0139         self._local_repo.active_branch.set_tracking_branch(remote_ref)