Optional fields #28

Merged
merged 20 commits on Jul 10, 2024
Commits (20)
fb42a84
Added new unit tests to simulate failed prompts where the returned va…
aelaguiz Jul 10, 2024
3075808
Removed __init__ constructors from TestFailedPromptSignature and Test…
aelaguiz Jul 10, 2024
14e6333
The commit message for the changes made in the provided diffs is:
aelaguiz Jul 10, 2024
397f0f5
Added support for FakeLLM in PromptRunner.
aelaguiz Jul 10, 2024
b54d783
Added support for a new 'fake_anthropic' LLM type that wraps the Fake…
aelaguiz Jul 10, 2024
170d276
Added unit tests for the FakeLLM from langchain_contrib.
aelaguiz Jul 10, 2024
b58d41f
Renamed `TestFailedPromptSignature` and `TestFailedPromptStrategy` cl…
aelaguiz Jul 10, 2024
2cc14c7
Rewritten test_failed_prompts to use the fake Anthropic LLM pattern.
aelaguiz Jul 10, 2024
35680de
Improved error message for failed prompt output validation.
aelaguiz Jul 10, 2024
d6513c0
Added a new test for optional output field.
aelaguiz Jul 10, 2024
5c0d0b4
Configured the PromptRunner to use the correct LLM type and set hard_…
aelaguiz Jul 10, 2024
10c37d8
Refactored the `_invoke_with_retries` function in the PromptRunner cl…
aelaguiz Jul 10, 2024
dd432e6
Removed unnecessary print statements and added support for 'fake_anth…
aelaguiz Jul 10, 2024
f626efb
Removed the `print_prompt` variable from `prompt_runners.py` and `pro…
aelaguiz Jul 10, 2024
38ff31e
Implemented handling of optional output fields in the PromptRunner cl…
aelaguiz Jul 10, 2024
79c2909
Corrected the error message in the `PromptStrategy` class to display …
aelaguiz Jul 10, 2024
b01677c
Improved the exception raised in `validate_inputs` to provide more de…
aelaguiz Jul 10, 2024
d979b3d
Added logging to the `tests/test_model_train.py` file.
aelaguiz Jul 10, 2024
cf3cea0
Initialized the langdspy logger and configured it to log to the conso…
aelaguiz Jul 10, 2024
c92959c
All tests pass
aelaguiz Jul 10, 2024
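
The headline feature in this PR is support for optional output fields: per the `_validate_output` changes in `langdspy/prompt_runners.py` below, a missing value for a field marked `optional` is now recorded as `None` instead of failing validation and burning a retry. The sketch below restates that rule on its own; the `Field` dataclass and the field names are hypothetical stand-ins for illustration, not the langdspy field API.

```python
# Minimal sketch of the optional-field rule this PR introduces. `Field` is a
# hypothetical stand-in for a langdspy output field; only the `optional` flag
# and the missing-value handling mirror the diff below.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Field:
    name: str
    optional: bool = False


def validate_parsed_output(parsed: Dict[str, Optional[str]],
                           fields: Dict[str, Field]) -> Optional[str]:
    """Return a validation error message, or None if the parsed output is acceptable."""
    for attr_name, field in fields.items():
        value = parsed.get(attr_name)
        if value is None:
            if not field.optional:
                # Missing non-optional fields still fail validation (and trigger a retry).
                return f"Failed to get output value for non-optional field {attr_name}"
            # Missing optional fields are simply recorded as None.
            parsed[attr_name] = None
    return None


fields = {"slug": Field("slug"), "subtitle": Field("subtitle", optional=True)}
print(validate_parsed_output({"slug": "my-product"}, fields))  # None: passes, subtitle -> None
print(validate_parsed_output({"subtitle": "x"}, fields))       # error: slug is missing
```

Before this change, any missing output key failed validation, so prompts whose optional fields were legitimately empty would exhaust their retries.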
1 change: 1 addition & 0 deletions .gitignore
@@ -16,3 +16,4 @@ __pycache__/
htmlcov/
.coverage
*.swp
.aider*
2 changes: 1 addition & 1 deletion README.md
@@ -24,7 +24,7 @@ if __name__ == "__main__":
X_test = dataset['test']['X']
y_test = dataset['test']['y']

model = ProductSlugGenerator(n_jobs=4, print_prompt=True)
model = ProductSlugGenerator(n_jobs=4)

before_test_accuracy = None
if os.path.exists(output_path):
3 changes: 1 addition & 2 deletions examples/amazon/generate_slugs.py
@@ -105,8 +105,7 @@ def evaluate_model(model, X, y):
X_test = dataset['test']['X']
y_test = dataset['test']['y']

model = ProductSlugGenerator(n_jobs=1, print_prompt=True)
# model.generate_slug.set_model_kwargs({'print_prompt': True})
model = ProductSlugGenerator(n_jobs=1)

before_test_accuracy = None
if os.path.exists(output_path):
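
Both the README and the Amazon example drop the `print_prompt=True` flag; prompt and result dumps now go through the `prompt_logger` (see `_log_prompt` and `_log_result` in the `prompt_runners.py` diff below) and the logger initialized in commit cf3cea0. A rough sketch of how a caller could surface those logs on the console is shown here — the `"langdspy"` logger name is an assumption, not something confirmed by this excerpt.

```python
import logging

# Sketch: surface langdspy's prompt/result logging on the console now that the
# print_prompt flag is gone. The logger name "langdspy" is an assumption; check
# the logger set up in commit cf3cea0 for the actual name.
logging.basicConfig(level=logging.INFO)
logging.getLogger("langdspy").setLevel(logging.INFO)
```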
230 changes: 102 additions & 128 deletions langdspy/prompt_runners.py
@@ -6,6 +6,7 @@
from langchain_core.pydantic_v1 import validator
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
# from langchain_contrib.llms.testing import FakeLLM
from typing import Any, Dict, List, Type, Optional, Callable
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -114,152 +115,127 @@ def clear_prompt_history(self):

def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[RunnableConfig] = {}):
total_max_tries = max_tries

hard_fail = config.get('hard_fail', False)
llm_type = config.get('llm_type') # Get the LLM type from the configuration
if llm_type is None:
llm_type = self._determine_llm_type(config['llm']) # Auto-detect the LLM type if not specified

llm_model = self._determine_llm_model(config['llm'])

hard_fail = config.get('hard_fail', True)
Collaborator comment: oh yeah this file is a lot cleaner now

llm_type, llm_model = self._get_llm_info(config)

logger.debug(f"LLM type: {llm_type} - model {llm_model}")


res = {}
formatted_prompt = None
prompt_res = None

while max_tries >= 1:
start_time = time.time()
try:
kwargs = {**self.model_kwargs, **self.kwargs}
# logger.debug(f"PromptRunner invoke with input {input} and kwargs {kwargs} and config {config}")
# logger.debug(f"Prompt runner kwargs: {kwargs}")
trained_state = config.get('trained_state', None)
# logger.debug(f"1 - Trained state is {trained_state}")
if not trained_state or not trained_state.examples:
# logger.debug(f"2 - Trained state is {trained_state}")
trained_state = self.model_kwargs.get('trained_state', None)
# logger.debug(f"3 - Trained state is {trained_state}")

if not trained_state or not trained_state.examples:
_trained_state = self.kwargs.get('trained_state', None)
if not trained_state:
trained_state = _trained_state
# logger.debug(f"4 - Trained state is {trained_state}")

print_prompt = kwargs.get('print_prompt', config.get('print_prompt', False))
# logger.debug(f"Print prompt {print_prompt} kwargs print prompt {kwargs.get('print_prompt')} config print prompt {config.get('print_prompt')}")

# logger.debug(f"PromptRunner invoke with trained_state {trained_state}")
invoke_args = {**input, 'print_prompt': print_prompt, **kwargs, 'trained_state': trained_state, 'use_training': config.get('use_training', True), 'llm_type': llm_type}
formatted_prompt = self.template.format_prompt(**invoke_args)

if print_prompt:
print(f"------------------------PROMPT START--------------------------------")
print(formatted_prompt)
print(f"------------------------PROMPT END----------------------------------\n")

prompt_logger.info(f"------------------------PROMPT START--------------------------------")
prompt_logger.info(formatted_prompt)
prompt_logger.info(f"------------------------PROMPT END----------------------------------\n")

# logger.debug(f"Invoke args: {invoke_args}")
prompt_res = chain.invoke(invoke_args, config=config)
formatted_prompt, prompt_res = self._execute_prompt(chain, input, config, llm_type)
parsed_output, validation_err = self._process_output(prompt_res, input, llm_type)

end_time = time.time()
self._log_prompt_history(config, formatted_prompt, prompt_res, parsed_output, validation_err, start_time, end_time)

if validation_err is None:
return parsed_output

except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Failed in the LLM layer {e} - sleeping then trying again")
time.sleep(random.uniform(0.1, 1.5))
max_tries -= 1
continue


validation_err = None

# logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
if print_prompt:
print(f"------------------------RESULT START--------------------------------")
print(prompt_res)
print(f"------------------------RESULT END----------------------------------\n")

prompt_logger.info(f"------------------------RESULT START--------------------------------")
prompt_logger.info(prompt_res)
prompt_logger.info(f"------------------------RESULT END----------------------------------\n")

# Use the parse_output_to_fields method from the PromptStrategy
parsed_output = {}
self._handle_exception(e, max_tries)

max_tries -= 1
if max_tries >= 1:
self._handle_retry(max_tries)

return self._handle_failure(hard_fail, total_max_tries, prompt_res)

def _get_llm_info(self, config):
llm_type = config.get('llm_type') or self._determine_llm_type(config['llm'])
llm_model = self._determine_llm_model(config['llm'])
return llm_type, llm_model

def _execute_prompt(self, chain, input, config, llm_type):
kwargs = {**self.model_kwargs, **self.kwargs}
trained_state = self._get_trained_state(config)

invoke_args = {**input, **kwargs, 'trained_state': trained_state, 'use_training': config.get('use_training', True), 'llm_type': llm_type}
formatted_prompt = self.template.format_prompt(**invoke_args)

self._log_prompt(formatted_prompt)

prompt_res = chain.invoke(invoke_args, config=config)
return formatted_prompt, prompt_res

def _get_trained_state(self, config):
trained_state = config.get('trained_state') or self.model_kwargs.get('trained_state') or self.kwargs.get('trained_state')
return trained_state if trained_state and trained_state.examples else None

def _log_prompt(self, formatted_prompt):
prompt_logger.info(f"------------------------PROMPT START--------------------------------")
prompt_logger.info(formatted_prompt)
prompt_logger.info(f"------------------------PROMPT END----------------------------------\n")

def _process_output(self, prompt_res, input, llm_type):
self._log_result(prompt_res)

parsed_output = {}
validation_err = None
try:
parsed_output = self.template.parse_output_to_fields(prompt_res, llm_type)
validation_err = self._validate_output(parsed_output, input)
except Exception as e:
validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
import traceback
traceback.print_exc()

return parsed_output, validation_err

def _log_result(self, prompt_res):
prompt_logger.info(f"------------------------RESULT START--------------------------------")
prompt_logger.info(prompt_res)
prompt_logger.info(f"------------------------RESULT END----------------------------------\n")

def _validate_output(self, parsed_output, input):
for attr_name, output_field in self.template.output_variables.items():
output_value = parsed_output.get(attr_name)
if output_value is None:
if not output_field.kwargs['optional']:
return f"Failed to get output value for non-optional field {attr_name} for prompt runner {self.template.__class__.__name__}"
else:
parsed_output[attr_name] = None
continue

if not output_field.validate_value(input, output_value):
return f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"

try:
parsed_output = self.template.parse_output_to_fields(prompt_res, llm_type)
parsed_output[attr_name] = output_field.transform_value(output_value)
except Exception as e:
import traceback
traceback.print_exc()
validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
# logger.debug(f"Parsed output: {parsed_output}")

len_parsed_output = len(parsed_output.keys())
len_output_variables = len(self.template.output_variables.keys())
logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")

if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)

if validation_err is None:
# Transform and validate the outputs
for attr_name, output_field in self.template.output_variables.items():
output_value = parsed_output.get(attr_name)
if not output_value:
validation_err = f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
continue

# Validate the transformed value
if not output_field.validate_value(input, output_value):
validation_err = f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)

# Get the transformed value
try:
transformed_val = output_field.transform_value(output_value)
except Exception as e:
import traceback
traceback.print_exc()
validation_err = f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
continue

# Update the output with the transformed value
parsed_output[attr_name] = transformed_val

end_time = time.time()
self.prompt_history.add_entry(self._determine_llm_type(config['llm']) + " " + self._determine_llm_model(config['llm']), formatted_prompt, prompt_res, parsed_output, validation_err, start_time, end_time)

res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}

if validation_err is None:
return res
return f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"

max_tries -= 1
if max_tries >= 1:
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
time.sleep(random.uniform(0.05, 0.25))
return None

def _log_prompt_history(self, config, formatted_prompt, prompt_res, parsed_output, validation_err, start_time, end_time):
llm_info = f"{self._determine_llm_type(config['llm'])} {self._determine_llm_model(config['llm'])}"
self.prompt_history.add_entry(llm_info, formatted_prompt, prompt_res, parsed_output, validation_err, start_time, end_time)

def _handle_exception(self, e, max_tries):
import traceback
traceback.print_exc()
logger.error(f"Failed in the LLM layer {e} - sleeping then trying again")
time.sleep(random.uniform(0.1, 1.5))

def _handle_retry(self, max_tries):
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
time.sleep(random.uniform(0.05, 0.25))

def _handle_failure(self, hard_fail, total_max_tries, prompt_res):
if hard_fail:
raise ValueError(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries.")
else:
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning None.")
if len(self.template.output_variables.keys()) == 1:
res = {attr_name: prompt_res for attr_name in self.template.output_variables.keys()}
return {attr_name: prompt_res for attr_name in self.template.output_variables.keys()}
else:
res = {attr_name: None for attr_name in self.template.output_variables.keys()}

return res
return {attr_name: None for attr_name in self.template.output_variables.keys()}

def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output:
# logger.debug(f"Template: {self.template}")
# logger.debug(f"Config: {config}")

chain = (
self.template
| config['llm']
@@ -273,8 +249,6 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output:

res = self._invoke_with_retries(chain, input, max_retries, config=config)

# logger.debug(f"Result: {res}")

prediction_data = {**input, **res}


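
Beyond the optional-field handling, the refactor appears to flip the default for `hard_fail` from `False` to `True` and concentrates the give-up path in `_handle_failure`: raise once the retries are exhausted, or, when `hard_fail` is `False`, log and return a best-effort result. The snippet below restates that end-of-retries behavior as a standalone sketch with hypothetical field names, not the actual class method.

```python
from typing import Any, Dict, List, Optional


def handle_failure(hard_fail: bool, total_max_tries: int, prompt_res: Optional[str],
                   output_keys: List[str]) -> Dict[str, Any]:
    """Sketch of PromptRunner._handle_failure from the diff above (names are illustrative)."""
    if hard_fail:
        raise ValueError(f"Output validation failed after {total_max_tries} tries.")
    # Soft failure: with a single output field, hand back the raw LLM response;
    # otherwise return None for every expected field.
    if len(output_keys) == 1:
        return {attr_name: prompt_res for attr_name in output_keys}
    return {attr_name: None for attr_name in output_keys}


print(handle_failure(False, 3, "raw response", ["slug"]))           # {'slug': 'raw response'}
print(handle_failure(False, 3, "raw response", ["slug", "title"]))  # {'slug': None, 'title': None}
```

If the default did change, callers that previously relied on the soft-fail dict of `None` values will now need to pass `hard_fail=False` in the config or be prepared to catch the `ValueError`.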
35 changes: 19 additions & 16 deletions langdspy/prompt_strategies.py
@@ -77,16 +77,26 @@ class PromptStrategy(BaseModel):
best_subset: List[Any] = []

def validate_inputs(self, inputs_dict):
if not set(inputs_dict.keys()) == set(self.input_variables.keys()):
missing_keys = set(self.input_variables.keys()) - set(inputs_dict.keys())
unexpected_keys = set(inputs_dict.keys()) - set(self.input_variables.keys())
expected_keys = set(self.input_variables.keys())
received_keys = set(inputs_dict.keys())

if expected_keys != received_keys:
missing_keys = expected_keys - received_keys
unexpected_keys = received_keys - expected_keys
error_message = []

if missing_keys:
error_message.append(f"Missing input keys: {', '.join(missing_keys)}")
logger.error(f"Missing input keys: {missing_keys}")
if unexpected_keys:
error_message.append(f"Unexpected input keys: {', '.join(unexpected_keys)}")
logger.error(f"Unexpected input keys: {unexpected_keys}")

logger.error(f"Input keys do not match expected input keys Expected = {inputs_dict.keys()} Received = {self.input_variables.keys()}")
raise ValueError(f"Input keys do not match expected input keys Expected: {inputs_dict.keys()} Received: {self.input_variables.keys()}")

error_message.append(f"Expected keys: {', '.join(expected_keys)}")
error_message.append(f"Received keys: {', '.join(received_keys)}")

logger.error(f"Input keys do not match expected input keys. Expected: {expected_keys}, Received: {received_keys}")
raise ValueError(". ".join(error_message))

def format(self, **kwargs: Any) -> str:
logger.debug(f"PromptStrategy format with kwargs: {kwargs}")
@@ -96,24 +106,17 @@ def format_prompt(self, **kwargs: Any) -> str:
llm_type = kwargs.pop('llm_type', None)

trained_state = kwargs.pop('trained_state', None)
print_prompt = kwargs.pop('print_prompt', False)
use_training = kwargs.pop('use_training', True)
examples = kwargs.pop('__examples__', self.__examples__) # Add this line

# print(f"Formatting prompt with trained_state {trained_state} and print_prompt {print_prompt} and kwargs {kwargs}")
# print(f"Formatting prompt with use_training {use_training}")

try:
# logger.debug(f"Formatting prompt with kwargs: {kwargs}")
self.validate_inputs(kwargs)

# logger.debug(f"PromptStrategy format_prompt with kwargs: {kwargs}")

if llm_type == 'openai':
prompt = self._format_openai_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'openai_json':
prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'anthropic':
elif llm_type == 'anthropic' or llm_type == 'fake_anthropic':
prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **kwargs)

return prompt
Expand All @@ -128,7 +131,7 @@ def parse_output_to_fields(self, output: str, llm_type: str) -> dict:
return self._parse_openai_json_output_to_fields(output)
elif llm_type == 'openai':
return self._parse_openai_output_to_fields(output)
elif llm_type == 'anthropic':
elif llm_type == 'anthropic' or llm_type == 'fake_anthropic':
return self._parse_anthropic_output_to_fields(output)
elif llm_type == 'test':
return self._parse_openai_output_to_fields(output)
@@ -460,4 +463,4 @@ def _parse_openai_json_output_to_fields(self, output: str) -> dict:
raise e
except Exception as e:
logger.error(f"An error occurred while parsing JSON output: {e}")
raise e
raise e
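
Two smaller changes in this file: the `'fake_anthropic'` LLM type is routed through the same formatting and parsing paths as `'anthropic'`, and `validate_inputs` no longer swaps the expected and received key sets in its error message — it now reports missing and unexpected keys separately. A standalone sketch of the new error reporting, assuming plain string field names:

```python
from typing import Iterable


def validate_inputs(received: Iterable[str], expected: Iterable[str]) -> None:
    """Sketch of the improved PromptStrategy.validate_inputs error reporting."""
    expected_keys, received_keys = set(expected), set(received)
    if expected_keys == received_keys:
        return
    parts = []
    missing = expected_keys - received_keys
    unexpected = received_keys - expected_keys
    if missing:
        parts.append(f"Missing input keys: {', '.join(sorted(missing))}")
    if unexpected:
        parts.append(f"Unexpected input keys: {', '.join(sorted(unexpected))}")
    parts.append(f"Expected keys: {', '.join(sorted(expected_keys))}")
    parts.append(f"Received keys: {', '.join(sorted(received_keys))}")
    raise ValueError(". ".join(parts))


# A typo'd input key now produces a message naming both the missing and the extra key.
try:
    validate_inputs(received=["prodcut_name"], expected=["product_name"])
except ValueError as err:
    print(err)
```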
9 changes: 6 additions & 3 deletions pyproject.toml
@@ -59,8 +59,6 @@ platformdirs = "4.2.0"
pluggy = "1.4.0"
ptyprocess = "0.7.0"
pycparser = "2.21"
pydantic = "2.6.1"
pydantic-core = "2.16.2"
pygments = "2.17.2"
pyproject-hooks = "1.0.0"
pytest = "8.0.2"
@@ -75,7 +73,7 @@ shellingham = "1.5.4"
sniffio = "1.3.0"
tenacity = "8.2.3"
threadpoolctl = "3.3.0"
tiktoken = "0.6.0"
tiktoken = "^0.7.0"
tokenizers = "0.15.2"
tomlkit = "0.12.4"
tqdm = "4.66.2"
@@ -87,6 +85,11 @@ xattr = "1.1.0"
yarl = "1.9.4"
zipp = "3.17.0"
ratelimit = "^2.2.1"
langchain = "^0.2.7"
langchain-anthropic = "^0.1.19"
langchain-openai = "^0.1.14"
langchain-community = "^0.2.7"
scikit-learn = "^1.5.1"

[tool.poetry.dev-dependencies]
