diff --git a/docs/advanced_usage.md b/docs/advanced_usage.md index b6dd139e3..c82c6b194 100644 --- a/docs/advanced_usage.md +++ b/docs/advanced_usage.md @@ -13,6 +13,7 @@ Argument description of run.py ([supported MatMul combinations](#supported-matri | --compute_dtype | Data type of Gemm computation: int8/bf16/fp16/fp32 (default: fp32) | | --use_ggml | Enable ggml for quantization and inference | | -p / --prompt | Prompt to start generation with: String (default: empty) | +| -f / --file | Path to a text file containing the prompt (for large prompts) | | -n / --n_predict | Number of tokens to predict: Int (default: -1, -1 = infinity) | | -t / --threads | Number of threads to use during computation: Int (default: 56) | | -b / --batch_size_truncate | Batch size for prompt processing: Int (default: 512) | @@ -22,6 +23,7 @@ Argument description of run.py ([supported MatMul combinations](#supported-matri | --color | Colorise output to distinguish prompt and user input from generations | | --keep | Number of tokens to keep from the initial prompt: Int (default: 0, -1 = all) | | --shift-roped-k | Use [ring-buffer](./docs/infinite_inference.md#shift-rope-k-and-ring-buffer) and thus do not re-computing after reaching ctx_size (default: False) | +| --token | Access token ID for models that require it (e.g: LLaMa2, etc..) | ### 1. Conversion and Quantization @@ -108,6 +110,7 @@ Argument description of inference.py: | -m / --model | Path to the executed model: String | | --build_dir | Path to the build file: String | | -p / --prompt | Prompt to start generation with: String (default: empty) | +| -f / --file | Path to a text file containing the prompt (for large prompts) | | -n / --n_predict | Number of tokens to predict: Int (default: -1, -1 = infinity) | | -t / --threads | Number of threads to use during computation: Int (default: 56) | | -b / --batch_size | Batch size for prompt processing: Int (default: 512) | diff --git a/scripts/inference.py b/scripts/inference.py index 08de992b8..c78940777 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import sys from pathlib import Path import argparse from typing import List, Optional @@ -36,6 +35,13 @@ def main(args_in: Optional[List[str]] = None) -> None: help="Prompt to start generation with: String (default: empty)", default="", ) + parser.add_argument( + "-f", + "--file", + type=str, + help="Path to a text file containing the prompt (for large prompts)", + default=None, + ) parser.add_argument( "--tokenizer", type=str, @@ -126,13 +132,18 @@ def main(args_in: Optional[List[str]] = None) -> None: args = parser.parse_args(args_in) print(args) + if args.file: + with open(args.file, 'r', encoding='utf-8') as f: + prompt_text = f.read() + else: + prompt_text = args.prompt model_name = model_maps.get(args.model_name, args.model_name) package_path = os.path.dirname(neural_speed.__file__) path = Path(package_path, "./run_{}".format(model_name)) cmd = [path] cmd.extend(["--model", args.model]) - cmd.extend(["--prompt", args.prompt]) + cmd.extend(["--prompt", prompt_text]) cmd.extend(["--n-predict", str(args.n_predict)]) cmd.extend(["--threads", str(args.threads)]) cmd.extend(["--batch-size-truncate", str(args.batch_size_truncate)]) @@ -153,7 +164,7 @@ def main(args_in: Optional[List[str]] = None) -> None: if (args.model_name == "chatglm"): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) - token_ids_list = tokenizer.encode(args.prompt) + token_ids_list = tokenizer.encode(prompt_text) token_ids_list = map(str, token_ids_list) token_ids_str = ', '.join(token_ids_list) cmd.extend(["--ids", token_ids_str]) @@ -193,14 +204,14 @@ def encode_history(history, max_length=4096): else: ids.append(ASSISTANT_TOKEN_ID) - content_ids = tokenizer.encode(args.prompt) + content_ids = tokenizer.encode(prompt_text) ids.extend(content_ids) ids.append(ASSISTANT_TOKEN_ID) truncate(ids, max_length) return ids - history = [args.prompt] + history = [prompt_text] token_ids_list = encode_history(history) token_ids_list = map(str, token_ids_list) token_ids_str = ', '.join(token_ids_list) diff --git a/scripts/run.py b/scripts/run.py index d823edb05..631e0716e 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import sys from pathlib import Path import argparse from typing import List, Optional @@ -89,6 +87,13 @@ def main(args_in: Optional[List[str]] = None) -> None: help="Prompt to start generation with: String (default: empty)", default="Once upon a time, there existed a ", ) + parser.add_argument( + "-f", + "--file", + type=str, + help="Path to a text file containing the prompt (for large prompts)", + default=None, + ) parser.add_argument( "-n", "--n_predict", @@ -207,6 +212,7 @@ def main(args_in: Optional[List[str]] = None) -> None: infer_cmd.extend(["--model_name", model_type]) infer_cmd.extend(["-m", Path(work_path, "ne_{}_{}.bin".format(model_type, args.weight_dtype, args.group_size))]) infer_cmd.extend(["--prompt", args.prompt]) + infer_cmd.extend(["--file", args.file]) infer_cmd.extend(["--n_predict", str(args.n_predict)]) infer_cmd.extend(["--threads", str(args.threads)]) infer_cmd.extend(["--batch_size_truncate", str(args.batch_size_truncate)])