From 867960220e218ed2d74f5ee8b90fd8dcee0d151d Mon Sep 17 00:00:00 2001 From: chelouche9 Date: Thu, 23 Mar 2023 22:06:18 +0200 Subject: [PATCH] add jsdoc for models.js --- dist/types/models.d.ts | 562 ++++++++++++++++++++++---------- src/models.js | 718 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 1092 insertions(+), 188 deletions(-) diff --git a/dist/types/models.d.ts b/dist/types/models.d.ts index bb500c9b3..845306024 100644 --- a/dist/types/models.d.ts +++ b/dist/types/models.d.ts @@ -1,97 +1,279 @@ +/** + * Helper class to determine model type from config + */ export class AutoModel { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Instantiates a pre-trained model based on the given model path and config. + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} progressCallback - Optional. A callback function that can be used to track loading progress. + * @returns {Promise} - A promise that resolves to an instance of a pre-trained model. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * Class representing an automatic sequence-to-sequence language model. + */ export class AutoModelForSeq2SeqLM { static modelClassMapping: { t5: typeof T5ForConditionalGeneration; bart: typeof BartForConditionalGeneration; whisper: typeof WhisperForConditionalGeneration; }; - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Loads a pretrained sequence-to-sequence language model from a file path. + * @param {string} modelPath - The path to the model files. + * @param {function} [progressCallback=null] - A callback function to track loading progress. + * @returns {Promise} A Promise that resolves to an instance of the appropriate model class. + * @throws {Error} If the model type is unsupported. 
+ * @static + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * Helper class for loading sequence classification models from pretrained checkpoints + */ export class AutoModelForSequenceClassification { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Load a sequence classification model from a pretrained checkpoint + * @param {string} modelPath - The path to the model checkpoint directory + * @param {function} [progressCallback=null] - An optional callback function to receive progress updates + * @returns {Promise} A promise that resolves to a pre-trained sequence classification model + * @throws {Error} if an unsupported model type is encountered + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * A class for loading pre-trained models for causal language modeling tasks. + */ export class AutoModelForCausalLM { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Loads a pre-trained model from the given path and returns an instance of the appropriate class. + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} [progressCallback=null] - An optional callback function to track the progress of the loading process. + * @returns {Promise} An instance of the appropriate class for the loaded model. + * @throws {Error} If the loaded model type is not supported. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * A class to automatically select the appropriate model for Masked Language Modeling (MLM) tasks. + */ export class AutoModelForMaskedLM { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Loads a pre-trained model from a given directory and returns an instance of the appropriate model class. + * + * @async + * @param {string} modelPath - The path to the pre-trained model directory. 
+ * @param {function} [progressCallback=null] - An optional callback function to track the loading progress. + * @returns {Promise} An instance of the appropriate model class for MLM tasks. + * @throws {Error} If an unsupported model type is encountered. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * Automatic model class for question answering tasks. + */ export class AutoModelForQuestionAnswering { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Loads and returns a question answering model from a pretrained model path. + * @param {string} modelPath - The path to the pretrained model. + * @param {function} [progressCallback=null] - Optional callback function to track loading progress. + * @returns {Promise} - The loaded question answering model. + * @throws Will throw an error if an unsupported model type is encountered. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * Class representing an autoencoder-decoder model for vision-to-sequence tasks. + */ export class AutoModelForVision2Seq { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Loads a pretrained model from a given path. + * @param {string} modelPath - The path to the pretrained model. + * @param {function} progressCallback - Optional callback function to track progress of the model loading. + * @returns {Promise} - A Promise that resolves to a new instance of VisionEncoderDecoderModel. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * AutoModelForImageClassification is a class for loading pre-trained image classification models from ONNX format. + */ export class AutoModelForImageClassification { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Loads a pre-trained image classification model from a given directory path. 
+ * @param {string} modelPath - The path to the directory containing the pre-trained model. + * @param {function} [progressCallback=null] - A callback function to monitor the loading progress. + * @returns {Promise} A Promise that resolves with an instance of the ViTForImageClassification class. + * @throws {Error} If the specified model type is not supported. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; } +/** + * T5Model is a class representing a T5 model for conditional generation. + * @extends T5PreTrainedModel + */ export class T5ForConditionalGeneration extends T5PreTrainedModel { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; - constructor(config: any, session: any, decoder_merged_session: any, generation_config: any); + /** + * Loads the pre-trained model from a given path. + * @async + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} progressCallback - A function to call with progress updates (optional). + * @returns {Promise} The loaded model instance. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; + /** + * @param {object} config - The model configuration. + * @param {any} session - session for the model. + * @param {any} decoder_merged_session - session for the decoder. + * @param {GenerationConfig} generation_config - The generation configuration. 
+ */ + constructor(config: object, session: any, decoder_merged_session: any, generation_config: GenerationConfig); decoder_merged_session: any; - generation_config: any; + generation_config: GenerationConfig; num_decoder_layers: any; num_decoder_heads: any; decoder_dim_kv: any; num_encoder_layers: any; num_encoder_heads: any; encoder_dim_kv: any; - getStartBeams(inputs: any, numOutputTokens: any, ...args: any[]): { - inputs: any; - encoder_outputs: any; - past_key_values: any; - output_token_ids: any[]; - done: boolean; - score: number; - id: number; - }[]; + /** + * Generates the start beams for a given set of inputs and output length. + * @param {number[][]} inputs - The input token IDs. + * @param {number} numOutputTokens - The desired output length. + * @returns {Array} The start beams. + */ + getStartBeams(inputs: number[][], numOutputTokens: number, ...args: any[]): any[]; + /** + * Runs the beam search for a given beam. + * @async + * @param {any} beam - The current beam. + * @returns {Promise} The model output. + */ runBeam(beam: any): Promise; - updateBeam(beam: any, newTokenId: any): void; - forward(model_inputs: any): Promise; + /** + * Updates the given beam with a new token ID. + * @param {any} beam - The current beam. + * @param {number} newTokenId - The new token ID to add to the output sequence. + */ + updateBeam(beam: any, newTokenId: number): void; } +/** + * A base class for pre-trained models that provides the model configuration and a TensorFlow session. + * @extends Callable + */ declare class PreTrainedModel extends Callable { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; - constructor(config: any, session: any); + /** + * Loads a pre-trained model from the given modelPath. + * @static + * @async + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} progressCallback - A function to be called with progress updates. 
+ * @returns {Promise} A new instance of the PreTrainedModel class. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; + /** + * @param {object} config - The configuration object for the model. + * @param {any} session - The TensorFlow session for running inference. + */ + constructor(config: object, session: any); config: any; session: any; - dispose(): Promise; - toI64Tensor(items: any): Tensor; + /** + * Disposes of all the ONNX sessions that were created during inference. + * @returns {Promise>} - An array of promises, one for each ONNX session that is being disposed. + */ + dispose(): Promise>; + /** + * Converts an array or Tensor of integers to an int64 Tensor. + * @param {Array|Tensor} items - The input integers to be converted. + * @returns {Tensor} The int64 Tensor with the converted values. + * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length. + */ + toI64Tensor(items: any[] | Tensor): Tensor; + /** + * Runs the model with the provided inputs + * @param {Object} model_inputs - Object containing input tensors + * @returns {Promise} - Object containing output tensors + */ _call(model_inputs: any): Promise; - forward(model_inputs: any): Promise; + /** + * Forward method should be implemented in subclasses. + * @abstract + * @param {object} model_inputs - The input data to the model in the format specified in the ONNX model. + * @returns {Promise} - The output data from the model in the format specified in the ONNX model. + * @throws {Error} - This method must be implemented in subclasses. 
+ */ + forward(model_inputs: object): Promise; /** * @param {GenerationConfig} generation_config * @param {number} input_ids_seq_length * @returns {LogitsProcessorList} */ _get_logits_processor(generation_config: GenerationConfig, input_ids_seq_length: number, logits_processor?: any): LogitsProcessorList; - _get_generation_config(generation_config: any): GenerationConfig; - generate(inputs: any, generation_config?: any, logits_processor?: any, { inputs_attention_mask }?: { + /** + * This function merges multiple generation configs together to form a final generation config to be used by the model for text generation. + * It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object. + * + * @param {GenerationConfig} generation_config - A `GenerationConfig` object containing generation parameters. + * @returns {GenerationConfig} The final generation config object to be used by the model for text generation. + */ + _get_generation_config(generation_config: GenerationConfig): GenerationConfig; + /** + * Generates text based on the given inputs and generation configuration using the model. + * @param {Array} inputs - An array of input token IDs. + * @param {Object|null} generation_config - The generation configuration to use. If null, default configuration will be used. + * @param {Object|null} logits_processor - An optional logits processor to use. If null, a new LogitsProcessorList instance will be created. + * @param {Object} options - options + * @param {Object} [options.inputs_attention_mask=null] - An optional attention mask for the inputs. + * @returns {Promise} An array of generated output sequences, where each sequence is an array of token IDs. + * @throws {Error} Throws an error if the inputs array is empty. 
+ */ + generate(inputs: any[], generation_config?: any | null, logits_processor?: any | null, { inputs_attention_mask }?: { inputs_attention_mask?: any; }): Promise; - groupBeams(beams: any): any[]; - getPastKeyValues(decoderResults: any): {}; + /** + * Groups an array of beam objects by their ids. + * + * @param {Array} beams - The array of beam objects to group. + * @returns {Array} - An array of arrays, where each inner array contains beam objects with the same id. + */ + groupBeams(beams: any[]): any[]; + /** + * Returns an object containing past key values from the given decoder results object. + * + * @param {Object} decoderResults - The decoder results object. + * @returns {Object} - An object containing past key values. + */ + getPastKeyValues(decoderResults: any): any; + /** + * Adds past key values to the decoder feeds object. If pastKeyValues is null, creates new tensors for past key values. + * + * @param {Object} decoderFeeds - The decoder feeds object to add past key values to. + * @param {Object} pastKeyValues - An object containing past key values. + * @param {boolean} [hasDecoder=false] - Whether the model has a decoder. + */ addPastKeyValues(decoderFeeds: any, pastKeyValues: any, hasDecoder?: boolean): void; } -declare class T5Model extends T5PreTrainedModel { - generate(...args: any[]): Promise; -} -declare class BartModel extends BartPretrainedModel { - generate(...args: any[]): Promise; -} -declare class WhisperModel extends WhisperPreTrainedModel { - generate(...args: any[]): Promise; -} -declare class GPT2Model extends GPT2PreTrainedModel { - generate(...args: any[]): Promise; -} -declare class CodeGenModel extends CodeGenPreTrainedModel { - generate(...args: any[]): Promise; -} +/** + * BART model with a language model head for conditional generation. 
+ * @extends BartPretrainedModel + */ declare class BartForConditionalGeneration extends BartPretrainedModel { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; - constructor(config: any, session: any, decoder_merged_session: any, generation_config: any); + /** + * Loads a BartForConditionalGeneration instance from a pretrained model stored on disk. + * @param {string} modelPath - The path to the directory containing the pretrained model. + * @param {function} [progressCallback=null] - An optional callback function to track the download progress. + * @returns {Promise} - The pretrained BartForConditionalGeneration instance. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; + /** + * Create a new BartForConditionalGeneration instance. + * @param {object} config - The configuration object for the Bart model. + * @param {object} session - The TensorFlow.js session used to execute the model. + * @param {object} decoder_merged_session - The TensorFlow.js session used to execute the decoder. + * @param {object} generation_config - The generation configuration object. + */ + constructor(config: object, session: object, decoder_merged_session: object, generation_config: object); decoder_merged_session: any; generation_config: any; num_decoder_layers: any; @@ -100,21 +282,46 @@ declare class BartForConditionalGeneration extends BartPretrainedModel { num_encoder_layers: any; num_encoder_heads: any; encoder_dim_kv: number; - getStartBeams(inputs: any, numOutputTokens: any, ...args: any[]): { - inputs: any; - encoder_outputs: any; - past_key_values: any; - output_token_ids: any[]; - done: boolean; - score: number; - id: number; - }[]; + /** + * Returns the initial beam for generating output text. + * @param {object} inputs - The input object containing the encoded input text. + * @param {number} numOutputTokens - The maximum number of output tokens to generate. 
+ * @param {...any} args - Additional arguments to pass to the sequence-to-sequence generation function. + * @returns {any} - The initial beam for generating output text. + */ + getStartBeams(inputs: object, numOutputTokens: number, ...args: any[]): any; + /** + * Runs a single step of the beam search generation algorithm. + * @param {any} beam - The current beam being generated. + * @returns {Promise} - The updated beam after a single generation step. + */ runBeam(beam: any): Promise; - updateBeam(beam: any, newTokenId: any): void; - forward(model_inputs: any): Promise; + /** + * Updates the beam by appending the newly generated token ID to the list of output token IDs. + * @param {any} beam - The current beam being generated. + * @param {number} newTokenId - The ID of the newly generated token to append to the list of output token IDs. + */ + updateBeam(beam: any, newTokenId: number): void; } +/** + * WhisperForConditionalGeneration class for generating conditional outputs from Whisper models. + * @extends WhisperPreTrainedModel + */ declare class WhisperForConditionalGeneration extends WhisperPreTrainedModel { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; + /** + * Loads a pre-trained model from a saved model directory. + * @param {string} modelPath - Path to the saved model directory. + * @param {function} progressCallback - Optional function for tracking loading progress. + * @returns {Promise} Promise object represents the loaded model. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; + /** + * Initializes the WhisperForConditionalGeneration object. + * @param {Object} config - Configuration object for the model. + * @param {Object} session - TensorFlow.js Session object for the model. + * @param {Object} decoder_merged_session - TensorFlow.js Session object for the decoder. + * @param {Object} generation_config - Configuration object for the generation process. 
+ */ constructor(config: any, session: any, decoder_merged_session: any, generation_config: any); decoder_merged_session: any; generation_config: any; @@ -124,122 +331,156 @@ declare class WhisperForConditionalGeneration extends WhisperPreTrainedModel { num_encoder_layers: any; num_encoder_heads: any; encoder_dim_kv: number; - generate(inputs: any, generation_config?: any, logits_processor?: any): Promise; - getStartBeams(inputTokenIds: any, numOutputTokens: any, ...args: any[]): { - inputs: any; - encoder_outputs: any; - past_key_values: any; - output_token_ids: any[]; - done: boolean; - score: number; - id: number; - }[]; + /** + * Generates outputs based on input and generation configuration. + * @param {Object} inputs - Input data for the model. + * @param {Object} generation_config - Configuration object for the generation process. + * @param {Object} logits_processor - Optional logits processor object. + * @returns {Promise} Promise object represents the generated outputs. + */ + generate(inputs: any, generation_config?: any, logits_processor?: any): Promise; + /** + * Gets the start beams for generating outputs. + * @param {Array} inputTokenIds - Array of input token IDs. + * @param {number} numOutputTokens - Number of output tokens to generate. + * @returns {Array} Array of start beams. + */ + getStartBeams(inputTokenIds: any[], numOutputTokens: number, ...args: any[]): any[]; + /** + * Runs a beam for generating outputs. + * @param {Object} beam - Beam object. + * @returns {Promise} Promise object represents the generated outputs for the beam. 
+ */ runBeam(beam: any): Promise; - updateBeam(beam: any, newTokenId: any): void; - forward(model_inputs: any): Promise; -} -declare class BertForSequenceClassification extends BertPreTrainedModel { - _call(model_inputs: any): Promise; -} -declare class DistilBertForSequenceClassification extends DistilBertPreTrainedModel { - _call(model_inputs: any): Promise; -} -declare class AlbertForSequenceClassification extends AlbertPreTrainedModel { - _call(model_inputs: any): Promise; -} -declare class RobertaForSequenceClassification extends RobertaPreTrainedModel { - _call(model_inputs: any): Promise; + /** + * Updates the beam by appending the newly generated token ID to the list of output token IDs. + * @param {any} beam - The current beam being generated. + * @param {number} newTokenId - The ID of the newly generated token to append to the list of output token IDs. + */ + updateBeam(beam: any, newTokenId: number): void; } +/** + * GPT-2 language model head on top of the GPT-2 base model. This model is suitable for text generation tasks. + * @extends GPT2PreTrainedModel + */ declare class GPT2LMHeadModel extends GPT2PreTrainedModel { num_heads: any; num_layers: any; dim_kv: number; - getStartBeams(inputTokenIds: any, numOutputTokens: any, inputs_attention_mask: any): { - input: any; - model_input_ids: any; - attention_mask: any; - past_key_values: any; - output_token_ids: any[]; - num_output_tokens: any; - done: boolean; - score: number; - id: number; - }[]; + /** + * Initializes and returns the beam for text generation task + * @param {Tensor} inputTokenIds - The input token ids. + * @param {number} numOutputTokens - The number of tokens to be generated. + * @param {Tensor} inputs_attention_mask - Optional input attention mask. + * @returns {any} A Beam object representing the initialized beam. + */ + getStartBeams(inputTokenIds: Tensor, numOutputTokens: number, inputs_attention_mask: Tensor): any; + /** + * Runs beam search for text generation given a beam. 
+ * @param {any} beam - The Beam object representing the beam. + * @returns {Promise} A Beam object representing the updated beam after running beam search. + */ runBeam(beam: any): Promise; - updateBeam(beam: any, newTokenId: any): void; - forward(model_inputs: any): Promise<{ - logits: any; - past_key_values: any; - }>; + /** + * Updates the given beam with the new generated token id. + * @param {any} beam - The Beam object representing the beam. + * @param {number} newTokenId - The new generated token id to be added to the beam. + */ + updateBeam(beam: any, newTokenId: number): void; } +/** + * CodeGenForCausalLM is a class that represents a code generation model based on the GPT-2 architecture. It extends the `CodeGenPreTrainedModel` class. + * @extends CodeGenPreTrainedModel + */ declare class CodeGenForCausalLM extends CodeGenPreTrainedModel { num_heads: any; num_layers: any; dim_kv: number; - getStartBeams(inputTokenIds: any, numOutputTokens: any, inputs_attention_mask: any): { - input: any; - model_input_ids: any; - attention_mask: any; - past_key_values: any; - output_token_ids: any[]; - num_output_tokens: any; - done: boolean; - score: number; - id: number; - }[]; + /** + * Initializes and returns the beam for text generation task + * @param {Tensor} inputTokenIds - The input token ids. + * @param {number} numOutputTokens - The number of tokens to be generated. + * @param {Tensor} inputs_attention_mask - Optional input attention mask. + * @returns {any} A Beam object representing the initialized beam. + */ + getStartBeams(inputTokenIds: Tensor, numOutputTokens: number, inputs_attention_mask: Tensor): any; + /** + * Runs beam search for text generation given a beam. + * @param {any} beam - The Beam object representing the beam. + * @returns {Promise} A Beam object representing the updated beam after running beam search. 
+ */ runBeam(beam: any): Promise; - updateBeam(beam: any, newTokenId: any): void; - forward(model_inputs: any): Promise<{ - logits: any; - past_key_values: any; - }>; -} -declare class BertForQuestionAnswering extends BertPreTrainedModel { - _call(model_inputs: any): Promise; -} -declare class DistilBertForQuestionAnswering extends DistilBertPreTrainedModel { - _call(model_inputs: any): Promise; -} -declare class AlbertForQuestionAnswering extends AlbertPreTrainedModel { - _call(model_inputs: any): Promise; -} -declare class RobertaForQuestionAnswering extends RobertaPreTrainedModel { - _call(model_inputs: any): Promise; + /** + * Updates the given beam with the new generated token id. + * @param {any} beam - The Beam object representing the beam. + * @param {number} newTokenId - The new generated token id to be added to the beam. + */ + updateBeam(beam: any, newTokenId: number): void; } +/** + * Vision Encoder-Decoder model based on OpenAI's GPT architecture for image captioning and other vision tasks + * @extends PreTrainedModel + */ declare class VisionEncoderDecoderModel extends PreTrainedModel { - static from_pretrained(modelPath: any, progressCallback?: any): Promise; - constructor(config: any, session: any, decoder_merged_session: any); + /** + * Loads a VisionEncoderDecoderModel from the given path. + * + * @param {string} modelPath - The path to the folder containing the saved model files. + * @param {function} [progressCallback=null] - Optional callback function to track the progress of model loading. + * @returns {Promise} A Promise that resolves with the loaded VisionEncoderDecoderModel instance. + */ + static from_pretrained(modelPath: string, progressCallback?: Function): Promise; + /** + * @param {object} config - The configuration object specifying the hyperparameters and other model settings. + * @param {object} session - The TensorFlow.js session containing the encoder model. 
+ * @param {any} decoder_merged_session - The TensorFlow.js session containing the merged decoder model. + */ + constructor(config: object, session: object, decoder_merged_session: any); decoder_merged_session: any; num_layers: any; num_heads: any; dim_kv: number; - getStartBeams(inputs: any, numOutputTokens: any, ...args: any[]): { - inputs: any; - encoder_outputs: any; - past_key_values: any; - output_token_ids: any[]; - done: boolean; - score: number; - id: number; - }[]; + /** + * Generate beam search outputs for the given input pixels and number of output tokens. + * + * @param {array} inputs - The input pixels as a Tensor. + * @param {number} numOutputTokens - The number of output tokens to generate. + * @param {...*} args - Optional additional arguments to pass to seq2seqStartBeams. + * @returns {any} An array of Beam objects representing the top-K output sequences. + */ + getStartBeams(inputs: any[], numOutputTokens: number, ...args: any[]): any; + /** + * Generate the next beam step for the given beam. + * + * @param {any} beam - The current beam. + * @returns {Promise} The updated beam with the additional predicted token ID. + */ runBeam(beam: any): Promise; - updateBeam(beam: any, newTokenId: any): void; - forward(model_inputs: any): Promise; + /** + * Update the given beam with the additional predicted token ID. + * + * @param {any} beam - The current beam. + * @param {number} newTokenId - The new predicted token ID to add to the beam's output sequence. + */ + updateBeam(beam: any, newTokenId: number): void; } +/** + * Vision Transformer model for image classification tasks. + * @extends PreTrainedModel + */ declare class ViTForImageClassification extends PreTrainedModel { - _call(model_inputs: any): Promise; + /** + * Runs a forward pass of the model. + * @param {object} model_inputs - Inputs to the model. + * @returns {Promise} - Output of the model. 
+ */ + _call(model_inputs: object): Promise; } declare class T5PreTrainedModel extends PreTrainedModel { } -declare class Seq2SeqLMOutput { - constructor(logits: any, past_key_values: any, encoder_outputs: any); - logits: any; - past_key_values: any; - encoder_outputs: any; -} +import { GenerationConfig } from "./generation.js"; import { Callable } from "./utils.js"; import { Tensor } from "./tensor_utils.js"; -import { GenerationConfig } from "./generation.js"; import { LogitsProcessorList } from "./generation.js"; declare class BartPretrainedModel extends PreTrainedModel { } @@ -249,21 +490,14 @@ declare class GPT2PreTrainedModel extends PreTrainedModel { } declare class CodeGenPreTrainedModel extends PreTrainedModel { } -declare class BertPreTrainedModel extends PreTrainedModel { -} +/** + * Output type for a sequence classification model. + */ declare class SequenceClassifierOutput { - constructor(logits: any); - logits: any; -} -declare class DistilBertPreTrainedModel extends PreTrainedModel { -} -declare class AlbertPreTrainedModel extends PreTrainedModel { -} -declare class RobertaPreTrainedModel extends PreTrainedModel { -} -declare class QuestionAnsweringModelOutput { - constructor(start_logits: any, end_logits: any); - start_logits: any; - end_logits: any; + /** + * @param {Tensor} logits + */ + constructor(logits: Tensor); + logits: Tensor; } export {}; diff --git a/src/models.js b/src/models.js index 84bfb6beb..c4b3c221f 100644 --- a/src/models.js +++ b/src/models.js @@ -29,7 +29,13 @@ const ONNXTensor = ONNX.Tensor ////////////////////////////////////////////////// // Helper functions - +/** + * Constructs an InferenceSession using a model file located at the specified path. + * @param {string} modelPath - The path to the directory containing the model file. + * @param {string} fileName - The name of the model file. + * @param {function} [progressCallback=null] - An optional function to track progress during the creation of the session. 
+ * @returns {Promise} - A Promise that resolves to an InferenceSession object. + */ async function constructSession(modelPath, fileName, progressCallback = null) { let buffer = await getModelFile(modelPath, fileName, progressCallback); @@ -41,12 +47,23 @@ async function constructSession(modelPath, fileName, progressCallback = null) { return session } +/** + * Executes an InferenceSession using the specified inputs. + * @param {InferenceSession} session - The InferenceSession object to run. + * @param {Object} inputs - An object that maps input names to input tensors. + * @returns {Promise} - A Promise that resolves to an object that maps output names to output tensors. + */ async function sessionRun(session, inputs) { let output = await session.run(inputs); output = replaceTensors(output); return output; } +/** + * Replaces ONNX Tensor objects with custom Tensor objects to support additional functions. + * @param {Object} obj - The object to replace tensor objects in. + * @returns {Object} - The object with tensor objects replaced by custom Tensor objects. + */ function replaceTensors(obj) { // Convert ONNX Tensors with our custom Tensor class // to support additional functions @@ -58,6 +75,12 @@ function replaceTensors(obj) { return obj; } +/** + * Prepares an attention mask for a sequence of tokens based on configuration options. + * @param {Object} self - The calling object instance. + * @param {Tensor} tokens - The input tokens. + * @returns {Tensor} - The attention mask tensor. + */ function _prepare_attention_mask(self, tokens) { // Prepare attention mask @@ -85,13 +108,23 @@ function _prepare_attention_mask(self, tokens) { } } +/** + * Creates a boolean tensor with a single value. + * @param {boolean} value - The value of the tensor. + * @returns {Tensor} - The boolean tensor. 
+ */ function boolTensor(value) { // Create boolean tensor return new Tensor('bool', [value], [1]); } // JS doesn't support mixings, so we define some reused functions here, and allow "this" to be passed in - +/** + * Loads a sequence-to-sequence model from the specified path. + * @param {string} modelPath - The path to the model directory. + * @param {function} progressCallback - The optional progress callback function. + * @returns {Promise} - A promise that resolves with information about the loaded model. + */ async function seq2seqLoadModel(modelPath, progressCallback) { let info = await Promise.all([ fetchJSON(modelPath, 'config.json', progressCallback), @@ -108,6 +141,18 @@ async function seq2seqLoadModel(modelPath, progressCallback) { return info; } + +/** + * Perform forward pass on the seq2seq model. + * @async + * @function + * @param {Object} self - The seq2seq model object. + * @param {Object} model_inputs - The input object for the model containing encoder and decoder inputs. + * @param {Object} options - The options + * @param {string} [options.encoder_input_name='input_ids'] - The name of the input tensor for the encoder. + * @param {boolean} [options.add_decoder_pkv=true] - Flag to add the decoder past key values. + * @returns {Promise} - Promise that resolves with the output of the seq2seq model. + */ async function seq2seq_forward(self, model_inputs, { encoder_input_name = 'input_ids', add_decoder_pkv = true @@ -144,6 +189,15 @@ async function seq2seq_forward(self, model_inputs, { return new Seq2SeqLMOutput(logits, pastKeyValues, encoderOutputs); } +/** + * Start the beam search process for the seq2seq model. + * @function + * @param {Object} self - The seq2seq model object. + * @param {Array} inputTokenIds - Array of input token ids for each input sequence. + * @param {number} numOutputTokens - The maximum number of output tokens for the model. 
+ * @param {boolean} [requires_attention_mask=true] - Flag to indicate if the model requires an attention mask. + * @returns {Array} - Array of beam search objects. + */ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attention_mask = true) { let beams = []; let beamId = 0; @@ -176,6 +230,16 @@ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attent return beams; } +/** + * Run beam search on the seq2seq model for a single beam. + * @async + * @function + * @param {Object} self - The seq2seq model object. + * @param {Object} beam - The beam search object for which to run the model. + * @param {Object} options - options + * @param {string} [options.input_name='input_ids'] - The name of the input tensor for the encoder. + * @returns {Promise} - Promise that resolves with the output of the seq2seq model for the given beam. + */ async function seq2seqRunBeam(self, beam, { input_name = 'input_ids', } = {} @@ -201,6 +265,14 @@ async function seq2seqRunBeam(self, beam, { return output; } +/** + * Forward pass of the text generation model. + * @async + * @function + * @param {Object} self - The text generation model object. + * @param {Object} model_inputs - The input data to be used for the forward pass. + * @returns {Promise} - Promise that resolves with an object containing the logits and past key values. + */ async function textgen_forward(self, model_inputs) { let past_key_values = model_inputs.past_key_values; let decoderFeeds = { @@ -217,6 +289,14 @@ async function textgen_forward(self, model_inputs) { return { logits, past_key_values }; } +/** + * Starts the generation of text by initializing the beams for the given input token IDs. + * @param {Object} self - The text generation model object. + * @param {Tensor} inputTokenIds - An array of input token IDs to generate text from. + * @param {number} numOutputTokens - The maximum number of tokens to generate for each beam. 
+ * @param {Tensor} [inputs_attention_mask] - The attention mask tensor for the input token IDs. + * @returns {Object[]} An array of beams initialized with the given inputs and parameters. + */ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attention_mask) { let beams = []; @@ -255,6 +335,20 @@ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio return beams; } +/** + * Runs a single step of the text generation process for a given beam. + * + * @async + * @function textgenRunBeam + * @param {Object} self - The textgen object. + * @param {Object} beam - The beam to run. + * @param {Tensor} beam.input - The input tensor. + * @param {Tensor} beam.model_input_ids - The input ids to the model. + * @param {Tensor} beam.attention_mask - The attention mask. + * @param {Object} beam.past_key_values - The past key values. + * @param {number[]} beam.output_token_ids - The output token ids. + * @returns {Promise} The output of the generation step. + */ async function textgenRunBeam(self, beam) { let attnMaskData = new BigInt64Array(beam.input.data.length + beam.output_token_ids.length).fill(1n) @@ -278,6 +372,11 @@ async function textgenRunBeam(self, beam) { return output; } +/** + * Update a beam with a new token ID. + * @param {object} beam - The beam to update. + * @param {number} newTokenId - The new token ID to add to the beam's output. + */ function textgenUpdatebeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; beam.model_input_ids = new Tensor('int64', [BigInt(newTokenId)], [1, 1]); @@ -286,7 +385,15 @@ function textgenUpdatebeam(beam, newTokenId) { ////////////////////////////////////////////////// // Base class +/** + * A base class for pre-trained models that provides the model configuration and an ONNX inference session. + * @extends Callable + */ class PreTrainedModel extends Callable { + /** + * @param {object} config - The configuration object for the model.
+ * @param {any} session - The ONNX session for running inference. + */ constructor(config, session) { super(); @@ -295,6 +402,10 @@ class PreTrainedModel extends Callable { } + /** + * Disposes of all the ONNX sessions that were created during inference. + * @returns {Promise} - An array of promises, one for each ONNX session that is being disposed. + */ async dispose() { // Dispose of all ONNX sessions sessions // TODO use: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry @@ -309,6 +420,14 @@ class PreTrainedModel extends Callable { return await Promise.all(promises); } + /** + * Loads a pre-trained model from the given modelPath. + * @static + * @async + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} progressCallback - A function to be called with progress updates. + * @returns {Promise} A new instance of the PreTrainedModel class. + */ static async from_pretrained(modelPath, progressCallback = null) { let config = await fetchJSON(modelPath, 'config.json', progressCallback); @@ -326,6 +445,12 @@ class PreTrainedModel extends Callable { return new this(config, session); } + /** + * Converts an array or Tensor of integers to an int64 Tensor. + * @param {Array|Tensor} items - The input integers to be converted. + * @returns {Tensor} The int64 Tensor with the converted values. + * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length. + */ toI64Tensor(items) { if (items instanceof Tensor) { return items; @@ -354,10 +479,22 @@ class PreTrainedModel extends Callable { } } + /** + * Runs the model with the provided inputs + * @param {Object} model_inputs - Object containing input tensors + * @returns {Promise} - Object containing output tensors + */ async _call(model_inputs) { return await sessionRun(this.session, model_inputs); } + /** + * Forward method should be implemented in subclasses.
+ * @abstract + * @param {object} model_inputs - The input data to the model in the format specified in the ONNX model. + * @returns {Promise} - The output data from the model in the format specified in the ONNX model. + * @throws {Error} - This method must be implemented in subclasses. + */ async forward(model_inputs) { throw Error("forward should be implemented in subclasses.") } @@ -486,6 +623,13 @@ class PreTrainedModel extends Callable { return processors; } + /** + * This function merges multiple generation configs together to form a final generation config to be used by the model for text generation. + * It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object. + * + * @param {GenerationConfig} generation_config - A `GenerationConfig` object containing generation parameters. + * @returns {GenerationConfig} The final generation config object to be used by the model for text generation. + */ _get_generation_config(generation_config) { // Create empty generation config (contains defaults) let gen_config = new GenerationConfig(); @@ -500,6 +644,17 @@ class PreTrainedModel extends Callable { } return gen_config; } + + /** + * Generates text based on the given inputs and generation configuration using the model. + * @param {Array} inputs - An array of input token IDs. + * @param {Object|null} generation_config - The generation configuration to use. If null, default configuration will be used. + * @param {Object|null} logits_processor - An optional logits processor to use. If null, a new LogitsProcessorList instance will be created. + * @param {Object} options - options + * @param {Object} [options.inputs_attention_mask=null] - An optional attention mask for the inputs. 
+ * @returns {Promise} An array of generated output sequences, where each sequence is an array of token IDs. + * @throws {Error} Throws an error if the inputs array is empty. + */ async generate( inputs, generation_config = null, @@ -597,6 +752,13 @@ class PreTrainedModel extends Callable { } ) } + + /** + * Groups an array of beam objects by their ids. + * + * @param {Array} beams - The array of beam objects to group. + * @returns {Array} - An array of arrays, where each inner array contains beam objects with the same id. + */ groupBeams(beams) { // Group beams by their ids const groups = {}; @@ -610,6 +772,13 @@ class PreTrainedModel extends Callable { return Object.values(groups); } + + /** + * Returns an object containing past key values from the given decoder results object. + * + * @param {Object} decoderResults - The decoder results object. + * @returns {Object} - An object containing past key values. + */ getPastKeyValues(decoderResults) { const pkvs = {}; @@ -620,6 +789,14 @@ class PreTrainedModel extends Callable { } return pkvs; } + + /** + * Adds past key values to the decoder feeds object. If pastKeyValues is null, creates new tensors for past key values. + * + * @param {Object} decoderFeeds - The decoder feeds object to add past key values to. + * @param {Object} pastKeyValues - An object containing past key values. + * @param {boolean} [hasDecoder=false] - Whether the model has a decoder. + */ addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) { if (pastKeyValues === null) { // TODO support batches (i.e., batch_size > 1) @@ -654,19 +831,49 @@ class PreTrainedModel extends Callable { // Bert models class BertPreTrainedModel extends PreTrainedModel { } class BertModel extends BertPreTrainedModel { } +/** + * BertForMaskedLM is a class representing a BERT model for masked language modeling. + * @extends BertPreTrainedModel + */ class BertForMaskedLM extends BertPreTrainedModel { + /** + * Calls the model on new inputs. 
+ * + * @param {Object} model_inputs - The inputs to the model. + * @returns {Promise} - An object containing the model's output logits for masked language modeling. + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new MaskedLMOutput(logits) } } +/** + * BertForSequenceClassification is a class representing a BERT model for sequence classification. + * @extends BertPreTrainedModel + */ class BertForSequenceClassification extends BertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model. + * @returns {Promise} - An object containing the model's output logits for sequence classification. + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new SequenceClassifierOutput(logits) } } +/** + * BertForQuestionAnswering is a class representing a BERT model for question answering. + * @extends BertPreTrainedModel + */ class BertForQuestionAnswering extends BertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model. + * @returns {Promise} - An object containing the model's output logits for question answering. + */ async _call(model_inputs) { let outputs = await super._call(model_inputs); return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); @@ -678,19 +885,49 @@ class BertForQuestionAnswering extends BertPreTrainedModel { // DistilBert models class DistilBertPreTrainedModel extends PreTrainedModel { } class DistilBertModel extends DistilBertPreTrainedModel { } +/** + * DistilBertForSequenceClassification is a class representing a DistilBERT model for sequence classification. + * @extends DistilBertPreTrainedModel + */ class DistilBertForSequenceClassification extends DistilBertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model. 
+ * @returns {Promise} - An object containing the model's output logits for sequence classification. + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new SequenceClassifierOutput(logits) } } +/** + * DistilBertForQuestionAnswering is a class representing a DistilBERT model for question answering. + * @extends DistilBertPreTrainedModel + */ class DistilBertForQuestionAnswering extends DistilBertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model. + * @returns {Promise} - An object containing the model's output logits for question answering. + */ async _call(model_inputs) { let outputs = await super._call(model_inputs); return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); } } +/** + * DistilBertForMaskedLM is a class representing a DistilBERT model for the masked language modeling task. + * @extends DistilBertPreTrainedModel + */ class DistilBertForMaskedLM extends DistilBertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model. + * @returns {Promise} - An object containing the model's output logits for masked language modeling. + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new MaskedLMOutput(logits) } @@ -702,19 +939,49 @@ class DistilBertForMaskedLM extends DistilBertPreTrainedModel { // DistilBert models class AlbertPreTrainedModel extends PreTrainedModel { } class AlbertModel extends AlbertPreTrainedModel { } +/** + * AlbertForSequenceClassification is a class representing an Albert model for sequence classification. + * @extends AlbertPreTrainedModel + */ class AlbertForSequenceClassification extends AlbertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model.
+ * @returns {Promise} - returned object + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new SequenceClassifierOutput(logits) } } +/** + * AlbertForQuestionAnswering is a class representing an Albert model for question answering. + * @extends AlbertPreTrainedModel + */ class AlbertForQuestionAnswering extends AlbertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model. + * @returns {Promise} - returned object + */ async _call(model_inputs) { let outputs = await super._call(model_inputs); return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); } } +/** + * AlbertForMaskedLM is a class representing an Albert model for masking task. + * @extends AlbertPreTrainedModel + */ class AlbertForMaskedLM extends AlbertPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs - The inputs to the model. + * @returns {Promise} - returned object + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new MaskedLMOutput(logits) @@ -726,8 +993,16 @@ class AlbertForMaskedLM extends AlbertPreTrainedModel { ////////////////////////////////////////////////// // T5 models class T5PreTrainedModel extends PreTrainedModel { }; - +/** + * T5Model is a class representing a T5 model without a language model head. + * @extends T5PreTrainedModel + */ class T5Model extends T5PreTrainedModel { + /** + * Generates text based on the provided arguments. + * + * @throws {Error} - Throws an error as the current model class (T5Model) is not compatible with `.generate()`. + */ async generate(...args) { throw Error( "The current model class (T5Model) is not compatible with `.generate()`, as it doesn't have a language model head. 
Please use one of the following classes instead: {'T5ForConditionalGeneration'}" ) } } +/** + * T5ForConditionalGeneration is a class representing a T5 model for conditional generation. + * @extends T5PreTrainedModel + */ class T5ForConditionalGeneration extends T5PreTrainedModel { + /** + * @param {object} config - The model configuration. + * @param {any} session - session for the model. + * @param {any} decoder_merged_session - session for the decoder. + * @param {GenerationConfig} generation_config - The generation configuration. + */ constructor(config, session, decoder_merged_session, generation_config) { super(config, session); this.decoder_merged_session = decoder_merged_session; @@ -750,22 +1035,53 @@ class T5ForConditionalGeneration extends T5PreTrainedModel { this.encoder_dim_kv = this.config.d_kv; } + /** + * Loads the pre-trained model from a given path. + * @async + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} progressCallback - A function to call with progress updates (optional). + * @returns {Promise} The loaded model instance. + */ static async from_pretrained(modelPath, progressCallback = null) { let info = await seq2seqLoadModel(modelPath, progressCallback); return new this(...info); } + /** + * Generates the start beams for a given set of inputs and output length. + * @param {number[][]} inputs - The input token IDs. + * @param {number} numOutputTokens - The desired output length. + * @returns {Array} The start beams. + */ getStartBeams(inputs, numOutputTokens, ...args) { return seq2seqStartBeams(this, inputs, numOutputTokens); } + /** + * Runs the beam search for a given beam. + * @async + * @param {any} beam - The current beam. + * @returns {Promise} The model output. + */ async runBeam(beam) { return await seq2seqRunBeam(this, beam); } + + /** + * Updates the given beam with a new token ID. + * @param {any} beam - The current beam.
+ * @param {number} newTokenId - The new token ID to add to the output sequence. + */ updateBeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; } + /** + * Runs the forward pass of the model for a given set of inputs. + * @async + * @param {Object} model_inputs - The model inputs. + * @returns {Promise} The model output. + */ async forward(model_inputs) { return await seq2seq_forward(this, model_inputs); } @@ -775,8 +1091,20 @@ class T5ForConditionalGeneration extends T5PreTrainedModel { ////////////////////////////////////////////////// // Bart models class BartPretrainedModel extends PreTrainedModel { }; - +/** + * BART encoder and decoder model. + * + * @hideconstructor + * @extends BartPretrainedModel + */ class BartModel extends BartPretrainedModel { + /** + * Throws an error because the current model class (BartModel) is not compatible with `.generate()`. + * + * @async + * @throws {Error} The current model class (BartModel) is not compatible with `.generate()`. + * @returns {Promise} + */ async generate(...args) { throw Error( "The current model class (BartModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'BartForConditionalGeneration'}" ) } } +/** + * BART model with a language model head for conditional generation. + * @extends BartPretrainedModel + */ class BartForConditionalGeneration extends BartPretrainedModel { + /** + * Create a new BartForConditionalGeneration instance. + * @param {object} config - The configuration object for the Bart model. + * @param {object} session - The ONNX session used to execute the model. + * @param {object} decoder_merged_session - The ONNX session used to execute the decoder. + * @param {object} generation_config - The generation configuration object.
+ */ constructor(config, session, decoder_merged_session, generation_config) { super(config, session); this.decoder_merged_session = decoder_merged_session; @@ -799,22 +1138,52 @@ class BartForConditionalGeneration extends BartPretrainedModel { this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } + /** + * Loads a BartForConditionalGeneration instance from a pretrained model stored on disk. + * @param {string} modelPath - The path to the directory containing the pretrained model. + * @param {function} [progressCallback=null] - An optional callback function to track the download progress. + * @returns {Promise} - The pretrained BartForConditionalGeneration instance. + */ static async from_pretrained(modelPath, progressCallback = null) { let info = await seq2seqLoadModel(modelPath, progressCallback); return new this(...info); } + /** + * Returns the initial beam for generating output text. + * @param {object} inputs - The input object containing the encoded input text. + * @param {number} numOutputTokens - The maximum number of output tokens to generate. + * @param {...any} args - Additional arguments to pass to the sequence-to-sequence generation function. + * @returns {any} - The initial beam for generating output text. + */ getStartBeams(inputs, numOutputTokens, ...args) { return seq2seqStartBeams(this, inputs, numOutputTokens); } + /** + * Runs a single step of the beam search generation algorithm. + * @param {any} beam - The current beam being generated. + * @returns {Promise} - The updated beam after a single generation step. + */ async runBeam(beam) { return await seq2seqRunBeam(this, beam); } + + /** + * Updates the beam by appending the newly generated token ID to the list of output token IDs. + * @param {any} beam - The current beam being generated. + * @param {number} newTokenId - The ID of the newly generated token to append to the list of output token IDs. 
+ */ updateBeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; } + /** + * Runs the forward pass of the model for a given set of inputs. + * @async + * @param {Object} model_inputs - The model inputs. + * @returns {Promise} The model output. + */ async forward(model_inputs) { return await seq2seq_forward(this, model_inputs); } @@ -826,19 +1195,47 @@ class BartForConditionalGeneration extends BartPretrainedModel { // Roberta models class RobertaPreTrainedModel extends PreTrainedModel { } class RobertaModel extends RobertaPreTrainedModel { } + +/** + * RobertaForMaskedLM class for performing masked language modeling on Roberta models. + * @extends RobertaPreTrainedModel + */ class RobertaForMaskedLM extends RobertaPreTrainedModel { + /** + * Calls parent _call function and returns MaskedLMOutput object with logits. + * @param {Object} model_inputs - Input data for the model. + * @returns {Promise} Promise object represents the masked language modeling output. + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new MaskedLMOutput(logits) } } +/** + * RobertaForSequenceClassification class for performing sequence classification on Roberta models. + * @extends RobertaPreTrainedModel + */ class RobertaForSequenceClassification extends RobertaPreTrainedModel { + /** + * Calls parent _call function and returns SequenceClassifierOutput object with logits. + * @param {Object} model_inputs - Input data for the model. + * @returns {Promise} Promise object represents the sequence classification output. + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new SequenceClassifierOutput(logits) } } +/** + * RobertaForQuestionAnswering class for performing question answering on Roberta models. 
+ * @extends RobertaPreTrainedModel + */ class RobertaForQuestionAnswering extends RobertaPreTrainedModel { + /** + * Calls parent _call function and returns QuestionAnsweringModelOutput object with start and end logits. + * @param {Object} model_inputs - Input data for the model. + * @returns {Promise} Promise object represents the question answering output. + */ async _call(model_inputs) { let outputs = await super._call(model_inputs); return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); } @@ -849,16 +1246,33 @@ class RobertaForQuestionAnswering extends RobertaPreTrainedModel { ////////////////////////////////////////////////// // T5 models class WhisperPreTrainedModel extends PreTrainedModel { }; - +/** + * WhisperModel class representing a Whisper model without a language model head. + * @extends WhisperPreTrainedModel + */ class WhisperModel extends WhisperPreTrainedModel { + /** + * Throws an error when attempting to generate output since this model doesn't have a language model head. + * @throws Error + */ async generate(...args) { throw Error( "The current model class (WhisperModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'WhisperForConditionalGeneration'}" ) } } - +/** + * WhisperForConditionalGeneration class for generating conditional outputs from Whisper models. + * @extends WhisperPreTrainedModel + */ class WhisperForConditionalGeneration extends WhisperPreTrainedModel { + /** + * Initializes the WhisperForConditionalGeneration object. + * @param {Object} config - Configuration object for the model. + * @param {Object} session - ONNX session object for the model. + * @param {Object} decoder_merged_session - ONNX session object for the decoder. + * @param {Object} generation_config - Configuration object for the generation process.
+ */ constructor(config, session, decoder_merged_session, generation_config) { super(config, session); this.decoder_merged_session = decoder_merged_session; @@ -874,7 +1288,13 @@ class WhisperForConditionalGeneration extends WhisperPreTrainedModel { } - + /** + * Generates outputs based on input and generation configuration. + * @param {Object} inputs - Input data for the model. + * @param {Object} generation_config - Configuration object for the generation process. + * @param {Object} logits_processor - Optional logits processor object. + * @returns {Promise} Promise object represents the generated outputs. + */ async generate( inputs, generation_config = null, @@ -903,25 +1323,54 @@ class WhisperForConditionalGeneration extends WhisperPreTrainedModel { return super.generate(inputs, generation_config, logits_processor) } + /** + * Loads a pre-trained model from a saved model directory. + * @param {string} modelPath - Path to the saved model directory. + * @param {function} progressCallback - Optional function for tracking loading progress. + * @returns {Promise} Promise object represents the loaded model. + */ static async from_pretrained(modelPath, progressCallback = null) { let info = await seq2seqLoadModel(modelPath, progressCallback); return new this(...info); } + /** + * Gets the start beams for generating outputs. + * @param {Array} inputTokenIds - Array of input token IDs. + * @param {number} numOutputTokens - Number of output tokens to generate. + * @returns {Array} Array of start beams. + */ getStartBeams(inputTokenIds, numOutputTokens, ...args) { // arguments ignored in this case return seq2seqStartBeams(this, inputTokenIds, numOutputTokens, false); } + /** + * Runs a beam for generating outputs. + * @param {Object} beam - Beam object. + * @returns {Promise} Promise object represents the generated outputs for the beam. 
+ */ async runBeam(beam) { return await seq2seqRunBeam(this, beam, { input_name: 'input_features', }); } + + /** + * Updates the beam by appending the newly generated token ID to the list of output token IDs. + * @param {any} beam - The current beam being generated. + * @param {number} newTokenId - The ID of the newly generated token to append to the list of output token IDs. + */ updateBeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; } + /** + * Runs the forward pass of the model for a given set of inputs. + * @async + * @param {Object} model_inputs - The model inputs. + * @returns {Promise} The model output. + */ async forward(model_inputs) { return await seq2seq_forward(this, model_inputs, { encoder_input_name: 'input_features', }); } ////////////////////////////////////////////////// ////////////////////////////////////////////////// +/** + * Vision Encoder-Decoder model based on OpenAI's GPT architecture for image captioning and other vision tasks + * @extends PreTrainedModel + */ class VisionEncoderDecoderModel extends PreTrainedModel { + /** + * @param {object} config - The configuration object specifying the hyperparameters and other model settings. + * @param {object} session - The ONNX session containing the encoder model. + * @param {any} decoder_merged_session - The ONNX session containing the merged decoder model. + */ constructor(config, session, decoder_merged_session) { super(config, session); this.decoder_merged_session = decoder_merged_session; @@ -941,6 +1399,13 @@ class VisionEncoderDecoderModel extends PreTrainedModel { this.dim_kv = this.config.decoder.n_embd / this.num_heads; } + /** + * Loads a VisionEncoderDecoderModel from the given path. + * + * @param {string} modelPath - The path to the folder containing the saved model files.
+ * @param {function} [progressCallback=null] - Optional callback function to track the progress of model loading. + * @returns {Promise} A Promise that resolves with the loaded VisionEncoderDecoderModel instance. + */ static async from_pretrained(modelPath, progressCallback = null) { let [config, session, decoder_merged_session] = await Promise.all([ @@ -958,19 +1423,45 @@ class VisionEncoderDecoderModel extends PreTrainedModel { return new this(config, session, decoder_merged_session); } + /** + * Generate beam search outputs for the given input pixels and number of output tokens. + * + * @param {array} inputs - The input pixels as a Tensor. + * @param {number} numOutputTokens - The number of output tokens to generate. + * @param {...*} args - Optional additional arguments to pass to seq2seqStartBeams. + * @returns {any} An array of Beam objects representing the top-K output sequences. + */ getStartBeams(inputs, numOutputTokens, ...args) { return seq2seqStartBeams(this, inputs, numOutputTokens); } + /** + * Generate the next beam step for the given beam. + * + * @param {any} beam - The current beam. + * @returns {Promise} The updated beam with the additional predicted token ID. + */ async runBeam(beam) { return seq2seqRunBeam(this, beam, { input_name: 'pixel_values', }); } + /** + * Update the given beam with the additional predicted token ID. + * + * @param {any} beam - The current beam. + * @param {number} newTokenId - The new predicted token ID to add to the beam's output sequence. + */ updateBeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; } + /** + * Compute the forward pass of the model on the given input tensors. + * + * @param {object} model_inputs - The input tensors as an object with keys 'pixel_values' and 'decoder_input_ids'. + * @returns {Promise} The output tensor of the model. 
+ */ async forward(model_inputs) { return await seq2seq_forward(this, model_inputs, { encoder_input_name: 'pixel_values', @@ -992,6 +1483,10 @@ class CLIPModel extends CLIPPreTrainedModel { ////////////////////////////////////////////////// // GPT2 models class GPT2PreTrainedModel extends PreTrainedModel { } +/** + * GPT2Model is not compatible with `.generate()`, as it doesn't have a language model head. + * @extends GPT2PreTrainedModel + */ class GPT2Model extends GPT2PreTrainedModel { async generate(...args) { throw Error( @@ -999,8 +1494,15 @@ class GPT2Model extends GPT2PreTrainedModel { ) } } - +/** + * GPT-2 language model head on top of the GPT-2 base model. This model is suitable for text generation tasks. + * @extends GPT2PreTrainedModel + */ class GPT2LMHeadModel extends GPT2PreTrainedModel { + /** + * @param {object} config - The configuration of the model. + * @param {any} session - The ONNX session containing the model weights. + */ constructor(config, session) { super(config, session); @@ -1012,19 +1514,40 @@ class GPT2LMHeadModel extends GPT2PreTrainedModel { this.dim_kv = this.config.n_embd / this.num_heads; } + /** + * Initializes and returns the beam for text generation task + * @param {Tensor} inputTokenIds - The input token ids. + * @param {number} numOutputTokens - The number of tokens to be generated. + * @param {Tensor} inputs_attention_mask - Optional input attention mask. + * @returns {any} A Beam object representing the initialized beam. + */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } - + /** + * Runs beam search for text generation given a beam. + * @param {any} beam - The Beam object representing the beam. + * @returns {Promise} A Beam object representing the updated beam after running beam search. 
+ */ async runBeam(beam) { return await textgenRunBeam(this, beam); } + /** + * Updates the given beam with the new generated token id. + * @param {any} beam - The Beam object representing the beam. + * @param {number} newTokenId - The new generated token id to be added to the beam. + */ updateBeam(beam, newTokenId) { return textgenUpdatebeam(beam, newTokenId); } + /** + * Forward pass for the model. + * @param {object} model_inputs - The inputs for the model. + * @returns {Promise} The output tensor of the model. + */ async forward(model_inputs) { return await textgen_forward(this, model_inputs) } @@ -1039,7 +1562,20 @@ class GPT2LMHeadModel extends GPT2PreTrainedModel { ////////////////////////////////////////////////// // CodeGen models class CodeGenPreTrainedModel extends PreTrainedModel { } +/** + * CodeGenModel is a class representing a code generation model without a language model head. + * + * @extends CodeGenPreTrainedModel + */ class CodeGenModel extends CodeGenPreTrainedModel { + /** + * Throws an error indicating that the current model class is not compatible with `.generate()`, + * as it doesn't have a language model head. + * + * @throws {Error} The current model class is not compatible with `.generate()` + * + * @param {...any} args - Arguments passed to the generate function + */ async generate(...args) { throw Error( "The current model class (CodeGenModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'CodeGenForCausalLM'}" @@ -1047,7 +1583,15 @@ class CodeGenModel extends CodeGenPreTrainedModel { } } +/** + * CodeGenForCausalLM is a class that represents a code generation model based on the GPT-2 architecture. It extends the `CodeGenPreTrainedModel` class. + * @extends CodeGenPreTrainedModel + */ class CodeGenForCausalLM extends CodeGenPreTrainedModel { + /** + * @param {object} config The model configuration object. 
+ * @param {object} session The ONNX session object. + */ constructor(config, session) { super(config, session); @@ -1059,18 +1603,40 @@ class CodeGenForCausalLM extends CodeGenPreTrainedModel { this.dim_kv = this.config.n_embd / this.num_heads; } + /** + * Initializes and returns the beam for text generation task + * @param {Tensor} inputTokenIds - The input token ids. + * @param {number} numOutputTokens - The number of tokens to be generated. + * @param {Tensor} inputs_attention_mask - Optional input attention mask. + * @returns {any} A Beam object representing the initialized beam. + */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } + /** + * Runs beam search for text generation given a beam. + * @param {any} beam - The Beam object representing the beam. + * @returns {Promise} A Beam object representing the updated beam after running beam search. + */ async runBeam(beam) { return await textgenRunBeam(this, beam); } + /** + * Updates the given beam with the new generated token id. + * @param {any} beam - The Beam object representing the beam. + * @param {number} newTokenId - The new generated token id to be added to the beam. + */ updateBeam(beam, newTokenId) { return textgenUpdatebeam(beam, newTokenId); } + /** + * Forward pass for the model. + * @param {object} model_inputs - The inputs for the model. + * @returns {Promise} The output tensor of the model. + */ async forward(model_inputs) { return await textgen_forward(this, model_inputs) } @@ -1079,8 +1645,16 @@ class CodeGenForCausalLM extends CodeGenPreTrainedModel { ////////////////////////////////////////////////// ////////////////////////////////////////////////// +/** + * Vision Transformer model for image classification tasks. + * @extends PreTrainedModel + */ class ViTForImageClassification extends PreTrainedModel { - + /** + * Runs a forward pass of the model. 
+ * @param {object} model_inputs - Inputs to the model. + * @returns {Promise} - Output of the model. + */ async _call(model_inputs) { let logits = (await super._call(model_inputs)).logits; return new SequenceClassifierOutput(logits) @@ -1093,9 +1667,17 @@ class ViTForImageClassification extends PreTrainedModel { ////////////////////////////////////////////////// // AutoModels, used to simplify construction of PreTrainedModels // (uses config to instantiate correct class) +/** + * Helper class to determine model type from config + */ class AutoModel { // Helper class to determine model type from config - + /** + * Instantiates a pre-trained model based on the given model path and config. + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} progressCallback - Optional. A callback function that can be used to track loading progress. + * @returns {Promise} - A promise that resolves to an instance of a pre-trained model. + */ static async from_pretrained(modelPath, progressCallback = null) { let config = await fetchJSON(modelPath, 'config.json', progressCallback); @@ -1137,9 +1719,17 @@ class AutoModel { } } } - +/** + * Helper class for loading sequence classification models from pretrained checkpoints + */ class AutoModelForSequenceClassification { - + /** + * Load a sequence classification model from a pretrained checkpoint + * @param {string} modelPath - The path to the model checkpoint directory + * @param {function} [progressCallback=null] - An optional callback function to receive progress updates + * @returns {Promise} A promise that resolves to a pre-trained sequence classification model + * @throws {Error} if an unsupported model type is encountered + */ static async from_pretrained(modelPath, progressCallback = null) { let [config, session] = await Promise.all([ @@ -1168,13 +1758,23 @@ class AutoModelForSequenceClassification { } } } - +/** + * Class representing an automatic sequence-to-sequence language model. 
+ */ class AutoModelForSeq2SeqLM { static modelClassMapping = { 't5': T5ForConditionalGeneration, 'bart': BartForConditionalGeneration, 'whisper': WhisperForConditionalGeneration, } + /** + * Loads a pretrained sequence-to-sequence language model from a file path. + * @param {string} modelPath - The path to the model files. + * @param {function} [progressCallback=null] - A callback function to track loading progress. + * @returns {Promise} A Promise that resolves to an instance of the appropriate model class. + * @throws {Error} If the model type is unsupported. + * @static + */ static async from_pretrained(modelPath, progressCallback = null) { let info = await seq2seqLoadModel(modelPath, progressCallback); let config = info[0]; @@ -1185,8 +1785,17 @@ class AutoModelForSeq2SeqLM { return new cls(...info) } } - +/** + * A class for loading pre-trained models for causal language modeling tasks. + */ class AutoModelForCausalLM { + /** + * Loads a pre-trained model from the given path and returns an instance of the appropriate class. + * @param {string} modelPath - The path to the pre-trained model. + * @param {function} [progressCallback=null] - An optional callback function to track the progress of the loading process. + * @returns {Promise} An instance of the appropriate class for the loaded model. + * @throws {Error} If the loaded model type is not supported. + */ static async from_pretrained(modelPath, progressCallback = null) { let [config, session] = await Promise.all([ @@ -1218,9 +1827,19 @@ class AutoModelForCausalLM { } } } - +/** + * A class to automatically select the appropriate model for Masked Language Modeling (MLM) tasks. + */ class AutoModelForMaskedLM { - + /** + * Loads a pre-trained model from a given directory and returns an instance of the appropriate model class. + * + * @async + * @param {string} modelPath - The path to the pre-trained model directory. 
+ * @param {function} [progressCallback=null] - An optional callback function to track the loading progress. + * @returns {Promise} An instance of the appropriate model class for MLM tasks. + * @throws {Error} If an unsupported model type is encountered. + */ static async from_pretrained(modelPath, progressCallback = null) { let config = await fetchJSON(modelPath, 'config.json', progressCallback); @@ -1250,9 +1869,17 @@ class AutoModelForMaskedLM { } } } - +/** + * Automatic model class for question answering tasks. + */ class AutoModelForQuestionAnswering { - + /** + * Loads and returns a question answering model from a pretrained model path. + * @param {string} modelPath - The path to the pretrained model. + * @param {function} [progressCallback=null] - Optional callback function to track loading progress. + * @returns {Promise} - The loaded question answering model. + * @throws Will throw an error if an unsupported model type is encountered. + */ static async from_pretrained(modelPath, progressCallback = null) { let [config, session] = await Promise.all([ @@ -1281,8 +1908,16 @@ class AutoModelForQuestionAnswering { } } } - +/** + * Class representing an autoencoder-decoder model for vision-to-sequence tasks. + */ class AutoModelForVision2Seq { + /** + * Loads a pretrained model from a given path. + * @param {string} modelPath - The path to the pretrained model. + * @param {function} progressCallback - Optional callback function to track progress of the model loading. + * @returns {Promise} - A Promise that resolves to a new instance of VisionEncoderDecoderModel. + */ static async from_pretrained(modelPath, progressCallback = null) { let [config, session, decoder_merged_session] = await Promise.all([ @@ -1309,8 +1944,18 @@ class AutoModelForVision2Seq { } } } - +/** + * AutoModelForImageClassification is a class for loading pre-trained image classification models from ONNX format. 
+ */ class AutoModelForImageClassification { + /** + * Loads a pre-trained image classification model from a given directory path. + * @param {string} modelPath - The path to the directory containing the pre-trained model. + * @param {function} [progressCallback=null] - A callback function to monitor the loading progress. + * @returns {Promise} A Promise that resolves with an instance of the ViTForImageClassification class. + * @throws {Error} If the specified model type is not supported. + */ + static async from_pretrained(modelPath, progressCallback = null) { let [config, session] = await Promise.all([ @@ -1339,27 +1984,52 @@ class AutoModelForImageClassification { ////////////////////////////////////////////////// ////////////////////////////////////////////////// +/** + * Represents the output of a sequence-to-sequence language model. + */ class Seq2SeqLMOutput { + /** + * @param {Tensor} logits - The output logits of the model. + * @param {Array} past_key_values - An array of key/value pairs that represent the previous state of the model. + * @param {Tensor} encoder_outputs - The output of the encoder in a sequence-to-sequence model. + */ constructor(logits, past_key_values, encoder_outputs) { this.logits = logits; this.past_key_values = past_key_values; this.encoder_outputs = encoder_outputs; } } - +/** + * Output type for a sequence classification model. + */ class SequenceClassifierOutput { + /** + * @param {Tensor} logits + */ constructor(logits) { this.logits = logits; } } - +/** + * Output of a masked language modeling (MLM) model. + */ class MaskedLMOutput { + /** + * + * @param {Tensor} logits + */ constructor(logits) { this.logits = logits; } } - +/** + * Output of a Question Answering Model. + */ class QuestionAnsweringModelOutput { + /** + * @param {Float32Array} start_logits - The logits for start positions of the answer. + * @param {Float32Array} end_logits - The logits for end positions of the answer. 
+ */ constructor(start_logits, end_logits) { this.start_logits = start_logits; this.end_logits = end_logits;