This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Add support for large prompts that don't fit in cmd line (#133)
* Support for large prompts

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
aahouzi and pre-commit-ci[bot] authored Mar 14, 2024
1 parent 0ec1a6e commit e76a58e
Showing 3 changed files with 27 additions and 7 deletions.
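
For context (a hedged note, not part of the commit): the motivation is the OS-level cap on the combined size of a process's argument list and environment, which very large prompts exceed when passed inline with -p/--prompt. On POSIX systems that cap can be inspected as sketched below; the exact value varies by platform.

```python
# Sketch: query the OS limit that inline prompts run into (POSIX only; Windows
# has a different, smaller command-line limit and does not expose os.sysconf).
import os

try:
    print("ARG_MAX:", os.sysconf("SC_ARG_MAX"), "bytes")
except (AttributeError, ValueError, OSError):
    print("ARG_MAX not exposed on this platform")
```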
3 changes: 3 additions & 0 deletions docs/advanced_usage.md
@@ -13,6 +13,7 @@ Argument description of run.py ([supported MatMul combinations](#supported-matri
| --compute_dtype | Data type of Gemm computation: int8/bf16/fp16/fp32 (default: fp32) |
| --use_ggml | Enable ggml for quantization and inference |
| -p / --prompt | Prompt to start generation with: String (default: empty) |
+| -f / --file | Path to a text file containing the prompt (for large prompts) |
| -n / --n_predict | Number of tokens to predict: Int (default: -1, -1 = infinity) |
| -t / --threads | Number of threads to use during computation: Int (default: 56) |
| -b / --batch_size_truncate | Batch size for prompt processing: Int (default: 512) |
@@ -22,6 +23,7 @@ Argument description of run.py ([supported MatMul combinations](#supported-matri
| --color | Colorise output to distinguish prompt and user input from generations |
| --keep | Number of tokens to keep from the initial prompt: Int (default: 0, -1 = all) |
| --shift-roped-k | Use [ring-buffer](./docs/infinite_inference.md#shift-rope-k-and-ring-buffer) and thus avoid re-computation after reaching ctx_size (default: False) |
+| --token | Access token ID for models that require it (e.g., LLaMa2) |


### 1. Conversion and Quantization
@@ -108,6 +110,7 @@ Argument description of inference.py:
| -m / --model | Path to the executed model: String |
| --build_dir | Path to the build file: String |
| -p / --prompt | Prompt to start generation with: String (default: empty) |
+| -f / --file | Path to a text file containing the prompt (for large prompts) |
| -n / --n_predict | Number of tokens to predict: Int (default: -1, -1 = infinity) |
| -t / --threads | Number of threads to use during computation: Int (default: 56) |
| -b / --batch_size | Batch size for prompt processing: Int (default: 512) |
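For illustration, a minimal usage sketch of the new flag (not part of the commit): it launches scripts/inference.py as a subprocess with the prompt supplied via -f. The model name and binary path below are placeholders, not values taken from this repository.

```python
# Hypothetical usage sketch (model name and binary path are placeholders).
import subprocess
import sys

# Write a prompt that would be too long to pass inline on the command line.
with open("long_prompt.txt", "w", encoding="utf-8") as f:
    f.write("Summarize the following document.\n" + "lorem ipsum " * 20000)

# Hand the prompt over with -f/--file instead of -p/--prompt.
subprocess.run(
    [
        sys.executable, "scripts/inference.py",
        "--model_name", "llama",                # placeholder model name
        "-m", "runtime_outs/ne_llama_q.bin",    # placeholder path to the executed model
        "-f", "long_prompt.txt",
        "-n", "128",
        "-t", "8",
    ],
    check=True,
)
```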
21 changes: 16 additions & 5 deletions scripts/inference.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-import sys
from pathlib import Path
import argparse
from typing import List, Optional
@@ -36,6 +35,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
        help="Prompt to start generation with: String (default: empty)",
        default="",
    )
+    parser.add_argument(
+        "-f",
+        "--file",
+        type=str,
+        help="Path to a text file containing the prompt (for large prompts)",
+        default=None,
+    )
    parser.add_argument(
        "--tokenizer",
        type=str,
Expand Down Expand Up @@ -126,13 +132,18 @@ def main(args_in: Optional[List[str]] = None) -> None:

args = parser.parse_args(args_in)
print(args)
if args.file:
with open(args.file, 'r', encoding='utf-8') as f:
prompt_text = f.read()
else:
prompt_text = args.prompt
model_name = model_maps.get(args.model_name, args.model_name)
package_path = os.path.dirname(neural_speed.__file__)
path = Path(package_path, "./run_{}".format(model_name))

cmd = [path]
cmd.extend(["--model", args.model])
cmd.extend(["--prompt", args.prompt])
cmd.extend(["--prompt", prompt_text])
cmd.extend(["--n-predict", str(args.n_predict)])
cmd.extend(["--threads", str(args.threads)])
cmd.extend(["--batch-size-truncate", str(args.batch_size_truncate)])
@@ -153,7 +164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:

    if (args.model_name == "chatglm"):
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
-        token_ids_list = tokenizer.encode(args.prompt)
+        token_ids_list = tokenizer.encode(prompt_text)
        token_ids_list = map(str, token_ids_list)
        token_ids_str = ', '.join(token_ids_list)
        cmd.extend(["--ids", token_ids_str])
@@ -193,14 +204,14 @@ def encode_history(history, max_length=4096):
                else:
                    ids.append(ASSISTANT_TOKEN_ID)

-                content_ids = tokenizer.encode(args.prompt)
+                content_ids = tokenizer.encode(prompt_text)
                ids.extend(content_ids)

            ids.append(ASSISTANT_TOKEN_ID)
            truncate(ids, max_length)
            return ids

-        history = [args.prompt]
+        history = [prompt_text]
        token_ids_list = encode_history(history)
        token_ids_list = map(str, token_ids_list)
        token_ids_str = ', '.join(token_ids_list)
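The hunks above make inference.py prefer a prompt file over the inline -p string. A condensed, standalone sketch of that resolution logic (a simplified illustration, not the commit's full script):

```python
# Simplified sketch of the prompt-resolution pattern added above.
import argparse

def resolve_prompt(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--prompt", type=str, default="")
    parser.add_argument("-f", "--file", type=str, default=None)
    args = parser.parse_args(argv)

    if args.file:
        # Large prompts are read from disk instead of being passed on the command line.
        with open(args.file, "r", encoding="utf-8") as f:
            return f.read()
    return args.prompt

print(resolve_prompt(["-p", "short inline prompt"]))
```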
10 changes: 8 additions & 2 deletions scripts/run.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-import sys
from pathlib import Path
import argparse
from typing import List, Optional
@@ -89,6 +87,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
        help="Prompt to start generation with: String (default: empty)",
        default="Once upon a time, there existed a ",
    )
+    parser.add_argument(
+        "-f",
+        "--file",
+        type=str,
+        help="Path to a text file containing the prompt (for large prompts)",
+        default=None,
+    )
    parser.add_argument(
        "-n",
        "--n_predict",
@@ -207,6 +212,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
    infer_cmd.extend(["--model_name", model_type])
    infer_cmd.extend(["-m", Path(work_path, "ne_{}_{}.bin".format(model_type, args.weight_dtype, args.group_size))])
    infer_cmd.extend(["--prompt", args.prompt])
+    if args.file:
+        infer_cmd.extend(["--file", args.file])
    infer_cmd.extend(["--n_predict", str(args.n_predict)])
    infer_cmd.extend(["--threads", str(args.threads)])
    infer_cmd.extend(["--batch_size_truncate", str(args.batch_size_truncate)])
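As a usage sketch for the one-click path (assumptions flagged inline): run.py's positional model argument is not shown in this diff, so the model identifier below is a placeholder; the prompt itself is handed over with -f so it never appears on the command line.

```python
# Hypothetical end-to-end sketch; the model identifier is a placeholder and any
# run.py arguments beyond those documented above are assumptions, not taken from this diff.
import subprocess
import sys

with open("long_prompt.txt", "w", encoding="utf-8") as f:
    f.write("Translate the following text.\n" + "some very long context " * 10000)

subprocess.run(
    [
        sys.executable, "scripts/run.py",
        "meta-llama/Llama-2-7b-hf",   # placeholder model id/path (positional argument assumed);
                                      # a gated model may additionally need --token
        "-f", "long_prompt.txt",      # new flag from this commit
        "-n", "64",
        "-t", "8",
    ],
    check=True,
)
```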
