Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experiment] Change experiment cli "--file" to "--template" #2311

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/promptflow/promptflow/_cli/_pf/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ def add_experiment_start(subparsers):
# Start a named experiment:
pf experiment start -n my_experiment --inputs data1=data1_val data2=data2_val
# Run an experiment by yaml file:
pf experiment start --file path/to/my_experiment.exp.yaml --inputs data1=data1_val data2=data2_val
pf experiment start --template path/to/my_experiment.exp.yaml --inputs data1=data1_val data2=data2_val
"""
activate_action(
name="start",
description="Start an experiment.",
epilog=epilog,
add_params=[add_param_name, add_param_file, add_param_input, add_param_stream] + base_params,
add_params=[add_param_name, add_param_template, add_param_input, add_param_stream] + base_params,
subparsers=subparsers,
help_message="Start an experiment.",
action_param_name="sub_action",
Expand Down Expand Up @@ -235,20 +235,18 @@ def start_experiment(args: argparse.Namespace):
if args.name:
logger.debug(f"Starting a named experiment {args.name}.")
inputs = list_of_dict_to_dict(args.inputs)
if inputs:
logger.warning("The inputs of named experiment cannot be modified.")
client = _get_pf_client()
experiment = client._experiments.get(args.name)
result = client._experiments.start(experiment=experiment, stream=args.stream)
elif args.file:
result = client._experiments.start(experiment=experiment, inputs=inputs, stream=args.stream)
elif args.template:
from promptflow._sdk._load_functions import _load_experiment

logger.debug(f"Starting an anonymous experiment {args.file}.")
experiment = _load_experiment(source=args.file)
logger.debug(f"Starting an anonymous experiment {args.template}.")
experiment = _load_experiment(source=args.template)
inputs = list_of_dict_to_dict(args.inputs)
result = _get_pf_client()._experiments.start(experiment=experiment, inputs=inputs, stream=args.stream)
else:
raise UserErrorException("To start an experiment, one of [name, file] must be specified.")
raise UserErrorException("To start an experiment, one of [name, template] must be specified.")
print(json.dumps(result._to_dict(), indent=4))


Expand Down
7 changes: 1 addition & 6 deletions src/promptflow/promptflow/_sdk/_load_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import hashlib
from os import PathLike
from pathlib import Path
from typing import IO, AnyStr, Optional, Union
Expand All @@ -11,7 +10,6 @@
from .._utils.logger_utils import get_cli_sdk_logger
from .._utils.yaml_utils import load_yaml
from ._errors import MultipleExperimentTemplateError, NoExperimentTemplateError
from ._utils import _sanitize_python_variable_name
from .entities import Run
from .entities._connection import CustomConnection, _Connection
from .entities._experiment import Experiment, ExperimentTemplate
Expand Down Expand Up @@ -205,8 +203,5 @@ def _load_experiment(
absolute_path = source.resolve().absolute().as_posix()
if not source.exists():
raise NoExperimentTemplateError(f"Experiment file {absolute_path} not found.")
anonymous_exp_name = _sanitize_python_variable_name(
f"{source.stem}_{hashlib.sha1(absolute_path.encode('utf-8')).hexdigest()}"
)
experiment = load_common(Experiment, source, params_override=[{"name": anonymous_exp_name}], **kwargs)
experiment = load_common(Experiment, source, **kwargs)
return experiment
24 changes: 13 additions & 11 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,18 +1985,20 @@ def wait_for_experiment_terminated(experiment_name):
@pytest.mark.skipif(condition=not is_live(), reason="Injection cannot passed to detach process.")
@pytest.mark.usefixtures("setup_experiment_table")
def test_experiment_start_anonymous_experiment(self, monkeypatch, local_client):
from promptflow._sdk._load_functions import _load_experiment

with mock.patch("promptflow._sdk._configuration.Configuration.is_internal_features_enabled") as mock_func:
mock_func.return_value = True
experiment_file = f"{EXPERIMENT_DIR}/basic-script-template/basic-script.exp.yaml"
run_pf_command("experiment", "start", "--file", experiment_file, "--stream")
experiment = _load_experiment(source=experiment_file)
exp = local_client._experiments.get(name=experiment.name)
assert len(exp.node_runs) == 4
assert all(len(exp.node_runs[node_name]) > 0 for node_name in exp.node_runs)
metrics = local_client.runs.get_metrics(name=exp.node_runs["eval"][0]["name"])
assert "accuracy" in metrics
from promptflow._sdk.entities._experiment import Experiment

with mock.patch.object(Experiment, "_generate_name") as mock_generate_name:
experiment_name = str(uuid.uuid4())
mock_generate_name.return_value = experiment_name
mock_func.return_value = True
experiment_file = f"{EXPERIMENT_DIR}/basic-script-template/basic-script.exp.yaml"
run_pf_command("experiment", "start", "--template", experiment_file, "--stream")
exp = local_client._experiments.get(name=experiment_name)
assert len(exp.node_runs) == 4
assert all(len(exp.node_runs[node_name]) > 0 for node_name in exp.node_runs)
metrics = local_client.runs.get_metrics(name=exp.node_runs["eval"][0]["name"])
assert "accuracy" in metrics

@pytest.mark.usefixtures("setup_experiment_table", "recording_injection")
def test_experiment_test(self, monkeypatch, capfd, local_client, tmpdir):
Expand Down
Loading