Skip to content

Commit

Permalink
Lint + test
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerard authored and Gerard committed Oct 16, 2023
1 parent 768cad8 commit 4c1b041
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
11 changes: 7 additions & 4 deletions src/promptflow-tools/promptflow/tools/open_source_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
REQUIRED_CONFIG_KEYS = ["endpoint_url", "model_family"]
REQUIRED_SECRET_KEYS = ["endpoint_api_key"]
DEFAULT_ENDPOINT_NAME = "-- please enter an endpoint name --"
ENDPOINT_REQUIRED_ENV_VARS = ["AZUREML_ARM_SUBSCRIPTION", "AZUREML_ARM_RESOURCEGROUP", "AZUREML_ARM_WORKSPACE_NAME"]


def handle_oneline_endpoint_error(max_retries: int = 3,
Expand Down Expand Up @@ -105,9 +106,11 @@ def get_deployment_from_endpoint(endpoint_name: str, deployment_name: str = None
subscription_id=os.getenv("AZUREML_ARM_SUBSCRIPTION"),
resource_group_name=os.getenv("AZUREML_ARM_RESOURCEGROUP"),
workspace_name=os.getenv("AZUREML_ARM_WORKSPACE_NAME"))
except:
from promptflow.runtime import PromptFlowRuntime
ml_client = PromptFlowRuntime.get_instance().config.get_ml_client(credential=credential)
except Exception as e:
print(e)
message = "Unable to connect to AzureML. Please ensure the following environment variables are set: "
message += ",".join(ENDPOINT_REQUIRED_ENV_VARS)
raise OpenSourceLLMOnlineEndpointError(message=message)

found = False
for ep in ml_client.online_endpoints.list():
Expand Down Expand Up @@ -453,7 +456,7 @@ def estimate_tokens(input_str: str) -> int:
logger_index.log_metric("completion_tokens", prompt_tokens + response_tokens)
logger_index.log_metric("prompt_tokens", prompt_tokens)
logger_index.log_metric("total_tokens", response_tokens)
except:
except Exception:
pass

return response
Expand Down
13 changes: 5 additions & 8 deletions src/promptflow-tools/tests/test_open_source_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,9 @@ def test_open_source_llm_llama_deployment_miss(self, gpt2_custom_connection):

def test_open_source_llm_endpoint_name(self):
os.environ["AZUREML_ARM_SUBSCRIPTION"] = "ba7979f7-d040-49c9-af1a-7414402bf622"
os.environ["AZUREML_ARM_RESOURCEGROUP"] = "test_resource_groups"
os.environ["AZUREML_ARM_WORKSPACE_NAME"] = "test_workspace"
os.environ["AZUREML_ARM_RESOURCEGROUP"] = "gewoods_rg"
os.environ["AZUREML_ARM_WORKSPACE_NAME"] = "gewoods_ml"

from azure.core.exceptions import ClientAuthenticationError
with pytest.raises(ClientAuthenticationError) as exc_info:
OpenSourceLLM(endpoint_name="Not_Real")

expected_message = "DefaultAzureCredential failed to retrieve a token from the included credentials."
assert exc_info.value.message .startswith(expected_message)
os_llm = OpenSourceLLM(endpoint_name="llama-temp-completion")
response = os_llm.call(self.completion_prompt, API.COMPLETION)
assert len(response) > 25

0 comments on commit 4c1b041

Please sign in to comment.