Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(autofix): Cache repo client objects to avoid re-initialization #1626

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions src/seer/automation/autofix/autofix_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,34 +134,25 @@ def get_repo_client(
repo_name: str | None = None,
repo_external_id: str | None = None,
type: RepoClientType = RepoClientType.READ,
):
) -> RepoClient:
"""
Gets a repo client for the current single repo or for a given repo name.
If there are more than 1 repos, a repo name must be provided.
"""
repo_client: RepoClient | None = None
repo: RepoDefinition | None = None
if len(self.repos) == 1:
repo_client = RepoClient.from_repo_definition(self.repos[0], type)
repo = self.repos[0]
elif repo_name:
repo = next((r for r in self.repos if r.full_name == repo_name), None)

if not repo:
raise AgentError() from ValueError(f"Repo {repo_name} not found.")

repo_client = RepoClient.from_repo_definition(repo, type)
elif repo_external_id:
repo = next((r for r in self.repos if r.external_id == repo_external_id), None)

if not repo:
raise AgentError() from ValueError(f"Repo {repo_external_id} not found.")

repo_client = RepoClient.from_repo_definition(repo, type)
else:
if not repo:
raise AgentError() from ValueError(
"Please provide a repo name because you have multiple repos."
"Repo not found. Please provide a valid repo name or external ID."
)

return repo_client
return RepoClient.from_repo_definition(repo, type)

def get_file_contents(
self, path: str, repo_name: str | None = None, ignore_local_changes: bool = False
Expand All @@ -185,7 +176,9 @@ def _process_stacktrace_paths(self, stacktrace: Stacktrace):
"""
for repo in self.repos:
try:
repo_client = RepoClient.from_repo_definition(repo, RepoClientType.READ)
repo_client = self.get_repo_client(
repo_external_id=repo.external_id, type=RepoClientType.READ
)
except UnknownObjectException:
self.event_manager.on_error(
error_msg=f"Autofix does not have access to the `{repo.full_name}` repo. Please give permission through the Sentry GitHub integration, or remove the repo from your code mappings.",
Expand Down Expand Up @@ -250,8 +243,8 @@ def commit_changes(
if repo_definition is None:
raise ValueError(f"Repo definition not found for key {key}")

repo_client = RepoClient.from_repo_definition(
repo_definition, RepoClientType.WRITE
repo_client = self.get_repo_client(
repo_external_id=repo_definition.external_id, type=RepoClientType.WRITE
)

branch_ref = repo_client.create_branch_from_changes(
Expand Down Expand Up @@ -360,7 +353,9 @@ def comment_root_cause_on_pr(
)

# comment root cause analysis on PR
repo_client = RepoClient.from_repo_definition(repo_definition, RepoClientType.READ)
repo_client = self.get_repo_client(
repo_external_id=repo_definition.external_id, type=RepoClientType.READ
)
repo_client.comment_root_cause_on_pr_for_copilot(
pr_url, state.run_id, state.request.issue.id, markdown_comment
)
Expand Down
2 changes: 2 additions & 0 deletions src/seer/automation/codebase/repo_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import os
import shutil
Expand Down Expand Up @@ -181,6 +182,7 @@ def check_repo_read_access(repo: RepoDefinition):
return False

@classmethod
@functools.cache
def from_repo_definition(cls, repo_def: RepoDefinition, type: RepoClientType):
if type == RepoClientType.WRITE:
return cls(*get_write_app_credentials(), repo_def)
Expand Down
7 changes: 7 additions & 0 deletions tests/automation/codebase/test_repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
from seer.automation.models import RepoDefinition


@pytest.fixture(autouse=True)
def clear_repo_client_cache():
"""Clear the RepoClient.from_repo_definition cache before each test"""
RepoClient.from_repo_definition.cache_clear()
yield


@pytest.fixture
def mock_github():
with patch("seer.automation.codebase.repo_client.Github") as mock:
Expand Down
Loading