Commit

Merge pull request #370 from waleedqk/generated_tokens
set default values for params to false or none
evaline-ju authored Aug 1, 2024
2 parents 1499cab + 842bba3 commit 1e5aac6
Showing 3 changed files with 20 additions and 17 deletions.
caikit_nlp/modules/text_generation/peft_tgis_remote.py (16 changes: 8 additions & 8 deletions)

@@ -213,10 +213,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -280,10 +280,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing against the model running in TGIS
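
The practical effect of this change is that per-token detail fields become opt-in rather than opt-out. Below is a minimal, runnable sketch of the caller-visible behavior; FakeTGISModule and its return shape are illustrative stand-ins, not this repository's actual module class or data model.

```python
from typing import Optional

class FakeTGISModule:
    """Illustrative stand-in for a TGIS-backed text generation module."""

    def run(
        self,
        text: str,
        generated_tokens: bool = False,  # was True before this commit
        token_logprobs: bool = False,    # was True before this commit
        token_ranks: bool = False,       # was True before this commit
        include_stop_sequence: Optional[bool] = None,  # was True before this commit
    ) -> dict:
        # Return per-token detail fields only when explicitly requested.
        return {
            "generated_text": f"<generated for {text!r}>",
            "tokens": ["tok1", "tok2"] if generated_tokens else None,
            "logprobs": [-0.1, -0.2] if token_logprobs else None,
            "ranks": [1, 1] if token_ranks else None,
        }

model = FakeTGISModule()

# Callers who relied on the old True defaults must now opt in explicitly:
detailed = model.run("Hello", generated_tokens=True, token_logprobs=True, token_ranks=True)

# A bare call now returns no per-token details:
print(model.run("Hello")["tokens"])  # -> None
```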
caikit_nlp/modules/text_generation/text_generation_tgis.py (16 changes: 8 additions & 8 deletions)

@@ -235,10 +235,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -296,10 +296,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing for text generation module.
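
The streaming entry point moves the same way. A short sketch of the opt-in pattern on a streaming call, using a hypothetical generator in place of the real run_stream_out:

```python
from typing import Dict, Iterable

def run_stream_out(text: str, generated_tokens: bool = False) -> Iterable[Dict]:
    """Hypothetical stand-in: yields chunks, attaching token details only on request."""
    for word in text.split():
        chunk = {"generated_text": word}
        if generated_tokens:
            chunk["tokens"] = [{"text": word, "logprob": -0.1}]  # illustrative values
        yield chunk

# Token details must now be requested explicitly on the streaming path too:
for chunk in run_stream_out("hello world", generated_tokens=True):
    print(chunk)
```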
caikit_nlp/toolkit/text_generation/tgis_utils.py (5 changes: 4 additions & 1 deletion)

@@ -144,7 +144,10 @@ def validate_inf_params(
     error.type_check("<NLP65883540E>", bool, token_logprobs=token_logprobs)
     error.type_check("<NLP65883541E>", bool, token_ranks=token_ranks)
     error.type_check(
-        "<NLP65883542E>", bool, include_stop_sequence=include_stop_sequence
+        "<NLP65883542E>",
+        bool,
+        allow_none=True,
+        include_stop_sequence=include_stop_sequence,
     )
     error.type_check("<NLP85452188E>", str, allow_none=True, eos_token=eos_token)
     error.type_check(
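
The added allow_none=True is what lets the new include_stop_sequence default of None pass validation while still rejecting other non-bool values. A simplified, runnable stand-in for the semantics of error.type_check (not caikit's actual implementation):

```python
from typing import Any, Type

def type_check(log_code: str, *types: Type, allow_none: bool = False, **kwargs: Any) -> None:
    """Simplified stand-in for the error handler's type_check semantics."""
    for name, value in kwargs.items():
        if value is None and allow_none:
            continue  # None is acceptable when allow_none=True
        if not isinstance(value, types):
            raise TypeError(f"{log_code}: {name} must be one of {types}, got {type(value)}")

# The new default, include_stop_sequence=None, now passes:
type_check("<NLP65883542E>", bool, allow_none=True, include_stop_sequence=None)

# Any other non-bool value still fails:
try:
    type_check("<NLP65883542E>", bool, allow_none=True, include_stop_sequence="yes")
except TypeError as err:
    print(err)
```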
