diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index f797728042..d396087443 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -18,6 +18,7 @@ from dvc.exceptions import DvcException from dvc.path_info import PathInfo +from dvc.repo import Repo from dvc.repo.experiments.base import ( EXEC_BASELINE, EXEC_BRANCH, @@ -331,6 +332,7 @@ def filter_pipeline(stages): checkpoint_func = partial( cls.checkpoint_callback, + dvc, dvc.scm, name, repro_force or checkpoint_reset, @@ -393,7 +395,6 @@ def _repro_dvc( git_url: Optional[str] = None, **kwargs, ): - from dvc.repo import Repo from dvc.utils.serialize import modify_yaml dvc = Repo(dvc_dir) @@ -453,6 +454,7 @@ def _repro_args(cls, dvc): @classmethod def checkpoint_callback( cls, + dvc: "Repo", scm: "Git", name: Optional[str], force: bool, @@ -464,6 +466,23 @@ def checkpoint_callback( exp_rev = cls.commit( scm, exp_hash, exp_name=name, force=force, checkpoint=True ) + + git_remote = os.environ.get("DVC_EXP_AUTO_PUSH", None) + if git_remote: + from dvc.repo.experiments.push import push + + branch = dvc.experiments.get_branch_by_rev( + exp_rev, allow_multiple=None + ) + branch_name = ExpRefInfo.from_ref(branch).name + push( + dvc, + git_remote, + branch_name, + push_cache=True, + run_cache=True, + ) + logger.info({"pushed": branch_name}) logger.info("Checkpoint experiment iteration '%s'.", exp_rev[:7]) except UnchangedExperimentError: pass