Convert SamplingParams.strategy
to a union
#767
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Cleans up how we provide sampling params. Earlier, strategy was an enum and all params (top_p, temperature, top_k) across all strategies were grouped. We now have a strategy union object with each strategy (greedy, top_p, top_k) having its corresponding params.
Earlier,
However, the
strategy
field was not being used in any providers making it confusing to know the exact sampling behavior purely based on the params since you could pass temperature, top_p, top_k and how the provider would interpret those would not be clear.Hence we introduced -- a union where the strategy and relevant params are all clubbed together to avoid this confusion.
Have updated all providers, tests, notebooks, readme and otehr places where sampling params was being used to use the new format.
Test Plan
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py
// inference on ollama, fireworks and together
with-proxy pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/inference/test_text_inference.py
// agents on fireworks
pytest -v -s -k 'fireworks and create_agent' --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/agents/test_agents.py --safety-shield="meta-llama/Llama-Guard-3-8B"
Before submitting
Pull Request section?