From 34f396fb80a484334c351fdc49c40456e7e03b5a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 4 Nov 2021 22:00:10 -0700 Subject: [PATCH 01/53] Add BeamSearch op schema --- .../core/graph/contrib_ops/contrib_defs.cc | 138 +++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index c9f6d315a1813..a2ca4bfa80d8d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -531,6 +531,141 @@ void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int p } } +bool ParseScalar(const TensorProto* initializer, int& value) { + std::vector parsed_data; + if (initializer->data_type() == TensorProto::INT32) { + const auto& data = ParseData(initializer); + parsed_data.insert(parsed_data.end(), data.begin(), data.end()); + + if (parsed_data.size() == 1) { + value = parsed_data[0]; + return true; + } + } + + return false; +} + +void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (ctx.getNumOutputs() > 1) { + // Here we assume that the third output exist only if second output exists. + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 6, 1); + if (ctx.getNumOutputs() > 2) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 6, 2); + } + } + + // Shape inference + // input 0 (input_ids) shape: (batch_size * num_beams, sequence_length) + // output 0 (sequences) shape: (batch_size * num_return_sequences, max_length) + // output 1 (sequences_scores) shape: (batch_size * num_return_sequences) + // output 2 (scores) shape: (max_length-sequence_length, batch_size*num_beams*num_return_sequences, vocab_size) + if (!hasInputShape(ctx, 0)) { + return; + } + auto& input_ids_shape = getInputShape(ctx, 0); + auto& input_ids_dims = input_ids_shape.dim(); + if (input_ids_dims.size() != 2) { + fail_shape_inference("Inputs 0 shall be 2 dimensions"); + } + if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) { + return; + } + int64_t batch_beam_size = input_ids_dims[0].dim_value(); + int64_t sequence_length = input_ids_dims[1].dim_value(); + + const auto max_length = ctx.getInputData(3); + const auto num_beams = ctx.getInputData(4); + const auto num_return_sequences = ctx.getInputData(5); + if (num_beams == nullptr || max_length == nullptr || num_return_sequences == nullptr) { // not initializer + return; + } + + int max_length_value = 0; + if (!ParseScalar(max_length, max_length_value) || max_length_value <= 0) { + fail_shape_inference("Failed to parse max_length or it is not positive integer scalar"); + } + + int num_beams_value = 0; + if (!ParseScalar(num_beams, num_beams_value) || num_beams_value <= 0) { + fail_shape_inference("Failed to parse num_beams or it is not positive integer scalar"); + } + + int num_return_sequences_value = 0; + if (!ParseScalar(num_return_sequences, num_return_sequences_value) || num_return_sequences_value <= 0) { + fail_shape_inference("Failed to parse num_return_sequences or it is not positive integer scalar"); + } + + if (batch_beam_size % num_beams_value != 0) { + fail_shape_inference("input_ids dimension 0 shall be multiple of num_beams"); + } + + int64_t batch_size = batch_beam_size / num_beams_value; + ONNX_NAMESPACE::TensorShapeProto sequences_shape; + sequences_shape.add_dim()->set_dim_value(batch_size * num_beams_value); + sequences_shape.add_dim()->set_dim_value(batch_size * sequence_length); + updateOutputShape(ctx, 0, sequences_shape); + + if (ctx.getNumOutputs() > 1) { + ONNX_NAMESPACE::TensorShapeProto sequences_scores_shape; + sequences_scores_shape.add_dim()->set_dim_value(batch_size * num_beams_value); + updateOutputShape(ctx, 1, sequences_scores_shape); + + if (ctx.getNumOutputs() > 2) { + ONNX_NAMESPACE::TensorShapeProto scores_shape; + scores_shape.add_dim()->set_dim_value(max_length_value - sequence_length); + scores_shape.add_dim()->set_dim_value(batch_size * num_beams_value * num_return_sequences_value); + scores_shape.add_dim(); // vocab_size is unknown + updateOutputShape(ctx, 2, scores_shape); + } + } +} + +void RegisterTextGenerationSchemas() { + ONNX_CONTRIB_OPERATOR_SCHEMA(BeamSearch) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc("Beam Search for text generation. Supports GPT-2 decoder.") + .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) + .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) + .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) + .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) + .Attr( + "body", + "The GPT subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output", + AttributeProto::GRAPH) + .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape (batch_size * num_beams, sequence_length)", "I") + .Input(1, "attention_mask", "Mask to avoid performing attention on padding token indices. Shape (batch_size, sequence_length)", "M", OpSchema::Optional) + .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. 1D input tensor with shape (1,)", "I", OpSchema::Optional) + .Input(3, "max_length", "1D input tensor with shape (1,)", "I") + .Input(4, "num_beams", "Number of beams for beam search. 1 means no beam search. 1D input tensor with shape (1,)", "I") + .Input(5, "num_return_sequences", "1D input tensor with shape (1,)", "I") + .Input(6, "temperature", "The value used to module the logits distribution. Accepts value != 0.0. 1D input tensor with shape (1,)", "T") + .Input(7, "length_penalty", + "Exponential penalty to the length. Default value 1.0 means no penalty." + "Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer sequences." + "1D input tensor with shape (1,)", + "T", OpSchema::Optional) + .Input(8, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0, shape (1,)", "T", OpSchema::Optional) + .Input(9, "vocab_mask", "Mask of vocabulary. Word that masked with 0 are not allowed to be generated, and 1 is allowed. shape (vacab_size,)", "M", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape: (batch_size * num_return_sequences, max_sequence_length)", "I") + .Output(1, "sequences_scores", "Final beam score of the generated sequences. shape (batch_size*num_return_sequences)", "T", OpSchema::Optional) + .Output(2, "scores", + "Processed beam scores for each vocabulary token at each generation step." + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Shape (max_length - input_ids_sequence_length, batch_size*num_beams*num_return_sequences, vocab_size)", + "T", OpSchema::Optional) + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + BeamSearchShapeInference(ctx); + }); +} + void RegisterBertSchemas() { static const char* Attention_ver1_doc = R"DOC( Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). @@ -3077,7 +3212,8 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i RegisterNhwcSchemas(); RegisterBertSchemas(); - + RegisterTextGenerationSchemas(); + #ifdef BUILD_MS_EXPERIMENTAL_OPS onnxruntime::signal::RegisterSignalSchemas(); #endif From 1fe45a251d8ad38fc0e1c27043e40e4f382a8cf1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 5 Nov 2021 12:44:03 -0700 Subject: [PATCH 02/53] Add ONNX conversion for beams search --- .../tools/transformers/convert_beam_search.py | 346 ++++++++++++++++++ 1 file changed, 346 insertions(+) create mode 100644 onnxruntime/python/tools/transformers/convert_beam_search.py diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py new file mode 100644 index 0000000000000..4dfc043e63ea4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -0,0 +1,346 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import onnx +import logging +import argparse +from pathlib import Path +from onnx import helper +import numpy as np +from transformers import AutoConfig +from gpt2_helper import PRETRAINED_GPT2_MODELS +from onnx_model import OnnxModel +from convert_to_onnx import main as convert_gpt2_to_onnx +from benchmark_helper import create_onnxruntime_session, setup_logger, prepare_environment, Precision +""" +This converts GPT2 model to onnx with beam search operator. + +Examples: + python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores +""" + +config = None + +logger = logging.getLogger('') + + +def parse_arguments(argv=None): + parser = argparse.ArgumentParser() + + parser.add_argument('-m', + '--model_name_or_path', + required=True, + type=str, + help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS)) + + parser.add_argument('--cache_dir', + required=False, + type=str, + default=os.path.join('.', 'cache_models'), + help='Directory to cache pre-trained models') + + parser.add_argument('--gpt2_onnx', + required=True, + type=str, + help='Output directory for GPT-2 onnx model, or model path ends with .onnx') + + parser.add_argument('--output', + required=False, + type=str, + help='Output directory for beam search model, or model path ends with .onnx') + + parser.add_argument("-p", + "--precision", + required=False, + type=Precision, + default=Precision.FLOAT32, + choices=[Precision.FLOAT32, Precision.FLOAT16], + help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision") + + parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference") + parser.set_defaults(use_gpu=False) + + parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true') + parser.set_defaults(use_external_data_format=False) + + beam_search_group = parser.add_argument_group("beam search options") + + beam_search_group.add_argument('--output_sequences_scores', + required=False, + action='store_true', + help="output sequences scores") + beam_search_group.set_defaults(output_sequences_scores=False) + + beam_search_group.add_argument('--output_token_scores', + required=False, + action='store_true', + help="output token scores") + beam_search_group.set_defaults(output_token_scores=False) + + beam_search_group.add_argument('--early_stopping', required=False, action='store_true') + beam_search_group.set_defaults(early_stopping=False) + + beam_search_group.add_argument('--min_length', type=int, required=False, default=1, help='Min sequence length') + + beam_search_group.add_argument('--max_length', type=int, required=False, default=50, help='Max sequence length') + + beam_search_group.add_argument('--no_repeat_ngram_size', + type=int, + required=False, + default=0, + help='No repeat ngram size') + + beam_search_group.add_argument('--beam_size', type=int, required=False, default=4, help='Beam size') + + beam_search_group.add_argument('--num_return_sequences', + type=int, + required=False, + default=1, + help='Number of return sequence') + + beam_search_group.add_argument('--temperature', + type=float, + required=False, + default=1, + help='Softmax temperature for output logits.') + + beam_search_group.add_argument('--length_penalty', + type=float, + required=False, + default=1, + help='Positive. >1 to penalize and <1 to encorage short sentence.') + + beam_search_group.add_argument('--repetition_penalty', + type=float, + required=False, + default=1, + help='Positive. >1 to penalize and <1 to encorage.') + + mixed_precision_option_grapu = parser.add_argument_group( + "mixed precision conversion parameters that works when \"--precision fp16\" is specified") + + mixed_precision_option_grapu.add_argument('--io_block_list', + nargs='+', + required=False, + default=[], + help='List of inputs or outputs in float32') + + mixed_precision_option_grapu.add_argument( + '--op_block_list', + nargs='+', + required=False, + default=[], + help='List of operators (like Add LayerNormalization FastGelu) to compute in float32.') + + mixed_precision_option_grapu.add_argument('--node_block_list', + nargs='+', + required=False, + default=[], + help='List of node names to compute in float32.') + + mixed_precision_option_grapu.add_argument('--force_fp16_initializers', + required=False, + action='store_true', + help='Convert all float initializers to float16.') + mixed_precision_option_grapu.set_defaults(force_fp16_initializers=False) + + args = parser.parse_args(argv) + + return args + + +def convert_gpt2_to_onnx(args): + model_name = args.model_name_or_path + + print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...") + arguments = [ + '--model_name_or_path', model_name, '--output', args.gpt2_onnx, '--optimize_onnx', '--precision', + 'fp32' if args.precision == Precision.FLOAT32 else 'fp16', '--test_runs', '1', '--test_cases', '10' + ] + if args.use_gpu: + arguments.append('--use_gpu') + if args.use_external_data_format: + arguments.append('--use_external_data_format') + + # mixed precision conversion options + if args.io_block_list: + arguments.append('--io_block_list') + arguments.extend(args.io_block_list) + if args.op_block_list: + arguments.append('--op_block_list') + arguments.extend(args.op_block_list) + if args.node_block_list: + arguments.append('--node_block_list') + arguments.extend(args.node_block_list) + if args.force_fp16_initializers: + arguments.append('--force_fp16_initializers') + + convert_gpt2_to_onnx(arguments) + + +def convert_model(args): + if os.path.exists(args.gpt2_onnx): + print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}") + else: + convert_gpt2_to_onnx(args) + + global config + config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + print(config) + + eos_token_id = config.eos_token_id + pad_token_id = config.eos_token_id + vocab_size = config.vocab_size + + model = onnx.load(args.gpt2_onnx) + model.graph.name = "gpt2 subgraph" + inputs = [ + "input_ids", "attention_mask", "min_length", "max_length", "num_beams", "num_return_sequences", "temperature", + "length_penalty", "repetition_penalty" + ] + + outputs = ["sequences"] + if args.output_sequences_scores: + outputs.append("sequences_scores") + if args.output_token_scores: + outputs.append("scores") + + node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2') + node.domain = "com.microsoft" + node.attribute.extend([ + helper.make_attribute("eos_token_id", eos_token_id), + helper.make_attribute("pad_token_id", pad_token_id), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), + helper.make_attribute("body", model.graph), + ]) + + from onnx import TensorProto + + # graph inputs + input_ids = helper.make_tensor_value_info('input_ids', TensorProto.INT32, ['batch_size', 'sequence_length']) + attention_mask = helper.make_tensor_value_info('attention_mask', TensorProto.INT32, + ['batch_size', 'sequence_length']) + + min_length = helper.make_tensor_value_info('min_length', TensorProto.INT32, [1]) + max_length = helper.make_tensor_value_info('max_length', TensorProto.INT32, [1]) + num_beams = helper.make_tensor_value_info('num_beams', TensorProto.INT32, [1]) + num_return_sequences = helper.make_tensor_value_info('num_return_sequences', TensorProto.INT32, [1]) + temperature = helper.make_tensor_value_info('temperature', TensorProto.FLOAT, [1]) + length_penalty = helper.make_tensor_value_info('length_penalty', TensorProto.FLOAT, [1]) + repetition_penalty = helper.make_tensor_value_info('repetition_penalty', TensorProto.FLOAT, [1]) + + graph_inputs = [ + input_ids, attention_mask, min_length, max_length, num_beams, num_return_sequences, temperature, length_penalty, + repetition_penalty + ] + + # graph outputs + sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32, + ['batch_size * num_return_sequences', 'max_length']) + sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT, + ['batch_size * num_return_sequences']) + scores = helper.make_tensor_value_info( + 'scores', TensorProto.FLOAT, + ['max_length - sequence_length', 'batch_size * num_beams * num_return_sequences', vocab_size]) + + initializers = [] + + graph_outputs = [sequences] + if args.output_sequences_scores: + graph_outputs.append(sequences_scores) + if args.output_token_scores: + graph_outputs.append(scores) + + new_graph = helper.make_graph([node], 'gpt2-beam-search', graph_inputs, graph_outputs, initializers) + + # Create the model + new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import) + onnx.save(new_model, args.output) + + +def test_model(args): + from transformers import GPT2Tokenizer, GPT2LMHeadModel + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path, + cache_dir=args.cache_dir, + pad_token_id=tokenizer.eos_token_id) + input_ids = tokenizer.encode('I enjoy walking in the park', return_tensors='pt') + + eos_token_id = config.eos_token_id + pad_token_id = config.eos_token_id + vocab_size = config.vocab_size + + print('-' * 50) + print("Test PyTorch model and beam search with huggingface transformers...") + beam_outputs = model.generate(input_ids, + min_length=args.min_length, + max_length=args.max_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + temperature=args.temperature, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty) + print("input_ids", input_ids) + print("huggingface transformers output:", beam_outputs) + for i, beam_output in enumerate(beam_outputs): + print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True))) + + print('-' * 50) + print("Test ONNX model and bream search with onnxruntime...") + from onnxruntime import SessionOptions, InferenceSession, __version__ as ort_version, GraphOptimizationLevel + sess_options = SessionOptions() + execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider' + ] if args.use_gpu else ['CPUExecutionProvider'] + + ort_session = InferenceSession(args.output, sess_options, providers=execution_providers) + + _, sequence_length = input_ids.shape + batch_size = 2 + input_ids = input_ids.repeat(batch_size, 1) + + inputs = { + "input_ids": input_ids.cpu().numpy().astype(np.int32), + "attention_mask": np.ones((batch_size, sequence_length), dtype=np.float32), + "min_length": np.array([args.min_length], dtype=np.int32), + "max_length": np.array([args.max_length], dtype=np.int32), + "num_beams": np.array([args.num_beams], dtype=np.int32), + "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32), + "temperature": np.array([args.temperature], dtype=np.float32), + "length_penalty": np.array([args.length_penalty], dtype=np.float32), + "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), + "vocab_mask": np.ones((vocab_size), dtype=np.float32) + } + + test_data_dir = Path(args.output).parent.as_posix() + print("test_data_dir", test_data_dir) + from bert_test_data import output_test_data + all_inputs = [inputs] + for i, inputs in enumerate(all_inputs): + dir = os.path.join(test_data_dir, 'test_data_set_' + str(i)) + output_test_data(dir, inputs) + + print("inputs", inputs) + result = ort_session.run(None, inputs) + print("outputs", result) + #print(tokenizer.decode(result[0][0], skip_special_tokens=True)) + + +def main(): + # TODO: remove debug code + import time + print('You have 30 seconds to attach a debugger.') + time.sleep(30) + + args = parse_arguments() + convert_model(args) + test_model(args) + + +if __name__ == '__main__': + main() From e7a665c394d7567fee449bd414a75412c0a1e41e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 6 Nov 2021 22:40:08 -0700 Subject: [PATCH 03/53] remove attention_mask and change input order --- .../core/graph/contrib_ops/contrib_defs.cc | 41 +++++---- .../tools/transformers/convert_beam_search.py | 88 +++++++++++-------- 2 files changed, 69 insertions(+), 60 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a2ca4bfa80d8d..e0d6af8c5138e 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -552,9 +552,9 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { if (ctx.getNumOutputs() > 1) { // Here we assume that the third output exist only if second output exists. - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 6, 1); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 1); if (ctx.getNumOutputs() > 2) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 6, 2); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 2); } } @@ -577,9 +577,9 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { int64_t batch_beam_size = input_ids_dims[0].dim_value(); int64_t sequence_length = input_ids_dims[1].dim_value(); - const auto max_length = ctx.getInputData(3); - const auto num_beams = ctx.getInputData(4); - const auto num_return_sequences = ctx.getInputData(5); + const auto max_length = ctx.getInputData(1); + const auto num_beams = ctx.getInputData(3); + const auto num_return_sequences = ctx.getInputData(4); if (num_beams == nullptr || max_length == nullptr || num_return_sequences == nullptr) { // not initializer return; } @@ -635,28 +635,27 @@ void RegisterTextGenerationSchemas() { .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr( "body", - "The GPT subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output", + "The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output", AttributeProto::GRAPH) - .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape (batch_size * num_beams, sequence_length)", "I") - .Input(1, "attention_mask", "Mask to avoid performing attention on padding token indices. Shape (batch_size, sequence_length)", "M", OpSchema::Optional) - .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. 1D input tensor with shape (1,)", "I", OpSchema::Optional) - .Input(3, "max_length", "1D input tensor with shape (1,)", "I") - .Input(4, "num_beams", "Number of beams for beam search. 1 means no beam search. 1D input tensor with shape (1,)", "I") - .Input(5, "num_return_sequences", "1D input tensor with shape (1,)", "I") - .Input(6, "temperature", "The value used to module the logits distribution. Accepts value != 0.0. 1D input tensor with shape (1,)", "T") - .Input(7, "length_penalty", + .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size * num_beams, sequence_length)", "I") + .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") + .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) + .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") + .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") + .Input(5, "temperature", "The value used to module the next token probabilities. Accepts value != 0.0. Shape is (1)", "T") + .Input(6, "length_penalty", "Exponential penalty to the length. Default value 1.0 means no penalty." - "Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer sequences." - "1D input tensor with shape (1,)", + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Shape is (1,)", "T", OpSchema::Optional) - .Input(8, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0, shape (1,)", "T", OpSchema::Optional) - .Input(9, "vocab_mask", "Mask of vocabulary. Word that masked with 0 are not allowed to be generated, and 1 is allowed. shape (vacab_size,)", "M", OpSchema::Optional) - .Output(0, "sequences", "Word IDs of generated sequences. Shape: (batch_size * num_return_sequences, max_sequence_length)", "I") - .Output(1, "sequences_scores", "Final beam score of the generated sequences. shape (batch_size*num_return_sequences)", "T", OpSchema::Optional) + .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) + .Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size * num_return_sequences, max_sequence_length)", "I") + .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size*num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", "Processed beam scores for each vocabulary token at each generation step." "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." - "Shape (max_length - input_ids_sequence_length, batch_size*num_beams*num_return_sequences, vocab_size)", + "Shape is (max_length - input_ids_sequence_length, batch_size*num_beams*num_return_sequences, vocab_size)", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 4dfc043e63ea4..df814e31655c2 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -8,11 +8,11 @@ from pathlib import Path from onnx import helper import numpy as np -from transformers import AutoConfig +from transformers import GPT2Config from gpt2_helper import PRETRAINED_GPT2_MODELS -from onnx_model import OnnxModel from convert_to_onnx import main as convert_gpt2_to_onnx -from benchmark_helper import create_onnxruntime_session, setup_logger, prepare_environment, Precision +from benchmark_helper import Precision + """ This converts GPT2 model to onnx with beam search operator. @@ -20,11 +20,10 @@ python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores """ -config = None +config:GPT2Config = None logger = logging.getLogger('') - def parse_arguments(argv=None): parser = argparse.ArgumentParser() @@ -91,7 +90,7 @@ def parse_arguments(argv=None): default=0, help='No repeat ngram size') - beam_search_group.add_argument('--beam_size', type=int, required=False, default=4, help='Beam size') + beam_search_group.add_argument('--num_beams', type=int, required=False, default=4, help='Beam size') beam_search_group.add_argument('--num_return_sequences', type=int, @@ -150,7 +149,7 @@ def parse_arguments(argv=None): return args -def convert_gpt2_to_onnx(args): +def gpt2_to_onnx(args): model_name = args.model_name_or_path print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...") @@ -164,29 +163,48 @@ def convert_gpt2_to_onnx(args): arguments.append('--use_external_data_format') # mixed precision conversion options - if args.io_block_list: - arguments.append('--io_block_list') - arguments.extend(args.io_block_list) - if args.op_block_list: - arguments.append('--op_block_list') - arguments.extend(args.op_block_list) - if args.node_block_list: - arguments.append('--node_block_list') - arguments.extend(args.node_block_list) - if args.force_fp16_initializers: - arguments.append('--force_fp16_initializers') + if args.precision == Precision.FLOAT16: + assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu" + if args.io_block_list: + arguments.append('--io_block_list') + arguments.extend(args.io_block_list) + if args.op_block_list: + arguments.append('--op_block_list') + arguments.extend(args.op_block_list) + if args.node_block_list: + arguments.append('--node_block_list') + arguments.extend(args.node_block_list) + if args.force_fp16_initializers: + arguments.append('--force_fp16_initializers') convert_gpt2_to_onnx(arguments) + # Run symbolic shape inference to walk around ORT shape inference issue for subgraph. + from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference + out = SymbolicShapeInference.infer_shapes(onnx.load(args.gpt2_onnx), auto_merge=True, guess_output_rank=False) + if out: + onnx.save(out, args.gpt2_onnx) + +def create_ort_session(model_path, use_gpu): + from onnxruntime import SessionOptions, InferenceSession, __version__ as ort_version, GraphOptimizationLevel + sess_options = SessionOptions() + sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL + execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider' + ] if use_gpu else ['CPUExecutionProvider'] + + ort_session = InferenceSession(model_path, sess_options, providers=execution_providers) + return ort_session def convert_model(args): if os.path.exists(args.gpt2_onnx): print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}") else: - convert_gpt2_to_onnx(args) + gpt2_to_onnx(args) + + #create_ort_session(args.gpt2_onnx, args.use_gpu) global config - config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) print(config) eos_token_id = config.eos_token_id @@ -196,7 +214,7 @@ def convert_model(args): model = onnx.load(args.gpt2_onnx) model.graph.name = "gpt2 subgraph" inputs = [ - "input_ids", "attention_mask", "min_length", "max_length", "num_beams", "num_return_sequences", "temperature", + "input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty", "repetition_penalty" ] @@ -220,11 +238,8 @@ def convert_model(args): # graph inputs input_ids = helper.make_tensor_value_info('input_ids', TensorProto.INT32, ['batch_size', 'sequence_length']) - attention_mask = helper.make_tensor_value_info('attention_mask', TensorProto.INT32, - ['batch_size', 'sequence_length']) - - min_length = helper.make_tensor_value_info('min_length', TensorProto.INT32, [1]) max_length = helper.make_tensor_value_info('max_length', TensorProto.INT32, [1]) + min_length = helper.make_tensor_value_info('min_length', TensorProto.INT32, [1]) num_beams = helper.make_tensor_value_info('num_beams', TensorProto.INT32, [1]) num_return_sequences = helper.make_tensor_value_info('num_return_sequences', TensorProto.INT32, [1]) temperature = helper.make_tensor_value_info('temperature', TensorProto.FLOAT, [1]) @@ -232,7 +247,7 @@ def convert_model(args): repetition_penalty = helper.make_tensor_value_info('repetition_penalty', TensorProto.FLOAT, [1]) graph_inputs = [ - input_ids, attention_mask, min_length, max_length, num_beams, num_return_sequences, temperature, length_penalty, + input_ids, max_length, min_length, num_beams, num_return_sequences, temperature, length_penalty, repetition_penalty ] @@ -275,8 +290,8 @@ def test_model(args): print('-' * 50) print("Test PyTorch model and beam search with huggingface transformers...") beam_outputs = model.generate(input_ids, - min_length=args.min_length, max_length=args.max_length, + min_length=args.min_length, num_beams=args.num_beams, early_stopping=args.early_stopping, no_repeat_ngram_size=args.no_repeat_ngram_size, @@ -293,22 +308,16 @@ def test_model(args): print('-' * 50) print("Test ONNX model and bream search with onnxruntime...") - from onnxruntime import SessionOptions, InferenceSession, __version__ as ort_version, GraphOptimizationLevel - sess_options = SessionOptions() - execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider' - ] if args.use_gpu else ['CPUExecutionProvider'] - ort_session = InferenceSession(args.output, sess_options, providers=execution_providers) + ort_session = create_ort_session(args.output, args.use_gpu) - _, sequence_length = input_ids.shape batch_size = 2 input_ids = input_ids.repeat(batch_size, 1) inputs = { "input_ids": input_ids.cpu().numpy().astype(np.int32), - "attention_mask": np.ones((batch_size, sequence_length), dtype=np.float32), - "min_length": np.array([args.min_length], dtype=np.int32), "max_length": np.array([args.max_length], dtype=np.int32), + "min_length": np.array([args.min_length], dtype=np.int32), "num_beams": np.array([args.num_beams], dtype=np.int32), "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32), "temperature": np.array([args.temperature], dtype=np.float32), @@ -332,15 +341,16 @@ def test_model(args): def main(): + args = parse_arguments() + # TODO: remove debug code import time - print('You have 30 seconds to attach a debugger.') - time.sleep(30) + print('You have 15 seconds to attach a debugger.') + time.sleep(15) - args = parse_arguments() convert_model(args) - test_model(args) + test_model(args) if __name__ == '__main__': main() From 27bf8093e79fcbf2168f8745ec786ca30ae9007b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 8 Nov 2021 14:41:55 -0800 Subject: [PATCH 04/53] add option to run baseline --- .../tools/transformers/convert_beam_search.py | 56 +++++++++++-------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index df814e31655c2..71fba8593ed9e 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -63,6 +63,9 @@ def parse_arguments(argv=None): parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true') parser.set_defaults(use_external_data_format=False) + parser.add_argument('--run_baseline', required=False, action='store_true', help="run huggingface beam search") + parser.set_defaults(run_baseline=False) + beam_search_group = parser.add_argument_group("beam search options") beam_search_group.add_argument('--output_sequences_scores', @@ -215,7 +218,7 @@ def convert_model(args): model.graph.name = "gpt2 subgraph" inputs = [ "input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", - "length_penalty", "repetition_penalty" + "length_penalty", "repetition_penalty", "vocab_mask" ] outputs = ["sequences"] @@ -245,10 +248,11 @@ def convert_model(args): temperature = helper.make_tensor_value_info('temperature', TensorProto.FLOAT, [1]) length_penalty = helper.make_tensor_value_info('length_penalty', TensorProto.FLOAT, [1]) repetition_penalty = helper.make_tensor_value_info('repetition_penalty', TensorProto.FLOAT, [1]) + vocab_mask = helper.make_tensor_value_info('vocab_mask', TensorProto.INT32, [vocab_size]) graph_inputs = [ input_ids, max_length, min_length, num_beams, num_return_sequences, temperature, length_penalty, - repetition_penalty + repetition_penalty, vocab_mask ] # graph outputs @@ -283,28 +287,33 @@ def test_model(args): pad_token_id=tokenizer.eos_token_id) input_ids = tokenizer.encode('I enjoy walking in the park', return_tensors='pt') + global config + if config is None: + config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + eos_token_id = config.eos_token_id pad_token_id = config.eos_token_id vocab_size = config.vocab_size - print('-' * 50) - print("Test PyTorch model and beam search with huggingface transformers...") - beam_outputs = model.generate(input_ids, - max_length=args.max_length, - min_length=args.min_length, - num_beams=args.num_beams, - early_stopping=args.early_stopping, - no_repeat_ngram_size=args.no_repeat_ngram_size, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - num_return_sequences=args.num_return_sequences, - temperature=args.temperature, - length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty) - print("input_ids", input_ids) - print("huggingface transformers output:", beam_outputs) - for i, beam_output in enumerate(beam_outputs): - print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True))) + if args.run_baseline: + print('-' * 50) + print("Test PyTorch model and beam search with huggingface transformers...") + beam_outputs = model.generate(input_ids, + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + temperature=args.temperature, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty) + print("input_ids", input_ids) + print("huggingface transformers output:", beam_outputs) + for i, beam_output in enumerate(beam_outputs): + print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True))) print('-' * 50) print("Test ONNX model and bream search with onnxruntime...") @@ -323,7 +332,7 @@ def test_model(args): "temperature": np.array([args.temperature], dtype=np.float32), "length_penalty": np.array([args.length_penalty], dtype=np.float32), "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), - "vocab_mask": np.ones((vocab_size), dtype=np.float32) + "vocab_mask": np.ones((vocab_size), dtype=np.int32) } test_data_dir = Path(args.output).parent.as_posix() @@ -348,7 +357,10 @@ def main(): print('You have 15 seconds to attach a debugger.') time.sleep(15) - convert_model(args) + if os.path.exists(args.output): + print(f"skip conversion since path existed: {args.output}") + else: + convert_model(args) test_model(args) From a6c402fe4ad8f00d04493e6e93a9ca8083658c5a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 8 Nov 2021 14:45:16 -0800 Subject: [PATCH 05/53] add check data type NULL --- .../core/optimizer/insert_cast_transformer.cc | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 82d8a501753aa..6741cc04ea6f1 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -176,17 +176,19 @@ enum TypeGroup { }; TypeGroup GetTypeGroup(DataType type) { - if (*type == "tensor(bool)") { - return Bool; - } + if (type != nullptr) { + if (*type == "tensor(bool)") { + return Bool; + } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { - return Integer; - } + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || + *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Integer; + } - if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { - return Float; + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { + return Float; + } } return Unknown; From 147763c6102cabd0bd1d210bc611a3e719528306 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 8 Nov 2021 14:46:00 -0800 Subject: [PATCH 06/53] applies VerifyNodeAndOpMatch to subgraph --- onnxruntime/core/graph/graph.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index eba88da49e1f7..dd60991570c38 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2543,6 +2543,15 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { } } + // verify subgraphs + for (auto node_index : nodes_in_topological_order_) { + auto& node = *GetNode(node_index); + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + Graph* subgraph = entry.second; + ORT_RETURN_IF_ERROR(subgraph->VerifyNodeAndOpMatch(options)); + } + } + return Status::OK(); } From bd10853afc479d27791ff2d765638723371d4809 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 9 Nov 2021 20:21:03 -0800 Subject: [PATCH 07/53] update input_ids shape --- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 12 ++++-------- .../python/tools/transformers/convert_beam_search.py | 10 +++++----- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e0d6af8c5138e..42200f34edb92 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -559,7 +559,7 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { } // Shape inference - // input 0 (input_ids) shape: (batch_size * num_beams, sequence_length) + // input 0 (input_ids) shape: (batch_size, sequence_length) // output 0 (sequences) shape: (batch_size * num_return_sequences, max_length) // output 1 (sequences_scores) shape: (batch_size * num_return_sequences) // output 2 (scores) shape: (max_length-sequence_length, batch_size*num_beams*num_return_sequences, vocab_size) @@ -574,7 +574,8 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) { return; } - int64_t batch_beam_size = input_ids_dims[0].dim_value(); + + int64_t batch_size = input_ids_dims[0].dim_value(); int64_t sequence_length = input_ids_dims[1].dim_value(); const auto max_length = ctx.getInputData(1); @@ -599,11 +600,6 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { fail_shape_inference("Failed to parse num_return_sequences or it is not positive integer scalar"); } - if (batch_beam_size % num_beams_value != 0) { - fail_shape_inference("input_ids dimension 0 shall be multiple of num_beams"); - } - - int64_t batch_size = batch_beam_size / num_beams_value; ONNX_NAMESPACE::TensorShapeProto sequences_shape; sequences_shape.add_dim()->set_dim_value(batch_size * num_beams_value); sequences_shape.add_dim()->set_dim_value(batch_size * sequence_length); @@ -637,7 +633,7 @@ void RegisterTextGenerationSchemas() { "body", "The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output", AttributeProto::GRAPH) - .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size * num_beams, sequence_length)", "I") + .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 71fba8593ed9e..9c98bce6e649d 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -318,6 +318,11 @@ def test_model(args): print('-' * 50) print("Test ONNX model and bream search with onnxruntime...") + # TODO: remove debug code + import time + print('You have 15 seconds to attach a debugger.') + time.sleep(15) + ort_session = create_ort_session(args.output, args.use_gpu) batch_size = 2 @@ -352,11 +357,6 @@ def test_model(args): def main(): args = parse_arguments() - # TODO: remove debug code - import time - print('You have 15 seconds to attach a debugger.') - time.sleep(15) - if os.path.exists(args.output): print(f"skip conversion since path existed: {args.output}") else: From 4ba5d0014771692fd1da00f9b41b86a7dbc1240e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 9 Nov 2021 20:22:45 -0800 Subject: [PATCH 08/53] Add node name for Cast node --- onnxruntime/python/tools/transformers/convert_to_onnx.py | 1 - onnxruntime/python/tools/transformers/fusion_utils.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_to_onnx.py b/onnxruntime/python/tools/transformers/convert_to_onnx.py index 500a3f3b90c8e..28120d5e80837 100644 --- a/onnxruntime/python/tools/transformers/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/convert_to_onnx.py @@ -17,7 +17,6 @@ import os import argparse -import coloredlogs import logging import torch import numpy diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index ae0587bd24933..760d26032a3ea 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -36,7 +36,8 @@ def cast_input_to_int32(self, input_name: str): if parent_node and parent_node.op_type == 'Cast': inputs = [parent_node.input[0]] - cast_node = helper.make_node('Cast', inputs=inputs, outputs=[cast_output]) + node_name = self.model.create_node_name('Cast') + cast_node = helper.make_node('Cast', inputs=inputs, outputs=[cast_output], name=node_name) cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.INT32))]) self.model.add_node(cast_node) From e343472cc7e5b87207bbc7f8beb3b7499e05e9e7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 15 Nov 2021 12:34:51 -0800 Subject: [PATCH 09/53] expose API for topk --- onnxruntime/core/providers/cpu/math/top_k.cc | 50 ++++++++++++++++++++ onnxruntime/core/providers/cpu/math/top_k.h | 7 +++ 2 files changed, 57 insertions(+) diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc index fb38cbe98b968..b2c7ac1b032af 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.cc +++ b/onnxruntime/core/providers/cpu/math/top_k.cc @@ -370,6 +370,56 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input return Status::OK(); } +// Wrapper over core TopK implementation +template +Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted, + AllocatorPtr allocator, + onnxruntime::concurrency::ThreadPool* threadpool, + std::unique_ptr& output_values, + std::unique_ptr& output_indices) { + const TensorShape& input_shape = input->Shape(); + + // Will return axis_ as is if positive or fixes it in case it is negative + const auto axis_parsed = HandleNegativeAxis(axis, static_cast(input_shape.NumDimensions())); + + // Check to ensure k is within the bounds of what is available in that specific axis + if (input_shape[axis_parsed] < k) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "k argument [", k, + "] should not be greater than specified axis dim value [", input_shape[axis_parsed], "]"); + } + + // Resize output tensors to be the same shape as the input except + // for the specified dimension ((i.e.) axis_parsed), which will be of size k. E.x. for an input tensor + // of shape [3, 4, 5] and k=2 with axis_parsed=1, both of the outputs will be shape [3, 2, 5] + TensorShape output_shape = input_shape; + output_shape[axis_parsed] = k; + + output_values = Tensor::Create(input->DataType(), output_shape, allocator); + output_indices = Tensor::Create(DataTypeImpl::GetType(), output_shape, allocator); + + // no-op - no output buffers to fill - return silently + if (k == 0) { + return Status::OK(); + } + + if (largest) { + FindTopKElements>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted, + gsl::narrow_cast(axis_parsed), threadpool); + } else { + FindTopKElements>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted, + gsl::narrow_cast(axis_parsed), threadpool); + } + + return Status::OK(); +} + +// explicit instantiation +template Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted, + AllocatorPtr allocator, + onnxruntime::concurrency::ThreadPool* threadpool, + std::unique_ptr& output_values, + std::unique_ptr& output_indices); + // Opset ver - 1 to 9 static void TopkOpset9ConstructorCommon(const OpKernelInfo& op_kernel_info, int& axis, unsigned int& k) { diff --git a/onnxruntime/core/providers/cpu/math/top_k.h b/onnxruntime/core/providers/cpu/math/top_k.h index 6cea1c7d6ed96..596f10a8e6f9f 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.h +++ b/onnxruntime/core/providers/cpu/math/top_k.h @@ -19,4 +19,11 @@ class TopK final : public OpKernel { bool largest_; // opset-11 only bool sorted_; // opset-11 only }; + +template +Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted, + AllocatorPtr allocator, + onnxruntime::concurrency::ThreadPool* threadpool, + std::unique_ptr& output_values, + std::unique_ptr& output_indices); } // namespace onnxruntime \ No newline at end of file From 1c2d9cd95ab550a9451d52f7a2f69495986ac0b0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 15 Nov 2021 12:35:52 -0800 Subject: [PATCH 10/53] parse parameters --- .../transformers/beam_search_parameters.cc | 63 +++++++++++++++++++ .../cpu/transformers/beam_search_parameters.h | 43 +++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc new file mode 100644 index 0000000000000..a830a1da86830 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "beam_search_parameters.h" + +namespace onnxruntime { +namespace contrib { + +Status BeamSearchParameters::Validate() { + ORT_RETURN_IF(eos_token_id < 0, "eos_token_id is invalid"); + ORT_RETURN_IF(pad_token_id < 0, "pad_token_id is invalid"); + return Status::OK(); +} + +void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { + early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1; + eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); + pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); + no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); +} + +void BeamSearchParameters::ParseFromInputs(OpKernelContext* context){ + ORT_ENFORCE(context != nullptr); + const Tensor* input_ids = context->Input(0); + const auto& dims = input_ids->Shape().GetDims(); + if (dims.size() == 2) { + batch_size = static_cast(dims[0]); + sequence_length = static_cast(dims[1]); + } else { + batch_size = 0; + sequence_length = 0; + } + + auto* max_length_tensor = context->Input(1); + max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : 4096; + + auto* min_length_tensor = context->Input(2); + min_length = min_length_tensor ? static_cast(*min_length_tensor->Data()) : 0; + + auto* num_beams_tensor = context->Input(3); + num_beams = num_beams_tensor ? static_cast(*num_beams_tensor->Data()) : 1; + + auto* num_return_sequences_tensor = context->Input(4); + num_return_sequences = num_return_sequences_tensor ? static_cast(*num_return_sequences_tensor->Data()) : 1; + + auto* temperature_tensor = context->Input(5); + temperature = temperature_tensor ? static_cast(*temperature_tensor->Data()) : 1; + + auto* length_penalty_tensor = context->Input(6); + length_penalty = length_penalty_tensor ? static_cast(*length_penalty_tensor->Data()) : 1; + + auto* repetition_penalty_tensor = context->Input(7); + repetition_penalty = repetition_penalty_tensor ? static_cast(*repetition_penalty_tensor->Data()) : 1.0f; +} + +void BeamSearchParameters::SetSubgraphParameters(int heads, int hidden_size_per_head, int vocabulary_size, int layers){ + num_heads = heads; + head_size = hidden_size_per_head; + vocab_size = vocabulary_size; + num_layers = layers; +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h new file mode 100644 index 0000000000000..95753e1d797c9 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +struct BeamSearchParameters { + // from node attributes + int eos_token_id; + int pad_token_id; + int no_repeat_ngram_size; + bool early_stopping; + + // from inputs + int min_length; + int max_length; + int num_beams; + int num_return_sequences; + float temperature; + float length_penalty; + float repetition_penalty; + int batch_size; // deduce from first dimension of input_ids + int sequence_length; // deduce from second dimension of input_ids + + // deduce from subgraph + int num_heads; + int head_size; + int vocab_size; + int num_layers; + + Status Validate(); + + void ParseFromAttributes(const OpKernelInfo& info); + void ParseFromInputs(OpKernelContext* context); + void SetSubgraphParameters(int num_heads, int head_size, int vocab_size, int num_layers); +}; + +} // namespace contrib +} // namespace onnxruntime From 820a53cd15da7f48b0d2734b250387f3b770a2e3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 17 Nov 2021 11:37:40 -0800 Subject: [PATCH 11/53] Add beam search scorer --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../cpu/transformers/beam_earch.cc | 937 ++++++++++++++++++ .../cpu/transformers/beam_search.h | 134 +++ .../cpu/transformers/beam_search_scorer.cc | 198 ++++ .../cpu/transformers/beam_search_scorer.h | 136 +++ .../cpu/transformers/dump_tensor.cc | 37 + .../cpu/transformers/dump_tensor.h | 151 +++ .../tools/transformers/convert_beam_search.py | 2 +- 8 files changed, 1596 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/beam_search.h create mode 100644 onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h create mode 100644 onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 702b5bcd67f5c..61bfe8b862693 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -12,6 +12,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); @@ -183,6 +184,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { // add more kernels here BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc new file mode 100644 index 0000000000000..804f14b8e0af0 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc @@ -0,0 +1,937 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// there's no way to use a raw pointer as the copy destination with std::copy_n +// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset +// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "core/providers/cpu/controlflow/utils.h" +#include "core/providers/cpu/math/top_k.h" +#include "core/framework/allocator.h" +#include "core/framework/framework_common.h" +#include "core/framework/op_kernel_context_internal.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/framework/session_options.h" +#include "core/framework/TensorSeq.h" +#include "gsl/gsl" +#include "core/providers/cpu/math/softmax_shared.h" +#include "beam_search.h" +#include "dump_tensor.h" + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BeamSearch, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BeamSearch); + +REGISTER_KERNEL_TYPED(float) + +// CPU does not support float16 +// REGISTER_KERNEL_TYPED(MLFloat16) + +GptSubgraphInfo::GptSubgraphInfo(const onnxruntime::Node& node, const GraphViewer& subgraph_in) + : subgraph(subgraph_in) { + num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + + auto& subgraph_inputs = subgraph.GetInputs(); + auto& subgraph_outputs = subgraph.GetOutputs(); + + // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ... + // outputs: logits, present_0, present_1, ... + num_subgraph_inputs = static_cast(subgraph_inputs.size()); + num_subgraph_outputs = static_cast(subgraph_outputs.size()); + + // CheckSubgraph will verify inputs and outputs later. + subgraph_input_names.reserve(num_subgraph_inputs); + for (int i = 0; i < num_subgraph_inputs; ++i) { + subgraph_input_names.push_back(subgraph_inputs[i]->Name()); + } + + subgraph_output_names.reserve(num_subgraph_outputs); + for (int i = 0; i < num_subgraph_outputs; ++i) { + subgraph_output_names.push_back(subgraph_outputs[i]->Name()); + } +} + +void Sequences::Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length) { + // Allocate buffer (shall we use allocator instead?) + sequences[0].assign(batch_beam_size * max_length, 0); + sequences[1].assign(batch_beam_size * max_length, 0); + + // copying input_ids to sequences[0] + gsl::span input = input_ids.Get().DataAsSpan(); + gsl::span output(sequences[0]); + for (int i = 0; i < batch_beam_size; i++) { + gsl::span source = input.subspan(i * sequence_length, sequence_length); + gsl::span target = output.subspan(i * max_length, sequence_length); + gsl::copy(source, target); + } + current_sequences_buffer = 0; + + batch_beam_size_ = batch_beam_size; + max_length_ = max_length; + current_length_ = sequence_length; +} + +gsl::span Sequences::GetSequence(int beam_index) { + gsl::span buffer(sequences[current_sequences_buffer]); + gsl::span sequence = buffer.subspan(beam_index * max_length_, current_length_); + return sequence; +} + +int Sequences::GetSequenceLength() { + return current_length_; +} + +void Sequences::PrintSequences() { +#ifdef DEBUG_BEAM_SEARCH + std::cout << "sequences:" << std::endl; + for (int i = 0; i < batch_beam_size_; i++) { + gsl::span sequence = GetSequence(i); + std::string beam_index = std::to_string(i); + DumpTensor(beam_index.c_str(), sequence.data(), 1, current_length_); + } +#endif +} + +void Sequences::AppendNextTokenToSequences( + gsl::span& beam_indices, + gsl::span& beam_next_tokens) { + //sequences = torch.cat([sequences[beam_indices, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + gsl::span input(sequences[current_sequences_buffer]); + gsl::span output(sequences[1 - current_sequences_buffer]); + + for (int i = 0; i < batch_beam_size_; i++) { + int beam_index = static_cast(beam_indices[i]); + gsl::span source = input.subspan(beam_index * max_length_, current_length_); + gsl::span target = output.subspan(i * max_length_, current_length_); + gsl::copy(source, target); + } + + // append next token to each beam + for (int i = 0; i < batch_beam_size_; i++) { + output[i * max_length_ + current_length_] = beam_next_tokens[i]; + } + + ++current_length_; + current_sequences_buffer = 1 - current_sequences_buffer; // rotate buffer for next round +} + +template +class BeamSearchImpl { + public: + BeamSearchImpl(OpKernelContextInternal& context, + const SessionState& session_state, + const GptSubgraphInfo& info, + concurrency::ThreadPool* thread_pool, + void* stream, + BeamSearchParameters& params); + + // Initialize by validating all the inputs, and allocating the output tensors + Status Initialize(); + + // Execute the batch, by iterating the sequence in each batch entry + // and calling the subgraph with each item in the sequence. + Status Execute(const FeedsFetchesManager& cached_ffm); + + private: + Status CheckInputs(const OpKernelContextInternal& context); + + Status CheckSubgraph(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) const; + + OrtValue ExpandInputs(const OrtValue& input_ids, int num_beams) const; + + // Prepare the inputs for first inference of subgraph + void CreateInitialFeeds(std::vector& feeds); + + // Process Logits and Update the input for next iteration. + Status ProcessLogitsAndUpdateFeeds( + const std::vector& last_outputs, + std::vector& next_inputs, + int current_length); + + // Process logits and append next tokens to sequences + Status GenerateNextToken(const OrtValue& logits, + gsl::span& beam_next_tokens, + gsl::span& beam_indices); + + Status ProcessLogits(const OrtValue& logits, + BeamSearchState& beam_state, + int top_k, + AllocatorPtr& allocator); + + void ProcessNextTokenScores(gsl::span& next_token_scores); + + // Reorder cache by picking the past state based on beam indices + void PickPastState(const std::vector& last_outputs, + std::vector& next_inputs, + gsl::span& beam_indices); + + OpKernelContextInternal& context_; + const SessionState& session_state_; + const GptSubgraphInfo& subgraph_info_; + + concurrency::ThreadPool* thread_pool_; + + const std::vector& implicit_inputs_; + + std::vector next_positions_; + + // Not used in CPU. Stream is for CUDA only. + void* stream_; + + BeamSearchParameters* parameters_; + + std::unique_ptr beam_scorer_; + + BeamSearchState beam_state_; + + AllocatorPtr allocator_; +}; + +template +void BeamSearch::Init(const OpKernelInfo& info) { + // make sure the attribute was present even though we don't need it here. + // The GraphProto is loaded as a Graph instance by main Graph::Resolve, + // and a SessionState instance for executing the subgraph is created by InferenceSession. + // This is available via Info().GetSubgraphSessionState("attribute_name") when Compute is called. + ONNX_NAMESPACE::GraphProto proto; + + ORT_ENFORCE(info.GetAttr("body", &proto).IsOK()); + ORT_IGNORE_RETURN_VALUE(proto); + + parameters_.ParseFromAttributes(info); + + stream_ = nullptr; +} + +template +std::unique_ptr BeamSearch::Create(const OpKernelInfo& info, + void* stream) { + auto result = std::make_unique(info); + result->SetComputeStream(stream); + return result; +} + +template +common::Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) { + ORT_ENFORCE(subgraph_info_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + ORT_UNUSED_PARAMETER(attribute_name); + + const auto& node = Node(); + subgraph_info_ = std::make_unique(node, subgraph_session_state.GetGraphViewer()); + + ORT_RETURN_IF(subgraph_info_->num_subgraph_outputs <= 1, + "Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs)."); + + ORT_RETURN_IF(subgraph_info_->num_subgraph_inputs != subgraph_info_->num_subgraph_outputs + 2, + "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2"); + + std::vector feed_names; + feed_names.reserve(subgraph_info_->num_subgraph_inputs + subgraph_info_->num_implicit_inputs); + + // First, get the location of input_ids of current operator. + const auto& node_inputs = node.InputDefs(); + const OrtMemoryInfo& input_ids_location = utils::FindMemoryInfoForValue(session_state, node_inputs[0]->Name()); + + // position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. + // as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids + feed_names.insert(feed_names.end(), subgraph_info_->subgraph_input_names.begin(), subgraph_info_->subgraph_input_names.end()); + + for (auto& entry : node.ImplicitInputDefs()) { + feed_names.push_back(entry->Name()); + } + + std::vector feed_locations; + feed_locations.resize(feed_names.size()); + + for (size_t i = 0, end = feed_names.size(); i < end; ++i) { + if (i >= subgraph_info_->subgraph_input_names.size()) { // implicit inputs + const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]); + feed_locations[i] = location.device; + } else { + feed_locations[i] = input_ids_location.device; + } + } + + std::unique_ptr ffm; + ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_info_->subgraph_output_names, + subgraph_session_state.GetOrtValueNameIdxMap(), ffm)); + ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm)); + + // setup the locations where we want the subgraph output to end up on + std::vector fetch_locations; + fetch_locations.reserve(subgraph_info_->num_subgraph_outputs); + + // past state need to be where we can feed them in to the next iteration, so set the fetch location to match the feed location. + for (int i = 0; i < subgraph_info_->num_subgraph_outputs; ++i) { + fetch_locations.push_back(&input_ids_location); + } + + utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations); + + feeds_fetches_manager_ = std::move(ffm); + + return Status::OK(); +} + +template +Status BeamSearch::Compute(OpKernelContext* ctx) const { + auto* ctx_internal = static_cast(ctx); + auto* session_state = ctx_internal->SubgraphSessionState("body"); + ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute."); + ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + BeamSearchParameters parameters = parameters_; // make a copy + + BeamSearchImpl impl{*ctx_internal, *session_state, *subgraph_info_, thread_pool, stream_, parameters}; + + auto status = impl.Initialize(); + ORT_RETURN_IF_ERROR(status); + + status = impl.Execute(*feeds_fetches_manager_); + + return status; +} + +template +BeamSearchImpl::BeamSearchImpl(OpKernelContextInternal& context, + const SessionState& session_state, + const GptSubgraphInfo& subgraph_info, + concurrency::ThreadPool* thread_pool, + void* stream, + BeamSearchParameters& params) + : context_(context), + session_state_(session_state), + subgraph_info_(subgraph_info), + thread_pool_(thread_pool), + implicit_inputs_(context_.GetImplicitInputs()), + stream_(stream), + parameters_(¶ms), + allocator_(nullptr) { + parameters_->ParseFromInputs(&context); + + allocator_ = session_state.GetExecutionProviders() + .Get(onnxruntime::kCpuExecutionProvider) + ->GetAllocator(0, OrtMemTypeDefault); +} + +template +Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) { + // Input shapes: + // input_ids : (batch_size, sequence_length) + // vocab_mask : (vocab_size) or nullptr + + const Tensor* input_ids = context.Input(0); + const auto& dims = input_ids->Shape().GetDims(); + if (dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ", + dims.size()); + } + + const Tensor* vocab_mask = context.Input(8); + if (vocab_mask != nullptr) { // vocab_mask is optional + const auto& vocab_mask_dims = vocab_mask->Shape().GetDims(); + if (vocab_mask_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ", + vocab_mask_dims.size()); + } + if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ", + vocab_mask_dims[0]); + } + } + + return Status::OK(); +} + +template +Status BeamSearchImpl::CheckSubgraph(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) const { + ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ", + subgraph_inputs[0]->Name()); + ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ", + subgraph_inputs[1]->Name()); + ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ", + subgraph_inputs[2]->Name()); + ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ", + subgraph_inputs[3]->Name()); + + // Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads. + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); + ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", + past_shape->dim_size()); + + ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, + "subgraph past state dimension 0 shall have length of 2"); + + ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for number of heads"); + + ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, + "subgraph past state dimension 4 shall have a positive value for hidden size per head"); + + // check subgraph outputs + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ", + subgraph_outputs[0]->Name()); + + ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ", + subgraph_outputs[1]->Name()); + + // Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size. + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); + ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ", + logits_shape->dim_size()); + + ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for vocabulary size"); + + int num_heads = static_cast(past_shape->dim(2).dim_value()); + int head_size = static_cast(past_shape->dim(4).dim_value()); + int vocab_size = static_cast(logits_shape->dim(2).dim_value()); + int num_layers = static_cast(subgraph_outputs.size()) - 1; + parameters_->SetSubgraphParameters(num_heads, head_size, vocab_size, num_layers); + + return Status::OK(); +} + +template +Status BeamSearchImpl::Initialize() { + auto status = Status::OK(); + +#define CHECK_SCALAR_INPUT(name, index, required) \ + auto* name##_tensor = context_.Input(index); \ + if (name##_tensor) { \ + if (!name##_tensor->Shape().IsScalar()) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \ + name##_tensor->Shape()); \ + } \ + } else if (required) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \ + } + + CHECK_SCALAR_INPUT(min_length, 1, false); + + CHECK_SCALAR_INPUT(max_length, 2, true); + + CHECK_SCALAR_INPUT(num_beams, 3, true); + + CHECK_SCALAR_INPUT(num_return_sequences, 4, true); + + CHECK_SCALAR_INPUT(temperature, 5, true); + + CHECK_SCALAR_INPUT(length_penalty, 6, true); + + ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'."); + + auto& inputs = subgraph_info_.subgraph.GetInputs(); + auto& outputs = subgraph_info_.subgraph.GetOutputs(); + ORT_RETURN_IF_ERROR(CheckSubgraph(inputs, outputs)); + + // CheckInputs shall be after CheckSubgraph due to its dependency on vocab_size + ORT_RETURN_IF_ERROR(CheckInputs(context_)); + + return status; +} + +template +OrtValue BeamSearchImpl::ExpandInputs(const OrtValue& input, int num_beams) const { + if (num_beams == 1) + return input; + + // Given input of shape (batch_size, sequence_length), expand the shape to be (batch_size * num_beams, sequence_length) + const TensorShape& input_shape = input.Get().Shape(); + ORT_ENFORCE(input_shape.NumDimensions() == 2 && input_shape[0] == parameters_->batch_size && input_shape[1] == parameters_->sequence_length); + + const int64_t& batch_size = input_shape[0]; + const int64_t& sequence_length = input_shape[1]; + int64_t dims[] = {batch_size * num_beams, sequence_length}; + TensorShape expanded_shape(&dims[0], 2); + + auto element_type = DataTypeImpl::GetType(); + OrtValue expanded; + Tensor::InitOrtValue(element_type, expanded_shape, allocator_, expanded); + + const int64_t* input_data = input.Get().Data(); + int64_t* expanded_data = expanded.GetMutable()->MutableData(); + int64_t* target = expanded_data; + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < num_beams; j++) { + memcpy(target, input_data + i * sequence_length, sizeof(int64_t) * sequence_length); + target += sequence_length; + } + } + + return expanded; +} + +template +void BeamSearchImpl::CreateInitialFeeds(std::vector& feeds) { + // Subgraph inputs: + // input_ids: shape (B, S) wher B is batch size, and S is sequence length + // position_ids: shape (B, S) + // attention_mask: shape (B, P+S), where past_sequence_length (P) is 0 + // After expansion, their shapes will become (B, M*S), where M is num_beams. + + const OrtValue* input_ids = context_.GetInputOrtValue(0); + + const Tensor& input_ids_tensor = input_ids->Get(); + + const TensorShape& input_ids_shape = input_ids_tensor.Shape(); + ORT_ENFORCE(input_ids_shape.NumDimensions() == 2); + const int64_t& batch_size = input_ids_shape[0]; + const int64_t& sequence_length = input_ids_shape[1]; + + // Allocate position_ids and attention_mask based on shape of input_ids + auto element_type = DataTypeImpl::GetType(); + + // input_ids for subgraph is int64, so we need Cast input_ids from int32 to int64. + OrtValue subgraph_input_ids; + // Current shape is (batch_size, sequence_length) + // Note that we will expand it to (batch_size * num_beams, sequence_length) later. + Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, subgraph_input_ids); + + int64_t* subgraph_input_data = subgraph_input_ids.GetMutable()->MutableData(); + const int32_t* source = input_ids_tensor.Data(); + int64_t* target = subgraph_input_data; + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < sequence_length; j++, source++, target++) { + *target = static_cast(*source); + } + } + + OrtValue position_ids; + Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids); + + OrtValue attention_mask; + Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, attention_mask); + + next_positions_.resize(batch_size * parameters_->num_beams); + // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. + // Set position id to be 0 for pad tokens, and cumulated sum of mask in a batch for other tokens + int64_t* mask_data = attention_mask.GetMutable()->MutableData(); + int64_t* position_data = position_ids.GetMutable()->MutableData(); + source = input_ids_tensor.Data(); + int64_t* mask = mask_data; + int64_t* position = position_data; + for (int i = 0; i < batch_size; i++) { + int64_t abs_position = 0; + for (int j = 0; j < sequence_length; j++, source++, mask++, position++) { + if (*source == parameters_->pad_token_id) { + *mask = 0; + *position = 0; + } else { + *mask = 1; + *position = abs_position; + abs_position++; + } + } + for (int k = 0; k < parameters_->num_beams; k++) { + next_positions_[i * parameters_->num_beams + k] = abs_position; + } + } + + // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask + // TODO: Try expand inputs/outputs after first subgraph call instead. That may get better peroformance, but more complex to implement. + OrtValue expanded_input_ids = ExpandInputs(subgraph_input_ids, parameters_->num_beams); + OrtValue expanded_position_ids = ExpandInputs(position_ids, parameters_->num_beams); + OrtValue expanded_attention_mask = ExpandInputs(attention_mask, parameters_->num_beams); + + // Initialize empty past state + auto past_type = DataTypeImpl::GetType(); + int64_t past_state_dims[] = {2, batch_size * parameters_->num_beams, parameters_->num_heads, 0, parameters_->head_size}; + TensorShape past_shape(&past_state_dims[0], 5); + OrtValue empty_past; + Tensor::InitOrtValue(past_type, past_shape, allocator_, empty_past); + + // The ordering is the same as used in SetupSubgraphExecutionInfo + feeds.reserve(subgraph_info_.num_subgraph_inputs + subgraph_info_.num_implicit_inputs); + feeds.push_back(expanded_input_ids); + feeds.push_back(expanded_position_ids); + feeds.push_back(expanded_attention_mask); + + // The remaing inputs are past state. + for (int i = 3; i < subgraph_info_.num_subgraph_inputs; ++i) { + feeds.push_back(empty_past); + } + + // pass in implicit inputs + for (const auto* entry : implicit_inputs_) { + feeds.push_back(*entry); + } +} + +template +Status BeamSearchImpl::ProcessLogits( + const OrtValue& logits, // logits output of subgraph + BeamSearchState& beam_state, + int top_k, + AllocatorPtr& allocator) { + const int64_t batch_beam_size = static_cast(parameters_->batch_size * parameters_->num_beams); + const int& vocab_size = parameters_->vocab_size; + +#ifdef DEBUG_BEAM_SEARCH + //DumpOrtValue("input_ids", input_ids); + DumpOrtValue("logits", logits); +#endif + + const float* logits_data = logits.Get().Data(); + + const TensorShape& logits_shape = logits.Get().Shape(); + ORT_ENFORCE(logits_shape.NumDimensions() == 3); + + // The sequence length of input_ids for the logits. + // It equals parameters_->sequence_length for first subgraph call, and 1 for the remaining. + auto input_length = logits_shape[1]; + + // Get logits for the last token, where logits has shape (batch_size * num_beams, input_length, vocab_size) + // next_token_logits = logits[:, -1, :], where its shape is (batch_size * num_beams, vocab_size) + // When input_length == 1, use logits directly to avoid copy logits to next_token_logits. + auto next_token_logits = gsl::make_span(beam_state.next_token_logits); + if (input_length > 1) { + const float* current_logits = logits_data + (input_length - 1) * vocab_size; + for (int i = 0; i < batch_beam_size; i++) { + gsl::span source(current_logits, vocab_size); + gsl::span target = next_token_logits.subspan(i * vocab_size, vocab_size); + gsl::copy(source, target); + current_logits += i * (input_length * vocab_size); + } + } + + // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1) + auto next_token_scores = gsl::make_span(beam_state.next_token_scores); + Status status = SoftmaxCPU(batch_beam_size, // rows + vocab_size, // elements per row + input_length > 1 ? next_token_logits.data() : logits_data, + next_token_scores.data(), + true, + thread_pool_); + if (!status.IsOK()) { + return status; + } + + // Extra processing: next_token_scores = logits_processor(input_ids, next_token_scores) + // where input_ids is current sequences in beam_state_ + ProcessNextTokenScores(next_token_scores); + + // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + // TODO: use thread pool to parrellel + int offset = 0; + int batch_beam_index = 0; + for (int i = 0; i < parameters_->batch_size; i++) { + for (int j = 0; j < parameters_->num_beams; j++, batch_beam_index++) { + for (int k = 0; k < parameters_->vocab_size; k++, offset++) { + next_token_scores[offset] += beam_state.beam_scores[batch_beam_index]; + } + } + } + + // TODO: Store scores only when required + // if output_scores: + // scores += (next_token_scores,) + beam_state.scores.insert(beam_state.scores.end(), next_token_scores.begin(), next_token_scores.end()); + + //next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + //next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) + int64_t next_token_scores_dims[] = {parameters_->batch_size, parameters_->num_beams * vocab_size}; + TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_scores_value; + Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value); + const Tensor& input = next_token_scores_value.Get(); + +#ifdef DEBUG_BEAM_SEARCH + DumpOrtValue("next_token_scores_value", next_token_scores_value); +#endif + + const int axis = 1; + const unsigned k = static_cast(top_k); + const bool largest = true; + const bool sorted = true; // results returned in sorted order. + + std::unique_ptr topk_scores; + std::unique_ptr topk_indices; + status = GetTopK(&input, axis, top_k, largest, sorted, allocator, thread_pool_, topk_scores, topk_indices); + if (!status.IsOK()) { + return status; + } + +#ifdef DEBUG_BEAM_SEARCH + DumpTensor("topk_scores", *(topk_scores.get())); + DumpTensor("topk_indices", *(topk_indices.get())); +#endif + + //next_indices = (next_tokens / vocab_size).long() + //next_tokens = next_tokens % vocab_size + gsl::span next_token_indices = topk_indices->DataAsSpan(); + beam_state.next_indices.resize(parameters_->batch_size * k); + beam_state.next_tokens.resize(parameters_->batch_size * k); + offset = 0; + for (int i = 0; i < parameters_->batch_size; i++) { + for (unsigned int j = 0; j < k; j++, offset++) { + beam_state.next_indices[offset] = next_token_indices[offset] / vocab_size; + beam_state.next_tokens[offset] = next_token_indices[offset] % vocab_size; + } + } + + gsl::span next_scores = topk_scores->DataAsSpan(); + gsl::span next_tokens(beam_state.next_tokens.data(), beam_state.next_tokens.size()); + gsl::span next_indices(beam_state.next_indices.data(), beam_state.next_indices.size()); + +#ifdef DEBUG_BEAM_SEARCH + DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, k); + DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, k); + DumpTensor("next_indices before scorer", next_indices.data(), parameters_->batch_size, k); +#endif + + beam_scorer_->Process( + &(beam_state.sequences), + next_scores, //next_token_scores, + next_tokens, + next_indices, + allocator); + + return Status::OK(); +} + +template +Status BeamSearchImpl::GenerateNextToken( + const OrtValue& logits, + gsl::span& beam_next_tokens, + gsl::span& beam_indices) { + // Process logits to get next token scores, and select top_k = 2 * num_beams + // TODO: we might not need 2 * num_beams when logits processors does not update token scores. + const int top_k = 2 * parameters_->num_beams; + ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state_, top_k, allocator_)); + + gsl::span& beam_scores = beam_scorer_->GetNextScores(); + // TODO: may not need clone beam_scores. + beam_state_.beam_scores.assign(beam_scores.begin(), beam_scores.end()); + + beam_next_tokens = beam_scorer_->GetNextTokens(); + beam_indices = beam_scorer_->GetNextIndices(); + +#ifdef DEBUG_BEAM_SEARCH + DumpTensor("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams); + DumpTensor("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams); + DumpTensor("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams); +#endif + + beam_state_.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens); + +#ifdef DEBUG_BEAM_SEARCH + beam_state_.sequences.PrintSequences(); +#endif + return Status::OK(); +} + +template +void BeamSearchImpl::ProcessNextTokenScores(gsl::span& /*next_token_scores*/) { + return; +} + +template +void BeamSearchImpl::PickPastState(const std::vector& last_outputs, + std::vector& next_inputs, + gsl::span& beam_indices) { + for (int i = 3; i < subgraph_info_.num_subgraph_inputs; ++i) { + const OrtValue& present = last_outputs[i - 2]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64) + const TensorShape& past_shape = present.Get().Shape(); + + // Create a tensor with same shape. + OrtValue past; + auto past_type = DataTypeImpl::GetType(); // present.Type() + Tensor::InitOrtValue(past_type, past_shape, allocator_, past); + + auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; + auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; + + gsl::span past_span = past.GetMutable()->MutableDataAsSpan(); + gsl::span present_span = present.Get().DataAsSpan(); + for (gsl::index j = 0; j < beam_indices.length(); j++) { + int64_t beam_index = beam_indices[j]; + gsl::span present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); + gsl::span present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); + + gsl::span past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam); + gsl::span past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); + gsl::copy(present_key, past_key); + gsl::copy(present_value, past_value); + +#ifdef DEBUG_BEAM_SEARCH + if (i == 3) // only dump past_0 + { + DumpTensorName("past_key of beam", j, true); + DumpTensor(nullptr, past_key.data(), 1, static_cast(block_size_per_beam)); + + DumpTensorName("past_value of beam", j, true); + DumpTensor(nullptr, past_value.data(), 1, static_cast(block_size_per_beam)); + } +#endif + } + + next_inputs[i] = past; + } +} + +template +Status BeamSearchImpl::ProcessLogitsAndUpdateFeeds( + const std::vector& last_outputs, + std::vector& next_inputs, + int current_length) { + // last_outputs: logits, present_0, present_1, ... + // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 + + // Process logits to get next token scores, and select top_k = 2 * num_beams + // TODO: we might not need 2 * num_beams when logits processors does not update token scores. + const OrtValue& logits = last_outputs[0]; + + gsl::span beam_next_tokens; + gsl::span beam_indices; + ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices)); + + // The following updates inputs for subgraph + // TODO: Reuse buffer for input_ids and position_ids to reduce memory allocation. + + // Update input_ids with next tokens. + int batch_beam_size = parameters_->batch_size * parameters_->num_beams; + int64_t dims[] = {batch_beam_size, 1}; + TensorShape input_ids_shape(&dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue input_ids; + Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, input_ids); + int64_t* input_ids_data = input_ids.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + input_ids_data[i] = beam_next_tokens[i]; + } + next_inputs[0] = input_ids; + + // Update position IDs + OrtValue position_ids; + Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids); + int64_t* position_data = position_ids.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + position_data[i] = next_positions_[i]; + next_positions_[i]++; + } + next_inputs[1] = position_ids; + + // Update attention mask + const OrtValue& old_mask = next_inputs[2]; + const int64_t* old_mask_data = old_mask.Get().Data(); + int64_t mask_dims[] = {batch_beam_size, current_length}; + TensorShape mask_shape(&mask_dims[0], 2); + OrtValue attention_mask; + Tensor::InitOrtValue(element_type, mask_shape, allocator_, attention_mask); + int64_t* mask_data = attention_mask.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + for (int j = 0; j < current_length - 1; j++) { + mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j]; + } + mask_data[i * current_length + current_length - 1] = 1; + } + next_inputs[2] = attention_mask; + +#ifdef DEBUG_BEAM_SEARCH + DumpOrtValue("input_ids", input_ids); + DumpOrtValue("position_ids", position_ids); + DumpOrtValue("attention_mask", attention_mask); +#endif + + // Update past state + if (parameters_->num_beams == 1) { + // feed present_* output to past_* inputs one by one + for (int i = 3; i < subgraph_info_.num_subgraph_inputs; ++i) { + next_inputs[i] = last_outputs[i - 2]; + } + } else { + PickPastState(last_outputs, next_inputs, beam_indices); + } + + return Status::OK(); +} + +template +Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { + auto status = Status::OK(); + + std::vector feeds; + std::vector fetches; + + CreateInitialFeeds(feeds); + + int current_length = parameters_->sequence_length; + while (current_length < parameters_->max_length) { + if (current_length > parameters_->sequence_length) { + // Initialize resources only when needed + if (beam_scorer_.get() == nullptr) { + beam_scorer_ = std::make_unique(parameters_->batch_size, + parameters_->num_beams, + parameters_->max_length, + parameters_->length_penalty, + parameters_->early_stopping, + parameters_->num_return_sequences, + parameters_->pad_token_id, + parameters_->eos_token_id); + const OrtValue& input_ids = feeds[0]; +#ifdef DEBUG_BEAM_SEARCH + DumpOrtValue("input_ids", input_ids); +#endif + beam_state_.Init(input_ids, + parameters_->batch_size, + parameters_->num_beams, + parameters_->vocab_size, + parameters_->sequence_length, + parameters_->max_length); + } + + ORT_RETURN_IF_ERROR(ProcessLogitsAndUpdateFeeds(fetches, feeds, current_length)); + fetches.clear(); + +#ifdef DEBUG_BEAM_SEARCH + if (current_length - parameters_->sequence_length == 3) { // only dump a few steps. + ConfigureTensorDump(false); + } +#endif + } + + status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, {}, + ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger()); + + ORT_RETURN_IF_ERROR(status); + + ++current_length; + } + + return status; +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h new file mode 100644 index 0000000000000..e14500d1c8417 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "gsl/gsl" +#include "core/common/common.h" +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/controlflow/utils.h" +#include "beam_search_parameters.h" +#include "beam_search_scorer.h" + +namespace onnxruntime { +namespace contrib { + +struct GptSubgraphInfo { + GptSubgraphInfo(const onnxruntime::Node& node, const GraphViewer& subgraph_in); + + const GraphViewer& subgraph; + + int num_implicit_inputs; + + int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience. + int num_subgraph_outputs; // same as subgraph_output_names.size() + + std::vector subgraph_input_names; + std::vector subgraph_output_names; +}; + +// This class keeps track of sequences generated. +class Sequences : public ISequences { +public: + Sequences(){} + + // Initialize the sequence with initial input_ids and related parameters. + void Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length); + + // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). + gsl::span GetSequence(int beam_index) override; + + // Returns current sequence length. + int GetSequenceLength() override; + + // Print the sequences to StdOut in debug mode + void PrintSequences(); + + // Select sequences based on beam indices, then append next token to selected sequences. + void AppendNextTokenToSequences( + gsl::span& beam_indices, + gsl::span& beam_next_tokens); + +private: + // Two buffers of shape (batch_size, num_beams, max_seq_length) to store sequences. + // At each time, there is only one buffer is active. The other one will be active in next token. + // Each AppendNextTokenToSequences call will trigger a rotation of active buffer. + std::vector sequences[2]; + + // Index (either 0 or 1) of two buffers that is currently is active. + int current_sequences_buffer; + + int batch_beam_size_; + int max_length_; + int current_length_; +}; + +struct BeamSearchState { + // TODO: use allocater to allocate a buffer, and point each data to a span of the buffer + // so as to reuse related code in CUDA. + std::vector done; // shape (batch_size) + std::vector beam_scores; // shape (batch_size, num_beams) + + std::vector next_token_logits; // shape (batch_size * num_beams, vocab_size) + std::vector next_token_scores; // shape (batch_size, num_beams * vocab_size) + + std::vector next_tokens; // shape (batch_size, num_beams) + std::vector next_indices; // shape (batch_size, num_beams) + + Sequences sequences; + + std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) + + void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length) { + int batch_beam_size = batch_size * num_beams; + done.assign(batch_size, 0); + beam_scores.assign(batch_beam_size, 0.0f); + for (int i = 0; i < batch_size; i++) + { + for (int j = 1; j < num_beams; j++) { + beam_scores[i * num_beams + j] = -1e9; + } + } + + next_token_logits.assign(batch_beam_size * vocab_size, 0.0f); + next_token_scores.assign(batch_beam_size * vocab_size, 0.0f); + + next_tokens.assign(batch_beam_size, 0); + next_indices.assign(batch_beam_size, 0); + + sequences.Init(input_ids, batch_beam_size, sequence_length, max_length); + + scores.reserve((max_length - sequence_length + 1) * batch_size * num_beams * vocab_size); + } +}; + +template +class BeamSearch : public controlflow::IControlFlowKernel { + public: + BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info) { Init(info); } + void Init(const OpKernelInfo& info); + + Status Compute(OpKernelContext* ctx) const override; + + Status SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) override; + + static std::unique_ptr Create(const OpKernelInfo& info, void* stream); + + protected: + void SetComputeStream(void* stream) { stream_ = stream; } + + private: + // Subgraph info and FeedsFetchesManager re-used for each subgraph execution. + std::unique_ptr subgraph_info_; + std::unique_ptr feeds_fetches_manager_; + + void* stream_; + + BeamSearchParameters parameters_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc new file mode 100644 index 0000000000000..e0d117cd58832 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/common/common.h" +#include "core/framework/allocator.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/cpu/rnn/rnn_helpers.h" +#include "beam_search_scorer.h" + +namespace onnxruntime { +namespace contrib { + +using ::onnxruntime::rnn::detail::Allocate; + +BeamHypotheses::BeamHypotheses(int num_beams, float length_penalty, bool early_stopping) + : num_beams_(num_beams), + length_penalty_(length_penalty), + early_stopping_(early_stopping), + worst_score_(1e9) {} + +void BeamHypotheses::Add(gsl::span& hypothesis, float sum_logprobs) { + auto length = hypothesis.size(); + float score = sum_logprobs / pow(static_cast(length), length_penalty_); + + if (this->Size() < num_beams_ || score > worst_score_) { + HypothesisScore item(hypothesis, score); + beams_.push(item); + if (this->Size() > num_beams_) { + beams_.pop(); + } + worst_score_ = beams_.top().score; + } +} + +bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) { + // If there are enough hypotheses and that none of the hypotheses being generated can become better + // than the worst one in the heap, then we are done with this sentence. + + if (Size() < num_beams_) + return false; + + if (early_stopping_) + return true; + + float current_score = best_sum_logprobs / pow(static_cast(current_length), length_penalty_); + return worst_score_ >= current_score; +} + +BeamSearchScorer::BeamSearchScorer(int batch_size, + int num_beams, + int max_length, + float length_penalty, + bool early_stopping, + int num_return_sequences, + int pad_token_id, + int eos_token_id) + : batch_size_(batch_size), + num_beams_(num_beams), + max_length_(max_length), + length_penalty_(length_penalty), + early_stopping_(early_stopping), + num_beam_hyps_to_keep_(num_return_sequences), + pad_token_id_(pad_token_id), + eos_token_id_(eos_token_id), + hypothesis_buffer_length_(0), + hypothesis_buffer_offset_(0) { + for (int batch = 0; batch < batch_size; batch++) { + beam_hyps.push_back(BeamHypotheses(num_beams, length_penalty, early_stopping)); + } + + for (int batch = 0; batch < batch_size; batch++) { + done_.push_back(false); + } +} + +bool BeamSearchScorer::IsDone() { + for (int batch = 0; batch < batch_size_; batch++) { + if (!done_[batch]) + return false; + } + return true; +} + +void BeamSearchScorer::Process(ISequences* sequences, + gsl::span& next_scores, + gsl::span& next_tokens, + gsl::span& next_indices, + AllocatorPtr& allocator) { + // sequences shape is (batch_size * num_beams, total_sequence_length) + // It contains word ID of whole sequence generated so far. + // It is different from subgraph input_ids, which only need one word when past state is not empty. + + const int sequence_length = sequences->GetSequenceLength(); + + ORT_ENFORCE(next_scores.size() == next_tokens.size()); + ORT_ENFORCE(next_scores.size() == next_indices.size()); + + // Allocate buffers only once + if (next_beam_scores_.empty()) { + size_t batch_beam_size = static_cast(batch_size_ * num_beams_); + const bool fill_zeros = false; + next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_, fill_zeros); + next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_, fill_zeros); + next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, fill_zeros); + + // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length + int buffer_per_beam = (max_length_ * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2; + hypothesis_buffer_length_ = batch_beam_size * static_cast(buffer_per_beam); + hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, fill_zeros); + } + + for (int batch = 0; batch < batch_size_; batch++) { + BeamHypotheses& beam_hyp = beam_hyps[batch]; + if (done_[batch]) { + ORT_ENFORCE(beam_hyp.Size() >= num_beams_, "Batch can only be done if all beams have been generated"); + + // Pad the batch + for (int j = 0; j < num_beams_; j++) { + next_beam_scores_[batch * num_beams_ + j] = 0.0f; + next_beam_tokens_[batch * num_beams_ + j] = pad_token_id_; + next_beam_indices_[batch * num_beams_ + j] = 0; + } + continue; + } + + // Next tokens for this sentence + int beam_idx = 0; + int top_k = 2 * num_beams_; + for (int j = 0; j < top_k; j++) { + int64_t next_token = next_tokens[batch * top_k + j]; + float next_score = next_scores[batch * top_k + j]; + int64_t next_index = next_indices[batch * top_k + j]; + + int batch_beam_idx = batch * num_beams_ + static_cast(next_index); + // Add to generated hypotheses if end of sentence + if ((eos_token_id_ >= 0) && (next_token == eos_token_id_)) { + bool is_beam_token_worse_than_top_num_beams = (j >= num_beams_); + if (is_beam_token_worse_than_top_num_beams) { + continue; + } + + // Clone the sequence and append to buffer. + gsl::span src = sequences->GetSequence(batch_beam_idx); + auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length); + gsl::copy(src, clone); + hypothesis_buffer_offset_ += sequence_length; + beam_hyp.Add(clone, next_score); + } else { + // Add next predicted token since it is not eos_token + next_beam_scores_[batch * num_beams_ + beam_idx] = next_score; + next_beam_tokens_[batch * num_beams_ + beam_idx] = next_token; + next_beam_indices_[batch * num_beams_ + beam_idx] = batch_beam_idx; + ++beam_idx; + } + + // Once the beam for next step is full, don't add more tokens to it. + if (beam_idx == num_beams_) + break; + } + + ORT_ENFORCE(beam_idx == num_beams_); + ORT_ENFORCE(hypothesis_buffer_offset_ <= batch_size_ * num_beams_ * max_length_); + + // Check if we are done so that we can save a pad step if all(done) + if (!done_[batch]) { + gsl::span topk_scores = next_scores.subspan(batch * num_beams_, top_k); + const float* best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end()); + if (beam_hyp.IsDone(*best_sum_logprobs, sequence_length)) { + done_[batch] = true; + } + } + } +} + +void BeamSearchScorer::Finalize(ISequences* sequences, + gsl::span& final_beam_scores, + gsl::span& final_beam_tokens, + gsl::span& final_beam_indices, + AllocatorPtr& allocator, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + //TODO: implement + ORT_ENFORCE(sequences != nullptr); + ORT_ENFORCE(final_beam_scores.data() != nullptr); + ORT_ENFORCE(final_beam_tokens.data() != nullptr); + ORT_ENFORCE(final_beam_indices.data() != nullptr); + ORT_ENFORCE(allocator.get() != nullptr); + ORT_ENFORCE(output_sequences != nullptr); + ORT_ENFORCE(output_sequence_scores != nullptr); +} + + +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h new file mode 100644 index 0000000000000..a733c9853365e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The implementation is based on huggingface transformers generation_beam_search.py + +#pragma once +#include +#include +#include "core/common/common.h" +#include "core/framework/allocator.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +namespace onnxruntime { +namespace contrib { + +class ISequences { +public: + virtual gsl::span GetSequence(int beam_index) = 0; + virtual int GetSequenceLength() = 0; +}; + +// Interface for all scorers for beam search or beam sample. +class IBeamScorer { + public: + virtual void Process(ISequences* sequences, + gsl::span& next_scores, + gsl::span& next_tokens, + gsl::span& next_indices, + AllocatorPtr& allocator) = 0; + + virtual void Finalize(ISequences* sequences, + gsl::span& final_beam_scores, + gsl::span& final_beam_tokens, + gsl::span& final_beam_indices, + AllocatorPtr& allocator, + Tensor* output_sequences, + Tensor* output_sequence_scores) = 0; +}; + +struct HypothesisScore { + HypothesisScore(gsl::span& _hypothesis, float _score) + : hypothesis(_hypothesis), score(_score) {} + + gsl::span hypothesis; + float score; +}; + +class HypothesisScoreCompare { + public: + bool operator()(const HypothesisScore& a, const HypothesisScore& b) { + return a.score > b.score; + } +}; + +class BeamHypotheses { + public: + BeamHypotheses(int num_beams, float length_penalty, bool early_stopping); + + // Number of hypotheses + int Size() { return static_cast(beams_.size()); } + + // Add a new hypothesis + void Add(gsl::span& hypothesis, float sum_logprobs); + + bool IsDone(float best_sum_logprobs, int current_length); + + private: + int num_beams_; + float length_penalty_; + bool early_stopping_; + float worst_score_; + std::priority_queue, HypothesisScoreCompare> beams_; // min-heap for top k +}; + +class BeamSearchScorer : public IBeamScorer { + public: + BeamSearchScorer(int batch_size, + int num_beams, + int max_length, + float length_penalty, + bool early_stopping, + int num_return_sequences, + int pad_token_id, + int eos_token_id); + + bool IsDone(); + + void Process(ISequences* sequences, + gsl::span& next_scores, + gsl::span& next_tokens, + gsl::span& next_indices, + AllocatorPtr& allocator) override; + + void Finalize(ISequences* sequences, + gsl::span& final_beam_scores, + gsl::span& final_beam_tokens, + gsl::span& final_beam_indices, + AllocatorPtr& allocator, + Tensor* output_sequences, + Tensor* output_sequence_scores) override; + + gsl::span& GetNextScores() { return next_beam_scores_; } + gsl::span& GetNextTokens() { return next_beam_tokens_; } + gsl::span& GetNextIndices() { return next_beam_indices_; } + + private: + int batch_size_; + int num_beams_; + int max_length_; + float length_penalty_; + bool early_stopping_; + int num_beam_hyps_to_keep_; + int pad_token_id_; + int eos_token_id_; + + std::vector beam_hyps; // List of batch result of beam search. Its shape is (batch_size) + std::vector done_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size). + + IAllocatorUniquePtr next_beam_scores_ptr_; + gsl::span next_beam_scores_; + + IAllocatorUniquePtr next_beam_tokens_ptr_; + gsl::span next_beam_tokens_; + + IAllocatorUniquePtr next_beam_indices_ptr_; + gsl::span next_beam_indices_; + + IAllocatorUniquePtr hypothesis_buffer_ptr_; // Allocated buffer to hold all hypotheses + gsl::span hypothesis_buffer_; // Span of the allocated buffer + size_t hypothesis_buffer_length_; // Total number of elements + size_t hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer. +}; + +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc new file mode 100644 index 0000000000000..00daeaedadd4c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "dump_tensor.h" + +namespace onnxruntime { +#ifdef DEBUG_BEAM_SEARCH + +#ifdef NDEBUG +bool g_enable_tensor_dump = false; +#else +bool g_enable_tensor_dump = true; +#endif + +void DumpOrtValue(const char* name, const OrtValue& value) { + if (!g_enable_tensor_dump) + return; + std::cout << std::string(name) << "\n"; + const Tensor& tensor = value.Get(); + MLDataType dataType = tensor.DataType(); + if (dataType == DataTypeImpl::GetType()) { + DumpTensor(nullptr, tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpTensor(nullptr, tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpTensor(nullptr, tensor); + } else { + std::cout << "not float/int32/int64"; + } +} + +void ConfigureTensorDump(bool enable) { + g_enable_tensor_dump = enable; +} +#endif + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h new file mode 100644 index 0000000000000..0e9b3a3280847 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include "core/framework/tensorprotoutils.h" + +namespace onnxruntime { + +#ifdef NDEBUG +#define DEBUG_BEAM_SEARCH 1 +#endif + +#ifdef DEBUG_BEAM_SEARCH + +#define MAX_ROW_OR_COLUMN 8 + +#define SKIP_IF_MORE_THAN(row_or_column_size, i, max_n, new_line) \ + if (row_or_column_size > max_n \ + && i >= max_n / 2 \ + && i + max_n / 2 < row_or_column_size){ \ + if (i == max_n / 2) { \ + std::cout << ", ..."; \ + if (new_line) \ + std::cout << std::endl; \ + } \ + continue; \ + } + +#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line) + +extern bool g_enable_tensor_dump; // global variance to turn on/off dump + +template +void PrintValue(const T& value){ + if (std::is_floating_point::value) + std::cout << std::setprecision(8) << value; + else + std::cout << value; +} + +template +void DumpTensor(const char* name, const Tensor& tensor) { + if (!g_enable_tensor_dump) + return; + + if (nullptr != name) { + std::cout << std::string(name) << std::endl; + } + + const auto& shape = tensor.Shape(); + auto num_items = shape.Size(); + + if (num_items == 0) { + std::cout << "no data"; + return; + } + + size_t num_dims = shape.NumDimensions(); + size_t num_rows = 1; + if (num_dims > 1) { + num_rows = static_cast(shape[0]); + } + + size_t row_size = num_items / num_rows; + + auto data = tensor.DataAsSpan(); + + for (size_t row = 0; row < num_rows; ++row) { + SKIP_IF_TOO_MANY(num_rows, row, true); + std::cout << "[" << row << "]:"; + for (size_t i = 0; i < row_size; ++i) { + SKIP_IF_TOO_MANY(row_size, i, false); + + if (i > 0) + std::cout << ", "; + + PrintValue(data[row * row_size + i]); + } + std::cout << "\n"; + } + + std::cout << std::endl; +} + +void DumpOrtValue(const char* name, const OrtValue& value); + +template +void DumpTensor(const char* name, const T* tensor, int dim0, int dim1) { + if (!g_enable_tensor_dump) + return; + + if (nullptr != name) { + std::cout << std::string(name) << std::endl; + } + + for (int i = 0; i < dim0; i++) { + SKIP_IF_TOO_MANY(dim0, i, true); + std::cout << "[" << i << "]:"; + for (int j = 0; j < dim1; j++) { + SKIP_IF_TOO_MANY(dim1, j, false); + if (j > 0) + std::cout << ", "; + T value = tensor[i * dim1 + j]; + PrintValue(value); + } + std::cout << std::endl; + } +} + +template +void DumpTensorName(const char* name, T index, bool end_line) { + std::cout << std::string(name) << "[" << index << "]"; + if(end_line){ + std::cout << std::endl; + } +} + +template +void DumpTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) { + if (!g_enable_tensor_dump) + return; + + if (nullptr != name) { + std::cout << std::string(name) << std::endl; + } + + for (int i = 0; i < dim0; i++) { + SKIP_IF_TOO_MANY(dim0, i, true); + for (int j = 0; j < dim1; j++) { + SKIP_IF_TOO_MANY(dim1, j, true); + std::cout << "[" << i << "][" << j << "]:"; + for (int k = 0; k < dim2; k++) { + SKIP_IF_TOO_MANY(dim2, k, false); + if (k > 0) + std::cout << ", "; + T value = tensor[i * dim1 * dim2 + j * dim2 + k]; + PrintValue(value); + } + std::cout << std::endl; + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +void ConfigureTensorDump(bool enable); +#endif + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 9c98bce6e649d..64a29bc7f6d88 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -325,7 +325,7 @@ def test_model(args): ort_session = create_ort_session(args.output, args.use_gpu) - batch_size = 2 + batch_size = 1 input_ids = input_ids.repeat(batch_size, 1) inputs = { From e1ae8485dd1baaae1ab011a9bcf9ade9fce1fa0a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 18 Nov 2021 01:40:55 -0800 Subject: [PATCH 12/53] output results --- .../cpu/transformers/beam_earch.cc | 147 ++++++++++-------- .../cpu/transformers/beam_search_scorer.cc | 104 +++++++++++-- .../cpu/transformers/beam_search_scorer.h | 22 ++- .../cpu/transformers/dump_tensor.cc | 38 ++++- .../cpu/transformers/dump_tensor.h | 63 ++++---- .../core/graph/contrib_ops/contrib_defs.cc | 23 +-- .../tools/transformers/convert_beam_search.py | 19 ++- 7 files changed, 270 insertions(+), 146 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc index 804f14b8e0af0..6d8e0595237b4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc @@ -9,6 +9,10 @@ #pragma warning(disable : 4996) #endif +#ifndef NDEBUG +#define DEBUG_BEAM_SEARCH 1 // TODO: remove this once this operator is ready for production. +#endif + #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/math/top_k.h" #include "core/framework/allocator.h" @@ -106,14 +110,13 @@ int Sequences::GetSequenceLength() { } void Sequences::PrintSequences() { -#ifdef DEBUG_BEAM_SEARCH - std::cout << "sequences:" << std::endl; +#ifdef DEBUG_BEAM_SEARCH for (int i = 0; i < batch_beam_size_; i++) { gsl::span sequence = GetSequence(i); - std::string beam_index = std::to_string(i); - DumpTensor(beam_index.c_str(), sequence.data(), 1, current_length_); + DumpString("sequences", i, false); + DumpTensor(nullptr, sequence.data(), 1, current_length_); } -#endif +#endif } void Sequences::AppendNextTokenToSequences( @@ -167,11 +170,13 @@ class BeamSearchImpl { // Prepare the inputs for first inference of subgraph void CreateInitialFeeds(std::vector& feeds); - // Process Logits and Update the input for next iteration. - Status ProcessLogitsAndUpdateFeeds( + // Update the input for next iteration. + Status UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, - int current_length); + int current_length, + gsl::span beam_next_tokens, + gsl::span beam_indices); // Process logits and append next tokens to sequences Status GenerateNextToken(const OrtValue& logits, @@ -188,7 +193,7 @@ class BeamSearchImpl { // Reorder cache by picking the past state based on beam indices void PickPastState(const std::vector& last_outputs, std::vector& next_inputs, - gsl::span& beam_indices); + gsl::span& beam_indices); OpKernelContextInternal& context_; const SessionState& session_state_; @@ -590,7 +595,7 @@ void BeamSearchImpl::CreateInitialFeeds(std::vector& feeds) { template Status BeamSearchImpl::ProcessLogits( - const OrtValue& logits, // logits output of subgraph + const OrtValue& logits, // logits output of subgraph BeamSearchState& beam_state, int top_k, AllocatorPtr& allocator) { @@ -723,9 +728,9 @@ Status BeamSearchImpl::ProcessLogits( template Status BeamSearchImpl::GenerateNextToken( - const OrtValue& logits, - gsl::span& beam_next_tokens, - gsl::span& beam_indices) { + const OrtValue& logits, + gsl::span& beam_next_tokens, + gsl::span& beam_indices) { // Process logits to get next token scores, and select top_k = 2 * num_beams // TODO: we might not need 2 * num_beams when logits processors does not update token scores. const int top_k = 2 * parameters_->num_beams; @@ -746,7 +751,7 @@ Status BeamSearchImpl::GenerateNextToken( beam_state_.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens); -#ifdef DEBUG_BEAM_SEARCH +#ifdef DEBUG_BEAM_SEARCH beam_state_.sequences.PrintSequences(); #endif return Status::OK(); @@ -760,14 +765,14 @@ void BeamSearchImpl::ProcessNextTokenScores(gsl::span& /*next_token_sc template void BeamSearchImpl::PickPastState(const std::vector& last_outputs, std::vector& next_inputs, - gsl::span& beam_indices) { + gsl::span& beam_indices) { for (int i = 3; i < subgraph_info_.num_subgraph_inputs; ++i) { const OrtValue& present = last_outputs[i - 2]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64) const TensorShape& past_shape = present.Get().Shape(); // Create a tensor with same shape. OrtValue past; - auto past_type = DataTypeImpl::GetType(); // present.Type() + auto past_type = DataTypeImpl::GetType(); // present.Type() Tensor::InitOrtValue(past_type, past_shape, allocator_, past); auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; @@ -786,12 +791,12 @@ void BeamSearchImpl::PickPastState(const std::vector& last_outputs, gsl::copy(present_value, past_value); #ifdef DEBUG_BEAM_SEARCH - if (i == 3) // only dump past_0 + if (i == 3) // only dump past_0 { - DumpTensorName("past_key of beam", j, true); + DumpString("past_key of beam", static_cast(j), true); DumpTensor(nullptr, past_key.data(), 1, static_cast(block_size_per_beam)); - DumpTensorName("past_value of beam", j, true); + DumpString("past_value of beam", static_cast(j), true); DumpTensor(nullptr, past_value.data(), 1, static_cast(block_size_per_beam)); } #endif @@ -802,21 +807,15 @@ void BeamSearchImpl::PickPastState(const std::vector& last_outputs, } template -Status BeamSearchImpl::ProcessLogitsAndUpdateFeeds( +Status BeamSearchImpl::UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, - int current_length) { + int current_length, + gsl::span beam_next_tokens, + gsl::span beam_indices) { // last_outputs: logits, present_0, present_1, ... // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 - // Process logits to get next token scores, and select top_k = 2 * num_beams - // TODO: we might not need 2 * num_beams when logits processors does not update token scores. - const OrtValue& logits = last_outputs[0]; - - gsl::span beam_next_tokens; - gsl::span beam_indices; - ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices)); - // The following updates inputs for subgraph // TODO: Reuse buffer for input_ids and position_ids to reduce memory allocation. @@ -838,8 +837,8 @@ Status BeamSearchImpl::ProcessLogitsAndUpdateFeeds( Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids); int64_t* position_data = position_ids.GetMutable()->MutableData(); for (int i = 0; i < batch_beam_size; i++) { - position_data[i] = next_positions_[i]; - next_positions_[i]++; + position_data[i] = next_positions_[i]; + next_positions_[i]++; } next_inputs[1] = position_ids; @@ -882,54 +881,80 @@ template Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { auto status = Status::OK(); + std::vector sequences_dims{parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length}; + TensorShape sequences_shape(sequences_dims); + Tensor* output_sequences = context_.Output(0, sequences_shape); + + std::vector sequences_scores_dims{parameters_->batch_size, parameters_->num_return_sequences}; + TensorShape sequences_scores_shape(sequences_scores_dims); + Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape); + std::vector feeds; std::vector fetches; CreateInitialFeeds(feeds); - int current_length = parameters_->sequence_length; - while (current_length < parameters_->max_length) { - if (current_length > parameters_->sequence_length) { - // Initialize resources only when needed - if (beam_scorer_.get() == nullptr) { - beam_scorer_ = std::make_unique(parameters_->batch_size, - parameters_->num_beams, - parameters_->max_length, - parameters_->length_penalty, - parameters_->early_stopping, - parameters_->num_return_sequences, - parameters_->pad_token_id, - parameters_->eos_token_id); - const OrtValue& input_ids = feeds[0]; + // Initialize resources + beam_scorer_ = std::make_unique(parameters_->batch_size, + parameters_->num_beams, + parameters_->max_length, + parameters_->length_penalty, + parameters_->early_stopping, + parameters_->num_return_sequences, + parameters_->pad_token_id, + parameters_->eos_token_id); + const OrtValue& input_ids = feeds[0]; #ifdef DEBUG_BEAM_SEARCH - DumpOrtValue("input_ids", input_ids); + DumpOrtValue("input_ids", input_ids); + DumpOrtValue("position_ids", feeds[1]); + DumpOrtValue("attention_mask", feeds[2]); #endif - beam_state_.Init(input_ids, - parameters_->batch_size, - parameters_->num_beams, - parameters_->vocab_size, - parameters_->sequence_length, - parameters_->max_length); - } - ORT_RETURN_IF_ERROR(ProcessLogitsAndUpdateFeeds(fetches, feeds, current_length)); - fetches.clear(); + beam_state_.Init(input_ids, + parameters_->batch_size, + parameters_->num_beams, + parameters_->vocab_size, + parameters_->sequence_length, + parameters_->max_length); -#ifdef DEBUG_BEAM_SEARCH - if (current_length - parameters_->sequence_length == 3) { // only dump a few steps. - ConfigureTensorDump(false); - } -#endif - } + int current_length = parameters_->sequence_length; + while (current_length < parameters_->max_length) { +#ifdef DEBUG_BEAM_SEARCH + DumpString("***CurrentLength", std::to_string(current_length), true); +#endif status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, {}, ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger()); ORT_RETURN_IF_ERROR(status); + const OrtValue& logits = fetches[0]; + gsl::span beam_next_tokens; + gsl::span beam_indices; + ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices)); + + // Increase sequence length after a new token is generated. ++current_length; + + // Prepare inputs for next round of subgraph call. + if (current_length < parameters_->max_length) { + ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, beam_next_tokens.as_span(), beam_indices.as_span())); + } + fetches.clear(); + +#ifdef DEBUG_BEAM_SEARCH + if (current_length - parameters_->sequence_length == 3) { // only dump a few steps. + DisableTensorDump(); + } +#endif } + gsl::span beam_scores(beam_state_.beam_scores.data(), beam_state_.beam_scores.size()); + beam_scorer_->Finalize(&(beam_state_.sequences), + beam_scores, + output_sequences, + output_sequences_scores); + return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index e0d117cd58832..f2d9aceaf1336 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -22,7 +22,7 @@ BeamHypotheses::BeamHypotheses(int num_beams, float length_penalty, bool early_s early_stopping_(early_stopping), worst_score_(1e9) {} -void BeamHypotheses::Add(gsl::span& hypothesis, float sum_logprobs) { +void BeamHypotheses::Add(gsl::span& hypothesis, float sum_logprobs) { auto length = hypothesis.size(); float score = sum_logprobs / pow(static_cast(length), length_penalty_); @@ -50,6 +50,41 @@ bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) { return worst_score_ >= current_score; } +void BeamHypotheses::Output( + int top_k, + int max_length, + gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) + gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty +{ + ORT_ENFORCE(top_k <= Size()); + int remove_count = Size() - top_k; + for (int i = 0; i < remove_count; i++) { + beams_.pop(); + } + + // Since pop get the worst sequence, so output it in the reverse order. + // The frist (worst) beam shall be put at the last position among top_k sequences. + int index = top_k - 1; + while (!beams_.empty()) { + auto item = beams_.top(); + gsl::span& source = item.hypothesis; + gsl::span target = sequences.subspan(index * max_length, max_length); + + // Note that word_ids might be less than max_length. + // Since the sequences has been filled with pad token ID, so padding is not needed here. + // Since data type need cast from int64_t to int32_t, we cannot use gsl::copy(word_ids, sequence) here. + for (int i = 0; i < source.length(); i++){ + target[i] = static_cast(source[i]); + } + + if (!sequences_scores.empty()) + sequences_scores[index] = item.score; + + beams_.pop(); + index--; + } +} + BeamSearchScorer::BeamSearchScorer(int batch_size, int num_beams, int max_length, @@ -61,8 +96,6 @@ BeamSearchScorer::BeamSearchScorer(int batch_size, : batch_size_(batch_size), num_beams_(num_beams), max_length_(max_length), - length_penalty_(length_penalty), - early_stopping_(early_stopping), num_beam_hyps_to_keep_(num_return_sequences), pad_token_id_(pad_token_id), eos_token_id_(eos_token_id), @@ -148,7 +181,8 @@ void BeamSearchScorer::Process(ISequences* sequences, auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length); gsl::copy(src, clone); hypothesis_buffer_offset_ += sequence_length; - beam_hyp.Add(clone, next_score); + auto sequence = clone.as_span(); + beam_hyp.Add(sequence, next_score); } else { // Add next predicted token since it is not eos_token next_beam_scores_[batch * num_beams_ + beam_idx] = next_score; @@ -178,21 +212,59 @@ void BeamSearchScorer::Process(ISequences* sequences, void BeamSearchScorer::Finalize(ISequences* sequences, gsl::span& final_beam_scores, - gsl::span& final_beam_tokens, - gsl::span& final_beam_indices, - AllocatorPtr& allocator, Tensor* output_sequences, Tensor* output_sequence_scores) { - //TODO: implement - ORT_ENFORCE(sequences != nullptr); - ORT_ENFORCE(final_beam_scores.data() != nullptr); - ORT_ENFORCE(final_beam_tokens.data() != nullptr); - ORT_ENFORCE(final_beam_indices.data() != nullptr); - ORT_ENFORCE(allocator.get() != nullptr); - ORT_ENFORCE(output_sequences != nullptr); - ORT_ENFORCE(output_sequence_scores != nullptr); -} + ORT_ENFORCE(sequences != nullptr); + ORT_ENFORCE(output_sequences != nullptr); + + // finalize all open beam hypotheses and add to generated hypotheses + for (int batch_index = 0; batch_index < batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = beam_hyps[batch_index]; + if (done_[batch_index]) { + continue; + } + for (int beam_index = 0; beam_index < num_beams_; beam_index++) { + int batch_beam_index = batch_index * num_beams_ + beam_index; + float final_score = final_beam_scores[batch_beam_index]; + auto final_tokens = sequences->GetSequence(batch_beam_index); + beam_hyp.Add(final_tokens, final_score); + } + } + + // word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length) + gsl::span output = output_sequences->MutableDataAsSpan(); + + // Fill output sequences with pad token ID so that we do not need append it later. + std::fill_n(output.data(), output.size(), pad_token_id_); + + // score of each sequence, with shape (batch_size * num_return_sequences) + gsl::span sequence_scores; + if (output_sequence_scores != nullptr){ + sequence_scores = output_sequence_scores->MutableDataAsSpan(); + } + + // span is empty when output_sequence_scores is NULL. + gsl::span batch_sequence_score; + + // Select the best hypotheses according to number of sequences to return. + for (int batch_index = 0; batch_index < batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = beam_hyps[batch_index]; + + const int num_return_sequences = num_beam_hyps_to_keep_; + auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_); + + if (output_sequence_scores != nullptr){ + batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences); + } + + beam_hyp.Output( + num_return_sequences, + max_length_, + batch_output, + batch_sequence_score); + } +} } // namespace contrib } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index a733c9853365e..662b3218d3dc9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -31,18 +31,15 @@ class IBeamScorer { virtual void Finalize(ISequences* sequences, gsl::span& final_beam_scores, - gsl::span& final_beam_tokens, - gsl::span& final_beam_indices, - AllocatorPtr& allocator, Tensor* output_sequences, Tensor* output_sequence_scores) = 0; }; struct HypothesisScore { - HypothesisScore(gsl::span& _hypothesis, float _score) + HypothesisScore(gsl::span& _hypothesis, float _score) : hypothesis(_hypothesis), score(_score) {} - gsl::span hypothesis; + gsl::span hypothesis; float score; }; @@ -61,10 +58,16 @@ class BeamHypotheses { int Size() { return static_cast(beams_.size()); } // Add a new hypothesis - void Add(gsl::span& hypothesis, float sum_logprobs); + void Add(gsl::span& hypothesis, float sum_logprobs); bool IsDone(float best_sum_logprobs, int current_length); + // Output results. Note that it will clear all beams. + void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + gsl::span& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length) + gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + private: int num_beams_; float length_penalty_; @@ -94,9 +97,6 @@ class BeamSearchScorer : public IBeamScorer { void Finalize(ISequences* sequences, gsl::span& final_beam_scores, - gsl::span& final_beam_tokens, - gsl::span& final_beam_indices, - AllocatorPtr& allocator, Tensor* output_sequences, Tensor* output_sequence_scores) override; @@ -108,8 +108,6 @@ class BeamSearchScorer : public IBeamScorer { int batch_size_; int num_beams_; int max_length_; - float length_penalty_; - bool early_stopping_; int num_beam_hyps_to_keep_; int pad_token_id_; int eos_token_id_; @@ -129,7 +127,7 @@ class BeamSearchScorer : public IBeamScorer { IAllocatorUniquePtr hypothesis_buffer_ptr_; // Allocated buffer to hold all hypotheses gsl::span hypothesis_buffer_; // Span of the allocated buffer size_t hypothesis_buffer_length_; // Total number of elements - size_t hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer. + int hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer. }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc index 00daeaedadd4c..e44a521c6e9c2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc @@ -2,9 +2,14 @@ // Licensed under the MIT License. #include "dump_tensor.h" +#include "core/platform/env.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { -#ifdef DEBUG_BEAM_SEARCH + +namespace dump_tensor_env_vars { +constexpr const char* kDumpBeamSearch = "ORT_DUMP_BEAM_SEARCH"; +} #ifdef NDEBUG bool g_enable_tensor_dump = false; @@ -29,9 +34,34 @@ void DumpOrtValue(const char* name, const OrtValue& value) { } } -void ConfigureTensorDump(bool enable) { - g_enable_tensor_dump = enable; +void ConfigureTensorDump() { + if (ParseEnvironmentVariableWithDefault(dump_tensor_env_vars::kDumpBeamSearch, false)) { + g_enable_tensor_dump = true; + } } -#endif +void DisableTensorDump() { + g_enable_tensor_dump = false; +} + +void DumpString(const char* name, int index, bool end_line) { + if (!g_enable_tensor_dump) + return; + std::cout << std::string(name) << "[" << index << "]"; + + if (end_line) { + std::cout << std::endl; + } +} + +void DumpString(const char* name, std::string value, bool end_line) { + if (!g_enable_tensor_dump) + return; + + std::cout << std::string(name) << "=" << value; + + if (end_line) { + std::cout << std::endl; + } +} } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h index 0e9b3a3280847..a4bebdf8a0f67 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h @@ -8,36 +8,28 @@ namespace onnxruntime { -#ifdef NDEBUG -#define DEBUG_BEAM_SEARCH 1 -#endif - -#ifdef DEBUG_BEAM_SEARCH - #define MAX_ROW_OR_COLUMN 8 -#define SKIP_IF_MORE_THAN(row_or_column_size, i, max_n, new_line) \ - if (row_or_column_size > max_n \ - && i >= max_n / 2 \ - && i + max_n / 2 < row_or_column_size){ \ - if (i == max_n / 2) { \ - std::cout << ", ..."; \ - if (new_line) \ - std::cout << std::endl; \ - } \ - continue; \ - } +#define SKIP_IF_MORE_THAN(row_or_column_size, i, max_n, new_line) \ + if (row_or_column_size > max_n && i >= max_n / 2 && i + max_n / 2 < row_or_column_size) { \ + if (i == max_n / 2) { \ + std::cout << ", ..."; \ + if (new_line) \ + std::cout << std::endl; \ + } \ + continue; \ + } #define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line) -extern bool g_enable_tensor_dump; // global variance to turn on/off dump +extern bool g_enable_tensor_dump; // global variance to turn on/off dump template -void PrintValue(const T& value){ +void PrintValue(const T& value) { if (std::is_floating_point::value) - std::cout << std::setprecision(8) << value; - else - std::cout << value; + std::cout << std::setprecision(8) << value; + else + std::cout << value; } template @@ -90,7 +82,7 @@ template void DumpTensor(const char* name, const T* tensor, int dim0, int dim1) { if (!g_enable_tensor_dump) return; - + if (nullptr != name) { std::cout << std::string(name) << std::endl; } @@ -109,13 +101,9 @@ void DumpTensor(const char* name, const T* tensor, int dim0, int dim1) { } } -template -void DumpTensorName(const char* name, T index, bool end_line) { - std::cout << std::string(name) << "[" << index << "]"; - if(end_line){ - std::cout << std::endl; - } -} +void DumpString(const char* name, int index, bool end_line); + +void DumpString(const char* name, std::string value, bool end_line); template void DumpTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) { @@ -125,15 +113,15 @@ void DumpTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) if (nullptr != name) { std::cout << std::string(name) << std::endl; } - + for (int i = 0; i < dim0; i++) { SKIP_IF_TOO_MANY(dim0, i, true); for (int j = 0; j < dim1; j++) { SKIP_IF_TOO_MANY(dim1, j, true); - std::cout << "[" << i << "][" << j << "]:"; - for (int k = 0; k < dim2; k++) { - SKIP_IF_TOO_MANY(dim2, k, false); - if (k > 0) + std::cout << "[" << i << "][" << j << "]:"; + for (int k = 0; k < dim2; k++) { + SKIP_IF_TOO_MANY(dim2, k, false); + if (k > 0) std::cout << ", "; T value = tensor[i * dim1 * dim2 + j * dim2 + k]; PrintValue(value); @@ -145,7 +133,8 @@ void DumpTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) std::cout << std::endl; } -void ConfigureTensorDump(bool enable); -#endif +void ConfigureTensorDump(); + +void DisableTensorDump(); } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 42200f34edb92..7850dc32b0154 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -560,9 +560,9 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Shape inference // input 0 (input_ids) shape: (batch_size, sequence_length) - // output 0 (sequences) shape: (batch_size * num_return_sequences, max_length) - // output 1 (sequences_scores) shape: (batch_size * num_return_sequences) - // output 2 (scores) shape: (max_length-sequence_length, batch_size*num_beams*num_return_sequences, vocab_size) + // output 0 (sequences) shape: (batch_size, num_return_sequences, max_length) + // output 1 (sequences_scores) shape: (batch_size, num_return_sequences) + // output 2 (scores) shape: (max_length - sequence_length, batch_size, num_beams, vocab_size) if (!hasInputShape(ctx, 0)) { return; } @@ -601,19 +601,22 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { } ONNX_NAMESPACE::TensorShapeProto sequences_shape; - sequences_shape.add_dim()->set_dim_value(batch_size * num_beams_value); - sequences_shape.add_dim()->set_dim_value(batch_size * sequence_length); + sequences_shape.add_dim()->set_dim_value(batch_size); + sequences_shape.add_dim()->set_dim_value(num_return_sequences_value); + sequences_shape.add_dim()->set_dim_value(max_length_value); updateOutputShape(ctx, 0, sequences_shape); if (ctx.getNumOutputs() > 1) { ONNX_NAMESPACE::TensorShapeProto sequences_scores_shape; - sequences_scores_shape.add_dim()->set_dim_value(batch_size * num_beams_value); + sequences_shape.add_dim()->set_dim_value(batch_size); + sequences_shape.add_dim()->set_dim_value(num_return_sequences_value); updateOutputShape(ctx, 1, sequences_scores_shape); if (ctx.getNumOutputs() > 2) { ONNX_NAMESPACE::TensorShapeProto scores_shape; scores_shape.add_dim()->set_dim_value(max_length_value - sequence_length); - scores_shape.add_dim()->set_dim_value(batch_size * num_beams_value * num_return_sequences_value); + scores_shape.add_dim()->set_dim_value(batch_size); + scores_shape.add_dim()->set_dim_value(num_beams_value); scores_shape.add_dim(); // vocab_size is unknown updateOutputShape(ctx, 2, scores_shape); } @@ -646,12 +649,12 @@ void RegisterTextGenerationSchemas() { "T", OpSchema::Optional) .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) .Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) - .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size * num_return_sequences, max_sequence_length)", "I") - .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size*num_return_sequences)", "T", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") + .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", "Processed beam scores for each vocabulary token at each generation step." "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." - "Shape is (max_length - input_ids_sequence_length, batch_size*num_beams*num_return_sequences, vocab_size)", + "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 64a29bc7f6d88..12818431b5775 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -257,18 +257,21 @@ def convert_model(args): # graph outputs sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32, - ['batch_size * num_return_sequences', 'max_length']) + ['batch_size', 'num_return_sequences', 'max_length']) + sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT, - ['batch_size * num_return_sequences']) + ['batch_size', 'num_return_sequences']) scores = helper.make_tensor_value_info( 'scores', TensorProto.FLOAT, - ['max_length - sequence_length', 'batch_size * num_beams * num_return_sequences', vocab_size]) + ['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size]) initializers = [] graph_outputs = [sequences] + if args.output_sequences_scores: graph_outputs.append(sequences_scores) + if args.output_token_scores: graph_outputs.append(scores) @@ -350,9 +353,13 @@ def test_model(args): print("inputs", inputs) result = ort_session.run(None, inputs) - print("outputs", result) - #print(tokenizer.decode(result[0][0], skip_special_tokens=True)) - + + sequences = result[0] + print("outputs", sequences) + + #TODO: print all sequences. Below shows only the first one + first_sequence = tokenizer.decode(sequences[0][0], skip_special_tokens=True) + print(first_sequence) def main(): args = parse_arguments() From 1f1ee1a53ae92c1e92a0b0dbad28b26a453fefb0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 18 Nov 2021 08:08:59 -0800 Subject: [PATCH 13/53] fix typo --- .../cpu/transformers/{beam_earch.cc => beam_search.cc} | 0 onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename onnxruntime/contrib_ops/cpu/transformers/{beam_earch.cc => beam_search.cc} (100%) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc similarity index 100% rename from onnxruntime/contrib_ops/cpu/transformers/beam_earch.cc rename to onnxruntime/contrib_ops/cpu/transformers/beam_search.cc diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index f2d9aceaf1336..f7c0b0e6d3ff3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -73,7 +73,7 @@ void BeamHypotheses::Output( // Note that word_ids might be less than max_length. // Since the sequences has been filled with pad token ID, so padding is not needed here. // Since data type need cast from int64_t to int32_t, we cannot use gsl::copy(word_ids, sequence) here. - for (int i = 0; i < source.length(); i++){ + for (size_t i = 0; i < source.length(); i++){ target[i] = static_cast(source[i]); } From 7c4fd9a9d63c80ea110274fc7f7e6bacd58a66d5 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 18 Nov 2021 11:32:06 -0800 Subject: [PATCH 14/53] use c++ template and format python --- .../cpu/transformers/beam_search.cc | 72 +++++++------- .../cpu/transformers/beam_search.h | 40 ++++---- .../cpu/transformers/beam_search_scorer.cc | 96 +++++++++++-------- .../cpu/transformers/beam_search_scorer.h | 55 ++++++----- .../tools/transformers/convert_beam_search.py | 58 +++++------ 5 files changed, 173 insertions(+), 148 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 6d8e0595237b4..c94a6f78657ff 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -184,11 +184,11 @@ class BeamSearchImpl { gsl::span& beam_indices); Status ProcessLogits(const OrtValue& logits, - BeamSearchState& beam_state, + BeamSearchState& beam_state, int top_k, AllocatorPtr& allocator); - void ProcessNextTokenScores(gsl::span& next_token_scores); + void ProcessNextTokenScores(gsl::span& next_token_scores); // Reorder cache by picking the past state based on beam indices void PickPastState(const std::vector& last_outputs, @@ -210,9 +210,9 @@ class BeamSearchImpl { BeamSearchParameters* parameters_; - std::unique_ptr beam_scorer_; + std::unique_ptr> beam_scorer_; - BeamSearchState beam_state_; + BeamSearchState beam_state_; AllocatorPtr allocator_; }; @@ -596,7 +596,7 @@ void BeamSearchImpl::CreateInitialFeeds(std::vector& feeds) { template Status BeamSearchImpl::ProcessLogits( const OrtValue& logits, // logits output of subgraph - BeamSearchState& beam_state, + BeamSearchState& beam_state, int top_k, AllocatorPtr& allocator) { const int64_t batch_beam_size = static_cast(parameters_->batch_size * parameters_->num_beams); @@ -607,7 +607,7 @@ Status BeamSearchImpl::ProcessLogits( DumpOrtValue("logits", logits); #endif - const float* logits_data = logits.Get().Data(); + const T* logits_data = logits.Get().Data(); const TensorShape& logits_shape = logits.Get().Shape(); ORT_ENFORCE(logits_shape.NumDimensions() == 3); @@ -621,10 +621,10 @@ Status BeamSearchImpl::ProcessLogits( // When input_length == 1, use logits directly to avoid copy logits to next_token_logits. auto next_token_logits = gsl::make_span(beam_state.next_token_logits); if (input_length > 1) { - const float* current_logits = logits_data + (input_length - 1) * vocab_size; + const T* current_logits = logits_data + (input_length - 1) * vocab_size; for (int i = 0; i < batch_beam_size; i++) { - gsl::span source(current_logits, vocab_size); - gsl::span target = next_token_logits.subspan(i * vocab_size, vocab_size); + gsl::span source(current_logits, vocab_size); + gsl::span target = next_token_logits.subspan(i * vocab_size, vocab_size); gsl::copy(source, target); current_logits += i * (input_length * vocab_size); } @@ -632,12 +632,12 @@ Status BeamSearchImpl::ProcessLogits( // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1) auto next_token_scores = gsl::make_span(beam_state.next_token_scores); - Status status = SoftmaxCPU(batch_beam_size, // rows - vocab_size, // elements per row - input_length > 1 ? next_token_logits.data() : logits_data, - next_token_scores.data(), - true, - thread_pool_); + Status status = SoftmaxCPU(batch_beam_size, // rows + vocab_size, // elements per row + input_length > 1 ? next_token_logits.data() : logits_data, + next_token_scores.data(), + true, + thread_pool_); if (!status.IsOK()) { return status; } @@ -667,7 +667,7 @@ Status BeamSearchImpl::ProcessLogits( //next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) int64_t next_token_scores_dims[] = {parameters_->batch_size, parameters_->num_beams * vocab_size}; TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); - auto element_type = DataTypeImpl::GetType(); + auto element_type = DataTypeImpl::GetType(); OrtValue next_token_scores_value; Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value); const Tensor& input = next_token_scores_value.Get(); @@ -683,13 +683,13 @@ Status BeamSearchImpl::ProcessLogits( std::unique_ptr topk_scores; std::unique_ptr topk_indices; - status = GetTopK(&input, axis, top_k, largest, sorted, allocator, thread_pool_, topk_scores, topk_indices); + status = GetTopK(&input, axis, top_k, largest, sorted, allocator, thread_pool_, topk_scores, topk_indices); if (!status.IsOK()) { return status; } #ifdef DEBUG_BEAM_SEARCH - DumpTensor("topk_scores", *(topk_scores.get())); + DumpTensor("topk_scores", *(topk_scores.get())); DumpTensor("topk_indices", *(topk_indices.get())); #endif @@ -706,12 +706,12 @@ Status BeamSearchImpl::ProcessLogits( } } - gsl::span next_scores = topk_scores->DataAsSpan(); + gsl::span next_scores = topk_scores->DataAsSpan(); gsl::span next_tokens(beam_state.next_tokens.data(), beam_state.next_tokens.size()); gsl::span next_indices(beam_state.next_indices.data(), beam_state.next_indices.size()); #ifdef DEBUG_BEAM_SEARCH - DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, k); + DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, k); DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, k); DumpTensor("next_indices before scorer", next_indices.data(), parameters_->batch_size, k); #endif @@ -736,7 +736,7 @@ Status BeamSearchImpl::GenerateNextToken( const int top_k = 2 * parameters_->num_beams; ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state_, top_k, allocator_)); - gsl::span& beam_scores = beam_scorer_->GetNextScores(); + gsl::span& beam_scores = beam_scorer_->GetNextScores(); // TODO: may not need clone beam_scores. beam_state_.beam_scores.assign(beam_scores.begin(), beam_scores.end()); @@ -744,7 +744,7 @@ Status BeamSearchImpl::GenerateNextToken( beam_indices = beam_scorer_->GetNextIndices(); #ifdef DEBUG_BEAM_SEARCH - DumpTensor("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams); + DumpTensor("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams); DumpTensor("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams); DumpTensor("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams); #endif @@ -758,7 +758,7 @@ Status BeamSearchImpl::GenerateNextToken( } template -void BeamSearchImpl::ProcessNextTokenScores(gsl::span& /*next_token_scores*/) { +void BeamSearchImpl::ProcessNextTokenScores(gsl::span& /*next_token_scores*/) { return; } @@ -794,10 +794,10 @@ void BeamSearchImpl::PickPastState(const std::vector& last_outputs, if (i == 3) // only dump past_0 { DumpString("past_key of beam", static_cast(j), true); - DumpTensor(nullptr, past_key.data(), 1, static_cast(block_size_per_beam)); + DumpTensor(nullptr, past_key.data(), 1, static_cast(block_size_per_beam)); DumpString("past_value of beam", static_cast(j), true); - DumpTensor(nullptr, past_value.data(), 1, static_cast(block_size_per_beam)); + DumpTensor(nullptr, past_value.data(), 1, static_cast(block_size_per_beam)); } #endif } @@ -895,14 +895,14 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { CreateInitialFeeds(feeds); // Initialize resources - beam_scorer_ = std::make_unique(parameters_->batch_size, - parameters_->num_beams, - parameters_->max_length, - parameters_->length_penalty, - parameters_->early_stopping, - parameters_->num_return_sequences, - parameters_->pad_token_id, - parameters_->eos_token_id); + beam_scorer_ = std::make_unique>(parameters_->batch_size, + parameters_->num_beams, + parameters_->max_length, + parameters_->length_penalty, + parameters_->early_stopping, + parameters_->num_return_sequences, + parameters_->pad_token_id, + parameters_->eos_token_id); const OrtValue& input_ids = feeds[0]; #ifdef DEBUG_BEAM_SEARCH DumpOrtValue("input_ids", input_ids); @@ -949,7 +949,7 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { #endif } - gsl::span beam_scores(beam_state_.beam_scores.data(), beam_state_.beam_scores.size()); + gsl::span beam_scores(beam_state_.beam_scores.data(), beam_state_.beam_scores.size()); beam_scorer_->Finalize(&(beam_state_.sequences), beam_scores, output_sequences, @@ -958,5 +958,9 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { return status; } +// Instantiation +template class BeamSearchImpl; +template class BeamSearch; + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index e14500d1c8417..2c0d4e8139356 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -21,8 +21,8 @@ struct GptSubgraphInfo { int num_implicit_inputs; - int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience. - int num_subgraph_outputs; // same as subgraph_output_names.size() + int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience. + int num_subgraph_outputs; // same as subgraph_output_names.size() std::vector subgraph_input_names; std::vector subgraph_output_names; @@ -30,8 +30,8 @@ struct GptSubgraphInfo { // This class keeps track of sequences generated. class Sequences : public ISequences { -public: - Sequences(){} + public: + Sequences() {} // Initialize the sequence with initial input_ids and related parameters. void Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length); @@ -47,10 +47,10 @@ class Sequences : public ISequences { // Select sequences based on beam indices, then append next token to selected sequences. void AppendNextTokenToSequences( - gsl::span& beam_indices, - gsl::span& beam_next_tokens); + gsl::span& beam_indices, + gsl::span& beam_next_tokens); -private: + private: // Two buffers of shape (batch_size, num_beams, max_seq_length) to store sequences. // At each time, there is only one buffer is active. The other one will be active in next token. // Each AppendNextTokenToSequences call will trigger a rotation of active buffer. @@ -61,31 +61,31 @@ class Sequences : public ISequences { int batch_beam_size_; int max_length_; - int current_length_; + int current_length_; }; +template struct BeamSearchState { // TODO: use allocater to allocate a buffer, and point each data to a span of the buffer // so as to reuse related code in CUDA. - std::vector done; // shape (batch_size) - std::vector beam_scores; // shape (batch_size, num_beams) - - std::vector next_token_logits; // shape (batch_size * num_beams, vocab_size) - std::vector next_token_scores; // shape (batch_size, num_beams * vocab_size) + std::vector done; // shape (batch_size) + std::vector beam_scores; // shape (batch_size, num_beams) - std::vector next_tokens; // shape (batch_size, num_beams) - std::vector next_indices; // shape (batch_size, num_beams) + std::vector next_token_logits; // shape (batch_size * num_beams, vocab_size) + std::vector next_token_scores; // shape (batch_size, num_beams * vocab_size) + + std::vector next_tokens; // shape (batch_size, num_beams) + std::vector next_indices; // shape (batch_size, num_beams) Sequences sequences; - std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) + std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length) { int batch_beam_size = batch_size * num_beams; done.assign(batch_size, 0); beam_scores.assign(batch_beam_size, 0.0f); - for (int i = 0; i < batch_size; i++) - { + for (int i = 0; i < batch_size; i++) { for (int j = 1; j < num_beams; j++) { beam_scores[i * num_beams + j] = -1e9; } @@ -118,13 +118,13 @@ class BeamSearch : public controlflow::IControlFlowKernel { static std::unique_ptr Create(const OpKernelInfo& info, void* stream); protected: - void SetComputeStream(void* stream) { stream_ = stream; } + void SetComputeStream(void* stream) { stream_ = stream; } private: // Subgraph info and FeedsFetchesManager re-used for each subgraph execution. std::unique_ptr subgraph_info_; std::unique_ptr feeds_fetches_manager_; - + void* stream_; BeamSearchParameters parameters_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index f7c0b0e6d3ff3..0908d60098160 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -16,18 +16,21 @@ namespace contrib { using ::onnxruntime::rnn::detail::Allocate; -BeamHypotheses::BeamHypotheses(int num_beams, float length_penalty, bool early_stopping) +template +BeamHypotheses::BeamHypotheses(int num_beams, T length_penalty, bool early_stopping) : num_beams_(num_beams), length_penalty_(length_penalty), early_stopping_(early_stopping), worst_score_(1e9) {} -void BeamHypotheses::Add(gsl::span& hypothesis, float sum_logprobs) { +template +void BeamHypotheses::Add(gsl::span& hypothesis, T sum_logprobs) { auto length = hypothesis.size(); - float score = sum_logprobs / pow(static_cast(length), length_penalty_); + // TODO: may need compute in FP32 when T is FP16 + T score = sum_logprobs / pow(static_cast(length), length_penalty_); if (this->Size() < num_beams_ || score > worst_score_) { - HypothesisScore item(hypothesis, score); + HypothesisScore item(hypothesis, score); beams_.push(item); if (this->Size() > num_beams_) { beams_.pop(); @@ -36,7 +39,8 @@ void BeamHypotheses::Add(gsl::span& hypothesis, float sum_logprob } } -bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) { +template +bool BeamHypotheses::IsDone(T best_sum_logprobs, int current_length) { // If there are enough hypotheses and that none of the hypotheses being generated can become better // than the worst one in the heap, then we are done with this sentence. @@ -46,15 +50,16 @@ bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) { if (early_stopping_) return true; - float current_score = best_sum_logprobs / pow(static_cast(current_length), length_penalty_); + T current_score = best_sum_logprobs / pow(static_cast(current_length), length_penalty_); return worst_score_ >= current_score; } -void BeamHypotheses::Output( +template +void BeamHypotheses::Output( int top_k, int max_length, - gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) - gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty + gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) + gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty { ORT_ENFORCE(top_k <= Size()); int remove_count = Size() - top_k; @@ -73,7 +78,7 @@ void BeamHypotheses::Output( // Note that word_ids might be less than max_length. // Since the sequences has been filled with pad token ID, so padding is not needed here. // Since data type need cast from int64_t to int32_t, we cannot use gsl::copy(word_ids, sequence) here. - for (size_t i = 0; i < source.length(); i++){ + for (size_t i = 0; i < source.length(); i++) { target[i] = static_cast(source[i]); } @@ -85,14 +90,15 @@ void BeamHypotheses::Output( } } -BeamSearchScorer::BeamSearchScorer(int batch_size, - int num_beams, - int max_length, - float length_penalty, - bool early_stopping, - int num_return_sequences, - int pad_token_id, - int eos_token_id) +template +BeamSearchScorer::BeamSearchScorer(int batch_size, + int num_beams, + int max_length, + T length_penalty, + bool early_stopping, + int num_return_sequences, + int pad_token_id, + int eos_token_id) : batch_size_(batch_size), num_beams_(num_beams), max_length_(max_length), @@ -110,7 +116,8 @@ BeamSearchScorer::BeamSearchScorer(int batch_size, } } -bool BeamSearchScorer::IsDone() { +template +bool BeamSearchScorer::IsDone() { for (int batch = 0; batch < batch_size_; batch++) { if (!done_[batch]) return false; @@ -118,11 +125,12 @@ bool BeamSearchScorer::IsDone() { return true; } -void BeamSearchScorer::Process(ISequences* sequences, - gsl::span& next_scores, - gsl::span& next_tokens, - gsl::span& next_indices, - AllocatorPtr& allocator) { +template +void BeamSearchScorer::Process(ISequences* sequences, + gsl::span& next_scores, + gsl::span& next_tokens, + gsl::span& next_indices, + AllocatorPtr& allocator) { // sequences shape is (batch_size * num_beams, total_sequence_length) // It contains word ID of whole sequence generated so far. // It is different from subgraph input_ids, which only need one word when past state is not empty. @@ -136,7 +144,7 @@ void BeamSearchScorer::Process(ISequences* sequences, if (next_beam_scores_.empty()) { size_t batch_beam_size = static_cast(batch_size_ * num_beams_); const bool fill_zeros = false; - next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_, fill_zeros); + next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_, fill_zeros); next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_, fill_zeros); next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, fill_zeros); @@ -147,7 +155,7 @@ void BeamSearchScorer::Process(ISequences* sequences, } for (int batch = 0; batch < batch_size_; batch++) { - BeamHypotheses& beam_hyp = beam_hyps[batch]; + BeamHypotheses& beam_hyp = beam_hyps[batch]; if (done_[batch]) { ORT_ENFORCE(beam_hyp.Size() >= num_beams_, "Batch can only be done if all beams have been generated"); @@ -165,7 +173,7 @@ void BeamSearchScorer::Process(ISequences* sequences, int top_k = 2 * num_beams_; for (int j = 0; j < top_k; j++) { int64_t next_token = next_tokens[batch * top_k + j]; - float next_score = next_scores[batch * top_k + j]; + T next_score = next_scores[batch * top_k + j]; int64_t next_index = next_indices[batch * top_k + j]; int batch_beam_idx = batch * num_beams_ + static_cast(next_index); @@ -201,8 +209,8 @@ void BeamSearchScorer::Process(ISequences* sequences, // Check if we are done so that we can save a pad step if all(done) if (!done_[batch]) { - gsl::span topk_scores = next_scores.subspan(batch * num_beams_, top_k); - const float* best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end()); + gsl::span topk_scores = next_scores.subspan(batch * num_beams_, top_k); + const T* best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end()); if (beam_hyp.IsDone(*best_sum_logprobs, sequence_length)) { done_[batch] = true; } @@ -210,23 +218,24 @@ void BeamSearchScorer::Process(ISequences* sequences, } } -void BeamSearchScorer::Finalize(ISequences* sequences, - gsl::span& final_beam_scores, - Tensor* output_sequences, - Tensor* output_sequence_scores) { +template +void BeamSearchScorer::Finalize(ISequences* sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { ORT_ENFORCE(sequences != nullptr); ORT_ENFORCE(output_sequences != nullptr); // finalize all open beam hypotheses and add to generated hypotheses for (int batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps[batch_index]; + BeamHypotheses& beam_hyp = beam_hyps[batch_index]; if (done_[batch_index]) { continue; } for (int beam_index = 0; beam_index < num_beams_; beam_index++) { int batch_beam_index = batch_index * num_beams_ + beam_index; - float final_score = final_beam_scores[batch_beam_index]; + T final_score = final_beam_scores[batch_beam_index]; auto final_tokens = sequences->GetSequence(batch_beam_index); beam_hyp.Add(final_tokens, final_score); } @@ -239,22 +248,22 @@ void BeamSearchScorer::Finalize(ISequences* sequences, std::fill_n(output.data(), output.size(), pad_token_id_); // score of each sequence, with shape (batch_size * num_return_sequences) - gsl::span sequence_scores; - if (output_sequence_scores != nullptr){ - sequence_scores = output_sequence_scores->MutableDataAsSpan(); + gsl::span sequence_scores; + if (output_sequence_scores != nullptr) { + sequence_scores = output_sequence_scores->MutableDataAsSpan(); } // span is empty when output_sequence_scores is NULL. - gsl::span batch_sequence_score; + gsl::span batch_sequence_score; // Select the best hypotheses according to number of sequences to return. for (int batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps[batch_index]; + BeamHypotheses& beam_hyp = beam_hyps[batch_index]; const int num_return_sequences = num_beam_hyps_to_keep_; auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_); - if (output_sequence_scores != nullptr){ + if (output_sequence_scores != nullptr) { batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences); } @@ -266,5 +275,10 @@ void BeamSearchScorer::Finalize(ISequences* sequences, } } +// Instantiation +template class HypothesisScoreCompare; +template class BeamHypotheses; +template class BeamSearchScorer; + } // namespace contrib } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 662b3218d3dc9..51dbbbe66688e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -15,73 +15,78 @@ namespace onnxruntime { namespace contrib { class ISequences { -public: + public: virtual gsl::span GetSequence(int beam_index) = 0; virtual int GetSequenceLength() = 0; }; // Interface for all scorers for beam search or beam sample. +template class IBeamScorer { public: virtual void Process(ISequences* sequences, - gsl::span& next_scores, + gsl::span& next_scores, gsl::span& next_tokens, gsl::span& next_indices, AllocatorPtr& allocator) = 0; virtual void Finalize(ISequences* sequences, - gsl::span& final_beam_scores, + gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) = 0; }; +template struct HypothesisScore { - HypothesisScore(gsl::span& _hypothesis, float _score) + HypothesisScore(gsl::span& _hypothesis, T _score) : hypothesis(_hypothesis), score(_score) {} gsl::span hypothesis; - float score; + T score; }; +template class HypothesisScoreCompare { public: - bool operator()(const HypothesisScore& a, const HypothesisScore& b) { + bool operator()(const HypothesisScore& a, const HypothesisScore& b) { return a.score > b.score; } }; +template class BeamHypotheses { public: - BeamHypotheses(int num_beams, float length_penalty, bool early_stopping); + BeamHypotheses(int num_beams, T length_penalty, bool early_stopping); // Number of hypotheses int Size() { return static_cast(beams_.size()); } // Add a new hypothesis - void Add(gsl::span& hypothesis, float sum_logprobs); + void Add(gsl::span& hypothesis, T sum_logprobs); - bool IsDone(float best_sum_logprobs, int current_length); + bool IsDone(T best_sum_logprobs, int current_length); // Output results. Note that it will clear all beams. - void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - gsl::span& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length) - gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + gsl::span& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length) + gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) private: int num_beams_; - float length_penalty_; + T length_penalty_; bool early_stopping_; - float worst_score_; - std::priority_queue, HypothesisScoreCompare> beams_; // min-heap for top k + T worst_score_; + std::priority_queue, std::vector>, HypothesisScoreCompare> beams_; // min-heap for top k }; -class BeamSearchScorer : public IBeamScorer { +template +class BeamSearchScorer : public IBeamScorer { public: BeamSearchScorer(int batch_size, int num_beams, int max_length, - float length_penalty, + T length_penalty, bool early_stopping, int num_return_sequences, int pad_token_id, @@ -90,17 +95,17 @@ class BeamSearchScorer : public IBeamScorer { bool IsDone(); void Process(ISequences* sequences, - gsl::span& next_scores, + gsl::span& next_scores, gsl::span& next_tokens, gsl::span& next_indices, AllocatorPtr& allocator) override; void Finalize(ISequences* sequences, - gsl::span& final_beam_scores, + gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) override; - gsl::span& GetNextScores() { return next_beam_scores_; } + gsl::span& GetNextScores() { return next_beam_scores_; } gsl::span& GetNextTokens() { return next_beam_tokens_; } gsl::span& GetNextIndices() { return next_beam_indices_; } @@ -112,11 +117,11 @@ class BeamSearchScorer : public IBeamScorer { int pad_token_id_; int eos_token_id_; - std::vector beam_hyps; // List of batch result of beam search. Its shape is (batch_size) - std::vector done_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size). + std::vector> beam_hyps; // List of batch result of beam search. Its shape is (batch_size) + std::vector done_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size). - IAllocatorUniquePtr next_beam_scores_ptr_; - gsl::span next_beam_scores_; + IAllocatorUniquePtr next_beam_scores_ptr_; + gsl::span next_beam_scores_; IAllocatorUniquePtr next_beam_tokens_ptr_; gsl::span next_beam_tokens_; diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 12818431b5775..3c0528df3a0af 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -12,7 +12,6 @@ from gpt2_helper import PRETRAINED_GPT2_MODELS from convert_to_onnx import main as convert_gpt2_to_onnx from benchmark_helper import Precision - """ This converts GPT2 model to onnx with beam search operator. @@ -20,10 +19,11 @@ python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores """ -config:GPT2Config = None +config: GPT2Config = None logger = logging.getLogger('') + def parse_arguments(argv=None): parser = argparse.ArgumentParser() @@ -186,18 +186,19 @@ def gpt2_to_onnx(args): from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference out = SymbolicShapeInference.infer_shapes(onnx.load(args.gpt2_onnx), auto_merge=True, guess_output_rank=False) if out: - onnx.save(out, args.gpt2_onnx) + onnx.save(out, args.gpt2_onnx) + def create_ort_session(model_path, use_gpu): from onnxruntime import SessionOptions, InferenceSession, __version__ as ort_version, GraphOptimizationLevel sess_options = SessionOptions() sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL - execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider' - ] if use_gpu else ['CPUExecutionProvider'] + execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider'] ort_session = InferenceSession(model_path, sess_options, providers=execution_providers) return ort_session + def convert_model(args): if os.path.exists(args.gpt2_onnx): print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}") @@ -217,8 +218,8 @@ def convert_model(args): model = onnx.load(args.gpt2_onnx) model.graph.name = "gpt2 subgraph" inputs = [ - "input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", - "length_penalty", "repetition_penalty", "vocab_mask" + "input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty", + "repetition_penalty", "vocab_mask" ] outputs = ["sequences"] @@ -257,21 +258,20 @@ def convert_model(args): # graph outputs sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32, - ['batch_size', 'num_return_sequences', 'max_length']) - + ['batch_size', 'num_return_sequences', 'max_length']) + sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT, ['batch_size', 'num_return_sequences']) - scores = helper.make_tensor_value_info( - 'scores', TensorProto.FLOAT, - ['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size]) + scores = helper.make_tensor_value_info('scores', TensorProto.FLOAT, + ['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size]) initializers = [] graph_outputs = [sequences] - + if args.output_sequences_scores: graph_outputs.append(sequences_scores) - + if args.output_token_scores: graph_outputs.append(scores) @@ -302,17 +302,17 @@ def test_model(args): print('-' * 50) print("Test PyTorch model and beam search with huggingface transformers...") beam_outputs = model.generate(input_ids, - max_length=args.max_length, - min_length=args.min_length, - num_beams=args.num_beams, - early_stopping=args.early_stopping, - no_repeat_ngram_size=args.no_repeat_ngram_size, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - num_return_sequences=args.num_return_sequences, - temperature=args.temperature, - length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty) + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + temperature=args.temperature, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty) print("input_ids", input_ids) print("huggingface transformers output:", beam_outputs) for i, beam_output in enumerate(beam_outputs): @@ -325,7 +325,7 @@ def test_model(args): import time print('You have 15 seconds to attach a debugger.') time.sleep(15) - + ort_session = create_ort_session(args.output, args.use_gpu) batch_size = 1 @@ -353,14 +353,15 @@ def test_model(args): print("inputs", inputs) result = ort_session.run(None, inputs) - + sequences = result[0] print("outputs", sequences) - + #TODO: print all sequences. Below shows only the first one first_sequence = tokenizer.decode(sequences[0][0], skip_special_tokens=True) print(first_sequence) + def main(): args = parse_arguments() @@ -371,5 +372,6 @@ def main(): test_model(args) + if __name__ == '__main__': main() From 09e2458c7f4559915a523cd08252ff7a3b90b4df Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 18 Nov 2021 14:53:18 -0800 Subject: [PATCH 15/53] fix build pipeline errors --- onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc | 2 +- onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 0908d60098160..9cd52eec37156 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -189,7 +189,7 @@ void BeamSearchScorer::Process(ISequences* sequences, auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length); gsl::copy(src, clone); hypothesis_buffer_offset_ += sequence_length; - auto sequence = clone.as_span(); + auto sequence = clone.template as_span(); beam_hyp.Add(sequence, next_score); } else { // Add next predicted token since it is not eos_token diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 51dbbbe66688e..bedf212c95f7f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -24,6 +24,8 @@ class ISequences { template class IBeamScorer { public: + virtual ~IBeamScorer() {} + virtual void Process(ISequences* sequences, gsl::span& next_scores, gsl::span& next_tokens, From afe4b129973019ee538513b83952c4372cd68261 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 18 Nov 2021 16:17:44 -0800 Subject: [PATCH 16/53] symbolic shape infer of input onnx --- .../python/tools/transformers/convert_beam_search.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 3c0528df3a0af..7a9bab4bac012 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -182,11 +182,16 @@ def gpt2_to_onnx(args): convert_gpt2_to_onnx(arguments) + +def shape_inference(gpt2_onnx_path): # Run symbolic shape inference to walk around ORT shape inference issue for subgraph. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - out = SymbolicShapeInference.infer_shapes(onnx.load(args.gpt2_onnx), auto_merge=True, guess_output_rank=False) + out = SymbolicShapeInference.infer_shapes(onnx.load(gpt2_onnx_path), auto_merge=True, guess_output_rank=False) if out: - onnx.save(out, args.gpt2_onnx) + # TODO: Use external format if input has extra data. + onnx.save(out, gpt2_onnx_path) + else: + print("Failed to run symbolic shape inference on the model.") def create_ort_session(model_path, use_gpu): @@ -205,7 +210,8 @@ def convert_model(args): else: gpt2_to_onnx(args) - #create_ort_session(args.gpt2_onnx, args.use_gpu) + print(f"Run symbolic shape inference on {args.gpt2_onnx}. The file will be overwritten.") + shape_inference(args.gpt2_onnx) global config config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) From 2396c3b41da7422c9f2db801de29d516d4be9026 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Nov 2021 13:59:43 -0800 Subject: [PATCH 17/53] output scores --- .../cpu/transformers/beam_search.cc | 35 ++++++++++++++--- .../cpu/transformers/beam_search.h | 6 ++- .../transformers/beam_search_parameters.cc | 26 ++++++++----- .../cpu/transformers/beam_search_parameters.h | 3 ++ .../cpu/transformers/beam_search_scorer.cc | 2 +- .../tools/transformers/convert_beam_search.py | 38 +++++++++++++------ 6 files changed, 81 insertions(+), 29 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c94a6f78657ff..1f63a981fd248 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -463,6 +463,9 @@ Status BeamSearchImpl::Initialize() { // CheckInputs shall be after CheckSubgraph due to its dependency on vocab_size ORT_RETURN_IF_ERROR(CheckInputs(context_)); + // This flag will be updated later when the scores output exists. + parameters_->output_scores = false; + return status; } @@ -658,10 +661,9 @@ Status BeamSearchImpl::ProcessLogits( } } - // TODO: Store scores only when required - // if output_scores: - // scores += (next_token_scores,) - beam_state.scores.insert(beam_state.scores.end(), next_token_scores.begin(), next_token_scores.end()); + if (parameters_->output_scores) { + beam_state.scores.insert(beam_state.scores.end(), next_token_scores.begin(), next_token_scores.end()); + } //next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) //next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) @@ -889,6 +891,15 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { TensorShape sequences_scores_shape(sequences_scores_dims); Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape); + std::vector scores_dims{ + parameters_->max_length - parameters_->sequence_length, + parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size}; + TensorShape scores_shape(scores_dims); + Tensor* output_scores = context_.Output(2, scores_shape); + + // Update the flag to indicate whether scores exists in output + parameters_->output_scores = (output_scores != nullptr); + std::vector feeds; std::vector fetches; @@ -915,7 +926,8 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { parameters_->num_beams, parameters_->vocab_size, parameters_->sequence_length, - parameters_->max_length); + parameters_->max_length, + parameters_->output_scores); int current_length = parameters_->sequence_length; while (current_length < parameters_->max_length) { @@ -955,6 +967,19 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { output_sequences, output_sequences_scores); + // Output per token scores + if (output_scores != nullptr) { + gsl::span target = output_scores->MutableDataAsSpan(); + gsl::span source = gsl::span(beam_state_.scores.data(), beam_state_.scores.size()); + gsl::copy(source, target); + + // Fill zeros for the remaining when beam search stopped early + if (target.length() > source.length()) { + gsl::span remaining = target.subspan(source.length()); + memset(remaining.data(), 0, remaining.size_bytes()); + } + } + return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 2c0d4e8139356..23634b9efb5cb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -81,7 +81,7 @@ struct BeamSearchState { std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) - void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length) { + void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length, bool output_scores) { int batch_beam_size = batch_size * num_beams; done.assign(batch_size, 0); beam_scores.assign(batch_beam_size, 0.0f); @@ -99,7 +99,9 @@ struct BeamSearchState { sequences.Init(input_ids, batch_beam_size, sequence_length, max_length); - scores.reserve((max_length - sequence_length + 1) * batch_size * num_beams * vocab_size); + if (output_scores) { + scores.reserve((max_length - sequence_length) * batch_size * num_beams * vocab_size); + } } }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index a830a1da86830..ec00614db508e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "beam_search_parameters.h" +constexpr int kMaxSequenceLength = 4096; + namespace onnxruntime { namespace contrib { @@ -18,41 +20,45 @@ void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); } -void BeamSearchParameters::ParseFromInputs(OpKernelContext* context){ +void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { ORT_ENFORCE(context != nullptr); const Tensor* input_ids = context->Input(0); const auto& dims = input_ids->Shape().GetDims(); - if (dims.size() == 2) { - batch_size = static_cast(dims[0]); - sequence_length = static_cast(dims[1]); - } else { - batch_size = 0; - sequence_length = 0; - } - + ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size()); + batch_size = static_cast(dims[0]); + sequence_length = static_cast(dims[1]); + auto* max_length_tensor = context->Input(1); max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : 4096; + ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")"); + ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength); auto* min_length_tensor = context->Input(2); min_length = min_length_tensor ? static_cast(*min_length_tensor->Data()) : 0; auto* num_beams_tensor = context->Input(3); num_beams = num_beams_tensor ? static_cast(*num_beams_tensor->Data()) : 1; + // TODO: shall we limit num_beams > 1. When num_beams==1, we can have another operator for greedy search. + ORT_ENFORCE(num_beams >= 1, "num_beams shall be a positive integer, got ", num_beams); auto* num_return_sequences_tensor = context->Input(4); num_return_sequences = num_return_sequences_tensor ? static_cast(*num_return_sequences_tensor->Data()) : 1; + ORT_ENFORCE(num_return_sequences >= 1, "num_return_sequences shall be a positive integer, got ", num_return_sequences); + ORT_ENFORCE(num_beams >= num_return_sequences, "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")"); auto* temperature_tensor = context->Input(5); temperature = temperature_tensor ? static_cast(*temperature_tensor->Data()) : 1; + ORT_ENFORCE(temperature > 0.0f, "temperature shall be greater than 0, got ", temperature); auto* length_penalty_tensor = context->Input(6); length_penalty = length_penalty_tensor ? static_cast(*length_penalty_tensor->Data()) : 1; auto* repetition_penalty_tensor = context->Input(7); repetition_penalty = repetition_penalty_tensor ? static_cast(*repetition_penalty_tensor->Data()) : 1.0f; + ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty); } -void BeamSearchParameters::SetSubgraphParameters(int heads, int hidden_size_per_head, int vocabulary_size, int layers){ +void BeamSearchParameters::SetSubgraphParameters(int heads, int hidden_size_per_head, int vocabulary_size, int layers) { num_heads = heads; head_size = hidden_size_per_head; vocab_size = vocabulary_size; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 95753e1d797c9..0456335865ce3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -25,6 +25,9 @@ struct BeamSearchParameters { float repetition_penalty; int batch_size; // deduce from first dimension of input_ids int sequence_length; // deduce from second dimension of input_ids + + // from outputs + bool output_scores; // whether scores existed in output // deduce from subgraph int num_heads; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 9cd52eec37156..9f69f2d8d3dd1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -26,7 +26,7 @@ BeamHypotheses::BeamHypotheses(int num_beams, T length_penalty, bool early_st template void BeamHypotheses::Add(gsl::span& hypothesis, T sum_logprobs) { auto length = hypothesis.size(); - // TODO: may need compute in FP32 when T is FP16 + // TODO: when T is FP16, compute in FP32, then cast result back to FP16. length_penalty_ might also be float. T score = sum_logprobs / pow(static_cast(length), length_penalty_); if (this->Size() < num_beams_ || score > worst_score_) { diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 7a9bab4bac012..327bab017c363 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -231,7 +231,9 @@ def convert_model(args): outputs = ["sequences"] if args.output_sequences_scores: outputs.append("sequences_scores") + if args.output_token_scores: + assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" outputs.append("scores") node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2') @@ -268,6 +270,7 @@ def convert_model(args): sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT, ['batch_size', 'num_return_sequences']) + scores = helper.make_tensor_value_info('scores', TensorProto.FLOAT, ['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size]) @@ -318,11 +321,19 @@ def test_model(args): num_return_sequences=args.num_return_sequences, temperature=args.temperature, length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty) + repetition_penalty=args.repetition_penalty, + return_dict_in_generate=True, + output_scores=True + ) print("input_ids", input_ids) - print("huggingface transformers output:", beam_outputs) - for i, beam_output in enumerate(beam_outputs): - print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True))) + print("huggingface transformers outputs:") + print("sequences", beam_outputs.sequences) + if args.output_sequences_scores: + print("sequences_scores", beam_outputs.sequences_scores) + if args.output_token_scores: + print("scores", beam_outputs.scores) + for i, sequence in enumerate(beam_outputs.sequences): + print("{}: {}".format(i, tokenizer.decode(sequence, skip_special_tokens=True))) print('-' * 50) print("Test ONNX model and bream search with onnxruntime...") @@ -359,14 +370,19 @@ def test_model(args): print("inputs", inputs) result = ort_session.run(None, inputs) - + print("ORT outputs:") sequences = result[0] - print("outputs", sequences) - - #TODO: print all sequences. Below shows only the first one - first_sequence = tokenizer.decode(sequences[0][0], skip_special_tokens=True) - print(first_sequence) - + print("sequences", sequences) + if args.output_sequences_scores: + print("sequences_scores", result[1]) + if args.output_token_scores: + print("scores", result[2]) + + (batch_size, num_sequences, max_length) = sequences.shape + for i in range(batch_size): + for j in range(num_sequences): + sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) + print(f"batch {i} sequence {j}: {sequence}") def main(): args = parse_arguments() From b444e70482ad32773149d470627b72f5becf9aa3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Nov 2021 14:00:14 -0800 Subject: [PATCH 18/53] add kernel def hash --- onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json index bb0f31e904196..e87ce07f0324f 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json @@ -3,6 +3,10 @@ "Affine ai.onnx CPUExecutionProvider", 7811918192248490408 ], + [ + "BeamSearch com.microsoft CPUExecutionProvider", + 6968087233460196528 + ], [ "Crop ai.onnx CPUExecutionProvider", 6914973556202621376 From bb032c244f73393e868bfbda64f3ee01b2e6ecaf Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Nov 2021 23:17:26 -0800 Subject: [PATCH 19/53] Handle vocab_mask; move CheckSubgraph --- .../cpu/transformers/beam_search.cc | 139 ++++++++++-------- .../cpu/transformers/beam_search.h | 3 + .../cpu/transformers/beam_search_parameters.h | 2 + .../tools/transformers/convert_beam_search.py | 24 ++- 4 files changed, 99 insertions(+), 69 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 1f63a981fd248..d8714ea4e9d73 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -162,9 +162,6 @@ class BeamSearchImpl { private: Status CheckInputs(const OpKernelContextInternal& context); - Status CheckSubgraph(const std::vector& subgraph_inputs, - const std::vector& subgraph_outputs) const; - OrtValue ExpandInputs(const OrtValue& input_ids, int num_beams) const; // Prepare the inputs for first inference of subgraph @@ -188,7 +185,8 @@ class BeamSearchImpl { int top_k, AllocatorPtr& allocator); - void ProcessNextTokenScores(gsl::span& next_token_scores); + // Mask tokens accroding to vocab_mask + void ApplyVocabMask(gsl::span& next_token_scores); // Reorder cache by picking the past state based on beam indices void PickPastState(const std::vector& last_outputs, @@ -302,6 +300,61 @@ common::Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& ses feeds_fetches_manager_ = std::move(ffm); + // CheckSubgraph is moved here so that it only need called once instead of every inference run. + auto& inputs = subgraph_info_->subgraph.GetInputs(); + auto& outputs = subgraph_info_->subgraph.GetOutputs(); + ORT_RETURN_IF_ERROR(CheckSubgraph(inputs, outputs)); + + return Status::OK(); +} + +template +Status BeamSearch::CheckSubgraph(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) { + ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ", + subgraph_inputs[0]->Name()); + ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ", + subgraph_inputs[1]->Name()); + ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ", + subgraph_inputs[2]->Name()); + ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ", + subgraph_inputs[3]->Name()); + + // Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads. + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); + ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", + past_shape->dim_size()); + + ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, + "subgraph past state dimension 0 shall have length of 2"); + + ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for number of heads"); + + ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, + "subgraph past state dimension 4 shall have a positive value for hidden size per head"); + + // check subgraph outputs + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ", + subgraph_outputs[0]->Name()); + + ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ", + subgraph_outputs[1]->Name()); + + // Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size. + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); + ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ", + logits_shape->dim_size()); + + ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for vocabulary size"); + + int num_heads = static_cast(past_shape->dim(2).dim_value()); + int head_size = static_cast(past_shape->dim(4).dim_value()); + int vocab_size = static_cast(logits_shape->dim(2).dim_value()); + int num_layers = static_cast(subgraph_outputs.size()) - 1; + parameters_.SetSubgraphParameters(num_heads, head_size, vocab_size, num_layers); + return Status::OK(); } @@ -372,57 +425,10 @@ Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ", vocab_mask_dims[0]); } - } - - return Status::OK(); -} - -template -Status BeamSearchImpl::CheckSubgraph(const std::vector& subgraph_inputs, - const std::vector& subgraph_outputs) const { - ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ", - subgraph_inputs[0]->Name()); - ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ", - subgraph_inputs[1]->Name()); - ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ", - subgraph_inputs[2]->Name()); - ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ", - subgraph_inputs[3]->Name()); - - // Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads. - const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); - ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", - past_shape->dim_size()); - - ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, - "subgraph past state dimension 0 shall have length of 2"); - - ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, - "subgraph past state dimension 2 shall have a positive value for number of heads"); - - ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, - "subgraph past state dimension 4 shall have a positive value for hidden size per head"); - // check subgraph outputs - ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ", - subgraph_outputs[0]->Name()); - - ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ", - subgraph_outputs[1]->Name()); - - // Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size. - const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); - ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ", - logits_shape->dim_size()); - - ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, - "subgraph past state dimension 2 shall have a positive value for vocabulary size"); - - int num_heads = static_cast(past_shape->dim(2).dim_value()); - int head_size = static_cast(past_shape->dim(4).dim_value()); - int vocab_size = static_cast(logits_shape->dim(2).dim_value()); - int num_layers = static_cast(subgraph_outputs.size()) - 1; - parameters_->SetSubgraphParameters(num_heads, head_size, vocab_size, num_layers); + // store vocab mask in parameters. + parameters_->vocab_mask = vocab_mask->DataAsSpan(); + } return Status::OK(); } @@ -456,10 +462,6 @@ Status BeamSearchImpl::Initialize() { ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'."); - auto& inputs = subgraph_info_.subgraph.GetInputs(); - auto& outputs = subgraph_info_.subgraph.GetOutputs(); - ORT_RETURN_IF_ERROR(CheckSubgraph(inputs, outputs)); - // CheckInputs shall be after CheckSubgraph due to its dependency on vocab_size ORT_RETURN_IF_ERROR(CheckInputs(context_)); @@ -645,9 +647,9 @@ Status BeamSearchImpl::ProcessLogits( return status; } - // Extra processing: next_token_scores = logits_processor(input_ids, next_token_scores) - // where input_ids is current sequences in beam_state_ - ProcessNextTokenScores(next_token_scores); + // Apply all logits processors that modify scores + ApplyVocabMask(next_token_scores); + // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) // TODO: use thread pool to parrellel @@ -760,7 +762,20 @@ Status BeamSearchImpl::GenerateNextToken( } template -void BeamSearchImpl::ProcessNextTokenScores(gsl::span& /*next_token_scores*/) { +void BeamSearchImpl::ApplyVocabMask(gsl::span& next_token_scores) { + // Process vocabulary mask and set tokens with mask value 0 to -inf. + auto& vocab_mask = parameters_->vocab_mask; + if (!vocab_mask.empty()) { + T* p = next_token_scores.data(); + // next_token_scores shape (batch_size * num_beams, vocab_size), vocab_mask shape (vocab_size) + for (int i = 0; i < parameters_->batch_size * parameters_->num_beams; i++) { + for (int j = 0; j < parameters_->vocab_size; j++, p++) { + if (vocab_mask[j] == 0) { + *p = std::numeric_limits::lowest(); + } + } + } + } return; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 23634b9efb5cb..bc229a6856f24 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -120,6 +120,9 @@ class BeamSearch : public controlflow::IControlFlowKernel { static std::unique_ptr Create(const OpKernelInfo& info, void* stream); protected: + Status CheckSubgraph(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs); + void SetComputeStream(void* stream) { stream_ = stream; } private: diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 0456335865ce3..b29c49696d1d0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -26,6 +26,8 @@ struct BeamSearchParameters { int batch_size; // deduce from first dimension of input_ids int sequence_length; // deduce from second dimension of input_ids + gsl::span vocab_mask; + // from outputs bool output_scores; // whether scores existed in output diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 327bab017c363..abeabc65ca82f 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -231,7 +231,7 @@ def convert_model(args): outputs = ["sequences"] if args.output_sequences_scores: outputs.append("sequences_scores") - + if args.output_token_scores: assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" outputs.append("scores") @@ -270,7 +270,7 @@ def convert_model(args): sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT, ['batch_size', 'num_return_sequences']) - + scores = helper.make_tensor_value_info('scores', TensorProto.FLOAT, ['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size]) @@ -299,6 +299,11 @@ def test_model(args): pad_token_id=tokenizer.eos_token_id) input_ids = tokenizer.encode('I enjoy walking in the park', return_tensors='pt') + bad_words = "walk in park" + bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True) + bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list + print("bad_words_ids", bad_words_ids) + global config if config is None: config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) @@ -322,9 +327,9 @@ def test_model(args): temperature=args.temperature, length_penalty=args.length_penalty, repetition_penalty=args.repetition_penalty, + bad_words_ids=bad_words_ids, return_dict_in_generate=True, - output_scores=True - ) + output_scores=True) print("input_ids", input_ids) print("huggingface transformers outputs:") print("sequences", beam_outputs.sequences) @@ -348,6 +353,10 @@ def test_model(args): batch_size = 1 input_ids = input_ids.repeat(batch_size, 1) + vocab_mask = np.ones((vocab_size), dtype=np.int32) + for bad_word_id in bad_words_ids: + vocab_mask[bad_word_id] = 0 + inputs = { "input_ids": input_ids.cpu().numpy().astype(np.int32), "max_length": np.array([args.max_length], dtype=np.int32), @@ -357,7 +366,7 @@ def test_model(args): "temperature": np.array([args.temperature], dtype=np.float32), "length_penalty": np.array([args.length_penalty], dtype=np.float32), "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), - "vocab_mask": np.ones((vocab_size), dtype=np.int32) + "vocab_mask": vocab_mask } test_data_dir = Path(args.output).parent.as_posix() @@ -377,13 +386,14 @@ def test_model(args): print("sequences_scores", result[1]) if args.output_token_scores: print("scores", result[2]) - - (batch_size, num_sequences, max_length) = sequences.shape + + (batch_size, num_sequences, max_length) = sequences.shape for i in range(batch_size): for j in range(num_sequences): sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) print(f"batch {i} sequence {j}: {sequence}") + def main(): args = parse_arguments() From fb275641e5e66b53ddcc38b2b7c9bca1d08ff58e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 20 Nov 2021 00:05:24 -0800 Subject: [PATCH 20/53] undo insert_cast_transformer.cc and fusion_utils.py --- .../core/optimizer/insert_cast_transformer.cc | 20 +++++++++---------- .../python/tools/transformers/fusion_utils.py | 3 +-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 6741cc04ea6f1..82d8a501753aa 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -176,19 +176,17 @@ enum TypeGroup { }; TypeGroup GetTypeGroup(DataType type) { - if (type != nullptr) { - if (*type == "tensor(bool)") { - return Bool; - } + if (*type == "tensor(bool)") { + return Bool; + } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { - return Integer; - } + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || + *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Integer; + } - if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { - return Float; - } + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { + return Float; } return Unknown; diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 760d26032a3ea..ae0587bd24933 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -36,8 +36,7 @@ def cast_input_to_int32(self, input_name: str): if parent_node and parent_node.op_type == 'Cast': inputs = [parent_node.input[0]] - node_name = self.model.create_node_name('Cast') - cast_node = helper.make_node('Cast', inputs=inputs, outputs=[cast_output], name=node_name) + cast_node = helper.make_node('Cast', inputs=inputs, outputs=[cast_output]) cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.INT32))]) self.model.add_node(cast_node) From cdb62bb3e7d1d447ea8ee69a3c5fb140ffdc9d35 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 20 Nov 2021 00:18:26 -0800 Subject: [PATCH 21/53] fix typo --- .../tools/transformers/convert_beam_search.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index abeabc65ca82f..9abde89d78d25 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -99,7 +99,7 @@ def parse_arguments(argv=None): type=int, required=False, default=1, - help='Number of return sequence') + help='Number of return sequence <= num_beams') beam_search_group.add_argument('--temperature', type=float, @@ -119,33 +119,33 @@ def parse_arguments(argv=None): default=1, help='Positive. >1 to penalize and <1 to encorage.') - mixed_precision_option_grapu = parser.add_argument_group( + mixed_precision_option_group = parser.add_argument_group( "mixed precision conversion parameters that works when \"--precision fp16\" is specified") - mixed_precision_option_grapu.add_argument('--io_block_list', + mixed_precision_option_group.add_argument('--io_block_list', nargs='+', required=False, default=[], help='List of inputs or outputs in float32') - mixed_precision_option_grapu.add_argument( + mixed_precision_option_group.add_argument( '--op_block_list', nargs='+', required=False, default=[], help='List of operators (like Add LayerNormalization FastGelu) to compute in float32.') - mixed_precision_option_grapu.add_argument('--node_block_list', + mixed_precision_option_group.add_argument('--node_block_list', nargs='+', required=False, default=[], help='List of node names to compute in float32.') - mixed_precision_option_grapu.add_argument('--force_fp16_initializers', + mixed_precision_option_group.add_argument('--force_fp16_initializers', required=False, action='store_true', help='Convert all float initializers to float16.') - mixed_precision_option_grapu.set_defaults(force_fp16_initializers=False) + mixed_precision_option_group.set_defaults(force_fp16_initializers=False) args = parser.parse_args(argv) From bca63fdf04767ac805fca7a24b5de7b085fd02c4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 20 Nov 2021 01:38:56 -0800 Subject: [PATCH 22/53] fix merge --- .../core/graph/contrib_ops/contrib_defs.cc | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 186a6a24e25a4..ba597be498219 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -659,6 +659,48 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { } } } + +void RegisterTextGenerationSchemas() { + ONNX_CONTRIB_OPERATOR_SCHEMA(BeamSearch) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc("Beam Search for text generation. Supports GPT-2 decoder.") + .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) + .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) + .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) + .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) + .Attr( + "body", + "The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output", + AttributeProto::GRAPH) + .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") + .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") + .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) + .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") + .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") + .Input(5, "temperature", "The value used to module the next token probabilities. Accepts value != 0.0. Shape is (1)", "T") + .Input(6, "length_penalty", + "Exponential penalty to the length. Default value 1.0 means no penalty." + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Shape is (1,)", + "T", OpSchema::Optional) + .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) + .Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") + .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) + .Output(2, "scores", + "Processed beam scores for each vocabulary token at each generation step." + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", + "T", OpSchema::Optional) + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + BeamSearchShapeInference(ctx); + }); +} + void RegisterBertSchemas() { static const char* Attention_ver1_doc = R"DOC( Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). @@ -812,6 +854,37 @@ Global attention flags have value 1 for the tokens attend globally and 0 otherwi .TypeConstraint("G", {"tensor(int32)"}, "Constrain to integer types") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + static const char* Decoder_Attention_doc = R"DOC( +This DecoderAttention supports self attention and cross attention, key and value cache, and key_padding_mask. The attention mask is not support at the moment. +Some boolean parameters are passed by runtime input for generic purpose +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(DecoderAttention) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(Decoder_Attention_doc) + .Attr("num_heads", "Number of attention heads", AttributeProto::INT) + .Input(0, "query", "3D input tensor with shape (sequence_length, batch_size, hidden_size), hidden_size = num_heads * head_size", "T") + .Input(1, "key", "3D input tensor with shape (total_sequence_length, batch_size, hidden_size)", "T") + .Input(2, "q_weight", "2D input tensor with shape (hidden_size, hidden_size)", "T") + .Input(3, "kv_weight", "2D input tensor with shape (hidden_size, 2 * hidden_size)", "T") + .Input(4, "bias", "1D input tensor with shape (3 * hidden_size)", "T") + .Input(5, "key_padding_mask", "2D input tensor with shape (batch_size, total_sequence_length)", "B", OpSchema::Optional) + .Input(6, "key_cache", "input tensor with shape (batch_size, num_heads, sequence_length or total_sequence_length, head_size)", "T", OpSchema::Optional) // self & cross + .Input(7, "value_cache", "input tensor with shape (batch_size, num_heads, sequence_length or total_sequence_length, head_size)", "T", OpSchema::Optional) // self & cross + .Input(8, "static_kv", "If static_kv = true, cross-attention; else self-attention", "B") + .Input(9, "use_past", "If use_past = true, use cache; else no cache", "B") + .Input(10, "has_layer_state", "If has_layer_state = true, layer_state = {} or [a,b]; else layer_state = None", "B") + .Input(11, "has_key_padding_mask", "has_key_padding_mask or not", "B") + .Output(0, "output", "3D output tensor with shape (sequence_length, batch_size, hidden_size)", "T") + .Output(1, "new_key_cache", "output tensor with shape (batch_size, num_heads, new sequence_length, head_size)", "T", OpSchema::Optional) // self & cross + .Output(2, "new_value_cache", "output tensor with shape (batch_size, num_heads, new sequence_length, head_size)", "T", OpSchema::Optional) // self & cross + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float and float16 tensors.") + .TypeConstraint("B", {"tensor(bool)"}, "Constrain key_padding_mask to bool tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + DecoderAttentionTypeAndShapeInference(ctx); + }); + static const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, @@ -3205,6 +3278,7 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i RegisterNhwcSchemas(); RegisterBertSchemas(); + RegisterTextGenerationSchemas(); #ifdef BUILD_MS_EXPERIMENTAL_OPS onnxruntime::signal::RegisterSignalSchemas(); From 3e0cb7f7321b56db60868b7f1c04bdf6738e298d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 20 Nov 2021 11:51:03 -0800 Subject: [PATCH 23/53] update doc --- docs/ContribOperators.md | 70 ++++++++++++++++++++++++++++++++++++++++ docs/OperatorKernels.md | 1 + 2 files changed, 71 insertions(+) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 3a0171a064589..fc1411a764ff4 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5,6 +5,7 @@ Do not modify directly.* * com.microsoft * com.microsoft.Attention * com.microsoft.AttnLSTM + * com.microsoft.BeamSearch * com.microsoft.BiasDropout * com.microsoft.BiasGelu * com.microsoft.BiasSoftmax @@ -337,6 +338,75 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.BeamSearch** + + Beam Search for text generation. Supports GPT-2 decoder. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
body : graph (required)
+
The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output
+
early_stopping : int
+
early stop or not
+
eos_token_id : int (required)
+
The id of the end-of-sequence token
+
no_repeat_ngram_size : int
+
no repeat ngrams size
+
pad_token_id : int (required)
+
The id of the padding token
+
+ +#### Inputs (6 - 9) + +
+
input_ids : I
+
The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)
+
max_length : I
+
The maximum length of the sequence to be generated. Shape is (1)
+
min_length (optional) : I
+
The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)
+
num_beams : I
+
Number of beams for beam search. 1 means no beam search. Shape is (1)
+
num_return_sequences : I
+
The number of returned sequences in the batch. Shape is (1)
+
temperature : T
+
The value used to module the next token probabilities. Accepts value != 0.0. Shape is (1)
+
length_penalty (optional) : T
+
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+
repetition_penalty (optional) : T
+
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
+
vocab_mask (optional) : M
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
+ +#### Outputs (1 - 3) + +
+
sequences : I
+
Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)
+
sequences_scores (optional) : T
+
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
+
scores (optional) : T
+
Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
I : tensor(int32)
+
Constrain to integer types
+
M : tensor(int32)
+
Constrain mask to integer types
+
+ + ### **com.microsoft.BiasDropout** output, dropout_mask = Dropout(data + bias, ratio) + residual, Intended to specialize the dropout pattern commonly found in transformer models. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bf637d6405788..01fd44bb5c009 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -377,6 +377,7 @@ Do not modify directly.* |**Operator Domain:** *com.microsoft*|||| |Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| +|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| |BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)| |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| From 872408ca013bb96f3fe3ce076c5a135337d7c44f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 21 Nov 2021 00:59:30 -0800 Subject: [PATCH 24/53] add repetition penalty --- docs/ContribOperators.md | 2 +- .../cpu/transformers/beam_search.cc | 36 ++++++++++++++----- .../cpu/transformers/beam_search.h | 2 +- .../cpu/transformers/beam_search_parameters.h | 1 + .../cpu/transformers/beam_search_scorer.h | 3 +- .../core/graph/contrib_ops/contrib_defs.cc | 2 +- 6 files changed, 34 insertions(+), 12 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index fc1411a764ff4..a982999f0a82d 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -375,7 +375,7 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
temperature : T
-
The value used to module the next token probabilities. Accepts value != 0.0. Shape is (1)
+
The value used to module the next token probabilities. Accepts value > 0.0. Shape is (1)
length_penalty (optional) : T
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
repetition_penalty (optional) : T
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index d8714ea4e9d73..e830bd82a1511 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -10,7 +10,7 @@ #endif #ifndef NDEBUG -#define DEBUG_BEAM_SEARCH 1 // TODO: remove this once this operator is ready for production. +//#define DEBUG_BEAM_SEARCH 1 // uncomment it for debugging #endif #include "core/providers/cpu/controlflow/utils.h" @@ -52,9 +52,6 @@ namespace contrib { REGISTER_KERNEL_TYPED(float) -// CPU does not support float16 -// REGISTER_KERNEL_TYPED(MLFloat16) - GptSubgraphInfo::GptSubgraphInfo(const onnxruntime::Node& node, const GraphViewer& subgraph_in) : subgraph(subgraph_in) { num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); @@ -99,9 +96,9 @@ void Sequences::Init(const OrtValue& input_ids, int batch_beam_size, int sequenc current_length_ = sequence_length; } -gsl::span Sequences::GetSequence(int beam_index) { - gsl::span buffer(sequences[current_sequences_buffer]); - gsl::span sequence = buffer.subspan(beam_index * max_length_, current_length_); +gsl::span Sequences::GetSequence(int beam_index) const { + gsl::span buffer(sequences[current_sequences_buffer]); + gsl::span sequence = buffer.subspan(beam_index * max_length_, current_length_); return sequence; } @@ -188,6 +185,9 @@ class BeamSearchImpl { // Mask tokens accroding to vocab_mask void ApplyVocabMask(gsl::span& next_token_scores); + // Apply repetion penalty + void ApplyRepetitionPenalty(const Sequences& sequences, gsl::span& next_token_scores); + // Reorder cache by picking the past state based on beam indices void PickPastState(const std::vector& last_outputs, std::vector& next_inputs, @@ -647,8 +647,9 @@ Status BeamSearchImpl::ProcessLogits( return status; } - // Apply all logits processors that modify scores + // Apply all score processors that updates scores ApplyVocabMask(next_token_scores); + ApplyRepetitionPenalty(beam_state.sequences, next_token_scores); // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) @@ -779,6 +780,25 @@ void BeamSearchImpl::ApplyVocabMask(gsl::span& next_token_scores) { return; } +template +void BeamSearchImpl::ApplyRepetitionPenalty(const Sequences& sequences, gsl::span& next_token_scores) { + if (parameters_->repetition_penalty == 1.0f) { // no penalty + return; + } + + int batch_beam_size = parameters_->BatchBeamSize(); + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.subspan(i * parameters_->vocab_size, parameters_->vocab_size); + gsl::span sequence = sequences.GetSequence(i); + for (const int64_t& word_id : sequence) { + T score = beam_token_scores[word_id]; + // If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability, + // This assumes that scores are either positive (like ctrl) or negative (like GPT-2), but not a mixture. + beam_token_scores[word_id] = (score < 0 ? score * parameters_->repetition_penalty : score / parameters_->repetition_penalty); + } + } +} + template void BeamSearchImpl::PickPastState(const std::vector& last_outputs, std::vector& next_inputs, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index bc229a6856f24..9ff42ba4e88e2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -37,7 +37,7 @@ class Sequences : public ISequences { void Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). - gsl::span GetSequence(int beam_index) override; + gsl::span GetSequence(int beam_index) const override; // Returns current sequence length. int GetSequenceLength() override; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index b29c49696d1d0..b9477bcbf3d96 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -39,6 +39,7 @@ struct BeamSearchParameters { Status Validate(); + int BatchBeamSize(){ return batch_size * num_beams; } void ParseFromAttributes(const OpKernelInfo& info); void ParseFromInputs(OpKernelContext* context); void SetSubgraphParameters(int num_heads, int head_size, int vocab_size, int num_layers); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index bedf212c95f7f..22172344baecb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -16,7 +16,8 @@ namespace contrib { class ISequences { public: - virtual gsl::span GetSequence(int beam_index) = 0; + virtual ~ISequences() {} + virtual gsl::span GetSequence(int beam_index) const = 0; virtual int GetSequenceLength() = 0; }; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index ba597be498219..6eb90e969c241 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -678,7 +678,7 @@ void RegisterTextGenerationSchemas() { .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") - .Input(5, "temperature", "The value used to module the next token probabilities. Accepts value != 0.0. Shape is (1)", "T") + .Input(5, "temperature", "The value used to module the next token probabilities. Accepts value > 0.0. Shape is (1)", "T") .Input(6, "length_penalty", "Exponential penalty to the length. Default value 1.0 means no penalty." "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." From 06dc72b2ce8f8d67014dc6fb91093eeb7bbc2c5c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 24 Nov 2021 23:30:57 -0800 Subject: [PATCH 25/53] refactoring: add GptSubgraph class --- .../cpu/transformers/beam_search.cc | 467 +----------------- .../cpu/transformers/beam_search.h | 60 +-- .../transformers/beam_search_parameters.cc | 10 +- .../cpu/transformers/beam_search_parameters.h | 20 +- .../cpu/transformers/beam_search_scorer.cc | 3 +- .../cpu/transformers/beam_search_scorer.h | 11 +- .../cpu/transformers/dump_tensor.cc | 5 + .../cpu/transformers/dump_tensor.h | 8 + .../cpu/transformers/gpt_subgraph.cc | 452 +++++++++++++++++ .../cpu/transformers/gpt_subgraph.h | 83 ++++ .../contrib_ops/cpu/transformers/sequences.cc | 72 +++ .../contrib_ops/cpu/transformers/sequences.h | 55 +++ 12 files changed, 725 insertions(+), 521 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h create mode 100644 onnxruntime/contrib_ops/cpu/transformers/sequences.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/sequences.h diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index e830bd82a1511..3ffa8a5be723b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -9,10 +9,6 @@ #pragma warning(disable : 4996) #endif -#ifndef NDEBUG -//#define DEBUG_BEAM_SEARCH 1 // uncomment it for debugging -#endif - #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/math/top_k.h" #include "core/framework/allocator.h" @@ -48,103 +44,18 @@ namespace contrib { kCpuExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - BeamSearch); + transformers::BeamSearch); REGISTER_KERNEL_TYPED(float) -GptSubgraphInfo::GptSubgraphInfo(const onnxruntime::Node& node, const GraphViewer& subgraph_in) - : subgraph(subgraph_in) { - num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); - - auto& subgraph_inputs = subgraph.GetInputs(); - auto& subgraph_outputs = subgraph.GetOutputs(); - - // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ... - // outputs: logits, present_0, present_1, ... - num_subgraph_inputs = static_cast(subgraph_inputs.size()); - num_subgraph_outputs = static_cast(subgraph_outputs.size()); - - // CheckSubgraph will verify inputs and outputs later. - subgraph_input_names.reserve(num_subgraph_inputs); - for (int i = 0; i < num_subgraph_inputs; ++i) { - subgraph_input_names.push_back(subgraph_inputs[i]->Name()); - } - - subgraph_output_names.reserve(num_subgraph_outputs); - for (int i = 0; i < num_subgraph_outputs; ++i) { - subgraph_output_names.push_back(subgraph_outputs[i]->Name()); - } -} - -void Sequences::Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length) { - // Allocate buffer (shall we use allocator instead?) - sequences[0].assign(batch_beam_size * max_length, 0); - sequences[1].assign(batch_beam_size * max_length, 0); - - // copying input_ids to sequences[0] - gsl::span input = input_ids.Get().DataAsSpan(); - gsl::span output(sequences[0]); - for (int i = 0; i < batch_beam_size; i++) { - gsl::span source = input.subspan(i * sequence_length, sequence_length); - gsl::span target = output.subspan(i * max_length, sequence_length); - gsl::copy(source, target); - } - current_sequences_buffer = 0; - - batch_beam_size_ = batch_beam_size; - max_length_ = max_length; - current_length_ = sequence_length; -} - -gsl::span Sequences::GetSequence(int beam_index) const { - gsl::span buffer(sequences[current_sequences_buffer]); - gsl::span sequence = buffer.subspan(beam_index * max_length_, current_length_); - return sequence; -} - -int Sequences::GetSequenceLength() { - return current_length_; -} - -void Sequences::PrintSequences() { -#ifdef DEBUG_BEAM_SEARCH - for (int i = 0; i < batch_beam_size_; i++) { - gsl::span sequence = GetSequence(i); - DumpString("sequences", i, false); - DumpTensor(nullptr, sequence.data(), 1, current_length_); - } -#endif -} - -void Sequences::AppendNextTokenToSequences( - gsl::span& beam_indices, - gsl::span& beam_next_tokens) { - //sequences = torch.cat([sequences[beam_indices, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - gsl::span input(sequences[current_sequences_buffer]); - gsl::span output(sequences[1 - current_sequences_buffer]); - - for (int i = 0; i < batch_beam_size_; i++) { - int beam_index = static_cast(beam_indices[i]); - gsl::span source = input.subspan(beam_index * max_length_, current_length_); - gsl::span target = output.subspan(i * max_length_, current_length_); - gsl::copy(source, target); - } - - // append next token to each beam - for (int i = 0; i < batch_beam_size_; i++) { - output[i * max_length_ + current_length_] = beam_next_tokens[i]; - } - - ++current_length_; - current_sequences_buffer = 1 - current_sequences_buffer; // rotate buffer for next round -} +namespace transformers { template class BeamSearchImpl { public: BeamSearchImpl(OpKernelContextInternal& context, const SessionState& session_state, - const GptSubgraphInfo& info, + GptSubgraph& gpt_subgraph, concurrency::ThreadPool* thread_pool, void* stream, BeamSearchParameters& params); @@ -159,8 +70,6 @@ class BeamSearchImpl { private: Status CheckInputs(const OpKernelContextInternal& context); - OrtValue ExpandInputs(const OrtValue& input_ids, int num_beams) const; - // Prepare the inputs for first inference of subgraph void CreateInitialFeeds(std::vector& feeds); @@ -188,21 +97,15 @@ class BeamSearchImpl { // Apply repetion penalty void ApplyRepetitionPenalty(const Sequences& sequences, gsl::span& next_token_scores); - // Reorder cache by picking the past state based on beam indices - void PickPastState(const std::vector& last_outputs, - std::vector& next_inputs, - gsl::span& beam_indices); - OpKernelContextInternal& context_; const SessionState& session_state_; - const GptSubgraphInfo& subgraph_info_; + + GptSubgraph& gpt_subgraph_; concurrency::ThreadPool* thread_pool_; const std::vector& implicit_inputs_; - std::vector next_positions_; - // Not used in CPU. Stream is for CUDA only. void* stream_; @@ -243,118 +146,15 @@ template common::Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { - ORT_ENFORCE(subgraph_info_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); - ORT_UNUSED_PARAMETER(attribute_name); - + ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); const auto& node = Node(); - subgraph_info_ = std::make_unique(node, subgraph_session_state.GetGraphViewer()); - - ORT_RETURN_IF(subgraph_info_->num_subgraph_outputs <= 1, - "Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs)."); - - ORT_RETURN_IF(subgraph_info_->num_subgraph_inputs != subgraph_info_->num_subgraph_outputs + 2, - "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2"); - - std::vector feed_names; - feed_names.reserve(subgraph_info_->num_subgraph_inputs + subgraph_info_->num_implicit_inputs); - - // First, get the location of input_ids of current operator. - const auto& node_inputs = node.InputDefs(); - const OrtMemoryInfo& input_ids_location = utils::FindMemoryInfoForValue(session_state, node_inputs[0]->Name()); - - // position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. - // as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids - feed_names.insert(feed_names.end(), subgraph_info_->subgraph_input_names.begin(), subgraph_info_->subgraph_input_names.end()); - - for (auto& entry : node.ImplicitInputDefs()) { - feed_names.push_back(entry->Name()); - } - - std::vector feed_locations; - feed_locations.resize(feed_names.size()); - - for (size_t i = 0, end = feed_names.size(); i < end; ++i) { - if (i >= subgraph_info_->subgraph_input_names.size()) { // implicit inputs - const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]); - feed_locations[i] = location.device; - } else { - feed_locations[i] = input_ids_location.device; - } - } - - std::unique_ptr ffm; - ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_info_->subgraph_output_names, - subgraph_session_state.GetOrtValueNameIdxMap(), ffm)); - ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm)); - - // setup the locations where we want the subgraph output to end up on - std::vector fetch_locations; - fetch_locations.reserve(subgraph_info_->num_subgraph_outputs); - - // past state need to be where we can feed them in to the next iteration, so set the fetch location to match the feed location. - for (int i = 0; i < subgraph_info_->num_subgraph_outputs; ++i) { - fetch_locations.push_back(&input_ids_location); - } - - utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations); - - feeds_fetches_manager_ = std::move(ffm); - - // CheckSubgraph is moved here so that it only need called once instead of every inference run. - auto& inputs = subgraph_info_->subgraph.GetInputs(); - auto& outputs = subgraph_info_->subgraph.GetOutputs(); - ORT_RETURN_IF_ERROR(CheckSubgraph(inputs, outputs)); - - return Status::OK(); -} - -template -Status BeamSearch::CheckSubgraph(const std::vector& subgraph_inputs, - const std::vector& subgraph_outputs) { - ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ", - subgraph_inputs[0]->Name()); - ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ", - subgraph_inputs[1]->Name()); - ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ", - subgraph_inputs[2]->Name()); - ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ", - subgraph_inputs[3]->Name()); - - // Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads. - const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); - ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", - past_shape->dim_size()); - - ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, - "subgraph past state dimension 0 shall have length of 2"); - - ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, - "subgraph past state dimension 2 shall have a positive value for number of heads"); - - ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, - "subgraph past state dimension 4 shall have a positive value for hidden size per head"); - - // check subgraph outputs - ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ", - subgraph_outputs[0]->Name()); - - ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ", - subgraph_outputs[1]->Name()); - - // Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size. - const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); - ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ", - logits_shape->dim_size()); - - ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, - "subgraph past state dimension 2 shall have a positive value for vocabulary size"); - - int num_heads = static_cast(past_shape->dim(2).dim_value()); - int head_size = static_cast(past_shape->dim(4).dim_value()); - int vocab_size = static_cast(logits_shape->dim(2).dim_value()); - int num_layers = static_cast(subgraph_outputs.size()) - 1; - parameters_.SetSubgraphParameters(num_heads, head_size, vocab_size, num_layers); - + gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); + ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state)); + feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); + parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size, + gpt_subgraph_->num_heads, + gpt_subgraph_->head_size, + gpt_subgraph_->num_layers); return Status::OK(); } @@ -369,7 +169,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { BeamSearchParameters parameters = parameters_; // make a copy - BeamSearchImpl impl{*ctx_internal, *session_state, *subgraph_info_, thread_pool, stream_, parameters}; + BeamSearchImpl impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, stream_, parameters}; auto status = impl.Initialize(); ORT_RETURN_IF_ERROR(status); @@ -382,13 +182,13 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { template BeamSearchImpl::BeamSearchImpl(OpKernelContextInternal& context, const SessionState& session_state, - const GptSubgraphInfo& subgraph_info, + GptSubgraph& gpt_subgraph, concurrency::ThreadPool* thread_pool, void* stream, BeamSearchParameters& params) : context_(context), session_state_(session_state), - subgraph_info_(subgraph_info), + gpt_subgraph_(gpt_subgraph), thread_pool_(thread_pool), implicit_inputs_(context_.GetImplicitInputs()), stream_(stream), @@ -471,131 +271,11 @@ Status BeamSearchImpl::Initialize() { return status; } -template -OrtValue BeamSearchImpl::ExpandInputs(const OrtValue& input, int num_beams) const { - if (num_beams == 1) - return input; - - // Given input of shape (batch_size, sequence_length), expand the shape to be (batch_size * num_beams, sequence_length) - const TensorShape& input_shape = input.Get().Shape(); - ORT_ENFORCE(input_shape.NumDimensions() == 2 && input_shape[0] == parameters_->batch_size && input_shape[1] == parameters_->sequence_length); - - const int64_t& batch_size = input_shape[0]; - const int64_t& sequence_length = input_shape[1]; - int64_t dims[] = {batch_size * num_beams, sequence_length}; - TensorShape expanded_shape(&dims[0], 2); - - auto element_type = DataTypeImpl::GetType(); - OrtValue expanded; - Tensor::InitOrtValue(element_type, expanded_shape, allocator_, expanded); - - const int64_t* input_data = input.Get().Data(); - int64_t* expanded_data = expanded.GetMutable()->MutableData(); - int64_t* target = expanded_data; - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < num_beams; j++) { - memcpy(target, input_data + i * sequence_length, sizeof(int64_t) * sequence_length); - target += sequence_length; - } - } - - return expanded; -} - template void BeamSearchImpl::CreateInitialFeeds(std::vector& feeds) { - // Subgraph inputs: - // input_ids: shape (B, S) wher B is batch size, and S is sequence length - // position_ids: shape (B, S) - // attention_mask: shape (B, P+S), where past_sequence_length (P) is 0 - // After expansion, their shapes will become (B, M*S), where M is num_beams. - - const OrtValue* input_ids = context_.GetInputOrtValue(0); - - const Tensor& input_ids_tensor = input_ids->Get(); - - const TensorShape& input_ids_shape = input_ids_tensor.Shape(); - ORT_ENFORCE(input_ids_shape.NumDimensions() == 2); - const int64_t& batch_size = input_ids_shape[0]; - const int64_t& sequence_length = input_ids_shape[1]; - - // Allocate position_ids and attention_mask based on shape of input_ids - auto element_type = DataTypeImpl::GetType(); - - // input_ids for subgraph is int64, so we need Cast input_ids from int32 to int64. - OrtValue subgraph_input_ids; - // Current shape is (batch_size, sequence_length) - // Note that we will expand it to (batch_size * num_beams, sequence_length) later. - Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, subgraph_input_ids); - - int64_t* subgraph_input_data = subgraph_input_ids.GetMutable()->MutableData(); - const int32_t* source = input_ids_tensor.Data(); - int64_t* target = subgraph_input_data; - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < sequence_length; j++, source++, target++) { - *target = static_cast(*source); - } - } - - OrtValue position_ids; - Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids); - - OrtValue attention_mask; - Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, attention_mask); - - next_positions_.resize(batch_size * parameters_->num_beams); - // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. - // Set position id to be 0 for pad tokens, and cumulated sum of mask in a batch for other tokens - int64_t* mask_data = attention_mask.GetMutable()->MutableData(); - int64_t* position_data = position_ids.GetMutable()->MutableData(); - source = input_ids_tensor.Data(); - int64_t* mask = mask_data; - int64_t* position = position_data; - for (int i = 0; i < batch_size; i++) { - int64_t abs_position = 0; - for (int j = 0; j < sequence_length; j++, source++, mask++, position++) { - if (*source == parameters_->pad_token_id) { - *mask = 0; - *position = 0; - } else { - *mask = 1; - *position = abs_position; - abs_position++; - } - } - for (int k = 0; k < parameters_->num_beams; k++) { - next_positions_[i * parameters_->num_beams + k] = abs_position; - } - } - - // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask - // TODO: Try expand inputs/outputs after first subgraph call instead. That may get better peroformance, but more complex to implement. - OrtValue expanded_input_ids = ExpandInputs(subgraph_input_ids, parameters_->num_beams); - OrtValue expanded_position_ids = ExpandInputs(position_ids, parameters_->num_beams); - OrtValue expanded_attention_mask = ExpandInputs(attention_mask, parameters_->num_beams); - - // Initialize empty past state - auto past_type = DataTypeImpl::GetType(); - int64_t past_state_dims[] = {2, batch_size * parameters_->num_beams, parameters_->num_heads, 0, parameters_->head_size}; - TensorShape past_shape(&past_state_dims[0], 5); - OrtValue empty_past; - Tensor::InitOrtValue(past_type, past_shape, allocator_, empty_past); - - // The ordering is the same as used in SetupSubgraphExecutionInfo - feeds.reserve(subgraph_info_.num_subgraph_inputs + subgraph_info_.num_implicit_inputs); - feeds.push_back(expanded_input_ids); - feeds.push_back(expanded_position_ids); - feeds.push_back(expanded_attention_mask); - - // The remaing inputs are past state. - for (int i = 3; i < subgraph_info_.num_subgraph_inputs; ++i) { - feeds.push_back(empty_past); - } - - // pass in implicit inputs - for (const auto* entry : implicit_inputs_) { - feeds.push_back(*entry); - } + const OrtValue* input_ids_value = context_.GetInputOrtValue(0); + const Tensor& input_ids = input_ids_value->Get(); + gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, feeds); } template @@ -651,7 +331,6 @@ Status BeamSearchImpl::ProcessLogits( ApplyVocabMask(next_token_scores); ApplyRepetitionPenalty(beam_state.sequences, next_token_scores); - // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) // TODO: use thread pool to parrellel int offset = 0; @@ -799,50 +478,6 @@ void BeamSearchImpl::ApplyRepetitionPenalty(const Sequences& sequences, gsl:: } } -template -void BeamSearchImpl::PickPastState(const std::vector& last_outputs, - std::vector& next_inputs, - gsl::span& beam_indices) { - for (int i = 3; i < subgraph_info_.num_subgraph_inputs; ++i) { - const OrtValue& present = last_outputs[i - 2]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64) - const TensorShape& past_shape = present.Get().Shape(); - - // Create a tensor with same shape. - OrtValue past; - auto past_type = DataTypeImpl::GetType(); // present.Type() - Tensor::InitOrtValue(past_type, past_shape, allocator_, past); - - auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; - auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; - - gsl::span past_span = past.GetMutable()->MutableDataAsSpan(); - gsl::span present_span = present.Get().DataAsSpan(); - for (gsl::index j = 0; j < beam_indices.length(); j++) { - int64_t beam_index = beam_indices[j]; - gsl::span present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); - gsl::span present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); - - gsl::span past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam); - gsl::span past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); - gsl::copy(present_key, past_key); - gsl::copy(present_value, past_value); - -#ifdef DEBUG_BEAM_SEARCH - if (i == 3) // only dump past_0 - { - DumpString("past_key of beam", static_cast(j), true); - DumpTensor(nullptr, past_key.data(), 1, static_cast(block_size_per_beam)); - - DumpString("past_value of beam", static_cast(j), true); - DumpTensor(nullptr, past_value.data(), 1, static_cast(block_size_per_beam)); - } -#endif - } - - next_inputs[i] = past; - } -} - template Status BeamSearchImpl::UpdateFeeds( const std::vector& last_outputs, @@ -850,68 +485,7 @@ Status BeamSearchImpl::UpdateFeeds( int current_length, gsl::span beam_next_tokens, gsl::span beam_indices) { - // last_outputs: logits, present_0, present_1, ... - // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 - - // The following updates inputs for subgraph - // TODO: Reuse buffer for input_ids and position_ids to reduce memory allocation. - - // Update input_ids with next tokens. - int batch_beam_size = parameters_->batch_size * parameters_->num_beams; - int64_t dims[] = {batch_beam_size, 1}; - TensorShape input_ids_shape(&dims[0], 2); - auto element_type = DataTypeImpl::GetType(); - OrtValue input_ids; - Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, input_ids); - int64_t* input_ids_data = input_ids.GetMutable()->MutableData(); - for (int i = 0; i < batch_beam_size; i++) { - input_ids_data[i] = beam_next_tokens[i]; - } - next_inputs[0] = input_ids; - - // Update position IDs - OrtValue position_ids; - Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids); - int64_t* position_data = position_ids.GetMutable()->MutableData(); - for (int i = 0; i < batch_beam_size; i++) { - position_data[i] = next_positions_[i]; - next_positions_[i]++; - } - next_inputs[1] = position_ids; - - // Update attention mask - const OrtValue& old_mask = next_inputs[2]; - const int64_t* old_mask_data = old_mask.Get().Data(); - int64_t mask_dims[] = {batch_beam_size, current_length}; - TensorShape mask_shape(&mask_dims[0], 2); - OrtValue attention_mask; - Tensor::InitOrtValue(element_type, mask_shape, allocator_, attention_mask); - int64_t* mask_data = attention_mask.GetMutable()->MutableData(); - for (int i = 0; i < batch_beam_size; i++) { - for (int j = 0; j < current_length - 1; j++) { - mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j]; - } - mask_data[i * current_length + current_length - 1] = 1; - } - next_inputs[2] = attention_mask; - -#ifdef DEBUG_BEAM_SEARCH - DumpOrtValue("input_ids", input_ids); - DumpOrtValue("position_ids", position_ids); - DumpOrtValue("attention_mask", attention_mask); -#endif - - // Update past state - if (parameters_->num_beams == 1) { - // feed present_* output to past_* inputs one by one - for (int i = 3; i < subgraph_info_.num_subgraph_inputs; ++i) { - next_inputs[i] = last_outputs[i - 2]; - } - } else { - PickPastState(last_outputs, next_inputs, beam_indices); - } - - return Status::OK(); + return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, beam_next_tokens, beam_indices, parameters_->num_beams); } template @@ -1022,5 +596,6 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { template class BeamSearchImpl; template class BeamSearch; +} // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 9ff42ba4e88e2..ae3640fc12a2f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -10,59 +10,12 @@ #include "core/providers/cpu/controlflow/utils.h" #include "beam_search_parameters.h" #include "beam_search_scorer.h" +#include "gpt_subgraph.h" +#include "sequences.h" namespace onnxruntime { namespace contrib { - -struct GptSubgraphInfo { - GptSubgraphInfo(const onnxruntime::Node& node, const GraphViewer& subgraph_in); - - const GraphViewer& subgraph; - - int num_implicit_inputs; - - int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience. - int num_subgraph_outputs; // same as subgraph_output_names.size() - - std::vector subgraph_input_names; - std::vector subgraph_output_names; -}; - -// This class keeps track of sequences generated. -class Sequences : public ISequences { - public: - Sequences() {} - - // Initialize the sequence with initial input_ids and related parameters. - void Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length); - - // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). - gsl::span GetSequence(int beam_index) const override; - - // Returns current sequence length. - int GetSequenceLength() override; - - // Print the sequences to StdOut in debug mode - void PrintSequences(); - - // Select sequences based on beam indices, then append next token to selected sequences. - void AppendNextTokenToSequences( - gsl::span& beam_indices, - gsl::span& beam_next_tokens); - - private: - // Two buffers of shape (batch_size, num_beams, max_seq_length) to store sequences. - // At each time, there is only one buffer is active. The other one will be active in next token. - // Each AppendNextTokenToSequences call will trigger a rotation of active buffer. - std::vector sequences[2]; - - // Index (either 0 or 1) of two buffers that is currently is active. - int current_sequences_buffer; - - int batch_beam_size_; - int max_length_; - int current_length_; -}; +namespace transformers { template struct BeamSearchState { @@ -120,20 +73,19 @@ class BeamSearch : public controlflow::IControlFlowKernel { static std::unique_ptr Create(const OpKernelInfo& info, void* stream); protected: - Status CheckSubgraph(const std::vector& subgraph_inputs, - const std::vector& subgraph_outputs); void SetComputeStream(void* stream) { stream_ = stream; } private: // Subgraph info and FeedsFetchesManager re-used for each subgraph execution. - std::unique_ptr subgraph_info_; - std::unique_ptr feeds_fetches_manager_; + std::unique_ptr gpt_subgraph_; + FeedsFetchesManager* feeds_fetches_manager_; void* stream_; BeamSearchParameters parameters_; }; +} // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index ec00614db508e..37fdc60b46096 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -2,10 +2,11 @@ // Licensed under the MIT License. #include "beam_search_parameters.h" -constexpr int kMaxSequenceLength = 4096; - namespace onnxruntime { namespace contrib { +namespace transformers { + +constexpr int kMaxSequenceLength = 4096; Status BeamSearchParameters::Validate() { ORT_RETURN_IF(eos_token_id < 0, "eos_token_id is invalid"); @@ -58,12 +59,13 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty); } -void BeamSearchParameters::SetSubgraphParameters(int heads, int hidden_size_per_head, int vocabulary_size, int layers) { +void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { + vocab_size = vocabulary_size; num_heads = heads; head_size = hidden_size_per_head; - vocab_size = vocabulary_size; num_layers = layers; } +} // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index b9477bcbf3d96..3ff1eede0e0ca 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -7,6 +7,7 @@ namespace onnxruntime { namespace contrib { +namespace transformers { struct BeamSearchParameters { // from node attributes @@ -23,27 +24,28 @@ struct BeamSearchParameters { float temperature; float length_penalty; float repetition_penalty; - int batch_size; // deduce from first dimension of input_ids - int sequence_length; // deduce from second dimension of input_ids - + int batch_size; // deduce from first dimension of input_ids + int sequence_length; // deduce from second dimension of input_ids + gsl::span vocab_mask; // from outputs - bool output_scores; // whether scores existed in output + bool output_scores; // whether scores existed in output // deduce from subgraph - int num_heads; - int head_size; int vocab_size; - int num_layers; + int num_heads; // not used + int head_size; // not used + int num_layers; // not used Status Validate(); - int BatchBeamSize(){ return batch_size * num_beams; } + int BatchBeamSize() { return batch_size * num_beams; } void ParseFromAttributes(const OpKernelInfo& info); void ParseFromInputs(OpKernelContext* context); - void SetSubgraphParameters(int num_heads, int head_size, int vocab_size, int num_layers); + void SetSubgraphParameters(int vocab_size, int num_heads, int head_size, int num_layers); }; +} // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 9f69f2d8d3dd1..802b84afc6e83 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -13,7 +13,7 @@ namespace onnxruntime { namespace contrib { - +namespace transformers { using ::onnxruntime::rnn::detail::Allocate; template @@ -280,5 +280,6 @@ template class HypothesisScoreCompare; template class BeamHypotheses; template class BeamSearchScorer; +} // namespace transformers } // namespace contrib } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 22172344baecb..c2eb0b13c5c83 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -11,15 +11,11 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" +#include "sequences.h" + namespace onnxruntime { namespace contrib { - -class ISequences { - public: - virtual ~ISequences() {} - virtual gsl::span GetSequence(int beam_index) const = 0; - virtual int GetSequenceLength() = 0; -}; +namespace transformers { // Interface for all scorers for beam search or beam sample. template @@ -138,5 +134,6 @@ class BeamSearchScorer : public IBeamScorer { int hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer. }; +} // namespace transformers } // namespace contrib } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc index e44a521c6e9c2..1e671c01e9038 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc @@ -6,6 +6,8 @@ #include "core/platform/env_var_utils.h" namespace onnxruntime { +namespace contrib { +namespace transformers { namespace dump_tensor_env_vars { constexpr const char* kDumpBeamSearch = "ORT_DUMP_BEAM_SEARCH"; @@ -64,4 +66,7 @@ void DumpString(const char* name, std::string value, bool end_line) { std::cout << std::endl; } } + +} // namespace transformers +} // namespace contrib } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h index a4bebdf8a0f67..2f80410011dab 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h @@ -6,7 +6,13 @@ #include #include "core/framework/tensorprotoutils.h" +#ifndef NDEBUG +//#define DEBUG_BEAM_SEARCH 1 // uncomment it for debugging beam search +#endif + namespace onnxruntime { +namespace contrib { +namespace transformers { #define MAX_ROW_OR_COLUMN 8 @@ -137,4 +143,6 @@ void ConfigureTensorDump(); void DisableTensorDump(); +} // namespace transformers +} // namespace contrib } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc new file mode 100644 index 0000000000000..a27726178776b --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc @@ -0,0 +1,452 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// there's no way to use a raw pointer as the copy destination with std::copy_n +// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset +// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "core/framework/framework_common.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +#include "gsl/gsl" +#include "gpt_subgraph.h" +#include "dump_tensor.h" + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +GptSubgraph::GptSubgraph( + const onnxruntime::Node& node_in, + const std::string& attribute_name, + const GraphViewer& subgraph_in) + : node(node_in), attribute(attribute_name), subgraph(subgraph_in), allocator_(nullptr) { + num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + + auto& subgraph_inputs = subgraph.GetInputs(); + auto& subgraph_outputs = subgraph.GetOutputs(); + + // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ... + // outputs: logits, present_0, present_1, ... + num_subgraph_inputs = static_cast(subgraph_inputs.size()); + num_subgraph_outputs = static_cast(subgraph_outputs.size()); + + // CheckSubgraph will verify inputs and outputs later. + subgraph_input_names.reserve(num_subgraph_inputs); + for (int i = 0; i < num_subgraph_inputs; ++i) { + subgraph_input_names.push_back(subgraph_inputs[i]->Name()); + } + + subgraph_output_names.reserve(num_subgraph_outputs); + for (int i = 0; i < num_subgraph_outputs; ++i) { + subgraph_output_names.push_back(subgraph_outputs[i]->Name()); + } +} + +Status GptSubgraph::Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) { + ORT_RETURN_IF(num_subgraph_outputs <= 1, + "Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs)."); + + ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2, + "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2"); + + ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ", + subgraph_inputs[0]->Name()); + ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ", + subgraph_inputs[1]->Name()); + ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ", + subgraph_inputs[2]->Name()); + ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ", + subgraph_inputs[3]->Name()); + + // Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads. + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); + ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", + past_shape->dim_size()); + + ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, + "subgraph past state dimension 0 shall have length of 2"); + + ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for number of heads"); + + ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, + "subgraph past state dimension 4 shall have a positive value for hidden size per head"); + + // check subgraph outputs + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ", + subgraph_outputs[0]->Name()); + + ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ", + subgraph_outputs[1]->Name()); + + // Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size. + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); + ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ", + logits_shape->dim_size()); + + ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for vocabulary size"); + + // Save parameters related to the subgraph. + num_heads = static_cast(past_shape->dim(2).dim_value()); + head_size = static_cast(past_shape->dim(4).dim_value()); + vocab_size = static_cast(logits_shape->dim(2).dim_value()); + num_layers = static_cast(subgraph_outputs.size()) - 1; + + ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64, + "subgraph input 0 (input_ids) shall have int64 type"); + ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64, + "subgraph input 1 (position_ids) shall have int64 type"); + // TODO: support float16 + ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, + "subgraph input 2 (attention_mask) shall have float type"); + ORT_RETURN_IF(subgraph_inputs[3]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, + "subgraph input 3 (past_0) shall have float type"); + ORT_RETURN_IF(subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, + "subgraph output 0 (logits) shall have float type"); + ORT_RETURN_IF(subgraph_outputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, + "subgraph output 1 (present_0) shall have float type"); + + return Status::OK(); +} + +Status GptSubgraph::Setup(const SessionState& session_state, + const SessionState& subgraph_session_state) { + session_state_ = &session_state; + subgraph_session_state_ = &subgraph_session_state; + + std::vector feed_names; + feed_names.reserve(num_subgraph_inputs + num_implicit_inputs); + + // First, get the location of input_ids of current operator. + const auto& node_inputs = node.InputDefs(); + const OrtMemoryInfo& input_ids_location = utils::FindMemoryInfoForValue(session_state, node_inputs[0]->Name()); + + // position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. + // as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids + feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); + + for (auto& entry : node.ImplicitInputDefs()) { + feed_names.push_back(entry->Name()); + } + + std::vector feed_locations; + feed_locations.resize(feed_names.size()); + + for (size_t i = 0, end = feed_names.size(); i < end; ++i) { + if (i >= subgraph_input_names.size()) { // implicit inputs + const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]); + feed_locations[i] = location.device; + } else { + feed_locations[i] = input_ids_location.device; + } + } + + std::unique_ptr ffm; + ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names, + subgraph_session_state.GetOrtValueNameIdxMap(), ffm)); + ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm)); + + // setup the locations where we want the subgraph output to end up on + std::vector fetch_locations; + fetch_locations.reserve(num_subgraph_outputs); + + // past state need to be where we can feed them in to the next iteration, so set the fetch location to match the feed location. + for (int i = 0; i < num_subgraph_outputs; ++i) { + fetch_locations.push_back(&input_ids_location); + } + + utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations); + + feeds_fetches_manager_ = std::move(ffm); + + // Check subgraph only need once so put in Setup function. + auto& inputs = subgraph.GetInputs(); + auto& outputs = subgraph.GetOutputs(); + ORT_RETURN_IF_ERROR(Validate(inputs, outputs)); + + return Status::OK(); +} + +void GptSubgraph::CreateInitialFeeds( + const Tensor& input_ids, + const std::vector& implicit_inputs, + int num_beams, + int pad_token_id, + std::vector& feeds) { + ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); + + // Subgraph inputs: + // input_ids: shape (B, S) wher B is batch size, and S is sequence length + // position_ids: shape (B, S) + // attention_mask: shape (B, P+S), where past_sequence_length (P) is 0 + // After expansion, their shapes will become (B, M*S), where M is num_beams. + + // Allocate subgraph inputs to be same device as input_ids + AllocatorPtr alloactor = session_state_->GetAllocator(input_ids.Location()); + + // Store allocator, which is needed in ExpandInputs. + allocator_ = alloactor; + + const TensorShape& input_ids_shape = input_ids.Shape(); + ORT_ENFORCE(input_ids_shape.NumDimensions() == 2); + const int64_t& batch_size = input_ids_shape[0]; + const int64_t& sequence_length = input_ids_shape[1]; + + // Allocate position_ids and attention_mask based on shape of input_ids + auto element_type = DataTypeImpl::GetType(); + + // input_ids for subgraph is int64, so we need Cast input_ids from int32 to int64. + OrtValue subgraph_input_ids; + // Current shape is (batch_size, sequence_length) + // Note that we will expand it to (batch_size * num_beams, sequence_length) later. + Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, subgraph_input_ids); + + int64_t* subgraph_input_data = subgraph_input_ids.GetMutable()->MutableData(); + const int32_t* source = input_ids.Data(); + int64_t* target = subgraph_input_data; + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < sequence_length; j++, source++, target++) { + *target = static_cast(*source); + } + } + + OrtValue position_ids; + Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, position_ids); + + OrtValue attention_mask; + auto mask_type = DataTypeImpl::GetType(); + Tensor::InitOrtValue(mask_type, input_ids_shape, alloactor, attention_mask); + + next_positions_.resize(batch_size * num_beams); + // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. + // Set position id to be 0 for pad tokens, and cumulated sum of mask in a batch for other tokens + float* mask_data = attention_mask.GetMutable()->MutableData(); + int64_t* position_data = position_ids.GetMutable()->MutableData(); + source = input_ids.Data(); + float* mask = mask_data; + int64_t* position = position_data; + for (int i = 0; i < batch_size; i++) { + int64_t abs_position = 0; + for (int j = 0; j < sequence_length; j++, source++, mask++, position++) { + if (*source == pad_token_id) { + *mask = 0.0f; + *position = 0; + } else { + *mask = 1.0f; + *position = abs_position; + abs_position++; + } + } + for (int k = 0; k < num_beams; k++) { + next_positions_[i * num_beams + k] = abs_position; + } + } + + // Initialize empty past state + auto past_type = DataTypeImpl::GetType(); + int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size}; + TensorShape past_shape(&past_state_dims[0], 5); + OrtValue empty_past; + Tensor::InitOrtValue(past_type, past_shape, allocator_, empty_past); + + // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask + // TODO: Try expand inputs/outputs after first subgraph call instead. That may get better peroformance, but more complex to implement. + OrtValue expanded_input_ids = ExpandInputs(subgraph_input_ids, num_beams); + OrtValue expanded_position_ids = ExpandInputs(position_ids, num_beams); + OrtValue expanded_attention_mask = ExpandInputs(attention_mask, num_beams); + + // The ordering is the same as used in Setup + feeds.reserve(num_subgraph_inputs + num_implicit_inputs); + feeds.push_back(expanded_input_ids); + feeds.push_back(expanded_position_ids); + feeds.push_back(expanded_attention_mask); + + // The remaing inputs are past state. + for (int i = 3; i < num_subgraph_inputs; ++i) { + feeds.push_back(empty_past); + } + + // pass in implicit inputs + for (const auto* entry : implicit_inputs) { + feeds.push_back(*entry); + } +} + +OrtValue GptSubgraph::ExpandInputs(const OrtValue& input, int num_beams) const { + if (num_beams == 1) + return input; + + // Given input of shape (batch_size, sequence_length), expand the shape to be (batch_size * num_beams, sequence_length) + const TensorShape& input_shape = input.Get().Shape(); + //ORT_ENFORCE(input_shape.NumDimensions() == 2 && input_shape[0] == parameters_->batch_size && input_shape[1] == parameters_->sequence_length); + + const int64_t& batch_size = input_shape[0]; + const int64_t& sequence_length = input_shape[1]; + int64_t dims[] = {batch_size * num_beams, sequence_length}; + TensorShape expanded_shape(&dims[0], 2); + + MLDataType element_type = input.Get().DataType(); + + OrtValue expanded; + Tensor::InitOrtValue(element_type, expanded_shape, allocator_, expanded); + + if (element_type == DataTypeImpl::GetType()) { + const int64_t* input_data = input.Get().Data(); + int64_t* expanded_data = expanded.GetMutable()->MutableData(); + int64_t* target = expanded_data; + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < num_beams; j++) { + memcpy(target, input_data + i * sequence_length, sizeof(int64_t) * sequence_length); + target += sequence_length; + } + } + } + else if (element_type == DataTypeImpl::GetType()) { + const float* input_data = input.Get().Data(); + float* expanded_data = expanded.GetMutable()->MutableData(); + float* target = expanded_data; + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < num_beams; j++) { + memcpy(target, input_data + i * sequence_length, sizeof(float) * sequence_length); + target += sequence_length; + } + } + } + + return expanded; +} + +void GptSubgraph::PickPastState(const std::vector& last_outputs, + std::vector& next_inputs, + gsl::span& beam_indices) { + for (int i = 3; i < num_subgraph_inputs; ++i) { + const OrtValue& present = last_outputs[i - 2]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64) + const TensorShape& past_shape = present.Get().Shape(); + + // Create a tensor with same shape. + OrtValue past; + auto past_type = DataTypeImpl::GetType(); //TODO: present.Type() + Tensor::InitOrtValue(past_type, past_shape, allocator_, past); + + auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; + auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; + + // TODO: support float16 + gsl::span past_span = past.GetMutable()->MutableDataAsSpan(); + gsl::span present_span = present.Get().DataAsSpan(); + for (gsl::index j = 0; j < beam_indices.length(); j++) { + int64_t beam_index = beam_indices[j]; + gsl::span present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); + gsl::span present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); + + gsl::span past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam); + gsl::span past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); + gsl::copy(present_key, past_key); + gsl::copy(present_value, past_value); +#ifdef DEBUG_BEAM_SEARCH + if (i == 3) // only dump past_0 + { + DumpString("past_key of beam", static_cast(j), true); + DumpTensor(nullptr, past_key.data(), 1, static_cast(block_size_per_beam)); + + DumpString("past_value of beam", static_cast(j), true); + DumpTensor(nullptr, past_value.data(), 1, static_cast(block_size_per_beam)); + } +#endif + } + + next_inputs[i] = past; + } +} + +Status GptSubgraph::UpdateFeeds( + const std::vector& last_outputs, + std::vector& next_inputs, + int current_length, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams) { + // last_outputs: logits, present_0, present_1, ... + // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 + + // The following updates inputs for subgraph + // TODO: Reuse buffer for input_ids and position_ids to reduce memory allocation. + + // Update input_ids with next tokens. + int batch_beam_size = static_cast(beam_next_tokens.length()); + int64_t dims[] = {batch_beam_size, 1}; + TensorShape input_ids_shape(&dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue input_ids; + Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, input_ids); + int64_t* input_ids_data = input_ids.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + input_ids_data[i] = beam_next_tokens[i]; + } + next_inputs[0] = input_ids; + + // Update position IDs + OrtValue position_ids; + Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids); + int64_t* position_data = position_ids.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + position_data[i] = next_positions_[i]; + next_positions_[i]++; + } + next_inputs[1] = position_ids; + + // Update attention mask + const OrtValue& old_mask = next_inputs[2]; + const float* old_mask_data = old_mask.Get().Data(); + int64_t mask_dims[] = {batch_beam_size, current_length}; + TensorShape mask_shape(&mask_dims[0], 2); + OrtValue attention_mask; + auto mask_type = DataTypeImpl::GetType(); + Tensor::InitOrtValue(mask_type, mask_shape, allocator_, attention_mask); + float* mask_data = attention_mask.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + for (int j = 0; j < current_length - 1; j++) { + mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j]; + } + mask_data[i * current_length + current_length - 1] = 1.0f; + } + next_inputs[2] = attention_mask; + +#ifdef DEBUG_BEAM_SEARCH + DumpOrtValue("input_ids", input_ids); + DumpOrtValue("position_ids", position_ids); + DumpOrtValue("attention_mask", attention_mask); +#endif + + // Update past state + if (num_beams == 1) { + // feed present_* output to past_* inputs one by one + for (int i = 3; i < num_subgraph_inputs; ++i) { + next_inputs[i] = last_outputs[i - 2]; + } + } else { + PickPastState(last_outputs, next_inputs, beam_indices); + } + + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h new file mode 100644 index 0000000000000..154dbf4b7390d --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +//#include +#include "gsl/gsl" +#include "core/framework/allocator.h" +#include "core/framework/session_state.h" +#include "core/framework/feeds_fetches_manager.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +// A class for GPT-2 subgraph inputs and outputs preparation. +struct GptSubgraph { + GptSubgraph( + const onnxruntime::Node& node_in, + const std::string& attribute_name, + const GraphViewer& subgraph_in); + + const onnxruntime::Node& node; // node that contains the subgraph + const std::string& attribute; // attribute of th node that contains the subgraph. Not used yet. + const GraphViewer& subgraph; // the subgraph + + int num_implicit_inputs; + + int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience. + int num_subgraph_outputs; // same as subgraph_output_names.size() + + std::vector subgraph_input_names; + std::vector subgraph_output_names; + + // Parameters deduced from the subgraph + int num_heads; + int head_size; + int vocab_size; + int num_layers; + + // Setup exectuion + Status Setup(const SessionState& session_state, + const SessionState& subgraph_session_state); + + // Create inputs for first inference of subgraph. + void CreateInitialFeeds( + const Tensor& input_ids, + const std::vector& implicit_inputs, + int num_beams, + int pad_token_id, + std::vector& feeds); + + Status UpdateFeeds( + const std::vector& last_outputs, + std::vector& next_inputs, + int current_length, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams); + + FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); } + + protected: + Status Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs); + + OrtValue ExpandInputs(const OrtValue& input, int num_beams) const; + + void PickPastState(const std::vector& last_outputs, + std::vector& next_inputs, + gsl::span& beam_indices); + + // TODO: move it to make this class state less. + std::vector next_positions_; + + AllocatorPtr allocator_; + const SessionState* session_state_; + const SessionState* subgraph_session_state_; + std::unique_ptr feeds_fetches_manager_; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc new file mode 100644 index 0000000000000..572be533970e6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -0,0 +1,72 @@ +#include "sequences.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +void Sequences::Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length) { + // Allocate buffer (shall we use allocator instead?) + sequences[0].assign(batch_beam_size * max_length, 0); + sequences[1].assign(batch_beam_size * max_length, 0); + + // copying input_ids to sequences[0] + gsl::span input = input_ids.Get().DataAsSpan(); + gsl::span output(sequences[0]); + for (int i = 0; i < batch_beam_size; i++) { + gsl::span source = input.subspan(i * sequence_length, sequence_length); + gsl::span target = output.subspan(i * max_length, sequence_length); + gsl::copy(source, target); + } + current_sequences_buffer = 0; + + batch_beam_size_ = batch_beam_size; + max_length_ = max_length; + current_length_ = sequence_length; +} + +gsl::span Sequences::GetSequence(int beam_index) const { + gsl::span buffer(sequences[current_sequences_buffer]); + gsl::span sequence = buffer.subspan(beam_index * max_length_, current_length_); + return sequence; +} + +int Sequences::GetSequenceLength() { + return current_length_; +} + +void Sequences::PrintSequences() { +#ifdef DEBUG_BEAM_SEARCH + for (int i = 0; i < batch_beam_size_; i++) { + gsl::span sequence = GetSequence(i); + DumpString("sequences", i, false); + DumpTensor(nullptr, sequence.data(), 1, current_length_); + } +#endif +} + +void Sequences::AppendNextTokenToSequences( + gsl::span& beam_indices, + gsl::span& beam_next_tokens) { + //sequences = torch.cat([sequences[beam_indices, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + gsl::span input(sequences[current_sequences_buffer]); + gsl::span output(sequences[1 - current_sequences_buffer]); + + for (int i = 0; i < batch_beam_size_; i++) { + int beam_index = static_cast(beam_indices[i]); + gsl::span source = input.subspan(beam_index * max_length_, current_length_); + gsl::span target = output.subspan(i * max_length_, current_length_); + gsl::copy(source, target); + } + + // append next token to each beam + for (int i = 0; i < batch_beam_size_; i++) { + output[i * max_length_ + current_length_] = beam_next_tokens[i]; + } + + ++current_length_; + current_sequences_buffer = 1 - current_sequences_buffer; // rotate buffer for next round +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h new file mode 100644 index 0000000000000..cab33c63b2c5c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -0,0 +1,55 @@ +#pragma once + +#include "gsl/gsl" +#include "core/framework/ort_value.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +class ISequences { + public: + virtual ~ISequences() {} + virtual gsl::span GetSequence(int beam_index) const = 0; + virtual int GetSequenceLength() = 0; +}; + +// This class keeps track of sequences generated. +class Sequences : public ISequences { + public: + Sequences() {} + + // Initialize the sequence with initial input_ids and related parameters. + void Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length); + + // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). + gsl::span GetSequence(int beam_index) const override; + + // Returns current sequence length. + int GetSequenceLength() override; + + // Print the sequences to StdOut in debug mode + void PrintSequences(); + + // Select sequences based on beam indices, then append next token to selected sequences. + void AppendNextTokenToSequences( + gsl::span& beam_indices, + gsl::span& beam_next_tokens); + + private: + // Two buffers of shape (batch_size, num_beams, max_seq_length) to store sequences. + // At each time, there is only one buffer is active. The other one will be active in next token. + // Each AppendNextTokenToSequences call will trigger a rotation of active buffer. + std::vector sequences[2]; + + // Index (either 0 or 1) of two buffers that is currently is active. + int current_sequences_buffer; + + int batch_beam_size_; + int max_length_; + int current_length_; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file From 8bab52a563e1b3ce8aadf6ad60a285f9a2c31a18 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Nov 2021 09:45:20 -0800 Subject: [PATCH 26/53] move BeamSearchState from .h to .cc file --- .../cpu/transformers/beam_search.cc | 77 +++++++++++++++---- .../cpu/transformers/beam_search.h | 43 +---------- .../cpu/transformers/beam_search_parameters.h | 15 ++-- .../cpu/transformers/beam_search_scorer.cc | 22 +++--- .../cpu/transformers/gpt_subgraph.cc | 21 +++-- 5 files changed, 91 insertions(+), 87 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 3ffa8a5be723b..6e8ed4e5e361a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -50,6 +50,47 @@ REGISTER_KERNEL_TYPED(float) namespace transformers { +template +struct BeamSearchState { + // TODO: use allocater to allocate a buffer, and point each data to a span of the buffer + // so as to reuse related code in CUDA. + std::vector done; // shape (batch_size) + std::vector beam_scores; // shape (batch_size, num_beams) + + std::vector next_token_logits; // shape (batch_size * num_beams, vocab_size) + std::vector next_token_scores; // shape (batch_size, num_beams * vocab_size) + + std::vector next_tokens; // shape (batch_size, num_beams) + std::vector next_indices; // shape (batch_size, num_beams) + + Sequences sequences; + + std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) + + void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length, bool output_scores) { + int batch_beam_size = batch_size * num_beams; + done.assign(batch_size, 0); + beam_scores.assign(batch_beam_size, 0.0f); + for (int i = 0; i < batch_size; i++) { + for (int j = 1; j < num_beams; j++) { + beam_scores[i * num_beams + j] = -1e9; + } + } + + next_token_logits.assign(batch_beam_size * vocab_size, 0.0f); + next_token_scores.assign(batch_beam_size * vocab_size, 0.0f); + + next_tokens.assign(batch_beam_size, 0); + next_indices.assign(batch_beam_size, 0); + + sequences.Init(input_ids, batch_beam_size, sequence_length, max_length); + + if (output_scores) { + scores.reserve((max_length - sequence_length) * batch_size * num_beams * vocab_size); + } + } +}; + template class BeamSearchImpl { public: @@ -60,14 +101,15 @@ class BeamSearchImpl { void* stream, BeamSearchParameters& params); - // Initialize by validating all the inputs, and allocating the output tensors + // Initialize by validating all the inputs, and allocating the output tensors. Status Initialize(); - // Execute the batch, by iterating the sequence in each batch entry - // and calling the subgraph with each item in the sequence. + // Execute beam search in iterations util stopping criteria is reached. + // In each iteration, GPT subgraph is called, and next token for each sequence is generated. Status Execute(const FeedsFetchesManager& cached_ffm); private: + // Validate inputs. Status CheckInputs(const OpKernelContextInternal& context); // Prepare the inputs for first inference of subgraph @@ -81,23 +123,25 @@ class BeamSearchImpl { gsl::span beam_next_tokens, gsl::span beam_indices); - // Process logits and append next tokens to sequences + // Process logits and append next tokens to sequences. Status GenerateNextToken(const OrtValue& logits, gsl::span& beam_next_tokens, gsl::span& beam_indices); + // Calculate scores from logits, then apply filtering and select next token for each beam. Status ProcessLogits(const OrtValue& logits, BeamSearchState& beam_state, int top_k, AllocatorPtr& allocator); - // Mask tokens accroding to vocab_mask + // Mask tokens according to vocab_mask. void ApplyVocabMask(gsl::span& next_token_scores); - // Apply repetion penalty + // Apply repetition penalty. void ApplyRepetitionPenalty(const Sequences& sequences, gsl::span& next_token_scores); OpKernelContextInternal& context_; + const SessionState& session_state_; GptSubgraph& gpt_subgraph_; @@ -120,12 +164,11 @@ class BeamSearchImpl { template void BeamSearch::Init(const OpKernelInfo& info) { - // make sure the attribute was present even though we don't need it here. + // Make sure the attribute was present even though we don't need it here. // The GraphProto is loaded as a Graph instance by main Graph::Resolve, // and a SessionState instance for executing the subgraph is created by InferenceSession. // This is available via Info().GetSubgraphSessionState("attribute_name") when Compute is called. ONNX_NAMESPACE::GraphProto proto; - ORT_ENFORCE(info.GetAttr("body", &proto).IsOK()); ORT_IGNORE_RETURN_VALUE(proto); @@ -221,6 +264,8 @@ Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ", vocab_mask_dims.size()); } + + // There is dependency on vocab_size parameter, which shall be set before calling this function. if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ", vocab_mask_dims[0]); @@ -262,7 +307,6 @@ Status BeamSearchImpl::Initialize() { ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'."); - // CheckInputs shall be after CheckSubgraph due to its dependency on vocab_size ORT_RETURN_IF_ERROR(CheckInputs(context_)); // This flag will be updated later when the scores output exists. @@ -288,7 +332,6 @@ Status BeamSearchImpl::ProcessLogits( const int& vocab_size = parameters_->vocab_size; #ifdef DEBUG_BEAM_SEARCH - //DumpOrtValue("input_ids", input_ids); DumpOrtValue("logits", logits); #endif @@ -298,7 +341,7 @@ Status BeamSearchImpl::ProcessLogits( ORT_ENFORCE(logits_shape.NumDimensions() == 3); // The sequence length of input_ids for the logits. - // It equals parameters_->sequence_length for first subgraph call, and 1 for the remaining. + // It equals to parameters_->sequence_length for first subgraph call, and 1 for the remaining calls. auto input_length = logits_shape[1]; // Get logits for the last token, where logits has shape (batch_size * num_beams, input_length, vocab_size) @@ -347,8 +390,9 @@ Status BeamSearchImpl::ProcessLogits( beam_state.scores.insert(beam_state.scores.end(), next_token_scores.begin(), next_token_scores.end()); } - //next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - //next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) + // Apply top-k selection like the following: + // next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + // next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) int64_t next_token_scores_dims[] = {parameters_->batch_size, parameters_->num_beams * vocab_size}; TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); auto element_type = DataTypeImpl::GetType(); @@ -377,8 +421,9 @@ Status BeamSearchImpl::ProcessLogits( DumpTensor("topk_indices", *(topk_indices.get())); #endif - //next_indices = (next_tokens / vocab_size).long() - //next_tokens = next_tokens % vocab_size + // Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following: + // next_indices = (next_tokens / vocab_size).long() + // next_tokens = next_tokens % vocab_size gsl::span next_token_indices = topk_indices->DataAsSpan(); beam_state.next_indices.resize(parameters_->batch_size * k); beam_state.next_tokens.resize(parameters_->batch_size * k); @@ -402,7 +447,7 @@ Status BeamSearchImpl::ProcessLogits( beam_scorer_->Process( &(beam_state.sequences), - next_scores, //next_token_scores, + next_scores, next_tokens, next_indices, allocator); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index ae3640fc12a2f..497780c1f11d4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -17,47 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -template -struct BeamSearchState { - // TODO: use allocater to allocate a buffer, and point each data to a span of the buffer - // so as to reuse related code in CUDA. - std::vector done; // shape (batch_size) - std::vector beam_scores; // shape (batch_size, num_beams) - - std::vector next_token_logits; // shape (batch_size * num_beams, vocab_size) - std::vector next_token_scores; // shape (batch_size, num_beams * vocab_size) - - std::vector next_tokens; // shape (batch_size, num_beams) - std::vector next_indices; // shape (batch_size, num_beams) - - Sequences sequences; - - std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) - - void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length, bool output_scores) { - int batch_beam_size = batch_size * num_beams; - done.assign(batch_size, 0); - beam_scores.assign(batch_beam_size, 0.0f); - for (int i = 0; i < batch_size; i++) { - for (int j = 1; j < num_beams; j++) { - beam_scores[i * num_beams + j] = -1e9; - } - } - - next_token_logits.assign(batch_beam_size * vocab_size, 0.0f); - next_token_scores.assign(batch_beam_size * vocab_size, 0.0f); - - next_tokens.assign(batch_beam_size, 0); - next_indices.assign(batch_beam_size, 0); - - sequences.Init(input_ids, batch_beam_size, sequence_length, max_length); - - if (output_scores) { - scores.reserve((max_length - sequence_length) * batch_size * num_beams * vocab_size); - } - } -}; - template class BeamSearch : public controlflow::IControlFlowKernel { public: @@ -77,7 +36,7 @@ class BeamSearch : public controlflow::IControlFlowKernel { void SetComputeStream(void* stream) { stream_ = stream; } private: - // Subgraph info and FeedsFetchesManager re-used for each subgraph execution. + // Subgraph and FeedsFetchesManager re-used for each subgraph execution. std::unique_ptr gpt_subgraph_; FeedsFetchesManager* feeds_fetches_manager_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 3ff1eede0e0ca..5553f88308618 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -10,13 +10,13 @@ namespace contrib { namespace transformers { struct BeamSearchParameters { - // from node attributes + // Parameters from node attributes int eos_token_id; int pad_token_id; int no_repeat_ngram_size; bool early_stopping; - // from inputs + // Parameters from inputs int min_length; int max_length; int num_beams; @@ -29,14 +29,15 @@ struct BeamSearchParameters { gsl::span vocab_mask; - // from outputs + // Parameters from outputs. bool output_scores; // whether scores existed in output - // deduce from subgraph + // Parameters from subgraph. int vocab_size; - int num_heads; // not used - int head_size; // not used - int num_layers; // not used + // Below are used in CPU, reserved for CUDA. + int num_heads; + int head_size; + int num_layers; Status Validate(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 802b84afc6e83..b3d9d4c472ea5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -131,7 +131,7 @@ void BeamSearchScorer::Process(ISequences* sequences, gsl::span& next_tokens, gsl::span& next_indices, AllocatorPtr& allocator) { - // sequences shape is (batch_size * num_beams, total_sequence_length) + // Sequences shape is (batch_size * num_beams, total_sequence_length) // It contains word ID of whole sequence generated so far. // It is different from subgraph input_ids, which only need one word when past state is not empty. @@ -140,7 +140,7 @@ void BeamSearchScorer::Process(ISequences* sequences, ORT_ENFORCE(next_scores.size() == next_tokens.size()); ORT_ENFORCE(next_scores.size() == next_indices.size()); - // Allocate buffers only once + // Allocate buffers only once. if (next_beam_scores_.empty()) { size_t batch_beam_size = static_cast(batch_size_ * num_beams_); const bool fill_zeros = false; @@ -148,7 +148,7 @@ void BeamSearchScorer::Process(ISequences* sequences, next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_, fill_zeros); next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, fill_zeros); - // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length + // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. int buffer_per_beam = (max_length_ * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2; hypothesis_buffer_length_ = batch_beam_size * static_cast(buffer_per_beam); hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, fill_zeros); @@ -159,7 +159,7 @@ void BeamSearchScorer::Process(ISequences* sequences, if (done_[batch]) { ORT_ENFORCE(beam_hyp.Size() >= num_beams_, "Batch can only be done if all beams have been generated"); - // Pad the batch + // Pad the batch. for (int j = 0; j < num_beams_; j++) { next_beam_scores_[batch * num_beams_ + j] = 0.0f; next_beam_tokens_[batch * num_beams_ + j] = pad_token_id_; @@ -168,7 +168,7 @@ void BeamSearchScorer::Process(ISequences* sequences, continue; } - // Next tokens for this sentence + // Next tokens for this sentence. int beam_idx = 0; int top_k = 2 * num_beams_; for (int j = 0; j < top_k; j++) { @@ -177,7 +177,7 @@ void BeamSearchScorer::Process(ISequences* sequences, int64_t next_index = next_indices[batch * top_k + j]; int batch_beam_idx = batch * num_beams_ + static_cast(next_index); - // Add to generated hypotheses if end of sentence + // Add to generated hypotheses if end of sentence. if ((eos_token_id_ >= 0) && (next_token == eos_token_id_)) { bool is_beam_token_worse_than_top_num_beams = (j >= num_beams_); if (is_beam_token_worse_than_top_num_beams) { @@ -192,7 +192,7 @@ void BeamSearchScorer::Process(ISequences* sequences, auto sequence = clone.template as_span(); beam_hyp.Add(sequence, next_score); } else { - // Add next predicted token since it is not eos_token + // Add next predicted token since it is not eos_token. next_beam_scores_[batch * num_beams_ + beam_idx] = next_score; next_beam_tokens_[batch * num_beams_ + beam_idx] = next_token; next_beam_indices_[batch * num_beams_ + beam_idx] = batch_beam_idx; @@ -226,7 +226,7 @@ void BeamSearchScorer::Finalize(ISequences* sequences, ORT_ENFORCE(sequences != nullptr); ORT_ENFORCE(output_sequences != nullptr); - // finalize all open beam hypotheses and add to generated hypotheses + // Finalize all open beam hypotheses and add to generated hypotheses. for (int batch_index = 0; batch_index < batch_size_; batch_index++) { BeamHypotheses& beam_hyp = beam_hyps[batch_index]; if (done_[batch_index]) { @@ -241,19 +241,19 @@ void BeamSearchScorer::Finalize(ISequences* sequences, } } - // word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length) + // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). gsl::span output = output_sequences->MutableDataAsSpan(); // Fill output sequences with pad token ID so that we do not need append it later. std::fill_n(output.data(), output.size(), pad_token_id_); - // score of each sequence, with shape (batch_size * num_return_sequences) + // Score of each sequence, with shape (batch_size * num_return_sequences). gsl::span sequence_scores; if (output_sequence_scores != nullptr) { sequence_scores = output_sequence_scores->MutableDataAsSpan(); } - // span is empty when output_sequence_scores is NULL. + // Span is empty when output_sequence_scores is NULL. gsl::span batch_sequence_score; // Select the best hypotheses according to number of sequences to return. diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc index a27726178776b..cca9bb311cb60 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc @@ -302,7 +302,7 @@ OrtValue GptSubgraph::ExpandInputs(const OrtValue& input, int num_beams) const { TensorShape expanded_shape(&dims[0], 2); MLDataType element_type = input.Get().DataType(); - + OrtValue expanded; Tensor::InitOrtValue(element_type, expanded_shape, allocator_, expanded); @@ -316,17 +316,16 @@ OrtValue GptSubgraph::ExpandInputs(const OrtValue& input, int num_beams) const { target += sequence_length; } } - } - else if (element_type == DataTypeImpl::GetType()) { - const float* input_data = input.Get().Data(); - float* expanded_data = expanded.GetMutable()->MutableData(); - float* target = expanded_data; - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < num_beams; j++) { - memcpy(target, input_data + i * sequence_length, sizeof(float) * sequence_length); - target += sequence_length; - } + } else if (element_type == DataTypeImpl::GetType()) { + const float* input_data = input.Get().Data(); + float* expanded_data = expanded.GetMutable()->MutableData(); + float* target = expanded_data; + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < num_beams; j++) { + memcpy(target, input_data + i * sequence_length, sizeof(float) * sequence_length); + target += sequence_length; } + } } return expanded; From 7db869c9ac99c03b63a963f4428bba10cce7a0d0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Nov 2021 12:11:19 -0800 Subject: [PATCH 27/53] adjust logits processor order --- onnxruntime/contrib_ops/cpu/transformers/beam_search.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 6e8ed4e5e361a..2f97f938bcaea 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -371,9 +371,9 @@ Status BeamSearchImpl::ProcessLogits( } // Apply all score processors that updates scores - ApplyVocabMask(next_token_scores); ApplyRepetitionPenalty(beam_state.sequences, next_token_scores); - + ApplyVocabMask(next_token_scores); + // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) // TODO: use thread pool to parrellel int offset = 0; From b23540e03eabc09521f1fc9451837e3174c1c60f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Nov 2021 12:17:40 -0800 Subject: [PATCH 28/53] add batch generation example --- .../tools/transformers/convert_beam_search.py | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 9abde89d78d25..b21bd8f103ae7 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -8,6 +8,7 @@ from pathlib import Path from onnx import helper import numpy as np +import torch from transformers import GPT2Config from gpt2_helper import PRETRAINED_GPT2_MODELS from convert_to_onnx import main as convert_gpt2_to_onnx @@ -293,11 +294,21 @@ def convert_model(args): def test_model(args): from transformers import GPT2Tokenizer, GPT2LMHeadModel + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir, pad_token_id=tokenizer.eos_token_id) - input_ids = tokenizer.encode('I enjoy walking in the park', return_tensors='pt') + + # use different length sentences to test batching + sentences = ["The product is released", "I enjoy walking in the park"] + + inputs = tokenizer(sentences, return_tensors='pt', padding=True) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] bad_words = "walk in park" bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True) @@ -305,9 +316,7 @@ def test_model(args): print("bad_words_ids", bad_words_ids) global config - if config is None: - config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) - + config = model.config eos_token_id = config.eos_token_id pad_token_id = config.eos_token_id vocab_size = config.vocab_size @@ -315,7 +324,8 @@ def test_model(args): if args.run_baseline: print('-' * 50) print("Test PyTorch model and beam search with huggingface transformers...") - beam_outputs = model.generate(input_ids, + beam_outputs = model.generate(input_ids=input_ids, + attention_mask=attention_mask, max_length=args.max_length, min_length=args.min_length, num_beams=args.num_beams, @@ -350,9 +360,6 @@ def test_model(args): ort_session = create_ort_session(args.output, args.use_gpu) - batch_size = 1 - input_ids = input_ids.repeat(batch_size, 1) - vocab_mask = np.ones((vocab_size), dtype=np.int32) for bad_word_id in bad_words_ids: vocab_mask[bad_word_id] = 0 @@ -393,6 +400,16 @@ def test_model(args): sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) print(f"batch {i} sequence {j}: {sequence}") + if args.run_baseline: + torch_sequences = beam_outputs.sequences.reshape(sequences.shape) + ort_sequences = torch.LongTensor(sequences) + print(torch_sequences) + print(ort_sequences) + is_same = torch.equal(torch_sequences, ort_sequences) + + print("Torch and ORT result is ", "same" if is_same else "different") + return is_same + def main(): args = parse_arguments() From 25f7605fc3403f19e762acb4d9bb37c4a29ea085 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Nov 2021 22:57:43 -0800 Subject: [PATCH 29/53] fix repetition penalty for dup words in sequence --- .../contrib_ops/cpu/transformers/beam_search.cc | 12 ++++++++++-- .../python/tools/transformers/convert_beam_search.py | 7 +++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 2f97f938bcaea..afda329f8cbc7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -373,7 +373,7 @@ Status BeamSearchImpl::ProcessLogits( // Apply all score processors that updates scores ApplyRepetitionPenalty(beam_state.sequences, next_token_scores); ApplyVocabMask(next_token_scores); - + // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) // TODO: use thread pool to parrellel int offset = 0; @@ -514,8 +514,16 @@ void BeamSearchImpl::ApplyRepetitionPenalty(const Sequences& sequences, gsl:: for (int i = 0; i < batch_beam_size; i++) { gsl::span beam_token_scores = next_token_scores.subspan(i * parameters_->vocab_size, parameters_->vocab_size); gsl::span sequence = sequences.GetSequence(i); - for (const int64_t& word_id : sequence) { + + // Find unique word IDs in sequence. + std::unordered_set unique_word_ids; + for (const auto& word_id : sequence) { + unique_word_ids.insert(word_id); + } + + for (const int64_t word_id : unique_word_ids) { T score = beam_token_scores[word_id]; + // If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability, // This assumes that scores are either positive (like ctrl) or negative (like GPT-2), but not a mixture. beam_token_scores[word_id] = (score < 0 ? score * parameters_->repetition_penalty : score / parameters_->repetition_penalty); diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index b21bd8f103ae7..a842bc697ed27 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -403,14 +403,17 @@ def test_model(args): if args.run_baseline: torch_sequences = beam_outputs.sequences.reshape(sequences.shape) ort_sequences = torch.LongTensor(sequences) + print("-" * 50) + print("Torch Sequences:") print(torch_sequences) + print("-" * 50) + print("ORT Sequences:") print(ort_sequences) + print("-" * 50) is_same = torch.equal(torch_sequences, ort_sequences) - print("Torch and ORT result is ", "same" if is_same else "different") return is_same - def main(): args = parse_arguments() From 25751e7a0f5f6367273052aa4c07f9bff1120860 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Nov 2021 01:51:09 -0800 Subject: [PATCH 30/53] Add test --- .../tools/transformers/convert_beam_search.py | 6 ++-- .../python/transformers/test_beam_search.py | 32 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/test/python/transformers/test_beam_search.py diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index a842bc697ed27..b24a58e331ef0 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -414,15 +414,15 @@ def test_model(args): print("Torch and ORT result is ", "same" if is_same else "different") return is_same -def main(): - args = parse_arguments() +def main(argv=None): + args = parse_arguments(argv) if os.path.exists(args.output): print(f"skip conversion since path existed: {args.output}") else: convert_model(args) - test_model(args) + return test_model(args) if __name__ == '__main__': diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py new file mode 100644 index 0000000000000..a629bd36b149a --- /dev/null +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG + +import unittest +import os +import pytest + +class TestBeamSearch(unittest.TestCase): + def setUp(self): + from onnxruntime import get_available_providers + self.test_cuda = 'CUDAExecutionProvider' in get_available_providers() + + def run_beam_search(self, arguments: str): + from onnxruntime.transformers.convert_beam_search import main as run + return run(arguments.split()) + + @pytest.mark.slow + def test_profiler_cpu(self): + gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') + beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search_v1.onnx') + result = self.run_beam_search(f'-m gpt2 --gpt2_onnx {gpt2_onnx_path} --output {beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0 --run_baseline') + self.assertTrue(result, "ORT and PyTorch is expected to have same result, but current result is different") + +if __name__ == '__main__': + unittest.main() From 6a2427f3385efdf02df40e1caa67673ae8ce7f4d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 3 Dec 2021 14:06:44 -0800 Subject: [PATCH 31/53] Add no repeat ngram processor --- .../cpu/transformers/beam_search.cc | 70 +++++++++++++++---- .../transformers/beam_search_parameters.cc | 1 + .../contrib_ops/cpu/transformers/sequences.cc | 2 +- .../contrib_ops/cpu/transformers/sequences.h | 4 +- .../python/transformers/test_beam_search.py | 22 +++++- 5 files changed, 81 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index afda329f8cbc7..ef8d5ac5b8be6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -131,7 +131,6 @@ class BeamSearchImpl { // Calculate scores from logits, then apply filtering and select next token for each beam. Status ProcessLogits(const OrtValue& logits, BeamSearchState& beam_state, - int top_k, AllocatorPtr& allocator); // Mask tokens according to vocab_mask. @@ -140,6 +139,12 @@ class BeamSearchImpl { // Apply repetition penalty. void ApplyRepetitionPenalty(const Sequences& sequences, gsl::span& next_token_scores); + // Apply constraint of No repeat NGram Size . + void ApplyNoRepeatNGram(const Sequences& sequences, gsl::span& next_token_scores); + + // Apply constraint of mininal sequence length + void ApplyMinLength(const Sequences& sequences, gsl::span& next_token_scores); + OpKernelContextInternal& context_; const SessionState& session_state_; @@ -326,7 +331,6 @@ template Status BeamSearchImpl::ProcessLogits( const OrtValue& logits, // logits output of subgraph BeamSearchState& beam_state, - int top_k, AllocatorPtr& allocator) { const int64_t batch_beam_size = static_cast(parameters_->batch_size * parameters_->num_beams); const int& vocab_size = parameters_->vocab_size; @@ -372,8 +376,10 @@ Status BeamSearchImpl::ProcessLogits( // Apply all score processors that updates scores ApplyRepetitionPenalty(beam_state.sequences, next_token_scores); + ApplyNoRepeatNGram(beam_state.sequences, next_token_scores); ApplyVocabMask(next_token_scores); - + ApplyMinLength(beam_state.sequences, next_token_scores); + // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) // TODO: use thread pool to parrellel int offset = 0; @@ -405,7 +411,7 @@ Status BeamSearchImpl::ProcessLogits( #endif const int axis = 1; - const unsigned k = static_cast(top_k); + const unsigned top_k = static_cast(2 * parameters_->num_beams); const bool largest = true; const bool sorted = true; // results returned in sorted order. @@ -425,11 +431,11 @@ Status BeamSearchImpl::ProcessLogits( // next_indices = (next_tokens / vocab_size).long() // next_tokens = next_tokens % vocab_size gsl::span next_token_indices = topk_indices->DataAsSpan(); - beam_state.next_indices.resize(parameters_->batch_size * k); - beam_state.next_tokens.resize(parameters_->batch_size * k); + beam_state.next_indices.resize(parameters_->batch_size * top_k); + beam_state.next_tokens.resize(parameters_->batch_size * top_k); offset = 0; for (int i = 0; i < parameters_->batch_size; i++) { - for (unsigned int j = 0; j < k; j++, offset++) { + for (unsigned int j = 0; j < top_k; j++, offset++) { beam_state.next_indices[offset] = next_token_indices[offset] / vocab_size; beam_state.next_tokens[offset] = next_token_indices[offset] % vocab_size; } @@ -440,9 +446,9 @@ Status BeamSearchImpl::ProcessLogits( gsl::span next_indices(beam_state.next_indices.data(), beam_state.next_indices.size()); #ifdef DEBUG_BEAM_SEARCH - DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, k); - DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, k); - DumpTensor("next_indices before scorer", next_indices.data(), parameters_->batch_size, k); + DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, top_k); + DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k); + DumpTensor("next_indices before scorer", next_indices.data(), parameters_->batch_size, top_k); #endif beam_scorer_->Process( @@ -460,10 +466,8 @@ Status BeamSearchImpl::GenerateNextToken( const OrtValue& logits, gsl::span& beam_next_tokens, gsl::span& beam_indices) { - // Process logits to get next token scores, and select top_k = 2 * num_beams - // TODO: we might not need 2 * num_beams when logits processors does not update token scores. - const int top_k = 2 * parameters_->num_beams; - ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state_, top_k, allocator_)); + // Process logits to get next token scores + ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state_, allocator_)); gsl::span& beam_scores = beam_scorer_->GetNextScores(); // TODO: may not need clone beam_scores. @@ -531,6 +535,44 @@ void BeamSearchImpl::ApplyRepetitionPenalty(const Sequences& sequences, gsl:: } } +template +void BeamSearchImpl::ApplyNoRepeatNGram(const Sequences& sequences, gsl::span& next_token_scores) { + if (parameters_->no_repeat_ngram_size == 0 || parameters_->no_repeat_ngram_size > sequences.GetSequenceLength()) { + return; + } + + const gsl::index prefix_length = static_cast(parameters_->no_repeat_ngram_size - 1); + int batch_beam_size = parameters_->BatchBeamSize(); + + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.subspan(i * parameters_->vocab_size, parameters_->vocab_size); + gsl::span sequence = sequences.GetSequence(i); + + gsl::span prefix = sequence.subspan(sequence.length() - prefix_length); + ORT_ENFORCE(prefix.length() == prefix_length); + + std::unordered_set blocked_word_ids; + for (int j = 0; j <= sequence.length() - parameters_->no_repeat_ngram_size; j++) { + // Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length) + // TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching. + if (parameters_->no_repeat_ngram_size == 1 || prefix == sequence.subspan(j, prefix_length)) { + blocked_word_ids.insert(sequence[j + prefix_length]); + } + } + + for (const int64_t word_id : blocked_word_ids) { + beam_token_scores[word_id] = std::numeric_limits::lowest(); + } + } +} + +template +void BeamSearchImpl::ApplyMinLength(const Sequences& sequences, gsl::span& next_token_scores) { + if (sequences.GetSequenceLength() < parameters_->min_length) { + next_token_scores[parameters_->eos_token_id] = std::numeric_limits::lowest(); + } +} + template Status BeamSearchImpl::UpdateFeeds( const std::vector& last_outputs, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 37fdc60b46096..002abfd6af336 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -11,6 +11,7 @@ constexpr int kMaxSequenceLength = 4096; Status BeamSearchParameters::Validate() { ORT_RETURN_IF(eos_token_id < 0, "eos_token_id is invalid"); ORT_RETURN_IF(pad_token_id < 0, "pad_token_id is invalid"); + ORT_RETURN_IF(min_length >= max_length, "min_length shall be smaller than max_length"); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 572be533970e6..3d4f6c244462e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -30,7 +30,7 @@ gsl::span Sequences::GetSequence(int beam_index) const { return sequence; } -int Sequences::GetSequenceLength() { +int Sequences::GetSequenceLength() const { return current_length_; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index cab33c63b2c5c..96f231099a5f3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -11,7 +11,7 @@ class ISequences { public: virtual ~ISequences() {} virtual gsl::span GetSequence(int beam_index) const = 0; - virtual int GetSequenceLength() = 0; + virtual int GetSequenceLength() const = 0; }; // This class keeps track of sequences generated. @@ -26,7 +26,7 @@ class Sequences : public ISequences { gsl::span GetSequence(int beam_index) const override; // Returns current sequence length. - int GetSequenceLength() override; + int GetSequenceLength() const override; // Print the sequences to StdOut in debug mode void PrintSequences(); diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py index a629bd36b149a..8461e0f9964f8 100644 --- a/onnxruntime/test/python/transformers/test_beam_search.py +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -22,11 +22,31 @@ def run_beam_search(self, arguments: str): return run(arguments.split()) @pytest.mark.slow - def test_profiler_cpu(self): + def test_cpu(self): gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search_v1.onnx') result = self.run_beam_search(f'-m gpt2 --gpt2_onnx {gpt2_onnx_path} --output {beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0 --run_baseline') + os.remove(gpt2_onnx_path) + os.remove(beam_search_onnx_path) self.assertTrue(result, "ORT and PyTorch is expected to have same result, but current result is different") + @pytest.mark.slow + def test_no_repeat_ngram_1(self): + gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') + beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search_v1.onnx') + result = self.run_beam_search(f'-m gpt2 --gpt2_onnx {gpt2_onnx_path} --output {beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0 --no_repeat_ngram_size 1 --run_baseline') + os.remove(gpt2_onnx_path) + os.remove(beam_search_onnx_path) + self.assertTrue(result, "ORT and PyTorch is expected to have same result, but current result is different") + + @pytest.mark.slow + def test_no_repeat_ngram_2(self): + gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') + beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search_v1.onnx') + result = self.run_beam_search(f'-m gpt2 --gpt2_onnx {gpt2_onnx_path} --output {beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0 --no_repeat_ngram_size 2 --run_baseline') + os.remove(gpt2_onnx_path) + os.remove(beam_search_onnx_path) + self.assertTrue(result, "ORT and PyTorch is expected to have same result, but current result is different") + if __name__ == '__main__': unittest.main() From 5e42e69d6e6ff09ffdeab689d24b725707b0ec59 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 6 Dec 2021 11:19:59 -0800 Subject: [PATCH 32/53] refactoring: move logits processor to classes --- .../cpu/transformers/beam_search.cc | 141 ++++---------- .../cpu/transformers/beam_search.h | 1 - .../transformers/beam_search_parameters.cc | 2 +- .../cpu/transformers/beam_search_parameters.h | 7 +- .../cpu/transformers/logits_processor.cc | 174 ++++++++++++++++++ .../cpu/transformers/logits_processor.h | 98 ++++++++++ .../tools/transformers/convert_beam_search.py | 39 ++-- 7 files changed, 336 insertions(+), 126 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/logits_processor.h diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index ef8d5ac5b8be6..1facf6a9aefef 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -23,6 +23,7 @@ #include "gsl/gsl" #include "core/providers/cpu/math/softmax_shared.h" #include "beam_search.h" +#include "logits_processor.h" #include "dump_tensor.h" #ifdef _MSC_VER @@ -68,8 +69,12 @@ struct BeamSearchState { std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length, bool output_scores) { - int batch_beam_size = batch_size * num_beams; done.assign(batch_size, 0); + + int batch_beam_size = batch_size * num_beams; + + // Initialize score of first beam of each group with 0 and the rest with -1e9. + // This ensures that the beams in the same group don't produce same tokens every time. beam_scores.assign(batch_beam_size, 0.0f); for (int i = 0; i < batch_size; i++) { for (int j = 1; j < num_beams; j++) { @@ -133,18 +138,6 @@ class BeamSearchImpl { BeamSearchState& beam_state, AllocatorPtr& allocator); - // Mask tokens according to vocab_mask. - void ApplyVocabMask(gsl::span& next_token_scores); - - // Apply repetition penalty. - void ApplyRepetitionPenalty(const Sequences& sequences, gsl::span& next_token_scores); - - // Apply constraint of No repeat NGram Size . - void ApplyNoRepeatNGram(const Sequences& sequences, gsl::span& next_token_scores); - - // Apply constraint of mininal sequence length - void ApplyMinLength(const Sequences& sequences, gsl::span& next_token_scores); - OpKernelContextInternal& context_; const SessionState& session_state_; @@ -160,6 +153,8 @@ class BeamSearchImpl { BeamSearchParameters* parameters_; + LogitsProcessorList logits_processors_; + std::unique_ptr> beam_scorer_; BeamSearchState beam_state_; @@ -317,6 +312,9 @@ Status BeamSearchImpl::Initialize() { // This flag will be updated later when the scores output exists. parameters_->output_scores = false; + // Initialize processsors after CheckInputs so that parameters_->vocab_mask is ready. + logits_processors_.Init(*parameters_); + return status; } @@ -335,10 +333,6 @@ Status BeamSearchImpl::ProcessLogits( const int64_t batch_beam_size = static_cast(parameters_->batch_size * parameters_->num_beams); const int& vocab_size = parameters_->vocab_size; -#ifdef DEBUG_BEAM_SEARCH - DumpOrtValue("logits", logits); -#endif - const T* logits_data = logits.Get().Data(); const TensorShape& logits_shape = logits.Get().Shape(); @@ -358,10 +352,15 @@ Status BeamSearchImpl::ProcessLogits( gsl::span source(current_logits, vocab_size); gsl::span target = next_token_logits.subspan(i * vocab_size, vocab_size); gsl::copy(source, target); - current_logits += i * (input_length * vocab_size); + current_logits += input_length * vocab_size; } } +#ifdef DEBUG_BEAM_SEARCH + //DumpOrtValue("logits", logits); + DumpTensor("next_token_logits", next_token_logits.data(), parameters_->batch_size, parameters_->num_beams, vocab_size); +#endif + // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1) auto next_token_scores = gsl::make_span(beam_state.next_token_scores); Status status = SoftmaxCPU(batch_beam_size, // rows @@ -374,12 +373,17 @@ Status BeamSearchImpl::ProcessLogits( return status; } +#ifdef DEBUG_BEAM_SEARCH + DumpTensor("next_token_scores after softmax", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size); +#endif + // Apply all score processors that updates scores - ApplyRepetitionPenalty(beam_state.sequences, next_token_scores); - ApplyNoRepeatNGram(beam_state.sequences, next_token_scores); - ApplyVocabMask(next_token_scores); - ApplyMinLength(beam_state.sequences, next_token_scores); - + logits_processors_.Process(&(beam_state.sequences), next_token_scores); + +#ifdef DEBUG_BEAM_SEARCH + DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size); +#endif + // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) // TODO: use thread pool to parrellel int offset = 0; @@ -392,6 +396,10 @@ Status BeamSearchImpl::ProcessLogits( } } +#ifdef DEBUG_BEAM_SEARCH + DumpTensor("next_token_scores after adding beam_scores", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size); +#endif + if (parameters_->output_scores) { beam_state.scores.insert(beam_state.scores.end(), next_token_scores.begin(), next_token_scores.end()); } @@ -406,10 +414,6 @@ Status BeamSearchImpl::ProcessLogits( Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value); const Tensor& input = next_token_scores_value.Get(); -#ifdef DEBUG_BEAM_SEARCH - DumpOrtValue("next_token_scores_value", next_token_scores_value); -#endif - const int axis = 1; const unsigned top_k = static_cast(2 * parameters_->num_beams); const bool largest = true; @@ -490,89 +494,6 @@ Status BeamSearchImpl::GenerateNextToken( return Status::OK(); } -template -void BeamSearchImpl::ApplyVocabMask(gsl::span& next_token_scores) { - // Process vocabulary mask and set tokens with mask value 0 to -inf. - auto& vocab_mask = parameters_->vocab_mask; - if (!vocab_mask.empty()) { - T* p = next_token_scores.data(); - // next_token_scores shape (batch_size * num_beams, vocab_size), vocab_mask shape (vocab_size) - for (int i = 0; i < parameters_->batch_size * parameters_->num_beams; i++) { - for (int j = 0; j < parameters_->vocab_size; j++, p++) { - if (vocab_mask[j] == 0) { - *p = std::numeric_limits::lowest(); - } - } - } - } - return; -} - -template -void BeamSearchImpl::ApplyRepetitionPenalty(const Sequences& sequences, gsl::span& next_token_scores) { - if (parameters_->repetition_penalty == 1.0f) { // no penalty - return; - } - - int batch_beam_size = parameters_->BatchBeamSize(); - for (int i = 0; i < batch_beam_size; i++) { - gsl::span beam_token_scores = next_token_scores.subspan(i * parameters_->vocab_size, parameters_->vocab_size); - gsl::span sequence = sequences.GetSequence(i); - - // Find unique word IDs in sequence. - std::unordered_set unique_word_ids; - for (const auto& word_id : sequence) { - unique_word_ids.insert(word_id); - } - - for (const int64_t word_id : unique_word_ids) { - T score = beam_token_scores[word_id]; - - // If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability, - // This assumes that scores are either positive (like ctrl) or negative (like GPT-2), but not a mixture. - beam_token_scores[word_id] = (score < 0 ? score * parameters_->repetition_penalty : score / parameters_->repetition_penalty); - } - } -} - -template -void BeamSearchImpl::ApplyNoRepeatNGram(const Sequences& sequences, gsl::span& next_token_scores) { - if (parameters_->no_repeat_ngram_size == 0 || parameters_->no_repeat_ngram_size > sequences.GetSequenceLength()) { - return; - } - - const gsl::index prefix_length = static_cast(parameters_->no_repeat_ngram_size - 1); - int batch_beam_size = parameters_->BatchBeamSize(); - - for (int i = 0; i < batch_beam_size; i++) { - gsl::span beam_token_scores = next_token_scores.subspan(i * parameters_->vocab_size, parameters_->vocab_size); - gsl::span sequence = sequences.GetSequence(i); - - gsl::span prefix = sequence.subspan(sequence.length() - prefix_length); - ORT_ENFORCE(prefix.length() == prefix_length); - - std::unordered_set blocked_word_ids; - for (int j = 0; j <= sequence.length() - parameters_->no_repeat_ngram_size; j++) { - // Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length) - // TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching. - if (parameters_->no_repeat_ngram_size == 1 || prefix == sequence.subspan(j, prefix_length)) { - blocked_word_ids.insert(sequence[j + prefix_length]); - } - } - - for (const int64_t word_id : blocked_word_ids) { - beam_token_scores[word_id] = std::numeric_limits::lowest(); - } - } -} - -template -void BeamSearchImpl::ApplyMinLength(const Sequences& sequences, gsl::span& next_token_scores) { - if (sequences.GetSequenceLength() < parameters_->min_length) { - next_token_scores[parameters_->eos_token_id] = std::numeric_limits::lowest(); - } -} - template Status BeamSearchImpl::UpdateFeeds( const std::vector& last_outputs, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 497780c1f11d4..cc359f2549dee 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -32,7 +32,6 @@ class BeamSearch : public controlflow::IControlFlowKernel { static std::unique_ptr Create(const OpKernelInfo& info, void* stream); protected: - void SetComputeStream(void* stream) { stream_ = stream; } private: diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 002abfd6af336..fc0f61705b35b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -8,7 +8,7 @@ namespace transformers { constexpr int kMaxSequenceLength = 4096; -Status BeamSearchParameters::Validate() { +Status BeamSearchParameters::Validate() const { ORT_RETURN_IF(eos_token_id < 0, "eos_token_id is invalid"); ORT_RETURN_IF(pad_token_id < 0, "pad_token_id is invalid"); ORT_RETURN_IF(min_length >= max_length, "min_length shall be smaller than max_length"); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 5553f88308618..26de2a98408eb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -39,11 +39,14 @@ struct BeamSearchParameters { int head_size; int num_layers; - Status Validate(); + Status Validate() const; - int BatchBeamSize() { return batch_size * num_beams; } + int BatchBeamSize() const { return batch_size * num_beams; } + void ParseFromAttributes(const OpKernelInfo& info); + void ParseFromInputs(OpKernelContext* context); + void SetSubgraphParameters(int vocab_size, int num_heads, int head_size, int num_layers); }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc new file mode 100644 index 0000000000000..0c549cc078cff --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -0,0 +1,174 @@ +#include +#include "logits_processor.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +template +gsl::span NextTokenScores::GetScores(int batch_beam_index) { + assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); + return scores.subspan(batch_beam_index * vocab_size, vocab_size); +} + +// template +// void NextTokenScores::SetScore(int batch_beam_index, int token_id, T score) { +// assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); +// assert(token_id >= 0 && token_id < vocab_size); +// scores[batch_beam_index * vocab_size + token_id] = score; +// } + +template +void NextTokenScores::SetScore(int token_id, T score) { + assert(token_id >= 0 && token_id < vocab_size); + for (int i = 0; i < batch_beam_size; i++) { + scores[i * vocab_size + token_id] = score; + } +} + +// Interface for all scorers for beam search or beam sample. +template +MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_token_id) + : min_length_(min_length), eos_token_id_(eos_token_id) {} + +template +void MinLengthLogitsProcessor::Process(const ISequences* sequences, + NextTokenScores& next_token_scores) { + if (sequences->GetSequenceLength() < min_length_) { + next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); + } +} + +template +RepetitionPenaltyLogitsProcessor::RepetitionPenaltyLogitsProcessor(float penalty) : penalty_(penalty) { +} + +template +void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, + NextTokenScores& next_token_scores) { + const int batch_beam_size = next_token_scores.batch_beam_size; + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.GetScores(i); + gsl::span sequence = sequences->GetSequence(i); + + // Find unique word IDs in sequence. + std::unordered_set unique_word_ids; + for (const auto& word_id : sequence) { + unique_word_ids.insert(word_id); + } + + for (const int64_t word_id : unique_word_ids) { + T score = beam_token_scores[word_id]; + + // If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability, + // This assumes that scores are either positive (like ctrl) or negative (like GPT-2), but not a mixture. + beam_token_scores[word_id] = (score < 0 ? score * penalty_ : score / penalty_); + } + } +} + +template +NoRepeatNGramLogitsProcessor::NoRepeatNGramLogitsProcessor(int ngram_size) : ngram_size_(ngram_size) { +} + +template +void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, + NextTokenScores& next_token_scores) { + if (ngram_size_ == 0 || ngram_size_ > sequences->GetSequenceLength()) { + return; + } + + const gsl::index prefix_length = static_cast(ngram_size_ - 1); + int batch_beam_size = next_token_scores.batch_beam_size; + + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.GetScores(i); + gsl::span sequence = sequences->GetSequence(i); + + gsl::span prefix = sequence.subspan(sequence.length() - prefix_length); + ORT_ENFORCE(prefix.length() == prefix_length); + + std::unordered_set blocked_word_ids; + for (int j = 0; j <= sequence.length() - ngram_size_; j++) { + // Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length) + // TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching. + if (ngram_size_ == 1 || prefix == sequence.subspan(j, prefix_length)) { + blocked_word_ids.insert(sequence[j + prefix_length]); + } + } + + for (const int64_t word_id : blocked_word_ids) { + beam_token_scores[word_id] = std::numeric_limits::lowest(); + } + } +} + +template +VocabMaskLogitsProcessor::VocabMaskLogitsProcessor(const gsl::span& vocab_mask) : vocab_mask_(vocab_mask) { +} + +template +void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, + NextTokenScores& next_token_scores) { + assert(!vocab_mask_.empty()); + + // Process vocabulary mask and set tokens with mask value 0 to -inf. + T* p = next_token_scores.scores.data(); + // next_token_scores shape (batch_size * num_beams, vocab_size) + // vocab_mask shape (vocab_size). TODO: support shape (batch_size, vocab_size) + for (int i = 0; i < next_token_scores.batch_beam_size; i++) { + for (int j = 0; j < next_token_scores.vocab_size; j++, p++) { + if (vocab_mask_[j] == 0) { + *p = std::numeric_limits::lowest(); + } + } + } +} + +template +void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { + processor_list_.clear(); + + if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty + repetition_penalty_processor_ = std::make_unique>(parameters.repetition_penalty); + processor_list_.push_back(repetition_penalty_processor_.get()); + } + + if (parameters.no_repeat_ngram_size > 0) { + no_repeat_ngram_processor_ = std::make_unique>(parameters.no_repeat_ngram_size); + processor_list_.push_back(no_repeat_ngram_processor_.get()); + } + + if (!parameters.vocab_mask.empty()) { + vocab_mask_processor_ = std::make_unique>(parameters.vocab_mask); + processor_list_.push_back(vocab_mask_processor_.get()); + } + + if (parameters.min_length > 0) { + min_length_processor_ = std::make_unique>(parameters.min_length, parameters.eos_token_id); + processor_list_.push_back(min_length_processor_.get()); + } + + batch_beam_size_ = parameters.BatchBeamSize(); + vocab_size_ = parameters.vocab_size; +} + +template +void LogitsProcessorList::Process(const ISequences* sequences, + gsl::span& next_token_scores) { + NextTokenScores input_scores = {next_token_scores, batch_beam_size_, vocab_size_}; + for (size_t i = 0; i < processor_list_.size(); i++) { + processor_list_[i]->Process(sequences, input_scores); + } +} + +// Instantiation +template class MinLengthLogitsProcessor; +template class RepetitionPenaltyLogitsProcessor; +template class NoRepeatNGramLogitsProcessor; +template class VocabMaskLogitsProcessor; +template class LogitsProcessorList; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h new file mode 100644 index 0000000000000..30f3180f1d098 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -0,0 +1,98 @@ +#pragma once +#include "sequences.h" +#include "beam_search_parameters.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +template +struct NextTokenScores { + gsl::span& scores; + int batch_beam_size; + int vocab_size; + gsl::span GetScores(int batch_beam_index); + //void SetScore(int batch_beam_index, int token_id, T score); + void SetScore(int token_id, T score); +}; + +// Interface for all scorers for beam search or beam sample. +template +class ILogitsProcessor { + public: + virtual ~ILogitsProcessor() {} + + virtual void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) = 0; +}; + +template +class MinLengthLogitsProcessor : public ILogitsProcessor { + public: + MinLengthLogitsProcessor(int min_length, int eos_token_id); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + int min_length_; + int eos_token_id_; +}; + +template +class RepetitionPenaltyLogitsProcessor : public ILogitsProcessor { + public: + RepetitionPenaltyLogitsProcessor(float penalty); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + float penalty_; +}; + +template +class NoRepeatNGramLogitsProcessor : public ILogitsProcessor { + public: + NoRepeatNGramLogitsProcessor(int ngram_size); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + int ngram_size_; +}; + +template +class VocabMaskLogitsProcessor : public ILogitsProcessor { + public: + VocabMaskLogitsProcessor(const gsl::span& vocab_mask); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + gsl::span vocab_mask_; +}; + +template +class LogitsProcessorList { +public: + LogitsProcessorList() = default ; + void Init(const BeamSearchParameters& parameters); + void Process(const ISequences* sequences, gsl::span& next_token_scores); + +private: + int batch_beam_size_; + int vocab_size_; + std::vector*> processor_list_; + + std::unique_ptr> repetition_penalty_processor_; + std::unique_ptr> no_repeat_ngram_processor_; + std::unique_ptr> vocab_mask_processor_; + std::unique_ptr> min_length_processor_; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index b24a58e331ef0..8921df3952c23 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -8,6 +8,7 @@ from pathlib import Path from onnx import helper import numpy as np +from typing import List import torch from transformers import GPT2Config from gpt2_helper import PRETRAINED_GPT2_MODELS @@ -292,7 +293,7 @@ def convert_model(args): onnx.save(new_model, args.output) -def test_model(args): +def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): from transformers import GPT2Tokenizer, GPT2LMHeadModel tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) @@ -303,8 +304,9 @@ def test_model(args): cache_dir=args.cache_dir, pad_token_id=tokenizer.eos_token_id) - # use different length sentences to test batching - sentences = ["The product is released", "I enjoy walking in the park"] + # Use different length sentences to test batching + if sentences is None: + sentences = ["The product is released", "I enjoy walking in the park", "Test best way to invest"] inputs = tokenizer(sentences, return_tensors='pt', padding=True) input_ids = inputs["input_ids"] @@ -313,7 +315,10 @@ def test_model(args): bad_words = "walk in park" bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True) bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list - print("bad_words_ids", bad_words_ids) + if use_vocab_mask: + print("bad_words_ids", bad_words_ids) + else: + bad_words_ids = None global config config = model.config @@ -321,6 +326,7 @@ def test_model(args): pad_token_id = config.eos_token_id vocab_size = config.vocab_size + torch_decoded_sequences = [] if args.run_baseline: print('-' * 50) print("Test PyTorch model and beam search with huggingface transformers...") @@ -348,7 +354,9 @@ def test_model(args): if args.output_token_scores: print("scores", beam_outputs.scores) for i, sequence in enumerate(beam_outputs.sequences): - print("{}: {}".format(i, tokenizer.decode(sequence, skip_special_tokens=True))) + decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True) + torch_decoded_sequences.append(decoded_sequence) + print("{}: {}".format(i, decoded_sequence)) print('-' * 50) print("Test ONNX model and bream search with onnxruntime...") @@ -361,8 +369,9 @@ def test_model(args): ort_session = create_ort_session(args.output, args.use_gpu) vocab_mask = np.ones((vocab_size), dtype=np.int32) - for bad_word_id in bad_words_ids: - vocab_mask[bad_word_id] = 0 + if use_vocab_mask: + for bad_word_id in bad_words_ids: + vocab_mask[bad_word_id] = 0 inputs = { "input_ids": input_ids.cpu().numpy().astype(np.int32), @@ -395,25 +404,31 @@ def test_model(args): print("scores", result[2]) (batch_size, num_sequences, max_length) = sequences.shape + ort_decoded_sequences = [] for i in range(batch_size): for j in range(num_sequences): - sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) - print(f"batch {i} sequence {j}: {sequence}") + decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) + ort_decoded_sequences.append(decoded_sequence) + print(f"batch {i} sequence {j}: {decoded_sequence}") if args.run_baseline: - torch_sequences = beam_outputs.sequences.reshape(sequences.shape) + torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1) ort_sequences = torch.LongTensor(sequences) print("-" * 50) print("Torch Sequences:") print(torch_sequences) + print(torch_decoded_sequences) print("-" * 50) print("ORT Sequences:") print(ort_sequences) + print(ort_decoded_sequences) print("-" * 50) - is_same = torch.equal(torch_sequences, ort_sequences) + # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. + is_same = (torch_decoded_sequences == ort_decoded_sequences) print("Torch and ORT result is ", "same" if is_same else "different") return is_same + def main(argv=None): args = parse_arguments(argv) @@ -422,7 +437,7 @@ def main(argv=None): else: convert_model(args) - return test_model(args) + return test_model(args, use_vocab_mask=True) if __name__ == '__main__': From 312d5486cc665abb1859ca0052250f8a62a443da Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 6 Dec 2021 12:26:06 -0800 Subject: [PATCH 33/53] fix build warning --- .../cpu/transformers/dump_tensor.cc | 5 +-- .../cpu/transformers/logits_processor.cc | 34 ++++++++++++++----- .../cpu/transformers/logits_processor.h | 3 +- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc index 1e671c01e9038..8aa9e8907fa52 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc @@ -37,8 +37,9 @@ void DumpOrtValue(const char* name, const OrtValue& value) { } void ConfigureTensorDump() { - if (ParseEnvironmentVariableWithDefault(dump_tensor_env_vars::kDumpBeamSearch, false)) { - g_enable_tensor_dump = true; + const auto parsed = ParseEnvironmentVariable(dump_tensor_env_vars::kDumpBeamSearch); + if (parsed.has_value()) { + g_enable_tensor_dump = *parsed; } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 0c549cc078cff..10ef8b6f698d5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -1,5 +1,6 @@ #include #include "logits_processor.h" +#include "dump_tensor.h" namespace onnxruntime { namespace contrib { @@ -11,13 +12,6 @@ gsl::span NextTokenScores::GetScores(int batch_beam_index) { return scores.subspan(batch_beam_index * vocab_size, vocab_size); } -// template -// void NextTokenScores::SetScore(int batch_beam_index, int token_id, T score) { -// assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); -// assert(token_id >= 0 && token_id < vocab_size); -// scores[batch_beam_index * vocab_size + token_id] = score; -// } - template void NextTokenScores::SetScore(int token_id, T score) { assert(token_id >= 0 && token_id < vocab_size); @@ -26,6 +20,14 @@ void NextTokenScores::SetScore(int token_id, T score) { } } +#ifdef DEBUG_BEAM_SEARCH +template +void DumpScores(const char* name, gsl::span& scores) { + DumpString(name, 0, true); + ORT_UNUSED_PARAMETER(scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_token_id) @@ -37,6 +39,10 @@ void MinLengthLogitsProcessor::Process(const ISequences* sequences, if (sequences->GetSequenceLength() < min_length_) { next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); } + +#ifdef DEBUG_BEAM_SEARCH + DumpScores("MinLengthLogitsProcessor", next_token_scores.scores); +#endif } template @@ -65,6 +71,10 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = (score < 0 ? score * penalty_ : score / penalty_); } } + +#ifdef DEBUG_BEAM_SEARCH + DumpScores("RepetitionPenaltyLogitsProcessor", next_token_scores.scores); +#endif } template @@ -89,7 +99,7 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, ORT_ENFORCE(prefix.length() == prefix_length); std::unordered_set blocked_word_ids; - for (int j = 0; j <= sequence.length() - ngram_size_; j++) { + for (int j = 0; j <= static_cast(sequence.length()) - ngram_size_; j++) { // Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length) // TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching. if (ngram_size_ == 1 || prefix == sequence.subspan(j, prefix_length)) { @@ -101,6 +111,10 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = std::numeric_limits::lowest(); } } + +#ifdef DEBUG_BEAM_SEARCH + DumpScores("NoRepeatNGramLogitsProcessor", next_token_scores.scores); +#endif } template @@ -123,6 +137,10 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } } } + +#ifdef DEBUG_BEAM_SEARCH + DumpScores("VocabMaskLogitsProcessor", next_token_scores.scores); +#endif } template diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 30f3180f1d098..78fe9acf63bcb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -11,8 +11,9 @@ struct NextTokenScores { gsl::span& scores; int batch_beam_size; int vocab_size; + gsl::span GetScores(int batch_beam_index); - //void SetScore(int batch_beam_index, int token_id, T score); + void SetScore(int token_id, T score); }; From 3371cb45ccc6abea4343780119e719d9b26bd20d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 7 Dec 2021 12:14:48 -0800 Subject: [PATCH 34/53] show latency --- .../tools/transformers/convert_beam_search.py | 37 +++++++++----- .../python/transformers/test_beam_search.py | 49 ++++++++----------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 8921df3952c23..9f8841dfbb5e1 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import os +import time import onnx import logging import argparse @@ -65,8 +66,14 @@ def parse_arguments(argv=None): parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true') parser.set_defaults(use_external_data_format=False) - parser.add_argument('--run_baseline', required=False, action='store_true', help="run huggingface beam search") - parser.set_defaults(run_baseline=False) + parser.add_argument('--disable_parity', required=False, action='store_true', help="do not run parity test") + parser.set_defaults(disable_parity=False) + + parser.add_argument('--total_runs', + required=False, + type=int, + default=1, + help='Number of times of inference for latency measurement') beam_search_group = parser.add_argument_group("beam search options") @@ -327,7 +334,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): vocab_size = config.vocab_size torch_decoded_sequences = [] - if args.run_baseline: + if not args.disable_parity: print('-' * 50) print("Test PyTorch model and beam search with huggingface transformers...") beam_outputs = model.generate(input_ids=input_ids, @@ -361,11 +368,6 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): print('-' * 50) print("Test ONNX model and bream search with onnxruntime...") - # TODO: remove debug code - import time - print('You have 15 seconds to attach a debugger.') - time.sleep(15) - ort_session = create_ort_session(args.output, args.use_gpu) vocab_mask = np.ones((vocab_size), dtype=np.int32) @@ -394,7 +396,17 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): output_test_data(dir, inputs) print("inputs", inputs) - result = ort_session.run(None, inputs) + + # Test performance + latency = [] + for _ in range(args.total_runs): + start = time.time() + result = ort_session.run(None, inputs) + latency.append(time.time() - start) + batch_size = input_ids.shape[0] + from benchmark_helper import get_latency_result + output = get_latency_result(latency, batch_size) + print("ORT outputs:") sequences = result[0] print("sequences", sequences) @@ -411,7 +423,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): ort_decoded_sequences.append(decoded_sequence) print(f"batch {i} sequence {j}: {decoded_sequence}") - if args.run_baseline: + if not args.disable_parity: torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1) ort_sequences = torch.LongTensor(sequences) print("-" * 50) @@ -426,7 +438,10 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. is_same = (torch_decoded_sequences == ort_decoded_sequences) print("Torch and ORT result is ", "same" if is_same else "different") - return is_same + output["parity"] = is_same + + print(output) + return output def main(argv=None): diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py index 8461e0f9964f8..1296191d568dd 100644 --- a/onnxruntime/test/python/transformers/test_beam_search.py +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -6,47 +6,40 @@ # license information. # -------------------------------------------------------------------------- -# For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG - import unittest import os import pytest +from parity_utilities import find_transformers_source +if find_transformers_source(): + from convert_beam_search import main as run +else: + from onnxruntime.transformers.convert_beam_search import main as run + class TestBeamSearch(unittest.TestCase): def setUp(self): - from onnxruntime import get_available_providers - self.test_cuda = 'CUDAExecutionProvider' in get_available_providers() - + self.model_name = "gpt2" + self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') + self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx') + self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0' + def run_beam_search(self, arguments: str): - from onnxruntime.transformers.convert_beam_search import main as run return run(arguments.split()) @pytest.mark.slow def test_cpu(self): - gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') - beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search_v1.onnx') - result = self.run_beam_search(f'-m gpt2 --gpt2_onnx {gpt2_onnx_path} --output {beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0 --run_baseline') - os.remove(gpt2_onnx_path) - os.remove(beam_search_onnx_path) - self.assertTrue(result, "ORT and PyTorch is expected to have same result, but current result is different") + result = self.run_beam_search(self.cpu_params) + os.remove(self.gpt2_onnx_path) + os.remove(self.beam_search_onnx_path) + self.assertTrue(result["parity"], "ORT and PyTorch result is different") @pytest.mark.slow - def test_no_repeat_ngram_1(self): - gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') - beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search_v1.onnx') - result = self.run_beam_search(f'-m gpt2 --gpt2_onnx {gpt2_onnx_path} --output {beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0 --no_repeat_ngram_size 1 --run_baseline') - os.remove(gpt2_onnx_path) - os.remove(beam_search_onnx_path) - self.assertTrue(result, "ORT and PyTorch is expected to have same result, but current result is different") - - @pytest.mark.slow - def test_no_repeat_ngram_2(self): - gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') - beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search_v1.onnx') - result = self.run_beam_search(f'-m gpt2 --gpt2_onnx {gpt2_onnx_path} --output {beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0 --no_repeat_ngram_size 2 --run_baseline') - os.remove(gpt2_onnx_path) - os.remove(beam_search_onnx_path) - self.assertTrue(result, "ORT and PyTorch is expected to have same result, but current result is different") + def test_cpu_no_repeat_ngram(self): + for ngram_size in [1, 2]: + result = self.run_beam_search(self.cpu_params + f' --no_repeat_ngram_size {ngram_size}') + os.remove(self.gpt2_onnx_path) + os.remove(self.beam_search_onnx_path) + self.assertTrue(result["parity"], "ORT and PyTorch result is different") if __name__ == '__main__': unittest.main() From a63a473abab6570603179e72c6868a806fc4274b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 7 Dec 2021 12:34:19 -0800 Subject: [PATCH 35/53] use allocator in beam state --- .../cpu/transformers/beam_search.cc | 153 ++++++++++++------ .../cpu/transformers/beam_search_scorer.cc | 40 ++--- .../cpu/transformers/beam_search_scorer.h | 17 +- .../python/transformers/test_beam_search.py | 8 +- 4 files changed, 140 insertions(+), 78 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 1facf6a9aefef..d7fbb0e657d42 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -52,48 +52,89 @@ REGISTER_KERNEL_TYPED(float) namespace transformers { template -struct BeamSearchState { - // TODO: use allocater to allocate a buffer, and point each data to a span of the buffer - // so as to reuse related code in CUDA. - std::vector done; // shape (batch_size) - std::vector beam_scores; // shape (batch_size, num_beams) +gsl::span AllocateScratchBuffer(AllocatorPtr allocator, + BufferUniquePtr& buffer, + size_t elements, + bool fill = false, + T fill_value = T{}) { + size_t bytes = SafeInt(sizeof(T)) * elements; + void* data = allocator->Alloc(bytes); + buffer = std::move(BufferUniquePtr(data, BufferDeleter(allocator))); + T* first = reinterpret_cast(buffer.get()); + auto span = gsl::make_span(first, elements); + + if (fill) { + std::fill_n(first, elements, fill_value); + } + + return span; +} - std::vector next_token_logits; // shape (batch_size * num_beams, vocab_size) - std::vector next_token_scores; // shape (batch_size, num_beams * vocab_size) +template +struct BeamSearchState { + gsl::span beam_scores; // shape (batch_size, num_beams) - std::vector next_tokens; // shape (batch_size, num_beams) - std::vector next_indices; // shape (batch_size, num_beams) + gsl::span next_token_logits; // shape (batch_size * num_beams, vocab_size) + gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) - Sequences sequences; + gsl::span next_tokens; // shape (batch_size, 2 * num_beams) + gsl::span next_indices; // shape (batch_size, 2 * num_beams) - std::vector scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) + gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) + gsl::span remaining_scores; // subspan of scores that is not used - void Init(const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length, bool output_scores) { - done.assign(batch_size, 0); + // gsl::span sequences_space; // shape (2, batch_size, num_beams, max_seq_length) - int batch_beam_size = batch_size * num_beams; + Sequences sequences; + void Init(AllocatorPtr allocator, + const OrtValue& input_ids, + int batch_size, + int num_beams, + int vocab_size, + int sequence_length, + int max_length, + bool output_scores) { + + size_t batch_beam_size = SafeInt(batch_size) * num_beams; + beam_scores = AllocateScratchBuffer(allocator, beam_scores_buffer_, batch_beam_size, true, static_cast(0)); + // Initialize score of first beam of each group with 0 and the rest with -1e9. // This ensures that the beams in the same group don't produce same tokens every time. - beam_scores.assign(batch_beam_size, 0.0f); for (int i = 0; i < batch_size; i++) { for (int j = 1; j < num_beams; j++) { beam_scores[i * num_beams + j] = -1e9; } } - next_token_logits.assign(batch_beam_size * vocab_size, 0.0f); - next_token_scores.assign(batch_beam_size * vocab_size, 0.0f); + size_t next_token_size = SafeInt(batch_beam_size) * vocab_size; + next_token_logits = AllocateScratchBuffer(allocator, next_token_logits_buffer_, next_token_size, true, static_cast(0)); + next_token_scores = AllocateScratchBuffer(allocator, next_token_scores_buffer_, next_token_size, true, static_cast(0)); - next_tokens.assign(batch_beam_size, 0); - next_indices.assign(batch_beam_size, 0); + next_tokens = AllocateScratchBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0)); - sequences.Init(input_ids, batch_beam_size, sequence_length, max_length); + next_indices = AllocateScratchBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0)); if (output_scores) { - scores.reserve((max_length - sequence_length) * batch_size * num_beams * vocab_size); + size_t elements = SafeInt(max_length - sequence_length) * batch_size * num_beams * vocab_size; + scores = AllocateScratchBuffer(allocator, scores_buffer_, elements); + remaining_scores = scores; } + + // size_t sequences_space_size = SafeInt(2) * batch_size * num_beams * max_length; + // sequences_space = AllocateScratchBuffer(allocator, sequences_space_buffer_, sequences_space_size, true, static_cast(0)); + + sequences.Init(input_ids, static_cast(batch_beam_size), sequence_length, max_length); } + +private: + BufferUniquePtr beam_scores_buffer_; + BufferUniquePtr next_token_logits_buffer_; + BufferUniquePtr next_token_scores_buffer_; + BufferUniquePtr next_tokens_buffer_; + BufferUniquePtr next_indices_buffer_; + // BufferUniquePtr sequences_space_buffer_; + BufferUniquePtr scores_buffer_; }; template @@ -131,7 +172,8 @@ class BeamSearchImpl { // Process logits and append next tokens to sequences. Status GenerateNextToken(const OrtValue& logits, gsl::span& beam_next_tokens, - gsl::span& beam_indices); + gsl::span& beam_indices, + BeamSearchState& beam_state); // Calculate scores from logits, then apply filtering and select next token for each beam. Status ProcessLogits(const OrtValue& logits, @@ -157,8 +199,6 @@ class BeamSearchImpl { std::unique_ptr> beam_scorer_; - BeamSearchState beam_state_; - AllocatorPtr allocator_; }; @@ -345,7 +385,7 @@ Status BeamSearchImpl::ProcessLogits( // Get logits for the last token, where logits has shape (batch_size * num_beams, input_length, vocab_size) // next_token_logits = logits[:, -1, :], where its shape is (batch_size * num_beams, vocab_size) // When input_length == 1, use logits directly to avoid copy logits to next_token_logits. - auto next_token_logits = gsl::make_span(beam_state.next_token_logits); + gsl::span& next_token_logits = beam_state.next_token_logits; if (input_length > 1) { const T* current_logits = logits_data + (input_length - 1) * vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -362,7 +402,7 @@ Status BeamSearchImpl::ProcessLogits( #endif // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1) - auto next_token_scores = gsl::make_span(beam_state.next_token_scores); + gsl::span& next_token_scores = beam_state.next_token_scores; Status status = SoftmaxCPU(batch_beam_size, // rows vocab_size, // elements per row input_length > 1 ? next_token_logits.data() : logits_data, @@ -401,7 +441,9 @@ Status BeamSearchImpl::ProcessLogits( #endif if (parameters_->output_scores) { - beam_state.scores.insert(beam_state.scores.end(), next_token_scores.begin(), next_token_scores.end()); + // Append next token scores to the scores output. + gsl::copy(next_token_scores, beam_state.remaining_scores); + beam_state.remaining_scores = beam_state.remaining_scores.subspan(next_token_scores.size()); } // Apply top-k selection like the following: @@ -435,8 +477,6 @@ Status BeamSearchImpl::ProcessLogits( // next_indices = (next_tokens / vocab_size).long() // next_tokens = next_tokens % vocab_size gsl::span next_token_indices = topk_indices->DataAsSpan(); - beam_state.next_indices.resize(parameters_->batch_size * top_k); - beam_state.next_tokens.resize(parameters_->batch_size * top_k); offset = 0; for (int i = 0; i < parameters_->batch_size; i++) { for (unsigned int j = 0; j < top_k; j++, offset++) { @@ -459,8 +499,7 @@ Status BeamSearchImpl::ProcessLogits( &(beam_state.sequences), next_scores, next_tokens, - next_indices, - allocator); + next_indices); return Status::OK(); } @@ -469,13 +508,14 @@ template Status BeamSearchImpl::GenerateNextToken( const OrtValue& logits, gsl::span& beam_next_tokens, - gsl::span& beam_indices) { + gsl::span& beam_indices, + BeamSearchState& beam_state) { // Process logits to get next token scores - ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state_, allocator_)); + ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_)); gsl::span& beam_scores = beam_scorer_->GetNextScores(); // TODO: may not need clone beam_scores. - beam_state_.beam_scores.assign(beam_scores.begin(), beam_scores.end()); + gsl::copy(beam_scores, beam_state.beam_scores); beam_next_tokens = beam_scorer_->GetNextTokens(); beam_indices = beam_scorer_->GetNextIndices(); @@ -486,10 +526,10 @@ Status BeamSearchImpl::GenerateNextToken( DumpTensor("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams); #endif - beam_state_.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens); + beam_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens); #ifdef DEBUG_BEAM_SEARCH - beam_state_.sequences.PrintSequences(); + beam_state.sequences.PrintSequences(); #endif return Status::OK(); } @@ -501,7 +541,8 @@ Status BeamSearchImpl::UpdateFeeds( int current_length, gsl::span beam_next_tokens, gsl::span beam_indices) { - return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, beam_next_tokens, beam_indices, parameters_->num_beams); + return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, + beam_next_tokens, beam_indices, parameters_->num_beams); } template @@ -530,6 +571,9 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { CreateInitialFeeds(feeds); + AllocatorPtr temp_space_allocator; + ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator)); + // Initialize resources beam_scorer_ = std::make_unique>(parameters_->batch_size, parameters_->num_beams, @@ -539,6 +583,8 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { parameters_->num_return_sequences, parameters_->pad_token_id, parameters_->eos_token_id); + beam_scorer_->Initialize(allocator_, parameters_->sequence_length); // TODO: use temp_space_allocator + const OrtValue& input_ids = feeds[0]; #ifdef DEBUG_BEAM_SEARCH DumpOrtValue("input_ids", input_ids); @@ -546,13 +592,15 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { DumpOrtValue("attention_mask", feeds[2]); #endif - beam_state_.Init(input_ids, - parameters_->batch_size, - parameters_->num_beams, - parameters_->vocab_size, - parameters_->sequence_length, - parameters_->max_length, - parameters_->output_scores); + BeamSearchState beam_state; + beam_state.Init(temp_space_allocator, + input_ids, + parameters_->batch_size, + parameters_->num_beams, + parameters_->vocab_size, + parameters_->sequence_length, + parameters_->max_length, + parameters_->output_scores); int current_length = parameters_->sequence_length; while (current_length < parameters_->max_length) { @@ -568,17 +616,24 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { const OrtValue& logits = fetches[0]; gsl::span beam_next_tokens; gsl::span beam_indices; - ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices)); + ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state)); // Increase sequence length after a new token is generated. ++current_length; // Prepare inputs for next round of subgraph call. if (current_length < parameters_->max_length) { - ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, beam_next_tokens.as_span(), beam_indices.as_span())); + ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, + beam_next_tokens.as_span(), + beam_indices.as_span())); } fetches.clear(); + // When all batches are finished, stop earlier (make sure) + if (beam_scorer_->IsDone()) { + break; + } + #ifdef DEBUG_BEAM_SEARCH if (current_length - parameters_->sequence_length == 3) { // only dump a few steps. DisableTensorDump(); @@ -586,8 +641,8 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { #endif } - gsl::span beam_scores(beam_state_.beam_scores.data(), beam_state_.beam_scores.size()); - beam_scorer_->Finalize(&(beam_state_.sequences), + gsl::span beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size()); + beam_scorer_->Finalize(&(beam_state.sequences), beam_scores, output_sequences, output_sequences_scores); @@ -595,7 +650,7 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { // Output per token scores if (output_scores != nullptr) { gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = gsl::span(beam_state_.scores.data(), beam_state_.scores.size()); + gsl::span source = gsl::span(beam_state.scores.data(), beam_state.scores.size()); gsl::copy(source, target); // Fill zeros for the remaining when beam search stopped early diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index b3d9d4c472ea5..bb7aeb989e966 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -110,10 +110,6 @@ BeamSearchScorer::BeamSearchScorer(int batch_size, for (int batch = 0; batch < batch_size; batch++) { beam_hyps.push_back(BeamHypotheses(num_beams, length_penalty, early_stopping)); } - - for (int batch = 0; batch < batch_size; batch++) { - done_.push_back(false); - } } template @@ -125,12 +121,30 @@ bool BeamSearchScorer::IsDone() { return true; } +template +void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length){ + ORT_ENFORCE(next_beam_scores_.empty()); // Make sure this is called only once. + + size_t batch_beam_size = static_cast(batch_size_ * num_beams_); + const bool no_fill = false; // do not fill values after allocation + next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_, no_fill); + next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_, no_fill); + next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill); + + // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. + int buffer_per_beam = (max_length_ * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2; + hypothesis_buffer_length_ = batch_beam_size * static_cast(buffer_per_beam); + hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill); + + done_ = Allocate(allocator, static_cast(batch_size_), done_ptr_, no_fill); + std::fill_n(done_.data(), done_.size(), false); +} + template void BeamSearchScorer::Process(ISequences* sequences, gsl::span& next_scores, gsl::span& next_tokens, - gsl::span& next_indices, - AllocatorPtr& allocator) { + gsl::span& next_indices) { // Sequences shape is (batch_size * num_beams, total_sequence_length) // It contains word ID of whole sequence generated so far. // It is different from subgraph input_ids, which only need one word when past state is not empty. @@ -140,20 +154,6 @@ void BeamSearchScorer::Process(ISequences* sequences, ORT_ENFORCE(next_scores.size() == next_tokens.size()); ORT_ENFORCE(next_scores.size() == next_indices.size()); - // Allocate buffers only once. - if (next_beam_scores_.empty()) { - size_t batch_beam_size = static_cast(batch_size_ * num_beams_); - const bool fill_zeros = false; - next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_, fill_zeros); - next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_, fill_zeros); - next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, fill_zeros); - - // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. - int buffer_per_beam = (max_length_ * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2; - hypothesis_buffer_length_ = batch_beam_size * static_cast(buffer_per_beam); - hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, fill_zeros); - } - for (int batch = 0; batch < batch_size_; batch++) { BeamHypotheses& beam_hyp = beam_hyps[batch]; if (done_[batch]) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index c2eb0b13c5c83..2a150802369a2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -23,11 +23,12 @@ class IBeamScorer { public: virtual ~IBeamScorer() {} + virtual void Initialize(AllocatorPtr& allocator, int sequence_length) = 0; + virtual void Process(ISequences* sequences, gsl::span& next_scores, gsl::span& next_tokens, - gsl::span& next_indices, - AllocatorPtr& allocator) = 0; + gsl::span& next_indices) = 0; virtual void Finalize(ISequences* sequences, gsl::span& final_beam_scores, @@ -91,19 +92,20 @@ class BeamSearchScorer : public IBeamScorer { int pad_token_id, int eos_token_id); - bool IsDone(); + void Initialize(AllocatorPtr& allocator, int sequence_length) override; void Process(ISequences* sequences, gsl::span& next_scores, gsl::span& next_tokens, - gsl::span& next_indices, - AllocatorPtr& allocator) override; + gsl::span& next_indices) override; void Finalize(ISequences* sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) override; + bool IsDone(); + gsl::span& GetNextScores() { return next_beam_scores_; } gsl::span& GetNextTokens() { return next_beam_tokens_; } gsl::span& GetNextIndices() { return next_beam_indices_; } @@ -116,8 +118,11 @@ class BeamSearchScorer : public IBeamScorer { int pad_token_id_; int eos_token_id_; + // TODO: use ORT allocator to avoid allocating from heap directly std::vector> beam_hyps; // List of batch result of beam search. Its shape is (batch_size) - std::vector done_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size). + + IAllocatorUniquePtr done_ptr_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size). + gsl::span done_; IAllocatorUniquePtr next_beam_scores_ptr_; gsl::span next_beam_scores_; diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py index 1296191d568dd..04d3250b3cc3d 100644 --- a/onnxruntime/test/python/transformers/test_beam_search.py +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -15,14 +15,15 @@ from convert_beam_search import main as run else: from onnxruntime.transformers.convert_beam_search import main as run - + + class TestBeamSearch(unittest.TestCase): def setUp(self): self.model_name = "gpt2" self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx') self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0' - + def run_beam_search(self, arguments: str): return run(arguments.split()) @@ -31,7 +32,7 @@ def test_cpu(self): result = self.run_beam_search(self.cpu_params) os.remove(self.gpt2_onnx_path) os.remove(self.beam_search_onnx_path) - self.assertTrue(result["parity"], "ORT and PyTorch result is different") + self.assertTrue(result["parity"], "ORT and PyTorch result is different") @pytest.mark.slow def test_cpu_no_repeat_ngram(self): @@ -41,5 +42,6 @@ def test_cpu_no_repeat_ngram(self): os.remove(self.beam_search_onnx_path) self.assertTrue(result["parity"], "ORT and PyTorch result is different") + if __name__ == '__main__': unittest.main() From 08277874c8690cc264c9934c295b18aa3af25a09 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 8 Dec 2021 00:45:17 -0800 Subject: [PATCH 36/53] use allocator in sequences --- .../cpu/transformers/beam_search.cc | 40 ++++--------------- .../cpu/transformers/beam_search.h | 1 - .../contrib_ops/cpu/transformers/sequences.cc | 20 +++++----- .../contrib_ops/cpu/transformers/sequences.h | 28 ++++++++++++- 4 files changed, 45 insertions(+), 44 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index d7fbb0e657d42..6965c0f6255c2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -24,6 +24,7 @@ #include "core/providers/cpu/math/softmax_shared.h" #include "beam_search.h" #include "logits_processor.h" +#include "sequences.h" #include "dump_tensor.h" #ifdef _MSC_VER @@ -51,25 +52,6 @@ REGISTER_KERNEL_TYPED(float) namespace transformers { -template -gsl::span AllocateScratchBuffer(AllocatorPtr allocator, - BufferUniquePtr& buffer, - size_t elements, - bool fill = false, - T fill_value = T{}) { - size_t bytes = SafeInt(sizeof(T)) * elements; - void* data = allocator->Alloc(bytes); - buffer = std::move(BufferUniquePtr(data, BufferDeleter(allocator))); - T* first = reinterpret_cast(buffer.get()); - auto span = gsl::make_span(first, elements); - - if (fill) { - std::fill_n(first, elements, fill_value); - } - - return span; -} - template struct BeamSearchState { gsl::span beam_scores; // shape (batch_size, num_beams) @@ -83,8 +65,6 @@ struct BeamSearchState { gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) gsl::span remaining_scores; // subspan of scores that is not used - // gsl::span sequences_space; // shape (2, batch_size, num_beams, max_seq_length) - Sequences sequences; void Init(AllocatorPtr allocator, @@ -97,7 +77,7 @@ struct BeamSearchState { bool output_scores) { size_t batch_beam_size = SafeInt(batch_size) * num_beams; - beam_scores = AllocateScratchBuffer(allocator, beam_scores_buffer_, batch_beam_size, true, static_cast(0)); + beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size, true, static_cast(0)); // Initialize score of first beam of each group with 0 and the rest with -1e9. // This ensures that the beams in the same group don't produce same tokens every time. @@ -108,23 +88,20 @@ struct BeamSearchState { } size_t next_token_size = SafeInt(batch_beam_size) * vocab_size; - next_token_logits = AllocateScratchBuffer(allocator, next_token_logits_buffer_, next_token_size, true, static_cast(0)); - next_token_scores = AllocateScratchBuffer(allocator, next_token_scores_buffer_, next_token_size, true, static_cast(0)); + next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size, true, static_cast(0)); + next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, true, static_cast(0)); - next_tokens = AllocateScratchBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0)); + next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0)); - next_indices = AllocateScratchBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0)); + next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0)); if (output_scores) { size_t elements = SafeInt(max_length - sequence_length) * batch_size * num_beams * vocab_size; - scores = AllocateScratchBuffer(allocator, scores_buffer_, elements); + scores = AllocateBuffer(allocator, scores_buffer_, elements); remaining_scores = scores; } - // size_t sequences_space_size = SafeInt(2) * batch_size * num_beams * max_length; - // sequences_space = AllocateScratchBuffer(allocator, sequences_space_buffer_, sequences_space_size, true, static_cast(0)); - - sequences.Init(input_ids, static_cast(batch_beam_size), sequence_length, max_length); + sequences.Init(allocator, input_ids, static_cast(batch_beam_size), sequence_length, max_length); } private: @@ -133,7 +110,6 @@ struct BeamSearchState { BufferUniquePtr next_token_scores_buffer_; BufferUniquePtr next_tokens_buffer_; BufferUniquePtr next_indices_buffer_; - // BufferUniquePtr sequences_space_buffer_; BufferUniquePtr scores_buffer_; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index cc359f2549dee..9dc5cac4087c8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -11,7 +11,6 @@ #include "beam_search_parameters.h" #include "beam_search_scorer.h" #include "gpt_subgraph.h" -#include "sequences.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 3d4f6c244462e..d725699664332 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -4,14 +4,17 @@ namespace onnxruntime { namespace contrib { namespace transformers { -void Sequences::Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length) { - // Allocate buffer (shall we use allocator instead?) - sequences[0].assign(batch_beam_size * max_length, 0); - sequences[1].assign(batch_beam_size * max_length, 0); +void Sequences::Init(AllocatorPtr allocator, const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length) { + size_t sequences_size = SafeInt(batch_beam_size) * max_length; + size_t buffer_size = sequences_size + sequences_size; + gsl::span buffer = AllocateBuffer(allocator, sequences_space_buffer_, buffer_size, true, static_cast(0)); + + sequences[0] = buffer.subspan(0, sequences_size); + sequences[1] = buffer.subspan(sequences_size); // copying input_ids to sequences[0] gsl::span input = input_ids.Get().DataAsSpan(); - gsl::span output(sequences[0]); + gsl::span output = sequences[0]; for (int i = 0; i < batch_beam_size; i++) { gsl::span source = input.subspan(i * sequence_length, sequence_length); gsl::span target = output.subspan(i * max_length, sequence_length); @@ -25,7 +28,7 @@ void Sequences::Init(const OrtValue& input_ids, int batch_beam_size, int sequenc } gsl::span Sequences::GetSequence(int beam_index) const { - gsl::span buffer(sequences[current_sequences_buffer]); + gsl::span buffer(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size()); gsl::span sequence = buffer.subspan(beam_index * max_length_, current_length_); return sequence; } @@ -47,9 +50,8 @@ void Sequences::PrintSequences() { void Sequences::AppendNextTokenToSequences( gsl::span& beam_indices, gsl::span& beam_next_tokens) { - //sequences = torch.cat([sequences[beam_indices, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - gsl::span input(sequences[current_sequences_buffer]); - gsl::span output(sequences[1 - current_sequences_buffer]); + gsl::span input(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size()); + gsl::span output = sequences[1 - current_sequences_buffer]; for (int i = 0; i < batch_beam_size_; i++) { int beam_index = static_cast(beam_indices[i]); diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 96f231099a5f3..999b73c03c128 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -1,6 +1,8 @@ #pragma once #include "gsl/gsl" +#include "core/common/safeint.h" +#include "core/framework/allocator.h" #include "core/framework/ort_value.h" namespace onnxruntime { @@ -20,7 +22,7 @@ class Sequences : public ISequences { Sequences() {} // Initialize the sequence with initial input_ids and related parameters. - void Init(const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length); + void Init(AllocatorPtr allocator, const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). gsl::span GetSequence(int beam_index) const override; @@ -37,10 +39,13 @@ class Sequences : public ISequences { gsl::span& beam_next_tokens); private: + gsl::span sequences_space; // shape (2, batch_size, num_beams, max_seq_length) + BufferUniquePtr sequences_space_buffer_; + // Two buffers of shape (batch_size, num_beams, max_seq_length) to store sequences. // At each time, there is only one buffer is active. The other one will be active in next token. // Each AppendNextTokenToSequences call will trigger a rotation of active buffer. - std::vector sequences[2]; + gsl::span sequences[2]; // Index (either 0 or 1) of two buffers that is currently is active. int current_sequences_buffer; @@ -50,6 +55,25 @@ class Sequences : public ISequences { int current_length_; }; +template +gsl::span AllocateBuffer(AllocatorPtr allocator, + BufferUniquePtr& buffer, + size_t elements, + bool fill = false, + T fill_value = T{}) { + size_t bytes = SafeInt(sizeof(T)) * elements; + void* data = allocator->Alloc(bytes); + buffer = std::move(BufferUniquePtr(data, BufferDeleter(allocator))); + T* first = reinterpret_cast(buffer.get()); + auto span = gsl::make_span(first, elements); + + if (fill) { + std::fill_n(first, elements, fill_value); + } + + return span; +} + } // namespace transformers } // namespace contrib } // namespace onnxruntime \ No newline at end of file From 08bd1d77a39d63c2800397bf17143ffc38bb2fe1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 8 Dec 2021 01:25:12 -0800 Subject: [PATCH 37/53] fix build error --- onnxruntime/contrib_ops/cpu/transformers/sequences.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 999b73c03c128..9950330dae4e5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -63,7 +63,8 @@ gsl::span AllocateBuffer(AllocatorPtr allocator, T fill_value = T{}) { size_t bytes = SafeInt(sizeof(T)) * elements; void* data = allocator->Alloc(bytes); - buffer = std::move(BufferUniquePtr(data, BufferDeleter(allocator))); + BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); + buffer = std::move(temp_buffer); T* first = reinterpret_cast(buffer.get()); auto span = gsl::make_span(first, elements); From c7d456b5af409862c19b639bac239487ce26049d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 8 Dec 2021 16:52:58 -0800 Subject: [PATCH 38/53] move next_positions to beam state --- .../cpu/transformers/beam_search.cc | 120 +++++++++--------- .../transformers/beam_search_parameters.cc | 4 +- .../cpu/transformers/gpt_subgraph.cc | 22 ++-- .../cpu/transformers/gpt_subgraph.h | 5 +- .../contrib_ops/cpu/transformers/sequences.cc | 8 +- .../tools/transformers/convert_beam_search.py | 4 +- .../python/transformers/test_beam_search.py | 31 ++++- 7 files changed, 110 insertions(+), 84 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 6965c0f6255c2..1225442192232 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -9,6 +9,7 @@ #pragma warning(disable : 4996) #endif +#include #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/math/top_k.h" #include "core/framework/allocator.h" @@ -54,31 +55,28 @@ namespace transformers { template struct BeamSearchState { - gsl::span beam_scores; // shape (batch_size, num_beams) + gsl::span beam_scores; // shape (batch_size, num_beams) + gsl::span next_token_logits; // shape (batch_size * num_beams, vocab_size) + gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) + gsl::span next_tokens; // shape (batch_size, 2 * num_beams) + gsl::span next_indices; // shape (batch_size, 2 * num_beams) + gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. - gsl::span next_token_logits; // shape (batch_size * num_beams, vocab_size) - gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) - - gsl::span next_tokens; // shape (batch_size, 2 * num_beams) - gsl::span next_indices; // shape (batch_size, 2 * num_beams) - - gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) - gsl::span remaining_scores; // subspan of scores that is not used + gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) + gsl::span remaining_scores; // subspan that is avaiable for appending next token scores. Sequences sequences; void Init(AllocatorPtr allocator, - const OrtValue& input_ids, int batch_size, int num_beams, int vocab_size, int sequence_length, int max_length, bool output_scores) { - size_t batch_beam_size = SafeInt(batch_size) * num_beams; beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size, true, static_cast(0)); - + // Initialize score of first beam of each group with 0 and the rest with -1e9. // This ensures that the beams in the same group don't produce same tokens every time. for (int i = 0; i < batch_size; i++) { @@ -95,21 +93,24 @@ struct BeamSearchState { next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0)); + next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size, true, static_cast(0)); + if (output_scores) { size_t elements = SafeInt(max_length - sequence_length) * batch_size * num_beams * vocab_size; scores = AllocateBuffer(allocator, scores_buffer_, elements); remaining_scores = scores; } - sequences.Init(allocator, input_ids, static_cast(batch_beam_size), sequence_length, max_length); + // sequences will be initialized later since it has dependency on input_ids } -private: + private: BufferUniquePtr beam_scores_buffer_; BufferUniquePtr next_token_logits_buffer_; BufferUniquePtr next_token_scores_buffer_; BufferUniquePtr next_tokens_buffer_; BufferUniquePtr next_indices_buffer_; + BufferUniquePtr next_positions_buffer_; BufferUniquePtr scores_buffer_; }; @@ -135,13 +136,14 @@ class BeamSearchImpl { Status CheckInputs(const OpKernelContextInternal& context); // Prepare the inputs for first inference of subgraph - void CreateInitialFeeds(std::vector& feeds); + void CreateInitialFeeds(gsl::span& next_positions, std::vector& feeds); // Update the input for next iteration. Status UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, int current_length, + gsl::span& next_positions, gsl::span beam_next_tokens, gsl::span beam_indices); @@ -152,7 +154,7 @@ class BeamSearchImpl { BeamSearchState& beam_state); // Calculate scores from logits, then apply filtering and select next token for each beam. - Status ProcessLogits(const OrtValue& logits, + Status ProcessLogits(const OrtValue& logits, // logits output of subgraph BeamSearchState& beam_state, AllocatorPtr& allocator); @@ -180,10 +182,7 @@ class BeamSearchImpl { template void BeamSearch::Init(const OpKernelInfo& info) { - // Make sure the attribute was present even though we don't need it here. - // The GraphProto is loaded as a Graph instance by main Graph::Resolve, - // and a SessionState instance for executing the subgraph is created by InferenceSession. - // This is available via Info().GetSubgraphSessionState("attribute_name") when Compute is called. + // Make sure the body attribute was present even though we don't need it here. ONNX_NAMESPACE::GraphProto proto; ORT_ENFORCE(info.GetAttr("body", &proto).IsOK()); ORT_IGNORE_RETURN_VALUE(proto); @@ -226,7 +225,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - BeamSearchParameters parameters = parameters_; // make a copy + BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later BeamSearchImpl impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, stream_, parameters}; @@ -335,32 +334,31 @@ Status BeamSearchImpl::Initialize() { } template -void BeamSearchImpl::CreateInitialFeeds(std::vector& feeds) { +void BeamSearchImpl::CreateInitialFeeds(gsl::span& next_positions, std::vector& feeds) { const OrtValue* input_ids_value = context_.GetInputOrtValue(0); const Tensor& input_ids = input_ids_value->Get(); - gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, feeds); + gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, next_positions, feeds); } template Status BeamSearchImpl::ProcessLogits( - const OrtValue& logits, // logits output of subgraph + const OrtValue& logits, BeamSearchState& beam_state, AllocatorPtr& allocator) { - const int64_t batch_beam_size = static_cast(parameters_->batch_size * parameters_->num_beams); + const int64_t batch_beam_size = static_cast(parameters_->BatchBeamSize()); const int& vocab_size = parameters_->vocab_size; const T* logits_data = logits.Get().Data(); + // Logits has shape (batch_size * num_beams, input_length, vocab_size), + // where input_length equals to parameters_->sequence_length for first subgraph call, and 1 for the remaining calls. const TensorShape& logits_shape = logits.Get().Shape(); ORT_ENFORCE(logits_shape.NumDimensions() == 3); - - // The sequence length of input_ids for the logits. - // It equals to parameters_->sequence_length for first subgraph call, and 1 for the remaining calls. auto input_length = logits_shape[1]; - // Get logits for the last token, where logits has shape (batch_size * num_beams, input_length, vocab_size) - // next_token_logits = logits[:, -1, :], where its shape is (batch_size * num_beams, vocab_size) - // When input_length == 1, use logits directly to avoid copy logits to next_token_logits. + // Get logits for the last token: + // next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size) + // When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1. gsl::span& next_token_logits = beam_state.next_token_logits; if (input_length > 1) { const T* current_logits = logits_data + (input_length - 1) * vocab_size; @@ -400,7 +398,8 @@ Status BeamSearchImpl::ProcessLogits( DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size); #endif - // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + // Add beam score to next token scores. Corresponding python code is like: + // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) // TODO: use thread pool to parrellel int offset = 0; int batch_beam_index = 0; @@ -490,7 +489,9 @@ Status BeamSearchImpl::GenerateNextToken( ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_)); gsl::span& beam_scores = beam_scorer_->GetNextScores(); - // TODO: may not need clone beam_scores. + // It is optional to clone beam_scores. Change it to use same buffer also works: + // beam_state.beam_scores = beam_scores + // Here we make a copy to reduce the coupling with little cost (the buffer size is small). gsl::copy(beam_scores, beam_state.beam_scores); beam_next_tokens = beam_scorer_->GetNextTokens(); @@ -515,9 +516,10 @@ Status BeamSearchImpl::UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, int current_length, + gsl::span& next_positions, gsl::span beam_next_tokens, gsl::span beam_indices) { - return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, + return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, next_positions, beam_next_tokens, beam_indices, parameters_->num_beams); } @@ -545,12 +547,19 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { std::vector feeds; std::vector fetches; - CreateInitialFeeds(feeds); - + // Initialize resources AllocatorPtr temp_space_allocator; ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator)); - // Initialize resources + BeamSearchState beam_state; + beam_state.Init(temp_space_allocator, + parameters_->batch_size, + parameters_->num_beams, + parameters_->vocab_size, + parameters_->sequence_length, + parameters_->max_length, + parameters_->output_scores); + beam_scorer_ = std::make_unique>(parameters_->batch_size, parameters_->num_beams, parameters_->max_length, @@ -559,25 +568,22 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { parameters_->num_return_sequences, parameters_->pad_token_id, parameters_->eos_token_id); - beam_scorer_->Initialize(allocator_, parameters_->sequence_length); // TODO: use temp_space_allocator + beam_scorer_->Initialize(allocator_, parameters_->sequence_length); // TODO: use temp_space_allocator + CreateInitialFeeds(beam_state.next_positions, feeds); const OrtValue& input_ids = feeds[0]; + beam_state.sequences.Init(temp_space_allocator, + input_ids, + parameters_->BatchBeamSize(), + parameters_->sequence_length, + parameters_->max_length); + #ifdef DEBUG_BEAM_SEARCH DumpOrtValue("input_ids", input_ids); DumpOrtValue("position_ids", feeds[1]); DumpOrtValue("attention_mask", feeds[2]); #endif - BeamSearchState beam_state; - beam_state.Init(temp_space_allocator, - input_ids, - parameters_->batch_size, - parameters_->num_beams, - parameters_->vocab_size, - parameters_->sequence_length, - parameters_->max_length, - parameters_->output_scores); - int current_length = parameters_->sequence_length; while (current_length < parameters_->max_length) { #ifdef DEBUG_BEAM_SEARCH @@ -594,22 +600,23 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { gsl::span beam_indices; ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state)); + // When all batches are finished, stop earlier to avoid wasting computation. + if (beam_scorer_->IsDone()) { + break; + } + // Increase sequence length after a new token is generated. ++current_length; // Prepare inputs for next round of subgraph call. if (current_length < parameters_->max_length) { ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, + beam_state.next_positions, beam_next_tokens.as_span(), beam_indices.as_span())); } fetches.clear(); - // When all batches are finished, stop earlier (make sure) - if (beam_scorer_->IsDone()) { - break; - } - #ifdef DEBUG_BEAM_SEARCH if (current_length - parameters_->sequence_length == 3) { // only dump a few steps. DisableTensorDump(); @@ -627,13 +634,8 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { if (output_scores != nullptr) { gsl::span target = output_scores->MutableDataAsSpan(); gsl::span source = gsl::span(beam_state.scores.data(), beam_state.scores.size()); + assert(target.length() == source.length()); gsl::copy(source, target); - - // Fill zeros for the remaining when beam search stopped early - if (target.length() > source.length()) { - gsl::span remaining = target.subspan(source.length()); - memset(remaining.data(), 0, remaining.size_bytes()); - } } return status; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index fc0f61705b35b..fee3ec475338a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -31,7 +31,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { sequence_length = static_cast(dims[1]); auto* max_length_tensor = context->Input(1); - max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : 4096; + max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : kMaxSequenceLength; ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")"); ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength); @@ -40,7 +40,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { auto* num_beams_tensor = context->Input(3); num_beams = num_beams_tensor ? static_cast(*num_beams_tensor->Data()) : 1; - // TODO: shall we limit num_beams > 1. When num_beams==1, we can have another operator for greedy search. + // TODO: limit num_beams > 1 when we can have another operator for greedy search. ORT_ENFORCE(num_beams >= 1, "num_beams shall be a positive integer, got ", num_beams); auto* num_return_sequences_tensor = context->Input(4); diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc index cca9bb311cb60..80825c7e26037 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc @@ -188,6 +188,7 @@ void GptSubgraph::CreateInitialFeeds( const std::vector& implicit_inputs, int num_beams, int pad_token_id, + gsl::span& next_positions, std::vector& feeds) { ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); @@ -233,7 +234,6 @@ void GptSubgraph::CreateInitialFeeds( auto mask_type = DataTypeImpl::GetType(); Tensor::InitOrtValue(mask_type, input_ids_shape, alloactor, attention_mask); - next_positions_.resize(batch_size * num_beams); // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and cumulated sum of mask in a batch for other tokens float* mask_data = attention_mask.GetMutable()->MutableData(); @@ -254,7 +254,7 @@ void GptSubgraph::CreateInitialFeeds( } } for (int k = 0; k < num_beams; k++) { - next_positions_[i * num_beams + k] = abs_position; + next_positions[i * num_beams + k] = abs_position; } } @@ -289,21 +289,20 @@ void GptSubgraph::CreateInitialFeeds( } OrtValue GptSubgraph::ExpandInputs(const OrtValue& input, int num_beams) const { + // Input shape (batch_size, sequence_length) + // Output shape (batch_size * num_beams, sequence_length) if (num_beams == 1) return input; - // Given input of shape (batch_size, sequence_length), expand the shape to be (batch_size * num_beams, sequence_length) const TensorShape& input_shape = input.Get().Shape(); - //ORT_ENFORCE(input_shape.NumDimensions() == 2 && input_shape[0] == parameters_->batch_size && input_shape[1] == parameters_->sequence_length); - const int64_t& batch_size = input_shape[0]; const int64_t& sequence_length = input_shape[1]; + int64_t dims[] = {batch_size * num_beams, sequence_length}; TensorShape expanded_shape(&dims[0], 2); - MLDataType element_type = input.Get().DataType(); - OrtValue expanded; + MLDataType element_type = input.Get().DataType(); Tensor::InitOrtValue(element_type, expanded_shape, allocator_, expanded); if (element_type == DataTypeImpl::GetType()) { @@ -331,6 +330,7 @@ OrtValue GptSubgraph::ExpandInputs(const OrtValue& input, int num_beams) const { return expanded; } +// TODO: support float16 void GptSubgraph::PickPastState(const std::vector& last_outputs, std::vector& next_inputs, gsl::span& beam_indices) { @@ -340,13 +340,12 @@ void GptSubgraph::PickPastState(const std::vector& last_outputs, // Create a tensor with same shape. OrtValue past; - auto past_type = DataTypeImpl::GetType(); //TODO: present.Type() + auto past_type = DataTypeImpl::GetType(); Tensor::InitOrtValue(past_type, past_shape, allocator_, past); auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; - // TODO: support float16 gsl::span past_span = past.GetMutable()->MutableDataAsSpan(); gsl::span present_span = present.Get().DataAsSpan(); for (gsl::index j = 0; j < beam_indices.length(); j++) { @@ -378,6 +377,7 @@ Status GptSubgraph::UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, int current_length, + gsl::span& next_positions, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams) { @@ -405,8 +405,8 @@ Status GptSubgraph::UpdateFeeds( Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids); int64_t* position_data = position_ids.GetMutable()->MutableData(); for (int i = 0; i < batch_beam_size; i++) { - position_data[i] = next_positions_[i]; - next_positions_[i]++; + position_data[i] = next_positions[i]; + next_positions[i]++; } next_inputs[1] = position_ids; diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h index 154dbf4b7390d..75d879c7fefdb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h +++ b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h @@ -47,12 +47,14 @@ struct GptSubgraph { const std::vector& implicit_inputs, int num_beams, int pad_token_id, + gsl::span& next_positions, std::vector& feeds); Status UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, int current_length, + gsl::span& next_positions, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams); @@ -69,9 +71,6 @@ struct GptSubgraph { std::vector& next_inputs, gsl::span& beam_indices); - // TODO: move it to make this class state less. - std::vector next_positions_; - AllocatorPtr allocator_; const SessionState* session_state_; const SessionState* subgraph_session_state_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index d725699664332..a9c70ef4106dc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -12,7 +12,7 @@ void Sequences::Init(AllocatorPtr allocator, const OrtValue& input_ids, int batc sequences[0] = buffer.subspan(0, sequences_size); sequences[1] = buffer.subspan(sequences_size); - // copying input_ids to sequences[0] + // Copy input_ids to sequences[0]. gsl::span input = input_ids.Get().DataAsSpan(); gsl::span output = sequences[0]; for (int i = 0; i < batch_beam_size; i++) { @@ -60,13 +60,15 @@ void Sequences::AppendNextTokenToSequences( gsl::copy(source, target); } - // append next token to each beam + // Append next token to each beam. for (int i = 0; i < batch_beam_size_; i++) { output[i * max_length_ + current_length_] = beam_next_tokens[i]; } ++current_length_; - current_sequences_buffer = 1 - current_sequences_buffer; // rotate buffer for next round + + // Rotate buffer for next round. + current_sequences_buffer = 1 - current_sequences_buffer; } } // namespace transformers diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 9f8841dfbb5e1..3f251fb9ad507 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -444,7 +444,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): return output -def main(argv=None): +def main(argv=None, sentences=None): args = parse_arguments(argv) if os.path.exists(args.output): @@ -452,7 +452,7 @@ def main(argv=None): else: convert_model(args) - return test_model(args, use_vocab_mask=True) + return test_model(args, use_vocab_mask=True, sentences=sentences) if __name__ == '__main__': diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py index 04d3250b3cc3d..f469cdaaaa936 100644 --- a/onnxruntime/test/python/transformers/test_beam_search.py +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -19,23 +19,46 @@ class TestBeamSearch(unittest.TestCase): def setUp(self): + #TODO: use a smaller model and enable tests in CI pipeline self.model_name = "gpt2" self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx') self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0' - def run_beam_search(self, arguments: str): - return run(arguments.split()) + def run_beam_search(self, arguments: str, sentences=None): + return run(arguments.split(), sentences=sentences) @pytest.mark.slow def test_cpu(self): - result = self.run_beam_search(self.cpu_params) + result = self.run_beam_search(self.cpu_params + " --num_return_sequences 2", + sentences=["The product is released"]) os.remove(self.gpt2_onnx_path) os.remove(self.beam_search_onnx_path) self.assertTrue(result["parity"], "ORT and PyTorch result is different") @pytest.mark.slow - def test_cpu_no_repeat_ngram(self): + def test_early_stopping(self): + result = self.run_beam_search(self.cpu_params + " --early_stopping") + os.remove(self.gpt2_onnx_path) + os.remove(self.beam_search_onnx_path) + self.assertTrue(result["parity"], "ORT and PyTorch result is different") + + @pytest.mark.slow + def test_temperature(self): + result = self.run_beam_search(self.cpu_params + " --temperature 0.5") + os.remove(self.gpt2_onnx_path) + os.remove(self.beam_search_onnx_path) + self.assertTrue(result["parity"], "ORT and PyTorch result is different") + + @pytest.mark.slow + def test_length_penalty(self): + result = self.run_beam_search(self.cpu_params + " --length_penalty 0.5") + os.remove(self.gpt2_onnx_path) + os.remove(self.beam_search_onnx_path) + self.assertTrue(result["parity"], "ORT and PyTorch result is different") + + @pytest.mark.slow + def test_no_repeat_ngram(self): for ngram_size in [1, 2]: result = self.run_beam_search(self.cpu_params + f' --no_repeat_ngram_size {ngram_size}') os.remove(self.gpt2_onnx_path) From d676e7b9baea9fb2bf2471294cbd70778022fc98 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Sun, 9 Jan 2022 23:22:42 -0800 Subject: [PATCH 39/53] Changes for prefix matching --- .../cpu/transformers/beam_search.cc | 53 ++++++++++++-- .../cpu/transformers/beam_search_parameters.h | 3 +- .../cpu/transformers/dump_tensor.h | 3 +- .../cpu/transformers/gpt_subgraph.cc | 2 +- .../cpu/transformers/logits_processor.cc | 71 +++++++++++++++++-- .../cpu/transformers/logits_processor.h | 31 ++++++-- .../core/graph/contrib_ops/contrib_defs.cc | 1 + onnxruntime/core/providers/cpu/math/top_k.cc | 2 + .../tools/transformers/convert_beam_search.py | 17 +++++ 9 files changed, 161 insertions(+), 22 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 1225442192232..4c6fc9b15b8ed 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -151,12 +151,14 @@ class BeamSearchImpl { Status GenerateNextToken(const OrtValue& logits, gsl::span& beam_next_tokens, gsl::span& beam_indices, - BeamSearchState& beam_state); + BeamSearchState& beam_state, + int counter); // Calculate scores from logits, then apply filtering and select next token for each beam. Status ProcessLogits(const OrtValue& logits, // logits output of subgraph BeamSearchState& beam_state, - AllocatorPtr& allocator); + AllocatorPtr& allocator, + int counter); OpKernelContextInternal& context_; @@ -189,6 +191,7 @@ void BeamSearch::Init(const OpKernelInfo& info) { parameters_.ParseFromAttributes(info); + ConfigureTensorDump(); stream_ = nullptr; } @@ -290,6 +293,25 @@ Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) { parameters_->vocab_mask = vocab_mask->DataAsSpan(); } + const Tensor* prefix_vocab_mask = context.Input(9); + if (prefix_vocab_mask != nullptr) { + // prefix_vocab_mask is optional + const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims(); + if (vocab_mask_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 1 dimension, got ", + vocab_mask_dims.size()); + } + + // There is dependency on vocab_size parameter, which shall be set before calling this function. + if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ", + vocab_mask_dims[0]); + } + + // store prefix vocab mask in parameters. + parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan(); + } + return Status::OK(); } @@ -344,7 +366,8 @@ template Status BeamSearchImpl::ProcessLogits( const OrtValue& logits, BeamSearchState& beam_state, - AllocatorPtr& allocator) { + AllocatorPtr& allocator, + int counter) { const int64_t batch_beam_size = static_cast(parameters_->BatchBeamSize()); const int& vocab_size = parameters_->vocab_size; @@ -392,7 +415,7 @@ Status BeamSearchImpl::ProcessLogits( #endif // Apply all score processors that updates scores - logits_processors_.Process(&(beam_state.sequences), next_token_scores); + logits_processors_.Process(&(beam_state.sequences), next_token_scores, counter); #ifdef DEBUG_BEAM_SEARCH DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size); @@ -443,6 +466,9 @@ Status BeamSearchImpl::ProcessLogits( return status; } + DumpTensor("topk_scores", *(topk_scores.get())); + DumpTensor("topk_indices", *(topk_indices.get())); + #ifdef DEBUG_BEAM_SEARCH DumpTensor("topk_scores", *(topk_scores.get())); DumpTensor("topk_indices", *(topk_indices.get())); @@ -464,6 +490,8 @@ Status BeamSearchImpl::ProcessLogits( gsl::span next_tokens(beam_state.next_tokens.data(), beam_state.next_tokens.size()); gsl::span next_indices(beam_state.next_indices.data(), beam_state.next_indices.size()); + DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k); + #ifdef DEBUG_BEAM_SEARCH DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, top_k); DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k); @@ -484,9 +512,10 @@ Status BeamSearchImpl::GenerateNextToken( const OrtValue& logits, gsl::span& beam_next_tokens, gsl::span& beam_indices, - BeamSearchState& beam_state) { + BeamSearchState& beam_state, + int counter) { // Process logits to get next token scores - ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_)); + ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_, counter)); gsl::span& beam_scores = beam_scorer_->GetNextScores(); // It is optional to clone beam_scores. Change it to use same buffer also works: @@ -572,6 +601,9 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { CreateInitialFeeds(beam_state.next_positions, feeds); const OrtValue& input_ids = feeds[0]; + + DumpOrtValue("Before init, input_ids:", input_ids); + beam_state.sequences.Init(temp_space_allocator, input_ids, parameters_->BatchBeamSize(), @@ -585,7 +617,14 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { #endif int current_length = parameters_->sequence_length; + int iteration_counter = 0; while (current_length < parameters_->max_length) { + + DumpOrtValue("input_ids", input_ids); + DumpOrtValue("position_ids", feeds[1]); + DumpOrtValue("attention_mask", feeds[2]); + + iteration_counter++; #ifdef DEBUG_BEAM_SEARCH DumpString("***CurrentLength", std::to_string(current_length), true); #endif @@ -598,7 +637,7 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { const OrtValue& logits = fetches[0]; gsl::span beam_next_tokens; gsl::span beam_indices; - ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state)); + ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, iteration_counter)); // When all batches are finished, stop earlier to avoid wasting computation. if (beam_scorer_->IsDone()) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 26de2a98408eb..389f24ecdeb66 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -28,7 +28,8 @@ struct BeamSearchParameters { int sequence_length; // deduce from second dimension of input_ids gsl::span vocab_mask; - + gsl::span prefix_vocab_mask; + // Parameters from outputs. bool output_scores; // whether scores existed in output diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h index 2f80410011dab..d26c3c3efda3a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h @@ -26,7 +26,8 @@ namespace transformers { continue; \ } -#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line) +//#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line) +#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) extern bool g_enable_tensor_dump; // global variance to turn on/off dump diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc index 80825c7e26037..58d679f4bd563 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc @@ -266,7 +266,7 @@ void GptSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(past_type, past_shape, allocator_, empty_past); // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask - // TODO: Try expand inputs/outputs after first subgraph call instead. That may get better peroformance, but more complex to implement. + // TODO: Try expand inputs/outputs after first subgraph call instead. That may get better performance, but more complex to implement. OrtValue expanded_input_ids = ExpandInputs(subgraph_input_ids, num_beams); OrtValue expanded_position_ids = ExpandInputs(position_ids, num_beams); OrtValue expanded_attention_mask = ExpandInputs(attention_mask, num_beams); diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 10ef8b6f698d5..e73364b4fc64e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -35,11 +35,15 @@ MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_to template void MinLengthLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores) { + NextTokenScores& next_token_scores, + int counter) { if (sequences->GetSequenceLength() < min_length_) { next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); } + if (counter == -1) { + std::cout<< counter <::RepetitionPenaltyLogitsProcessor(float pena template void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores) { + NextTokenScores& next_token_scores, + int counter) { const int batch_beam_size = next_token_scores.batch_beam_size; + if (counter == -1) { + std::cout<< counter < beam_token_scores = next_token_scores.GetScores(i); gsl::span sequence = sequences->GetSequence(i); @@ -83,11 +92,16 @@ NoRepeatNGramLogitsProcessor::NoRepeatNGramLogitsProcessor(int ngram_size) : template void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores) { + NextTokenScores& next_token_scores, + int counter) { if (ngram_size_ == 0 || ngram_size_ > sequences->GetSequenceLength()) { return; } + if (counter == -1) { + std::cout<< counter <(ngram_size_ - 1); int batch_beam_size = next_token_scores.batch_beam_size; @@ -123,9 +137,14 @@ VocabMaskLogitsProcessor::VocabMaskLogitsProcessor(const gsl::span void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, - NextTokenScores& next_token_scores) { + NextTokenScores& next_token_scores, + int counter) { assert(!vocab_mask_.empty()); + if (counter == -1) { + std::cout<< counter <::Process(const ISequences* /*sequences*/, #endif } +template +PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::span& prefix_vocab_mask) : prefix_vocab_mask_(prefix_vocab_mask) { +} + +template +void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, + NextTokenScores& next_token_scores, + int counter) { + assert(!prefix_vocab_mask_.empty()); + + if (counter > 1) { + return; + } + + // Process vocabulary mask and set tokens with mask value 0 to -inf. + T* p = next_token_scores.scores.data(); + // next_token_scores shape (batch_size * num_beams, vocab_size) + // vocab_mask shape (vocab_size). TODO: support shape (batch_size, vocab_size) + for (int i = 0; i < next_token_scores.batch_beam_size; i++) { + for (int j = 0; j < next_token_scores.vocab_size; j++, p++) { + if (prefix_vocab_mask_[j] == 0) { + *p = std::numeric_limits::lowest(); + } + } + } + +#ifdef DEBUG_BEAM_SEARCH + DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores.scores); +#endif +} + template void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { processor_list_.clear(); @@ -162,6 +212,13 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { processor_list_.push_back(vocab_mask_processor_.get()); } + if (!parameters.prefix_vocab_mask.empty()) { + prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask); + processor_list_.push_back(prefix_vocab_mask_processor_.get()); + } else { + std::cout<<" Prefix vocab mask is empty"<< std::endl; + } + if (parameters.min_length > 0) { min_length_processor_ = std::make_unique>(parameters.min_length, parameters.eos_token_id); processor_list_.push_back(min_length_processor_.get()); @@ -173,10 +230,11 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { template void LogitsProcessorList::Process(const ISequences* sequences, - gsl::span& next_token_scores) { + gsl::span& next_token_scores, + int counter) { NextTokenScores input_scores = {next_token_scores, batch_beam_size_, vocab_size_}; for (size_t i = 0; i < processor_list_.size(); i++) { - processor_list_[i]->Process(sequences, input_scores); + processor_list_[i]->Process(sequences, input_scores, counter); } } @@ -185,6 +243,7 @@ template class MinLengthLogitsProcessor; template class RepetitionPenaltyLogitsProcessor; template class NoRepeatNGramLogitsProcessor; template class VocabMaskLogitsProcessor; +template class PrefixVocabMaskLogitsProcessor; template class LogitsProcessorList; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 78fe9acf63bcb..9e2900ab9fdd4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -24,7 +24,8 @@ class ILogitsProcessor { virtual ~ILogitsProcessor() {} virtual void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) = 0; + NextTokenScores& next_token_scores, + int counter) = 0; }; template @@ -33,7 +34,8 @@ class MinLengthLogitsProcessor : public ILogitsProcessor { MinLengthLogitsProcessor(int min_length, int eos_token_id); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores, + int counter) override; private: int min_length_; @@ -46,7 +48,8 @@ class RepetitionPenaltyLogitsProcessor : public ILogitsProcessor { RepetitionPenaltyLogitsProcessor(float penalty); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores, + int counter) override; private: float penalty_; @@ -58,7 +61,8 @@ class NoRepeatNGramLogitsProcessor : public ILogitsProcessor { NoRepeatNGramLogitsProcessor(int ngram_size); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores, + int counter) override; private: int ngram_size_; @@ -70,18 +74,32 @@ class VocabMaskLogitsProcessor : public ILogitsProcessor { VocabMaskLogitsProcessor(const gsl::span& vocab_mask); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores, + int counter) override; private: gsl::span vocab_mask_; }; +template +class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor { + public: + PrefixVocabMaskLogitsProcessor(const gsl::span& vocab_mask); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores, + int counter) override; + + private: + gsl::span prefix_vocab_mask_; +}; + template class LogitsProcessorList { public: LogitsProcessorList() = default ; void Init(const BeamSearchParameters& parameters); - void Process(const ISequences* sequences, gsl::span& next_token_scores); + void Process(const ISequences* sequences, gsl::span& next_token_scores, int counter); private: int batch_beam_size_; @@ -91,6 +109,7 @@ class LogitsProcessorList { std::unique_ptr> repetition_penalty_processor_; std::unique_ptr> no_repeat_ngram_processor_; std::unique_ptr> vocab_mask_processor_; + std::unique_ptr> prefix_vocab_mask_processor_; std::unique_ptr> min_length_processor_; }; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 6eb90e969c241..18174be669fab 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -686,6 +686,7 @@ void RegisterTextGenerationSchemas() { "T", OpSchema::Optional) .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) .Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(9, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc index f9b2c7d736ef9..829f5e7ea09b7 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.cc +++ b/onnxruntime/core/providers/cpu/math/top_k.cc @@ -285,6 +285,8 @@ static void FindTopKElements(const Tensor* input, const TensorShape& input_shape } }; } else { + std::cout<<"Sorting top K"< Date: Mon, 10 Jan 2022 09:16:07 -0800 Subject: [PATCH 40/53] removing debugs --- .../cpu/transformers/beam_search.cc | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index b8b4b002705eb..4cff9237ca2ac 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -193,9 +193,6 @@ void BeamSearch::Init(const OpKernelInfo& info) { parameters_.ParseFromAttributes(info); - //TODO remove this before commit - ConfigureTensorDump(); - stream_ = nullptr; } @@ -470,10 +467,6 @@ Status BeamSearchImpl::ProcessLogits( return status; } - // TODO remove this before commit - DumpTensor("topk_scores", *(topk_scores.get())); - DumpTensor("topk_indices", *(topk_indices.get())); - #ifdef DEBUG_BEAM_SEARCH DumpTensor("topk_scores", *(topk_scores.get())); DumpTensor("topk_indices", *(topk_indices.get())); @@ -495,9 +488,6 @@ Status BeamSearchImpl::ProcessLogits( gsl::span next_tokens(beam_state.next_tokens.data(), beam_state.next_tokens.size()); gsl::span next_indices(beam_state.next_indices.data(), beam_state.next_indices.size()); - // TODO remove before commit - DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k); - #ifdef DEBUG_BEAM_SEARCH DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, top_k); DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k); @@ -622,11 +612,6 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { int current_length = parameters_->sequence_length; int iteration_counter = 0; while (current_length < parameters_->max_length) { - - DumpOrtValue("input_ids", input_ids); - DumpOrtValue("position_ids", feeds[1]); - DumpOrtValue("attention_mask", feeds[2]); - iteration_counter++; #ifdef DEBUG_BEAM_SEARCH DumpString("***CurrentLength", std::to_string(current_length), true); @@ -640,11 +625,7 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { const OrtValue& logits = fetches[0]; gsl::span beam_next_tokens; gsl::span beam_indices; -<<<<<<< HEAD ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, iteration_counter)); -======= - ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state)); ->>>>>>> 7d93498e0ec4b1b2a6f55161539513318a908e04 // When all batches are finished, stop earlier to avoid wasting computation. if (beam_scorer_->IsDone()) { From 5fcac5686b4badd1cc96e8778620ccc90ea336e8 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Mon, 10 Jan 2022 10:42:53 -0800 Subject: [PATCH 41/53] removing more debugs --- .../cpu/transformers/logits_processor.cc | 39 +++++-------------- .../cpu/transformers/logits_processor.h | 22 +++++------ onnxruntime/core/providers/cpu/math/top_k.cc | 2 - 3 files changed, 20 insertions(+), 43 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 94d2a7abb3639..9a850dd3a8057 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -9,6 +9,8 @@ namespace onnxruntime { namespace contrib { namespace transformers { +static int beam_search_iteration; + template gsl::span NextTokenScores::GetScores(int batch_beam_index) { assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); @@ -38,15 +40,11 @@ MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_to template void MinLengthLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) { + NextTokenScores& next_token_scores) { if (sequences->GetSequenceLength() < min_length_) { next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); } - if (counter == -1) { - std::cout<< counter <::RepetitionPenaltyLogitsProcessor(float pena template void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) { + NextTokenScores& next_token_scores) { const int batch_beam_size = next_token_scores.batch_beam_size; - if (counter == -1) { - std::cout<< counter < beam_token_scores = next_token_scores.GetScores(i); @@ -95,16 +89,11 @@ NoRepeatNGramLogitsProcessor::NoRepeatNGramLogitsProcessor(int ngram_size) : template void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) { + NextTokenScores& next_token_scores) { if (ngram_size_ == 0 || ngram_size_ > sequences->GetSequenceLength()) { return; } - if (counter == -1) { - std::cout<< counter <(ngram_size_ - 1); int batch_beam_size = next_token_scores.batch_beam_size; @@ -140,14 +129,9 @@ VocabMaskLogitsProcessor::VocabMaskLogitsProcessor(const gsl::span void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, - NextTokenScores& next_token_scores, - int counter) { + NextTokenScores& next_token_scores) { assert(!vocab_mask_.empty()); - if (counter == -1) { - std::cout<< counter <::PrefixVocabMaskLogitsProcessor(const gsl::spa template void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, - NextTokenScores& next_token_scores, - int counter) { + NextTokenScores& next_token_scores) { assert(!prefix_vocab_mask_.empty()); - - if (counter > 1) { + if (beam_search_iteration > 1) { return; } @@ -218,8 +200,6 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { if (!parameters.prefix_vocab_mask.empty()) { prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask); processor_list_.push_back(prefix_vocab_mask_processor_.get()); - } else { - std::cout<<" Prefix vocab mask is empty"<< std::endl; } if (parameters.min_length > 0) { @@ -236,8 +216,9 @@ void LogitsProcessorList::Process(const ISequences* sequences, gsl::span& next_token_scores, int counter) { NextTokenScores input_scores = {next_token_scores, batch_beam_size_, vocab_size_}; + beam_search_iteration = counter; for (size_t i = 0; i < processor_list_.size(); i++) { - processor_list_[i]->Process(sequences, input_scores, counter); + processor_list_[i]->Process(sequences, input_scores); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 9e2900ab9fdd4..5e75679488655 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -24,8 +24,9 @@ class ILogitsProcessor { virtual ~ILogitsProcessor() {} virtual void Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) = 0; + NextTokenScores& next_token_scores) = 0; + + }; template @@ -34,8 +35,7 @@ class MinLengthLogitsProcessor : public ILogitsProcessor { MinLengthLogitsProcessor(int min_length, int eos_token_id); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) override; + NextTokenScores& next_token_scores) override; private: int min_length_; @@ -48,8 +48,7 @@ class RepetitionPenaltyLogitsProcessor : public ILogitsProcessor { RepetitionPenaltyLogitsProcessor(float penalty); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) override; + NextTokenScores& next_token_scores) override; private: float penalty_; @@ -61,8 +60,7 @@ class NoRepeatNGramLogitsProcessor : public ILogitsProcessor { NoRepeatNGramLogitsProcessor(int ngram_size); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) override; + NextTokenScores& next_token_scores) override; private: int ngram_size_; @@ -74,8 +72,7 @@ class VocabMaskLogitsProcessor : public ILogitsProcessor { VocabMaskLogitsProcessor(const gsl::span& vocab_mask); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) override; + NextTokenScores& next_token_scores) override; private: gsl::span vocab_mask_; @@ -87,8 +84,7 @@ class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor { PrefixVocabMaskLogitsProcessor(const gsl::span& vocab_mask); void Process(const ISequences* sequences, - NextTokenScores& next_token_scores, - int counter) override; + NextTokenScores& next_token_scores) override; private: gsl::span prefix_vocab_mask_; @@ -104,6 +100,8 @@ class LogitsProcessorList { private: int batch_beam_size_; int vocab_size_; + int counter_; + std::vector*> processor_list_; std::unique_ptr> repetition_penalty_processor_; diff --git a/onnxruntime/core/providers/cpu/math/top_k.cc b/onnxruntime/core/providers/cpu/math/top_k.cc index 829f5e7ea09b7..f9b2c7d736ef9 100644 --- a/onnxruntime/core/providers/cpu/math/top_k.cc +++ b/onnxruntime/core/providers/cpu/math/top_k.cc @@ -285,8 +285,6 @@ static void FindTopKElements(const Tensor* input, const TensorShape& input_shape } }; } else { - std::cout<<"Sorting top K"< Date: Tue, 11 Jan 2022 19:32:42 -0800 Subject: [PATCH 42/53] clean up --- onnxruntime/contrib_ops/cpu/transformers/logits_processor.h | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 5e75679488655..fa508f90a03b7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -100,7 +100,6 @@ class LogitsProcessorList { private: int batch_beam_size_; int vocab_size_; - int counter_; std::vector*> processor_list_; From 7111387b6af83804dee594388a2e6f5ebf485d0a Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Tue, 11 Jan 2022 19:47:42 -0800 Subject: [PATCH 43/53] clean up --- onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc | 1 - onnxruntime/contrib_ops/cpu/transformers/logits_processor.h | 3 --- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 1 + 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 9a850dd3a8057..5909192aa36b2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -58,7 +58,6 @@ template void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, NextTokenScores& next_token_scores) { const int batch_beam_size = next_token_scores.batch_beam_size; - for (int i = 0; i < batch_beam_size; i++) { gsl::span beam_token_scores = next_token_scores.GetScores(i); gsl::span sequence = sequences->GetSequence(i); diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index fa508f90a03b7..4040552ac364e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -25,8 +25,6 @@ class ILogitsProcessor { virtual void Process(const ISequences* sequences, NextTokenScores& next_token_scores) = 0; - - }; template @@ -100,7 +98,6 @@ class LogitsProcessorList { private: int batch_beam_size_; int vocab_size_; - std::vector*> processor_list_; std::unique_ptr> repetition_penalty_processor_; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 136153f133a3d..6f57a214b2b74 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -620,6 +620,7 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) { return; } + int64_t batch_size = input_ids_dims[0].dim_value(); int64_t sequence_length = input_ids_dims[1].dim_value(); From aa8d4c2b522a600d6992b5d74ad7209cb11c2856 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Tue, 11 Jan 2022 23:36:08 -0800 Subject: [PATCH 44/53] cpu doc updated --- docs/ContribOperators.md | 4 +- docs/OperatorKernels.md | 323 +-------------------------------------- 2 files changed, 4 insertions(+), 323 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index d8630076c1a12..88a43606f31a4 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -361,7 +361,7 @@ This version of the operator has been available since version 1 of the 'com.micr
The id of the padding token
-#### Inputs (6 - 9) +#### Inputs (6 - 10)
input_ids : I
@@ -382,6 +382,8 @@ This version of the operator has been available since version 1 of the 'com.micr
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
prefix_vocab_mask (optional) : M
+
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
#### Outputs (1 - 3) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ee92269f1bfef..b348b3a4eed5d 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -5,7 +5,6 @@ Do not modify directly.* ## Execution Providers - [CPUExecutionProvider](#cpuexecutionprovider) -- [CUDAExecutionProvider](#cudaexecutionprovider) --------------- @@ -377,7 +376,7 @@ Do not modify directly.* |**Operator Domain:** *com.microsoft*|||| |Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| -|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| +|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| |BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)| |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| @@ -434,323 +433,3 @@ Do not modify directly.* |Upsample|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| | | | | - - - - -## Operators implemented by CUDAExecutionProvider - -| Op Name | Parameters | OpSet Version | Types Supported | -|---------|------------|---------------|-----------------| -|**Operator Domain:** *ai.onnx*|||| -|Abs|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Add|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| -|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|||[7, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(double), tensor(float), tensor(float16)| -|||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| -|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Clip|*in* input:**T**
*in* min:**T**
*in* max:**T**
*out* output:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)| -|||11|**T** = tensor(float)| -|||[6, 10]|**T** = tensor(float)| -|Compress|*in* input:**T**
*in* condition:**T1**
*out* output:**T**|11+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| -|||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| -|Concat|*in* inputs:**T**
*out* concat_result:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ConvTranspose|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|Cos|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| -|Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| -|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| -|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**|10+|**T** = tensor(int8), tensor(uint8)| -|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|Dropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T1**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| -|||12|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| -|||[10, 11]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bool)| -|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| -|DynamicSlice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(float16)| -|Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| -|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)| -|Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Exp|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Expand|*in* input:**T**
*in* shape:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|EyeLike|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)
**T2** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)| -|Flatten|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 8]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Floor|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|GRU|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| -|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| -|Gather|*in* data:**T**
*in* indices:**Tind**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|GatherElements|*in* data:**T**
*in* indices:**Tind**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**Tind** = tensor(int64)| -|||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**Tind** = tensor(int64)| -|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|GlobalAveragePool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|GlobalMaxPool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| -|HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|If|*in* cond:**B**
*out* outputs:**V**|13+|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| -|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| -|LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)| -|LeakyRelu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| -|Log|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|Loop|*in* M:**I**
*in* cond:**B**
*in* v_initial:**V**
*out* v_final_and_scan_outputs:**V**|13+|**B** = tensor(bool)
**I** = tensor(int64)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 10]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|MatMulInteger|*in* A:**T1**
*in* B:**T2**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*out* Y:**T3**|10+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(int32)| -|Max|*in* data_0:**T**
*out* max:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|MaxPool|*in* X:**T**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**
*out* Indices:**I**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int8), tensor(uint8)| -|||11|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|||[8, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 7]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|Mul|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| -|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)| -|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)| -|Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)
**T1** = tensor(bool)| -|OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)
**T3** = tensor(float), tensor(float16), tensor(int64)| -|Or|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Pow|*in* X:**T**
*in* Y:**T**
*out* Z:**T**

or

*in* X:**T**
*in* Y:**T1**
*out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[13, 14]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|10+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| -|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| -|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| -|RandomNormal|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|RandomNormalLike|*in* input:**T1**
*out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(double), tensor(float), tensor(float16)| -|RandomUniform|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|RandomUniformLike|*in* input:**T1**
*out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(double), tensor(float), tensor(float16)| -|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| -|Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceL1|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceL2|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceLogSum|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceLogSumExp|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|ReduceMax|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceMean|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceMin|*in* data:**T**
*out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| -|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceProd|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| -|ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| -|ReduceSumSquare|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| -|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| -|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| -|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T** = tensor(double), tensor(float)
**T2** = tensor(int64)| -|Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| -|ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|11+|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[9, 10]|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| -|SequenceAt|*in* input_sequence:**S**
*in* position:**I**
*out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SequenceConstruct|*in* inputs:**T**
*out* output_sequence:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SequenceEmpty|*out* output:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**T**
*out* Y:**T**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)| -|Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| -|||10|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| -|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| -|Softmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|Softplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Softsign|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**

or

*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**

or

*in* input:**T**
*out* outputs:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|Sum|*in* data_0:**T**
*out* sum:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[8, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[6, 7]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|Tanh|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|ThresholdedRelu|*in* X:**T**
*out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)| -|||1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| -|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| -|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||10|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| -|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|9+|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| -|Xor|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -| | -| | -|**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| -|BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| -|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| -|ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| -|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| -|DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| -|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| -|EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| -|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| -|FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| -|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T** = tensor(float)| -|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| -|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)| -|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)| -|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -| | -| | From f1c449d6307e51f5615febfede15706adc094b76 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Wed, 12 Jan 2022 05:10:47 -0800 Subject: [PATCH 45/53] Updated docs --- docs/OperatorKernels.md | 321 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b348b3a4eed5d..4123f5c9ef684 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -5,6 +5,7 @@ Do not modify directly.* ## Execution Providers - [CPUExecutionProvider](#cpuexecutionprovider) +- [CUDAExecutionProvider](#cudaexecutionprovider) --------------- @@ -433,3 +434,323 @@ Do not modify directly.* |Upsample|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| | | | | + + +
+ +## Operators implemented by CUDAExecutionProvider + +| Op Name | Parameters | OpSet Version | Types Supported | +|---------|------------|---------------|-----------------| +|**Operator Domain:** *ai.onnx*|||| +|Abs|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Add|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| +|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(double), tensor(float), tensor(float16)| +|||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| +|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Clip|*in* input:**T**
*in* min:**T**
*in* max:**T**
*out* output:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)| +|||11|**T** = tensor(float)| +|||[6, 10]|**T** = tensor(float)| +|Compress|*in* input:**T**
*in* condition:**T1**
*out* output:**T**|11+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|Concat|*in* inputs:**T**
*out* concat_result:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ConvTranspose|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|Cos|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| +|Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| +|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(int64)| +|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**|10+|**T** = tensor(int8), tensor(uint8)| +|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|Dropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T1**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|||12|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|||[10, 11]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bool)| +|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| +|DynamicSlice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(float16)| +|Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)| +|Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Exp|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Expand|*in* input:**T**
*in* shape:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|EyeLike|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)
**T2** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)| +|Flatten|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 8]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Floor|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|GRU|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|Gather|*in* data:**T**
*in* indices:**Tind**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|GatherElements|*in* data:**T**
*in* indices:**Tind**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**Tind** = tensor(int64)| +|||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**Tind** = tensor(int64)| +|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|GlobalAveragePool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|GlobalMaxPool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|14+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|If|*in* cond:**B**
*out* outputs:**V**|13+|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)| +|LeakyRelu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|Log|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|Loop|*in* M:**I**
*in* cond:**B**
*in* v_initial:**V**
*out* v_final_and_scan_outputs:**V**|13+|**B** = tensor(bool)
**I** = tensor(int64)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 10]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|MatMulInteger|*in* A:**T1**
*in* B:**T2**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*out* Y:**T3**|10+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(int32)| +|Max|*in* data_0:**T**
*out* max:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|MaxPool|*in* X:**T**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**
*out* Indices:**I**|12+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int8), tensor(uint8)| +|||11|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| +|||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| +|||[8, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 7]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| +|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||12|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[6, 11]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|Mul|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| +|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)| +|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)| +|Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)
**T1** = tensor(bool)| +|OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)
**T3** = tensor(float), tensor(float16), tensor(int64)| +|Or|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| +|PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Pow|*in* X:**T**
*in* Y:**T**
*out* Z:**T**

or

*in* X:**T**
*in* Y:**T1**
*out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[13, 14]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|10+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| +|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|RandomNormal|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RandomNormalLike|*in* input:**T1**
*out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(double), tensor(float), tensor(float16)| +|RandomUniform|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RandomUniformLike|*in* input:**T1**
*out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(double), tensor(float), tensor(float16)| +|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| +|Reciprocal|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceL1|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceL2|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceLogSum|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceLogSumExp|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|ReduceMax|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceMean|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceMin|*in* data:**T**
*out* reduced:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||13|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)| +|||11|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceProd|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32)| +|ReduceSum|*in* data:**T**
*in* axes:**tensor(int64)**
*out* reduced:**T**

or

*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|ReduceSumSquare|*in* data:**T**
*out* reduced:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|11+|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 10]|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|SequenceAt|*in* input_sequence:**S**
*in* position:**I**
*out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|SequenceConstruct|*in* inputs:**T**
*out* output_sequence:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|SequenceEmpty|*out* output:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**T**
*out* Y:**T**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)| +|Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| +|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| +|||10|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| +|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(float), tensor(int32), tensor(int64)| +|Softmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|Softplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Softsign|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**

or

*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**

or

*in* input:**T**
*out* outputs:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|Sum|*in* data_0:**T**
*out* sum:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[8, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[6, 7]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|Tanh|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|ThresholdedRelu|*in* X:**T**
*out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)| +|||1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| +|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| +|TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||10|**I** = tensor(int64)
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Upsample|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**|9|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|Where|*in* condition:**B**
*in* X:**T**
*in* Y:**T**
*out* output:**T**|9+|**B** = tensor(bool)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| +|Xor|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| +| | +| | +|**Operator Domain:** *com.microsoft*|||| +|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| +|ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| +|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| +|DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| +|EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| +|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| +|FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| +|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T** = tensor(float)| +|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| +|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)| +|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)| +|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +| | +| | From eebf728ebf40bda1155a27d8e453d85a5eaefefb Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Thu, 20 Jan 2022 11:10:09 -0800 Subject: [PATCH 46/53] updated prefix_vocab_mask dimension in convert script --- .../contrib_ops/cpu/transformers/beam_search.cc | 11 ++++++++--- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 2 +- .../tools/transformers/convert_beam_search.py | 14 ++++++++++++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 4cff9237ca2ac..ee33ac7850c39 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -298,13 +298,18 @@ Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) { if (prefix_vocab_mask != nullptr) { // prefix_vocab_mask is optional const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims(); - if (vocab_mask_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 1 dimension, got ", + if (vocab_mask_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ", vocab_mask_dims.size()); } + // prefix_vocab_mask first dimension should be same as the first dimension of input_ids + if (static_cast(vocab_mask_dims[0]) != static_cast(dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size"); + } + // There is dependency on vocab_size parameter, which shall be set before calling this function. - if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) { + if (static_cast(vocab_mask_dims[1]) != parameters_->vocab_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ", vocab_mask_dims[0]); } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 6f57a214b2b74..eb9a1f14949dc 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -695,7 +695,7 @@ void RegisterTextGenerationSchemas() { "T", OpSchema::Optional) .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) .Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) - .Input(9, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(9, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 02cbd3bc4a4a3..bbb1ed3a93c2f 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -131,12 +131,13 @@ def parse_arguments(argv=None): beam_search_group.add_argument('--vocab_size', type=int, required=False, + default=-1, help="Vocab_size of the underlying model") beam_search_group.add_argument('--prefix_vocab_mask', required=False, action='store_true', - help="This vocab mask applies only to first iteration, enable if last work in query might need auto complete") + help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete") beam_search_group.set_defaults(prefix_vocab_mask=False) mixed_precision_option_group = parser.add_argument_group( @@ -241,6 +242,10 @@ def convert_model(args): pad_token_id = config.eos_token_id vocab_size = config.vocab_size + # if vocab_size is given in parameters use that. + if args.vocab_size != -1: + vocab_size = args.vocab_size + model = onnx.load(args.gpt2_onnx) model.graph.name = "gpt2 subgraph" inputs = [ @@ -287,7 +292,7 @@ def convert_model(args): ] if args.prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info('prefix_vocab_mask', TensorProto.INT32, [vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info('prefix_vocab_mask', TensorProto.INT32, ['batch_size', vocab_size]) graph_inputs.append(prefix_vocab_mask) # graph outputs @@ -318,6 +323,11 @@ def convert_model(args): def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): + + if args.prefix_vocab_mask: + print("Skipping parity test as prefix vocab mask is not implemented by Huggin Face") + return + from transformers import GPT2Tokenizer, GPT2LMHeadModel tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) From 07a9a3b0e78765f69648a9caa0209ccb18801d67 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Thu, 20 Jan 2022 23:20:58 -0800 Subject: [PATCH 47/53] changes to support bxs prefix_vocab_mask in beamsearchop kernel --- .../cpu/transformers/logits_processor.cc | 20 ++++++++++++------- .../cpu/transformers/logits_processor.h | 3 ++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 5909192aa36b2..7334d7a53e5e7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -149,7 +149,7 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } template -PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::span& prefix_vocab_mask) : prefix_vocab_mask_(prefix_vocab_mask) { +PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::span& prefix_vocab_mask, int batch_size) : prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) { } template @@ -162,12 +162,18 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, // Process vocabulary mask and set tokens with mask value 0 to -inf. T* p = next_token_scores.scores.data(); + // next_token_scores shape (batch_size * num_beams, vocab_size) - // vocab_mask shape (vocab_size). TODO: support shape (batch_size, vocab_size) - for (int i = 0; i < next_token_scores.batch_beam_size; i++) { - for (int j = 0; j < next_token_scores.vocab_size; j++, p++) { - if (prefix_vocab_mask_[j] == 0) { - *p = std::numeric_limits::lowest(); + int num_beams = next_token_scores.batch_beam_size / batch_size_; + assert(num_beams*batch_size_ == next_token_scores.batch_beam_size); + + for (int i = 0; i < batch_size_; i++) { + int batch_vocab_mask_offset = i * next_token_scores.vocab_size; + for (int j = 0; j < num_beams; j++) { + for (int k = batch_vocab_mask_offset; k < next_token_scores.vocab_size; k++, p++) { + if (prefix_vocab_mask_[k] == 0) { + *p = std::numeric_limits::lowest(); + } } } } @@ -197,7 +203,7 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { } if (!parameters.prefix_vocab_mask.empty()) { - prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask); + prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask, parameters.batch_size); processor_list_.push_back(prefix_vocab_mask_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 4040552ac364e..f5985b966ba05 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -79,13 +79,14 @@ class VocabMaskLogitsProcessor : public ILogitsProcessor { template class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor { public: - PrefixVocabMaskLogitsProcessor(const gsl::span& vocab_mask); + PrefixVocabMaskLogitsProcessor(const gsl::span& vocab_mask, int batch_size); void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override; private: gsl::span prefix_vocab_mask_; + const int batch_size_; }; template From 3a9544c9dca2b40f88a4da37f7ef563fac4ef508 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Fri, 21 Jan 2022 02:42:28 -0800 Subject: [PATCH 48/53] doc update --- docs/ContribOperators.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 6428cc0f2f214..88fd736b22bcf 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -383,7 +383,7 @@ This version of the operator has been available since version 1 of the 'com.micr
vocab_mask (optional) : M
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
prefix_vocab_mask (optional) : M
-
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
#### Outputs (1 - 3) @@ -488,7 +488,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
+
T : tensor(float16), tensor(float), tensor(double)
Constrain input and output types to float tensors.
From a312aa4d98bc05af9ec7e2535e41a80a5ece9eec Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Fri, 21 Jan 2022 02:44:13 -0800 Subject: [PATCH 49/53] OperatorKernels.md updated --- docs/OperatorKernels.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 35ac6e0e5b154..4123f5c9ef684 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -25,9 +25,9 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| -|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| +|||[1, 10]|**T** = tensor(float), tensor(int32)| |ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| |||[1, 10]|**T** = tensor(float), tensor(int32)| @@ -729,7 +729,7 @@ Do not modify directly.* |**Operator Domain:** *com.microsoft*|||| |Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| -|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| From c7d08dc68549a672953135af4952522eccfbb2bd Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Fri, 21 Jan 2022 05:06:56 -0800 Subject: [PATCH 50/53] matching docs from artifacts --- docs/ContribOperators.md | 2 +- docs/OperatorKernels.md | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 88fd736b22bcf..efbdbb52a2507 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -488,7 +488,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16), tensor(float), tensor(double)
+
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
Constrain input and output types to float tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 4123f5c9ef684..35ac6e0e5b154 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -25,9 +25,9 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)| -|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| -|||[1, 10]|**T** = tensor(float), tensor(int32)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| +|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)| |ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)| |||[1, 10]|**T** = tensor(float), tensor(int32)| @@ -729,7 +729,7 @@ Do not modify directly.* |**Operator Domain:** *com.microsoft*|||| |Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| -|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| From b34ae917fb67722278312b99495352f58b3981a0 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Wed, 26 Jan 2022 10:42:53 -0800 Subject: [PATCH 51/53] minor change in logits processor --- onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 7334d7a53e5e7..f428487b9522d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -168,9 +168,8 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, assert(num_beams*batch_size_ == next_token_scores.batch_beam_size); for (int i = 0; i < batch_size_; i++) { - int batch_vocab_mask_offset = i * next_token_scores.vocab_size; for (int j = 0; j < num_beams; j++) { - for (int k = batch_vocab_mask_offset; k < next_token_scores.vocab_size; k++, p++) { + for (int k = 0; k < next_token_scores.vocab_size; k++, p++) { if (prefix_vocab_mask_[k] == 0) { *p = std::numeric_limits::lowest(); } From 85a2ae3a5f6882c317d91efbab0b8704266de687 Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Sun, 30 Jan 2022 23:21:30 -0800 Subject: [PATCH 52/53] Addressing comments --- .../contrib_ops/cpu/transformers/dump_tensor.h | 3 +-- .../cpu/transformers/logits_processor.cc | 16 +++++++++------- .../tools/transformers/convert_beam_search.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h index d26c3c3efda3a..2f80410011dab 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h @@ -26,8 +26,7 @@ namespace transformers { continue; \ } -//#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line) -#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) +#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line) extern bool g_enable_tensor_dump; // global variance to turn on/off dump diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index f428487b9522d..6c42f84c50f46 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -9,6 +9,8 @@ namespace onnxruntime { namespace contrib { namespace transformers { +// beam_search_iteration represents the current iteration counter of beam search +// This value is used to apply processors as needed in specific iteration. static int beam_search_iteration; template @@ -154,23 +156,23 @@ PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::spa template void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, - NextTokenScores& next_token_scores) { + NextTokenScores& next_token_scores) { assert(!prefix_vocab_mask_.empty()); + if (beam_search_iteration > 1) { return; } - - // Process vocabulary mask and set tokens with mask value 0 to -inf. - T* p = next_token_scores.scores.data(); - // next_token_scores shape (batch_size * num_beams, vocab_size) int num_beams = next_token_scores.batch_beam_size / batch_size_; - assert(num_beams*batch_size_ == next_token_scores.batch_beam_size); + assert(num_beams * batch_size_ == next_token_scores.batch_beam_size); + // Process prefix vocabulary mask and set tokens with mask value 0 to -inf. + // prefix_vocab_mask shape (batch_szie, vocab_size). + T* p = next_token_scores.scores.data(); for (int i = 0; i < batch_size_; i++) { for (int j = 0; j < num_beams; j++) { for (int k = 0; k < next_token_scores.vocab_size; k++, p++) { - if (prefix_vocab_mask_[k] == 0) { + if (prefix_vocab_mask_[i][k] == 0) { *p = std::numeric_limits::lowest(); } } diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index bbb1ed3a93c2f..8bb25fd609b15 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -325,7 +325,7 @@ def convert_model(args): def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): if args.prefix_vocab_mask: - print("Skipping parity test as prefix vocab mask is not implemented by Huggin Face") + print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") return from transformers import GPT2Tokenizer, GPT2LMHeadModel From 3db1079843b37e4b0965d2c053f276588caafbdd Mon Sep 17 00:00:00 2001 From: Viswanath Boga Date: Mon, 31 Jan 2022 01:10:13 -0800 Subject: [PATCH 53/53] Updated the prefix vocab mask usage properly --- onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 6c42f84c50f46..6473ad5d14b5d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -170,9 +170,10 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, // prefix_vocab_mask shape (batch_szie, vocab_size). T* p = next_token_scores.scores.data(); for (int i = 0; i < batch_size_; i++) { + int prefix_vocab_mask_offset = i * next_token_scores.vocab_size; for (int j = 0; j < num_beams; j++) { for (int k = 0; k < next_token_scores.vocab_size; k++, p++) { - if (prefix_vocab_mask_[i][k] == 0) { + if (prefix_vocab_mask_[prefix_vocab_mask_offset + k] == 0) { *p = std::numeric_limits::lowest(); } }