From ddab8b5adb174f9bac107d550a0c79f973436f6a Mon Sep 17 00:00:00 2001 From: Brian Joseph Petro Date: Fri, 6 Dec 2024 17:24:16 -0500 Subject: [PATCH] Implement smart-rank-model using updated smart-model patterns - Updated README.md files for both smart-model and smart-rank-model to improve clarity and structure, including new sections on features, usage, and folder structure. - Introduced a new .scignore file in smart-rank-model to manage ignored files for the Smart Context VSCode extension https://github.com/brianpetro/smart-context-vscode - Enhanced adapter management in smart-rank-model by adding new adapters (Cohere, Transformers Worker, and Transformers Iframe) and improving existing ones for better integration and functionality. - Updated package.json to include the latest version of @huggingface/transformers for improved model support. - Refactored the SmartRankModel class to streamline model loading and ranking processes, including better error handling and configuration management. - Improved testing framework for both Cohere and Transformers adapters, ensuring robust validation of ranking functionality across various models. This update significantly enhances the usability and maintainability of the smart-model and smart-rank-model packages, providing a more cohesive development experience. --- smart-model/README.md | 72 +- smart-rank-model/.scignore | 4 + smart-rank-model/README.md | 87 ++- smart-rank-model/adapters.js | 5 +- smart-rank-model/adapters/_adapter.js | 29 +- smart-rank-model/adapters/{api.js => _api.js} | 70 +- smart-rank-model/adapters/_message.js | 106 +++ smart-rank-model/adapters/cohere.js | 33 +- smart-rank-model/adapters/iframe.js | 130 ++-- smart-rank-model/adapters/transformers.js | 116 ++- .../adapters/transformers_iframe.js | 34 +- .../adapters/transformers_worker.js | 23 +- smart-rank-model/adapters/worker.js | 94 ++- smart-rank-model/build/esbuild.js | 18 +- .../build/transformers_iframe_script.js | 25 +- .../build/transformers_worker_script.js | 59 +- .../connectors/transformers_iframe.js | 2 +- .../connectors/transformers_worker.js | 729 ++++++++++++++++-- smart-rank-model/package.json | 1 + smart-rank-model/smart_rank_model.js | 101 ++- smart-rank-model/test/cohere.test.js | 39 +- smart-rank-model/test/transformers.test.js | 102 ++- 22 files changed, 1426 insertions(+), 453 deletions(-) create mode 100644 smart-rank-model/.scignore rename smart-rank-model/adapters/{api.js => _api.js} (74%) create mode 100644 smart-rank-model/adapters/_message.js diff --git a/smart-model/README.md b/smart-model/README.md index ae17777b..843033a8 100644 --- a/smart-model/README.md +++ b/smart-model/README.md @@ -1,41 +1,61 @@ -# Smart Model +# smart-model + +A flexible base class for building "smart" models that handle configuration, state management, adapter loading, and model operations. Designed to provide a standardized interface, it simplifies creating specialized model classes (e.g., language models, ranking models) with different backends and adapters. Base class for smart-*-model packages. +## Features +- Base `SmartModel` class manages: + - Adapter lifecycle (load/unload) + - Settings configuration and schema processing + - State transitions (`unloaded`, `loading`, `loaded`, `unloading`) +- Extensible architecture for multiple adapters +- Centralized settings management and re-rendering triggers -## Usage +## Folder Structure +``` +smart-model +├── adapters +│ └── _adapter.js # Base adapter class +├── components +│ └── settings.js # Helper for rendering settings UIs +├── smart_model.js # Core SmartModel class +├── test +│ └── smart_model.test.js # Unit tests for SmartModel +└── package.json +``` -1. Create an instance of `SmartModel` with the required configuration options: +## Getting Started ```javascript +import { SmartModel } from 'smart-model'; + const model = new SmartModel({ - adapters: { mock: MockAdapter }, - settings: { model_key: 'mock_model' }, - model_config: { adapter: 'mock' } + adapters: { myAdapter: MyAdapterClass }, + settings: { model_key: 'my_model' }, + model_config: { adapter: 'myAdapter' } }); +await model.initialize(); // Loads adapter ``` -2. Initialize the model (loads the specified adapter): -```javascript -await model.initialize(); -``` - -3. Use adapter-specific methods: +## Extending +Subclass `SmartModel` to add domain logic: ```javascript -const result = await model.invoke_adapter_method('mock_method', 'test input'); -console.log(result); // "Processed: test input" +class MyCustomModel extends SmartModel { + get default_model_key() { return 'my_model'; } + async custom_method() { + return await this.invoke_adapter_method('some_adapter_method'); + } +} ``` +## Adapters +Adapters bridge between `SmartModel` and external APIs or local logic. Create a subclass of `SmartModelAdapter` and implement necessary methods (e.g., `load`, `rank`, `invoke_api_call`). +## Testing +Run tests: +``` +npm test +``` - -## State Transitions - -The `SmartModel` instance has the following states: -- `unloaded`: No adapter is loaded. -- `loading`: Adapter is in the process of being loaded. -- `loaded`: Adapter has been successfully loaded. -- `unloading`: Adapter is in the process of being unloaded. - -These states are managed automatically when calling `load` and `unload`. - - +## License +MIT \ No newline at end of file diff --git a/smart-rank-model/.scignore b/smart-rank-model/.scignore new file mode 100644 index 00000000..190a14df --- /dev/null +++ b/smart-rank-model/.scignore @@ -0,0 +1,4 @@ +# Ignore files when getting folder as context using Smart Context VSCode extension https://github.com/brianpetro/smart-context-vscode +cl100k_base.json +connectors/** +test/*.json \ No newline at end of file diff --git a/smart-rank-model/README.md b/smart-rank-model/README.md index 469d3df1..37f906d2 100644 --- a/smart-rank-model/README.md +++ b/smart-rank-model/README.md @@ -1,36 +1,77 @@ -# @smart-rank-model +# smart-rank-model -Convenient interface for utilizing various ranking models via API and locally. +`smart-rank-model` extends `smart-model` for ranking tasks. It supports multiple adapters (e.g., Cohere, transformers) to reorder documents based on a given query, providing a unified interface for switching backend models seamlessly. ## Features +- Extends `SmartModel` to handle ranking tasks. +- Provides a `rank(query, documents)` method, returning ranked results. +- Integrates adapters (API-based, local transformers, iframes, workers) to accommodate different runtime environments. +- Includes a curated `models.json` mapping adapter names to known model configurations. -- Supports multiple ranking models -- Flexible adapter system for different model implementations -- GPU acceleration support (when available) -- Easy-to-use API for document ranking +## Folder Structure +``` +smart-rank-model +├── adapters +│ ├── cohere.js # Cohere API adapter +│ ├── transformers.js # Local transformers adapter (Hugging Face) +│ ├── transformers_iframe.js # Transformers via iframe +│ ├── transformers_worker.js # Transformers via Web Worker +│ ├── _adapter.js # Base ranking adapter +│ ├── _api.js # Base API adapter class +│ └── _message.js # Base message-based adapter class +├── connectors # Scripts used by iframe/worker adapters +├── build # Build scripts for connectors +├── test +│ ├── cohere.test.js +│ ├── transformers.test.js +│ └── _env.js +├── models.json # Known model configurations +├── smart_rank_model.js # Core SmartRankModel class +└── package.json +``` -## Configuration +## Usage +```javascript +import { SmartRankModel } from 'smart-rank-model'; +import { cohere as CohereAdapter } from 'smart-rank-model/adapters.js'; -The `SmartRankModel` constructor accepts two parameters: +const rankModel = new SmartRankModel({ + model_key: 'cohere-rerank-english-v3.0', + adapters: { cohere: CohereAdapter }, + model_config: { api_key: 'YOUR_COHERE_API_KEY', endpoint: "https://api.cohere.ai/v1/rerank" } +}); +await rankModel.initialize(); -1. `env`: The environment object containing adapter configurations -2. `opts`: Model configuration options +const query = "What is the capital of the United States?"; +const documents = [ + "Carson City is the capital city of Nevada.", + "Washington, D.C. is the capital of the United States." +]; +const results = await rankModel.rank(query, documents); +console.log(results); +``` -### Model Options +## Switching Adapters +Change `model_key` and `adapter` to switch between Cohere, transformers, or other supported backends. `models.json` provides a quick lookup of supported models. -- `model_key`: Identifier for the model in the `models.json` file -- `adapter`: The adapter to use for this model -- `use_gpu`: Boolean to enable/disable GPU acceleration (auto-detected if not specified) -- `gpu_batch_size`: Batch size for GPU processing (default: 10) +## Local Transformers +Use `transformers` adapter to run local ranking models via WebAssembly, WebGPU, or CPU fallback: +```javascript +import { transformers as TransformersAdapter } from 'smart-rank-model/adapters.js'; -## Adapters +const localModel = new SmartRankModel({ + model_key: 'jinaai/jina-reranker-v1-tiny-en', + adapters: { transformers: TransformersAdapter }, +}); +await localModel.initialize(); +const response = await localModel.rank("organic skincare", ["Some doc", "Another doc"]); +``` -Adapters should be implemented and added to the `env.opts.smart_rank_adapters` object. Each adapter should implement the following methods: - -- `constructor(model)`: Initialize the adapter -- `load()`: Load the model -- `rank(query, documents)`: Rank the documents based on the query +## Testing +Run tests: +``` +npm test +``` ## License - -MIT License. See `LICENSE` file for details. \ No newline at end of file +MIT diff --git a/smart-rank-model/adapters.js b/smart-rank-model/adapters.js index ebcc93a4..f5c0e211 100644 --- a/smart-rank-model/adapters.js +++ b/smart-rank-model/adapters.js @@ -1,10 +1,13 @@ import { SmartRankAdapter } from "./adapters/_adapter.js"; +import { SmartRankCohereAdapter } from "./adapters/cohere.js"; import { SmartRankTransformersAdapter } from "./adapters/transformers.js"; import { SmartRankTransformersIframeAdapter } from "./adapters/transformers_iframe.js"; +import { SmartRankTransformersWorkerAdapter } from "./adapters/transformers_worker.js"; export { SmartRankAdapter as _default, SmartRankCohereAdapter as cohere, SmartRankTransformersAdapter as transformers, SmartRankTransformersIframeAdapter as transformers_iframe, -}; \ No newline at end of file + SmartRankTransformersWorkerAdapter as transformers_worker, +}; diff --git a/smart-rank-model/adapters/_adapter.js b/smart-rank-model/adapters/_adapter.js index 0ab4d0a3..4bde24ce 100644 --- a/smart-rank-model/adapters/_adapter.js +++ b/smart-rank-model/adapters/_adapter.js @@ -1,10 +1,33 @@ import { SmartModelAdapter } from "smart-model/adapters/_adapter.js"; -export class SmartRankModelAdapter extends SmartModelAdapter { +/** + * Base adapter class for ranking models + * @abstract + * @class SmartRankAdapter + * @extends SmartModelAdapter + */ +export class SmartRankAdapter extends SmartModelAdapter { + /** + * Create a SmartRankAdapter instance. + * @param {SmartRankModel} model - The parent SmartRankModel instance + */ constructor(model) { super(model); + /** + * @deprecated Use this.model instead + */ this.smart_rank = model; } - async count_tokens(input) { throw new Error("Not implemented"); } - async rank(query, documents) { throw new Error("Not implemented"); } + + /** + * Rank documents based on a query. + * @abstract + * @param {string} query - The query string + * @param {Array} documents - The documents to rank + * @returns {Promise>} Array of ranking results {index, score, ...} + * @throws {Error} If the method is not implemented by subclass + */ + async rank(query, documents) { + throw new Error('rank method not implemented'); + } } diff --git a/smart-rank-model/adapters/api.js b/smart-rank-model/adapters/_api.js similarity index 74% rename from smart-rank-model/adapters/api.js rename to smart-rank-model/adapters/_api.js index f92065f5..abb7e692 100644 --- a/smart-rank-model/adapters/api.js +++ b/smart-rank-model/adapters/_api.js @@ -1,35 +1,29 @@ +import { SmartRankAdapter } from "./_adapter.js"; import { SmartHttpRequest } from "smart-http-request"; -import { SmartRankModelAdapter } from "./_adapter.js"; import { SmartHttpRequestFetchAdapter } from "smart-http-request/adapters/fetch.js"; /** - * Base API adapter class for SmartRankModel. + * Base adapter class for API-based ranking models (e.g., Cohere) * Handles HTTP requests and response processing for remote ranking services. * @abstract * @class SmartRankModelApiAdapter - * @extends SmartRankModelAdapter + * @extends SmartRankAdapter */ -export class SmartRankModelApiAdapter extends SmartRankModelAdapter { - +export class SmartRankModelApiAdapter extends SmartRankAdapter { /** - * Get the request adapter class. - * @returns {SmartRankModelRequestAdapter} The request adapter class + * Get the API endpoint URL + * @returns {string} Endpoint URL */ - get req_adapter() { - return SmartRankModelRequestAdapter; + get endpoint() { + return this.model_config.endpoint; } /** - * Get the response adapter class. - * @returns {SmartRankModelResponseAdapter} The response adapter class + * Get the API key for authentication + * @returns {string} API key */ - get res_adapter() { - return SmartRankModelResponseAdapter; - } - - /** @returns {string} API endpoint URL */ - get endpoint() { - return this.model_config.endpoint; + get api_key() { + return this.adapter_settings.api_key || this.settings.api_key || this.model_config.api_key; } /** @@ -38,29 +32,17 @@ export class SmartRankModelApiAdapter extends SmartRankModelAdapter { */ get http_adapter() { if (!this._http_adapter) { - if (this.model.opts.http_adapter) - this._http_adapter = this.model.opts.http_adapter; - else - this._http_adapter = new SmartHttpRequest({ - adapter: SmartHttpRequestFetchAdapter, - }); + if (this.model.opts.http_adapter) this._http_adapter = this.model.opts.http_adapter; + else this._http_adapter = new SmartHttpRequest({ adapter: SmartHttpRequestFetchAdapter }); } return this._http_adapter; } - /** - * Get API key for authentication - * @returns {string} API key - */ - get api_key() { - return this.settings.api_key || this.model_config.api_key; - } - /** * Make an API request with retry logic * @param {Object} req - Request configuration * @param {number} [retries=0] - Number of retries attempted - * @returns {Promise} API response + * @returns {Promise} API response JSON */ async request(req, retries = 0) { try { @@ -70,9 +52,6 @@ export class SmartRankModelApiAdapter extends SmartRankModelAdapter { ...req, }); const resp_json = await this.get_resp_json(resp); - if (this.is_error(resp_json)) { - return await this.handle_request_err(resp_json, req, retries); - } return resp_json; } catch (error) { return await this.handle_request_err(error, req, retries); @@ -111,7 +90,7 @@ export class SmartRankModelApiAdapter extends SmartRankModelAdapter { * @returns {Promise} True if API key is valid */ async validate_api_key() { - const resp = await this.rank("test query", ["Test document 1", "Test document 2"]); + const resp = await this.rank("test query", ["Test document"]); return Array.isArray(resp) && resp.length > 0 && resp[0].score !== null; } } @@ -122,10 +101,10 @@ export class SmartRankModelApiAdapter extends SmartRankModelAdapter { */ export class SmartRankModelRequestAdapter { /** - * @constructor + * Create request adapter instance * @param {SmartRankModelApiAdapter} adapter - The SmartRankModelApiAdapter instance * @param {string} query - The query string - * @param {Array} documents - Array of document strings + * @param {Array} documents - Array of documents */ constructor(adapter, query, documents) { this.adapter = adapter; @@ -138,11 +117,13 @@ export class SmartRankModelRequestAdapter { * @returns {Object} Headers object */ get_headers() { - return { + let headers = { "Content-Type": "application/json", - ...(this.adapter.adapter_config.headers || {}), - "Authorization": `Bearer ${this.adapter.api_key}`, }; + if (this.adapter.api_key) { + headers["Authorization"] = `Bearer ${this.adapter.api_key}`; + } + return headers; } /** @@ -161,6 +142,7 @@ export class SmartRankModelRequestAdapter { * Prepare request body for API call * @abstract * @returns {Object} Request body object + * @throws {Error} If not implemented by subclass */ prepare_request_body() { throw new Error("prepare_request_body not implemented"); @@ -184,7 +166,7 @@ export class SmartRankModelResponseAdapter { /** * Convert response to standard format - * @returns {Array} Array of ranking results + * @returns {Array} Array of ranking results {index, score, ...} */ to_standard() { return this.parse_response(); @@ -194,9 +176,9 @@ export class SmartRankModelResponseAdapter { * Parse API response * @abstract * @returns {Array} Parsed ranking results + * @throws {Error} If not implemented by subclass */ parse_response() { throw new Error("parse_response not implemented"); } } - diff --git a/smart-rank-model/adapters/_message.js b/smart-rank-model/adapters/_message.js new file mode 100644 index 00000000..2dafea9d --- /dev/null +++ b/smart-rank-model/adapters/_message.js @@ -0,0 +1,106 @@ +import { SmartRankAdapter } from "./_adapter.js"; + +/** + * Base adapter for message-based ranking implementations (iframe/worker) + * Handles communication between main thread and isolated contexts. + * @abstract + * @class SmartRankMessageAdapter + * @extends SmartRankAdapter + */ +export class SmartRankMessageAdapter extends SmartRankAdapter { + /** + * Create message adapter instance + * @param {SmartRankModel} model - Parent model instance + */ + constructor(model) { + super(model); + /** + * Queue of pending message promises + * @type {Object.} + * @private + */ + this.message_queue = {}; + + /** + * Counter for message IDs + * @type {number} + * @private + */ + this.message_id = 0; + + /** + * Message connector implementation + * @type {string|null} + * @protected + */ + this.connector = null; + + /** + * Unique prefix for message IDs + * @type {string} + * @private + */ + this.message_prefix = `msg_${Math.random().toString(36).substr(2, 9)}_`; + } + + /** + * Send message and wait for response + * @protected + * @param {string} method - Method name to call (e.g., 'rank') + * @param {Object} params - Parameters for the method + * @returns {Promise} Response data + */ + async _send_message(method, params) { + return new Promise((resolve, reject) => { + const id = `${this.message_prefix}${this.message_id++}`; + this.message_queue[id] = { resolve, reject }; + this._post_message({ method, params, id }); + }); + } + + /** + * Handle response message from worker/iframe + * @protected + * @param {string} id - Message ID + * @param {*} result - Response result + * @param {Error} [error] - Response error + */ + _handle_message_result(id, result, error) { + if (!id.startsWith(this.message_prefix)) return; + + if (result?.model_loaded) { + console.log('model loaded'); + this.model.model_loaded = true; + } + + if (this.message_queue[id]) { + if (error) { + this.message_queue[id].reject(new Error(error)); + } else { + this.message_queue[id].resolve(result); + } + delete this.message_queue[id]; + } + } + + /** + * Rank documents based on a query + * @param {string} query - The query + * @param {Array} documents - Documents to rank + * @returns {Promise>} Ranking results + */ + async rank(query, documents) { + return this._send_message('rank', { query, documents }); + } + + /** + * Post message to worker/iframe + * @abstract + * @protected + * @param {Object} message_data - Message to send + * @throws {Error} If not implemented by subclass + */ + _post_message(message_data) { + throw new Error('_post_message must be implemented by subclass'); + } +} diff --git a/smart-rank-model/adapters/cohere.js b/smart-rank-model/adapters/cohere.js index a37040fd..55c7302a 100644 --- a/smart-rank-model/adapters/cohere.js +++ b/smart-rank-model/adapters/cohere.js @@ -1,4 +1,4 @@ -import { SmartRankModelApiAdapter, SmartRankModelRequestAdapter, SmartRankModelResponseAdapter } from './api.js'; +import { SmartRankModelApiAdapter, SmartRankModelRequestAdapter, SmartRankModelResponseAdapter } from './_api.js'; /** * Adapter for Cohere's ranking API. @@ -9,7 +9,7 @@ import { SmartRankModelApiAdapter, SmartRankModelRequestAdapter, SmartRankModelR export class SmartRankCohereAdapter extends SmartRankModelApiAdapter { /** * Get the request adapter class. - * @returns {SmartRankCohereRequestAdapter} The request adapter class + * @returns {typeof SmartRankCohereRequestAdapter} The request adapter class */ get req_adapter() { return SmartRankCohereRequestAdapter; @@ -17,19 +17,28 @@ export class SmartRankCohereAdapter extends SmartRankModelApiAdapter { /** * Get the response adapter class. - * @returns {SmartRankCohereResponseAdapter} The response adapter class + * @returns {typeof SmartRankCohereResponseAdapter} The response adapter class */ get res_adapter() { return SmartRankCohereResponseAdapter; } - /** @override */ - async load() { + /** + * Load the adapter + * @async + * @returns {Promise} + */ + async load() { // Implement any initialization if necessary - return true; + return; } - /** @override */ + /** + * Rank documents using Cohere API + * @param {string} query - The query + * @param {Array} documents - Documents to rank + * @returns {Promise>} Ranked documents + */ async rank(query, documents) { const request_adapter = new this.req_adapter(this, query, documents); const request_params = request_adapter.to_platform(); @@ -45,7 +54,7 @@ export class SmartRankCohereAdapter extends SmartRankModelApiAdapter { } /** - * Override the handle_request_err method for Cohere-specific error handling. + * Handle API request errors with specific logic for Cohere * @param {Error|Object} error - Error object * @param {Object} req - Original request * @param {number} retries - Number of retries attempted @@ -75,10 +84,9 @@ class SmartRankCohereRequestAdapter extends SmartRankModelRequestAdapter { */ prepare_request_body() { return { - model: "rerank-english-v2.0", query: this.query, documents: this.documents, - // top_n: 3, // Optional: specify if needed + model: "rerank-english-v2.0", }; } } @@ -98,10 +106,9 @@ class SmartRankCohereResponseAdapter extends SmartRankModelResponseAdapter { console.error("Invalid response format from Cohere API:", this.response); return []; } - return this.response.results.map((result, index) => ({ + return this.response.results.map((result) => ({ index: result.document_index, score: result.score, - // Add additional fields if necessary })); } -} \ No newline at end of file +} diff --git a/smart-rank-model/adapters/iframe.js b/smart-rank-model/adapters/iframe.js index b7e809bf..ac769704 100644 --- a/smart-rank-model/adapters/iframe.js +++ b/smart-rank-model/adapters/iframe.js @@ -1,89 +1,101 @@ -import { SmartRankModelAdapter } from "./_adapter.js"; +import { SmartRankMessageAdapter } from "./_message.js"; -export class SmartRankIframeAdapter extends SmartRankModelAdapter { - constructor(smart_rank) { - super(smart_rank); +/** + * Adapter for running ranking models in an iframe + * Provides isolation and separate context for model execution. + * @class SmartRankIframeAdapter + * @extends SmartRankMessageAdapter + */ +export class SmartRankIframeAdapter extends SmartRankMessageAdapter { + /** + * Create iframe adapter instance + * @param {SmartRankModel} model - Parent model instance + */ + constructor(model) { + super(model); + /** @type {HTMLIFrameElement|null} */ this.iframe = null; - this.message_queue = {}; - this.message_id = 0; - this.connector = null; // override in subclass - this.origin = window.location.origin; + /** @type {string} */ + this.origin = (typeof window !== 'undefined') ? window.location.origin : 'http://localhost'; + /** @type {string} */ this.iframe_id = `smart_rank_iframe_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`; } + /** + * Initialize iframe and load model + * @async + * @returns {Promise} + */ async load() { + if (typeof document === 'undefined') { + throw new Error('SmartRankIframeAdapter requires a browser environment'); + } + + const existing_iframe = document.getElementById(this.iframe_id); + if (existing_iframe) existing_iframe.remove(); + // Create and append iframe this.iframe = document.createElement('iframe'); this.iframe.style.display = 'none'; this.iframe.id = this.iframe_id; + this.iframe.sandbox = 'allow-scripts allow-same-origin'; document.body.appendChild(this.iframe); // Set up message listener - window.addEventListener('message', this._handle_message.bind(this)); + window.addEventListener('message', this._handle_window_message.bind(this)); - // Load the iframe content + // Load iframe content this.iframe.srcdoc = ` - - - - - - `; + + + + + + `; - // Wait for iframe to load await new Promise(resolve => this.iframe.onload = resolve); - // Initialize the model in the iframe - await this._send_message('load', this.smart_rank.opts); + const load_opts = { + model_key: this.model.model_key, + use_gpu: this.model.opts.use_gpu || false, + adapters: null, + settings: null, + }; + await this._send_message('load', load_opts); + return new Promise(resolve => { const check_model_loaded = () => { - if (this.smart_rank.model_loaded) { - resolve(); - } else { - setTimeout(check_model_loaded, 100); - } + if (this.model.model_loaded) resolve(); + else setTimeout(check_model_loaded, 100); }; check_model_loaded(); }); } - async _send_message(method, params) { - return new Promise((resolve, reject) => { - const id = this.message_id++; - this.message_queue[id] = { resolve, reject }; - this.iframe.contentWindow.postMessage({ method, params, id, iframe_id: this.iframe_id }, this.origin); - }); - } - - _handle_message(event) { + /** + * Handle messages from the iframe + * @private + * @param {MessageEvent} event - Message event + */ + _handle_window_message(event) { if (event.origin !== this.origin || event.data.iframe_id !== this.iframe_id) return; - const { id, result, error } = event.data; - if (result?.model_loaded) { - console.log('model loaded'); - this.smart_rank.model_loaded = true; - } - if (this.message_queue[id]) { - if (error) { - this.message_queue[id].reject(new Error(error)); - } else { - this.message_queue[id].resolve(result); - } - delete this.message_queue[id]; - } + this._handle_message_result(id, result, error); } - async rank(query, documents) { - return this._send_message('rank', { query, documents }); + /** + * Post message to iframe + * @protected + * @param {Object} message_data - Message to send + */ + _post_message(message_data) { + this.iframe.contentWindow.postMessage({ ...message_data, iframe_id: this.iframe_id }, this.origin); } - -} \ No newline at end of file +} diff --git a/smart-rank-model/adapters/transformers.js b/smart-rank-model/adapters/transformers.js index 10d4a0e1..5d75da8f 100644 --- a/smart-rank-model/adapters/transformers.js +++ b/smart-rank-model/adapters/transformers.js @@ -1,49 +1,129 @@ import { SmartRankAdapter } from "./_adapter.js"; +/** + * Default configurations for transformers adapter + * @property {string} adapter - Adapter identifier + * @property {string} description - Human-readable description + * @property {string} default_model - Default model to use + */ +export const transformers_defaults = { + adapter: 'transformers', + description: 'Transformers', + default_model: 'jinaai/jina-reranker-v1-tiny-en', +}; + +export const transformers_models = { + 'jinaai/jina-reranker-v1-tiny-en': { + adapter: 'transformers', + model_key: 'jinaai/jina-reranker-v1-tiny-en', + }, + 'jinaai/jina-reranker-v1-turbo-en': { + adapter: 'transformers', + model_key: 'jinaai/jina-reranker-v1-turbo-en', + }, + 'mixedbread-ai/mxbai-rerank-xsmall-v1': { + adapter: 'transformers', + model_key: 'mixedbread-ai/mxbai-rerank-xsmall-v1', + }, + 'Xenova/bge-reranker-base': { + adapter: 'transformers', + model_key: 'Xenova/bge-reranker-base', + }, +}; + +/** + * Adapter for local transformer-based ranking models + * Uses @huggingface/transformers for model loading and inference + * @class SmartRankTransformersAdapter + * @extends SmartRankAdapter + * + * @example + * ```javascript + * const model = await SmartRankModel.load(env, { + * model_key: 'jinaai/jina-reranker-v1-tiny-en', + * adapters: { + * transformers: SmartRankTransformersAdapter + * } + * }); + * const results = await model.rank('query', ['doc1', 'doc2']); + * console.log(results); + * ``` + */ export class SmartRankTransformersAdapter extends SmartRankAdapter { - constructor(smart_rank) { - super(smart_rank); - this.model = null; + static defaults = transformers_defaults; + + /** + * Create transformers adapter instance + * @param {SmartRankModel} model - Parent model instance + */ + constructor(model) { + super(model); + /** @type {any|null} */ + this.model_instance = null; + /** @type {any|null} */ this.tokenizer = null; } - get use_gpu() { return this.smart_rank.opts.use_gpu || false; } + /** + * Load model and tokenizer + * @async + * @returns {Promise} + */ async load() { console.log('TransformersAdapter initializing'); - const { env, AutoTokenizer, AutoModelForSequenceClassification } = await import('@xenova/transformers'); - console.log('Transformers loaded'); + console.log(this.model.model_key); + const { AutoTokenizer, AutoModelForSequenceClassification, env } = await import('@huggingface/transformers'); env.allowLocalModels = false; const pipeline_opts = { quantized: true, }; - if (this.use_gpu) { + + if (this.model.opts.use_gpu) { console.log("[Transformers] Using GPU"); pipeline_opts.device = 'webgpu'; - pipeline_opts.dtype = 'fp32'; + // pipeline_opts.dtype = 'fp32'; } else { console.log("[Transformers] Using CPU"); - env.backends.onnx.wasm.numThreads = 8; + // env.backends.onnx.wasm.numThreads = 8; } - this.model = await AutoModelForSequenceClassification.from_pretrained(this.smart_rank.opts.model_key, pipeline_opts); - console.log('Model loaded'); - this.tokenizer = await AutoTokenizer.from_pretrained(this.smart_rank.opts.model_key); - console.log('Tokenizer loaded'); + + this.model_instance = await AutoModelForSequenceClassification.from_pretrained(this.model.model_key, pipeline_opts); + this.tokenizer = await AutoTokenizer.from_pretrained(this.model.model_key); console.log('TransformersAdapter initialized'); } + /** + * Rank documents based on a query + * @param {string} query - The query string + * @param {Array} documents - Documents to rank + * @param {Object} [options={}] - Additional ranking options + * @param {number} [options.top_k] - Limit the number of returned documents + * @param {boolean} [options.return_documents=false] - Whether to include original documents in results + * @returns {Promise>} Ranked documents with properties like {index, score, text} + */ async rank(query, documents, options = {}) { + console.log('TransformersAdapter ranking'); + console.log(documents); const { top_k = undefined, return_documents = false } = options; + if (!this.model_instance || !this.tokenizer) await this.load(); + + console.log("tokenizing"); const inputs = this.tokenizer( new Array(documents.length).fill(query), { text_pair: documents, padding: true, truncation: true } ); - const { logits } = await this.model(inputs); - return logits.sigmoid().tolist() + console.log("running model"); + const { logits } = await this.model_instance(inputs); + console.log("done"); + return logits + .sigmoid() + .tolist() .map(([score], i) => ({ index: i, score, ...(return_documents ? { text: documents[i] } : {}) - })).sort((a, b) => b.score - a.score).slice(0, top_k); + })) + .sort((a, b) => b.score - a.score) + .slice(0, top_k); } - -} \ No newline at end of file +} diff --git a/smart-rank-model/adapters/transformers_iframe.js b/smart-rank-model/adapters/transformers_iframe.js index c867e5b2..c62e683b 100644 --- a/smart-rank-model/adapters/transformers_iframe.js +++ b/smart-rank-model/adapters/transformers_iframe.js @@ -1,9 +1,37 @@ import { SmartRankIframeAdapter } from "./iframe.js"; import { transformers_connector } from "../connectors/transformers_iframe.js"; +import { transformers_defaults, transformers_models } from "./transformers.js"; +/** + * Adapter for running transformer-based ranking models in an iframe + * Combines transformer capabilities with iframe isolation. + * @class SmartRankTransformersIframeAdapter + * @extends SmartRankIframeAdapter + * + * @example + * ```javascript + * const model = await SmartRankModel.load(env, { + * model_key: 'jinaai/jina-reranker-v1-tiny-en', + * adapters: { + * transformers_iframe: SmartRankTransformersIframeAdapter + * } + * }); + * const results = await model.rank('query', ['doc1', 'doc2']); + * console.log(results); + * ``` + */ export class SmartRankTransformersIframeAdapter extends SmartRankIframeAdapter { - constructor(smart_rank) { - super(smart_rank); + static defaults = transformers_defaults; + + /** + * Create transformers iframe adapter instance + * @param {SmartRankModel} model - Parent model instance + */ + constructor(model) { + super(model); this.connector = transformers_connector; } -} \ No newline at end of file + get models() { + return transformers_models; + } +} diff --git a/smart-rank-model/adapters/transformers_worker.js b/smart-rank-model/adapters/transformers_worker.js index 7ec1c490..4fc908e4 100644 --- a/smart-rank-model/adapters/transformers_worker.js +++ b/smart-rank-model/adapters/transformers_worker.js @@ -1,9 +1,22 @@ import { SmartRankWorkerAdapter } from "./worker.js"; -// import { transformers_connector } from "../connectors/transformers_worker.js"; +import { transformers_defaults } from "./transformers.js"; +/** + * Adapter for running transformer-based ranking models in a Web Worker + * Provides isolation and parallel processing. + * @class SmartRankTransformersWorkerAdapter + * @extends SmartRankWorkerAdapter + */ export class SmartRankTransformersWorkerAdapter extends SmartRankWorkerAdapter { - constructor(smart_rank) { - super(smart_rank); - this.connector = "../connectors/transformers_worker.js"; - } + static defaults = transformers_defaults; + + /** + * Create transformers worker adapter instance + * @param {SmartRankModel} model - Parent model instance + */ + constructor(model) { + super(model); + // Set connector URL to the worker script + this.connector = "../connectors/transformers_worker.js"; + } } diff --git a/smart-rank-model/adapters/worker.js b/smart-rank-model/adapters/worker.js index 2ef8bf0a..6f37742b 100644 --- a/smart-rank-model/adapters/worker.js +++ b/smart-rank-model/adapters/worker.js @@ -1,72 +1,66 @@ -import { SmartRankAdapter } from "./_adapter.js"; +import { SmartRankMessageAdapter } from "./_message.js"; -export class SmartRankWorkerAdapter extends SmartRankAdapter { - constructor(smart_rank) { - super(smart_rank); +/** + * Adapter for running ranking models in a Web Worker + * Provides parallel processing in a separate thread. + * @class SmartRankWorkerAdapter + * @extends SmartRankMessageAdapter + */ +export class SmartRankWorkerAdapter extends SmartRankMessageAdapter { + /** + * Create worker adapter instance + * @param {SmartRankModel} model - Parent model instance + */ + constructor(model) { + super(model); + /** @type {Worker|null} */ this.worker = null; - this.message_queue = {}; - this.message_id = 0; - this.connector = null; // override in subclass + /** @type {string} */ this.worker_id = `smart_rank_worker_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`; } + /** + * Initialize worker and load model + * @returns {Promise} + */ async load() { - console.log('loading worker adapter', this.smart_rank.opts); - - // Create worker using a relative path - const worker_url = new URL(this.connector, import.meta.url); - this.worker = new Worker(worker_url, { type: 'module' }); - console.log('worker', this.worker); - + if (!this.connector) { + throw new Error('No worker connector script specified for SmartRankWorkerAdapter.'); + } + this.worker = new Worker(this.connector, { type: 'module' }); + console.log('New worker created', this.worker); - // Set up message listener - this.worker.addEventListener('message', this._handle_message.bind(this)); + this.worker.addEventListener('message', this._handle_worker_message.bind(this)); - // Initialize the model in the worker - await this._send_message('load', { ...this.smart_rank.opts, worker_id: this.worker_id }); + await this._send_message('load', { model_key: this.model.model_key, adapters: null, settings: null, worker_id: this.worker_id }); await new Promise(resolve => { const check_model_loaded = () => { - console.log('check_model_loaded', this.smart_rank.model_loaded); - if (this.smart_rank.model_loaded) { - resolve(); - } else { - setTimeout(check_model_loaded, 100); - } + console.log('check_model_loaded', this.model.model_loaded); + if (this.model.model_loaded) resolve(); + else setTimeout(check_model_loaded, 100); }; check_model_loaded(); }); console.log('model loaded'); } - async _send_message(method, params) { - return new Promise((resolve, reject) => { - const id = this.message_id++; - this.message_queue[id] = { resolve, reject }; - this.worker.postMessage({ method, params, id, worker_id: this.worker_id }); - }); + /** + * Post message to worker + * @protected + * @param {Object} message_data - Message to send + */ + _post_message(message_data) { + this.worker.postMessage({ ...message_data, worker_id: this.worker_id }); } - _handle_message(event) { - console.log('handle_message', event.data); + /** + * Handle message from worker + * @private + * @param {MessageEvent} event - Message event + */ + _handle_worker_message(event) { const { id, result, error, worker_id } = event.data; if (worker_id !== this.worker_id) return; - - if (result?.model_loaded) { - console.log('model loaded'); - this.smart_rank.model_loaded = true; - } - if (this.message_queue[id]) { - if (error) { - this.message_queue[id].reject(new Error(error)); - } else { - this.message_queue[id].resolve(result); - } - delete this.message_queue[id]; - } + this._handle_message_result(id, result, error); } - - async rank(query, documents) { - return this._send_message('rank', { query, documents }); - } - } diff --git a/smart-rank-model/build/esbuild.js b/smart-rank-model/build/esbuild.js index 08d5e4bb..d3d8ffd3 100644 --- a/smart-rank-model/build/esbuild.js +++ b/smart-rank-model/build/esbuild.js @@ -14,20 +14,20 @@ async function build_transformers_iframe_connector() { target: 'es2020', outfile: join(__dirname, '../connectors/transformers_iframe.js'), write: false, - external: ['@xenova/transformers'], + external: ['@huggingface/transformers'], }); const outputContent = result.outputFiles[0].text; const wrappedContent = `export const transformers_connector = ${JSON.stringify(outputContent)};` - .replace('@xenova/transformers', 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2') + .replace('@huggingface/transformers', 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.1.1') // escape ${} // .replace(/\$\{([\w.]+)\}/g, '\\`+$1+\\`') ; writeFileSync(join(__dirname, '../connectors/transformers_iframe.js'), wrappedContent); - console.log('Build completed successfully.'); + console.log('Build transformers_iframe_connector completed successfully.'); } catch (error) { - console.error('Build failed:', error); + console.error('Build transformers_iframe_connector failed:', error); } } @@ -40,20 +40,20 @@ async function build_transformers_worker_connector() { target: 'es2020', outfile: join(__dirname, '../connectors/transformers_worker.js'), write: false, - external: ['@xenova/transformers'], + external: ['@huggingface/transformers'], }); const connector = result.outputFiles[0].text - .replace('@xenova/transformers', 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2') + .replace('@huggingface/transformers', 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.1.1') ; writeFileSync(join(__dirname, '../connectors/transformers_worker.js'), connector); - console.log('Build worker completed successfully.'); + console.log('Build transformers_worker_connector completed successfully.'); } catch (error) { - console.error('Build failed:', error); + console.error('Build transformers_worker_connector failed:', error); } } (async () => { await build_transformers_iframe_connector(); await build_transformers_worker_connector(); -})(); \ No newline at end of file +})(); diff --git a/smart-rank-model/build/transformers_iframe_script.js b/smart-rank-model/build/transformers_iframe_script.js index 35bbdb54..f026dacd 100644 --- a/smart-rank-model/build/transformers_iframe_script.js +++ b/smart-rank-model/build/transformers_iframe_script.js @@ -2,15 +2,12 @@ import { SmartRankModel } from '../smart_rank_model.js'; import { SmartRankTransformersAdapter } from '../adapters/transformers.js'; let model = null; -let smart_env = { - smart_rank_active_models: {}, - opts: { - smart_rank_adapters: { - transformers: SmartRankTransformersAdapter - } - } -} +/** + * Process incoming messages and perform ranking. + * @param {Object} data - Message data containing method, params, id, iframe_id + * @returns {Promise} Response containing id, result, and iframe_id + */ async function process_message(data) { const { method, params, id, iframe_id } = data; try { @@ -21,7 +18,13 @@ async function process_message(data) { break; case 'load': console.log('load', params); - model = await SmartRankModel.load(smart_env, { adapter: 'transformers', model_key: params.model_key, ...params }); + model = new SmartRankModel({ + ...params, + adapters: { transformers: SmartRankTransformersAdapter }, + adapter: 'transformers', + settings: {} + }); + await model.load(); result = { model_loaded: true }; break; case 'rank': @@ -37,4 +40,6 @@ async function process_message(data) { return { id, error: error.message, iframe_id }; } } -process_message({ method: 'init' }); \ No newline at end of file + +// Initialize if needed +process_message({ method: 'init' }); diff --git a/smart-rank-model/build/transformers_worker_script.js b/smart-rank-model/build/transformers_worker_script.js index a91a86a3..e9ecef92 100644 --- a/smart-rank-model/build/transformers_worker_script.js +++ b/smart-rank-model/build/transformers_worker_script.js @@ -9,37 +9,50 @@ let smart_env = { transformers: SmartRankTransformersAdapter } } -} +}; +/** + * Process incoming messages and perform ranking. + * @param {Object} data - Message data containing method, params, id, worker_id + * @returns {Promise} Response containing id, result, and worker_id + */ async function process_message(data) { - const { method, params, id, worker_id } = data; - try { - let result; - switch (method) { - case 'load': - console.log('load', params); - model = await SmartRankModel.load(smart_env, { adapter: 'transformers', model_key: params.model_key, ...params }); - result = { model_loaded: true }; - break; - case 'rank': - if (!model) throw new Error('Model not loaded'); - result = await model.rank(params.query, params.documents); - break; - default: - throw new Error(`Unknown method: ${method}`); + const { method, params, id, worker_id } = data; + try { + let result; + switch (method) { + case 'load': + console.log('load', params); + if (!model) { + model = new SmartRankModel({ + ...params, + adapters: { transformers: SmartRankTransformersAdapter }, + adapter: 'transformers', + settings: {} + }); + await model.load(); } - return { id, result, worker_id }; - } catch (error) { - console.error('Error processing message:', error); - return { id, error: error.message, worker_id }; + result = { model_loaded: true }; + break; + case 'rank': + if (!model) throw new Error('Model not loaded'); + result = await model.rank(params.query, params.documents); + break; + default: + throw new Error(`Unknown method: ${method}`); } + return { id, result, worker_id }; + } catch (error) { + console.error('Error processing message:', error); + return { id, error: error.message, worker_id }; + } } self.addEventListener('message', async (event) => { - console.log('message', event.data); - const response = await process_message(event.data); - self.postMessage(response); + const response = await process_message(event.data); + self.postMessage(response); }); + console.log('worker loaded'); // Export process_message for testing purposes diff --git a/smart-rank-model/connectors/transformers_iframe.js b/smart-rank-model/connectors/transformers_iframe.js index 8a9147b0..7f43cf06 100644 --- a/smart-rank-model/connectors/transformers_iframe.js +++ b/smart-rank-model/connectors/transformers_iframe.js @@ -1 +1 @@ -export const transformers_connector = "// models.json\nvar models_default = {\n \"cohere-rerank-english-v3.0\": {\n adapter: \"cohere\",\n model_name: \"rerank-english-v3.0\",\n model_description: \"Cohere Rerank English v3.0\",\n model_version: \"3.0\",\n endpoint: \"https://api.cohere.ai/v1/rerank\"\n },\n \"jinaai/jina-reranker-v1-tiny-en\": {\n adapter: \"transformers\",\n model_key: \"jinaai/jina-reranker-v1-tiny-en\"\n },\n \"jinaai/jina-reranker-v1-turbo-en\": {\n adapter: \"transformers\",\n model_key: \"jinaai/jina-reranker-v1-turbo-en\"\n },\n \"mixedbread-ai/mxbai-rerank-xsmall-v1\": {\n adapter: \"transformers\",\n model_key: \"mixedbread-ai/mxbai-rerank-xsmall-v1\"\n },\n \"Xenova/bge-reranker-base\": {\n adapter: \"transformers\",\n model_key: \"Xenova/bge-reranker-base\"\n }\n};\n\n// smart_rank_model.js\nvar SmartRankModel = class _SmartRankModel {\n /**\n * Create a SmartRank instance.\n * @param {string} env - The environment to use.\n * @param {object} opts - Full model configuration object or at least a model_key and adapter\n */\n constructor(env, opts = {}) {\n this.env = env;\n this.opts = {\n ...models_default[opts.model_key] || {},\n ...opts\n };\n if (!this.opts.adapter) return console.warn(\"SmartRankModel adapter not set\");\n if (!this.env.opts.smart_rank_adapters[this.opts.adapter]) return console.warn(`SmartRankModel adapter ${this.opts.adapter} not found`);\n if (typeof navigator !== \"undefined\") this.opts.use_gpu = !!navigator?.gpu && this.opts.gpu_batch_size !== 0;\n this.opts.use_gpu = false;\n this.adapter = new this.env.opts.smart_rank_adapters[this.opts.adapter](this);\n }\n /**\n * Used to load a model with a given configuration.\n * @param {*} env \n * @param {*} opts \n */\n static async load(env, opts = {}) {\n if (env.smart_rank_active_models?.[opts.model_key]) return env.smart_rank_active_models[opts.model_key];\n try {\n const model2 = new _SmartRankModel(env, opts);\n await model2.adapter.load();\n if (!env.smart_rank_active_models) env.smart_rank_active_models = {};\n env.smart_rank_active_models[opts.model_key] = model2;\n return model2;\n } catch (error) {\n console.error(`Error loading rank model ${opts.model_key}:`, error);\n return null;\n }\n }\n async rank(query, documents) {\n return this.adapter.rank(query, documents);\n }\n};\n\n// adapters/_adapter.js\nvar SmartRankAdapter = class {\n constructor(smart_rank) {\n this.smart_rank = smart_rank;\n }\n async load() {\n throw new Error(\"Not implemented\");\n }\n async rank(query, documents) {\n throw new Error(\"Not implemented\");\n }\n};\n\n// adapters/transformers.js\nvar SmartRankTransformersAdapter = class extends SmartRankAdapter {\n constructor(smart_rank) {\n super(smart_rank);\n this.model = null;\n this.tokenizer = null;\n }\n get use_gpu() {\n return this.smart_rank.opts.use_gpu || false;\n }\n async load() {\n console.log(\"TransformersAdapter initializing\");\n const { env, AutoTokenizer, AutoModelForSequenceClassification } = await import(\"https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.1\");\n console.log(\"Transformers loaded\");\n env.allowLocalModels = false;\n const pipeline_opts = {\n quantized: true\n };\n if (this.use_gpu) {\n console.log(\"[Transformers] Using GPU\");\n pipeline_opts.device = \"webgpu\";\n pipeline_opts.dtype = \"fp32\";\n } else {\n console.log(\"[Transformers] Using CPU\");\n env.backends.onnx.wasm.numThreads = 8;\n }\n this.model = await AutoModelForSequenceClassification.from_pretrained(this.smart_rank.opts.model_key, pipeline_opts);\n console.log(\"Model loaded\");\n this.tokenizer = await AutoTokenizer.from_pretrained(this.smart_rank.opts.model_key);\n console.log(\"Tokenizer loaded\");\n console.log(\"TransformersAdapter initialized\");\n }\n async rank(query, documents, options = {}) {\n const { top_k = void 0, return_documents = false } = options;\n const inputs = this.tokenizer(\n new Array(documents.length).fill(query),\n { text_pair: documents, padding: true, truncation: true }\n );\n const { logits } = await this.model(inputs);\n return logits.sigmoid().tolist().map(([score], i) => ({\n index: i,\n score,\n ...return_documents ? { text: documents[i] } : {}\n })).sort((a, b) => b.score - a.score).slice(0, top_k);\n }\n};\n\n// build/transformers_iframe_script.js\nvar model = null;\nvar smart_env = {\n smart_rank_active_models: {},\n opts: {\n smart_rank_adapters: {\n transformers: SmartRankTransformersAdapter\n }\n }\n};\nasync function processMessage(data) {\n const { method, params, id, iframe_id } = data;\n try {\n let result;\n switch (method) {\n case \"init\":\n console.log(\"init\");\n break;\n case \"load\":\n console.log(\"load\", params);\n model = await SmartRankModel.load(smart_env, { adapter: \"transformers\", model_key: params.model_key, ...params });\n result = { model_loaded: true };\n break;\n case \"rank\":\n if (!model) throw new Error(\"Model not loaded\");\n result = await model.rank(params.query, params.documents);\n break;\n default:\n throw new Error(`Unknown method: ${method}`);\n }\n return { id, result, iframe_id };\n } catch (error) {\n console.error(\"Error processing message:\", error);\n return { id, error: error.message, iframe_id };\n }\n}\nprocessMessage({ method: \"init\" });\n"; \ No newline at end of file +export const transformers_connector = "var __defProp = Object.defineProperty;\nvar __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value;\nvar __publicField = (obj, key, value) => __defNormalProp(obj, typeof key !== \"symbol\" ? key + \"\" : key, value);\n\n// ../smart-model/smart_model.js\nvar SmartModel = class {\n /**\n * Create a SmartModel instance.\n * @param {Object} opts - Configuration options\n * @param {Object} opts.adapters - Map of adapter names to adapter classes\n * @param {Object} opts.settings - Model settings configuration\n * @param {Object} opts.model_config - Model-specific configuration\n * @param {string} opts.model_config.adapter - Name of the adapter to use\n * @param {string} [opts.model_key] - Optional model identifier to override settings\n * @throws {Error} If required options are missing\n */\n constructor(opts = {}) {\n __publicField(this, \"scope_name\", \"smart_model\");\n this.opts = opts;\n this.validate_opts(opts);\n this.state = \"unloaded\";\n this._adapter = null;\n }\n /**\n * Initialize the model by loading the configured adapter.\n * @async\n * @returns {Promise}\n */\n async initialize() {\n this.load_adapter(this.adapter_name);\n await this.load();\n }\n /**\n * Validate required options.\n * @param {Object} opts - Configuration options\n */\n validate_opts(opts) {\n if (!opts.adapters) throw new Error(\"opts.adapters is required\");\n if (!opts.settings) throw new Error(\"opts.settings is required\");\n }\n /**\n * Get the current settings\n * @returns {Object} Current settings\n */\n get settings() {\n if (!this.opts.settings) this.opts.settings = {\n ...this.constructor.defaults\n };\n return this.opts.settings;\n }\n /**\n * Get the current adapter name\n * @returns {string} Current adapter name\n */\n get adapter_name() {\n const adapter_key = this.opts.model_config?.adapter || this.opts.adapter || this.settings.adapter || Object.keys(this.adapters)[0];\n if (!adapter_key || !this.adapters[adapter_key]) throw new Error(`Platform \"${adapter_key}\" not supported`);\n return adapter_key;\n }\n /**\n * Get adapter-specific settings.\n * @returns {Object} Settings for current adapter\n */\n get adapter_settings() {\n if (!this.settings[this.adapter_name]) this.settings[this.adapter_name] = {};\n return this.settings[this.adapter_name];\n }\n get adapter_config() {\n const base_config = this.adapters[this.adapter_name]?.defaults || {};\n return {\n ...base_config,\n ...this.adapter_settings,\n ...this.opts.adapter_config\n };\n }\n /**\n * Get available models.\n * @returns {Object} Map of model objects\n */\n get models() {\n return this.adapter.models;\n }\n /**\n * Get the default model key to use\n * @returns {string} Default model identifier\n */\n get default_model_key() {\n throw new Error(\"default_model_key must be overridden in sub-class\");\n }\n /**\n * Get the current model key\n * @returns {string} Current model key\n */\n get model_key() {\n return this.opts.model_key || this.adapter_config.model_key || this.settings.model_key || this.default_model_key;\n }\n /**\n * Get the current model configuration\n * @returns {Object} Combined base and custom model configuration\n */\n get model_config() {\n const model_key = this.model_key;\n const base_model_config = this.models[model_key] || {};\n return {\n ...this.adapter_config,\n ...base_model_config,\n ...this.opts.model_config\n };\n }\n get model_settings() {\n if (!this.settings[this.model_key]) this.settings[this.model_key] = {};\n return this.settings[this.model_key];\n }\n /**\n * Load the current adapter and transition to loaded state.\n * @async\n * @returns {Promise}\n */\n async load() {\n this.set_state(\"loading\");\n if (!this.adapter?.loaded) {\n await this.invoke_adapter_method(\"load\");\n }\n this.set_state(\"loaded\");\n }\n /**\n * Unload the current adapter and transition to unloaded state.\n * @async\n * @returns {Promise}\n */\n async unload() {\n if (this.adapter?.loaded) {\n this.set_state(\"unloading\");\n await this.invoke_adapter_method(\"unload\");\n this.set_state(\"unloaded\");\n }\n }\n /**\n * Set the model's state.\n * @param {('unloaded'|'loading'|'loaded'|'unloading')} new_state - The new state\n * @throws {Error} If the state is invalid\n */\n set_state(new_state) {\n const valid_states = [\"unloaded\", \"loading\", \"loaded\", \"unloading\"];\n if (!valid_states.includes(new_state)) {\n throw new Error(`Invalid state: ${new_state}`);\n }\n this.state = new_state;\n }\n get is_loading() {\n return this.state === \"loading\";\n }\n get is_loaded() {\n return this.state === \"loaded\";\n }\n get is_unloading() {\n return this.state === \"unloading\";\n }\n get is_unloaded() {\n return this.state === \"unloaded\";\n }\n // ADAPTERS\n /**\n * Get the map of available adapters\n * @returns {Object} Map of adapter names to adapter classes\n */\n get adapters() {\n return this.opts.adapters || {};\n }\n /**\n * Load a specific adapter by name.\n * @async\n * @param {string} adapter_name - Name of the adapter to load\n * @throws {Error} If adapter not found or loading fails\n * @returns {Promise}\n */\n async load_adapter(adapter_name) {\n this.set_adapter(adapter_name);\n if (!this._adapter.loaded) {\n this.set_state(\"loading\");\n try {\n await this.invoke_adapter_method(\"load\");\n this.set_state(\"loaded\");\n } catch (err) {\n this.set_state(\"unloaded\");\n throw new Error(`Failed to load adapter: ${err.message}`);\n }\n }\n }\n /**\n * Set an adapter instance by name without loading it.\n * @param {string} adapter_name - Name of the adapter to set\n * @throws {Error} If adapter not found\n */\n set_adapter(adapter_name) {\n const AdapterClass = this.adapters[adapter_name];\n if (!AdapterClass) {\n throw new Error(`Adapter \"${adapter_name}\" not found.`);\n }\n if (this._adapter?.constructor.name.toLowerCase() === adapter_name.toLowerCase()) {\n return;\n }\n this._adapter = new AdapterClass(this);\n }\n /**\n * Get the current active adapter instance\n * @returns {Object} The active adapter instance\n * @throws {Error} If adapter not found\n */\n get adapter() {\n const adapter_name = this.adapter_name;\n if (!adapter_name) {\n throw new Error(`Adapter not set for model.`);\n }\n if (!this._adapter) {\n this.load_adapter(adapter_name);\n }\n return this._adapter;\n }\n /**\n * Ensure the adapter is ready to execute a method.\n * @param {string} method - Name of the method to check\n * @throws {Error} If adapter not loaded or method not implemented\n */\n ensure_adapter_ready(method) {\n if (!this.adapter) {\n throw new Error(\"No adapter loaded.\");\n }\n if (typeof this.adapter[method] !== \"function\") {\n throw new Error(`Adapter does not implement method: ${method}`);\n }\n }\n /**\n * Invoke a method on the current adapter.\n * @async\n * @param {string} method - Name of the method to call\n * @param {...any} args - Arguments to pass to the method\n * @returns {Promise} Result from the adapter method\n * @throws {Error} If adapter not ready or method fails\n */\n async invoke_adapter_method(method, ...args) {\n this.ensure_adapter_ready(method);\n return await this.adapter[method](...args);\n }\n /**\n * Get platforms as dropdown options.\n * @returns {Array} Array of {value, name} option objects\n */\n get_platforms_as_options() {\n console.log(\"get_platforms_as_options\", this.adapters);\n return Object.entries(this.adapters).map(([key, AdapterClass]) => ({ value: key, name: AdapterClass.defaults.description || key }));\n }\n // SETTINGS\n /**\n * Get the settings configuration schema\n * @returns {Object} Settings configuration object\n */\n get settings_config() {\n return this.process_settings_config({\n adapter: {\n name: \"Model Platform\",\n type: \"dropdown\",\n description: \"Select a model platform to use with Smart Model.\",\n options_callback: \"get_platforms_as_options\",\n is_scope: true,\n // trigger re-render of settings when changed\n callback: \"adapter_changed\",\n default: \"default\"\n }\n });\n }\n /**\n * Process settings configuration with conditionals and prefixes.\n * @param {Object} _settings_config - Raw settings configuration\n * @param {string} [prefix] - Optional prefix for setting keys\n * @returns {Object} Processed settings configuration\n */\n process_settings_config(_settings_config, prefix = null) {\n return Object.entries(_settings_config).reduce((acc, [key, val]) => {\n if (val.conditional) {\n if (!val.conditional(this)) return acc;\n delete val.conditional;\n }\n const new_key = (prefix ? prefix + \".\" : \"\") + this.process_setting_key(key);\n acc[new_key] = val;\n return acc;\n }, {});\n }\n /**\n * Process an individual setting key.\n * @param {string} key - Setting key to process\n * @returns {string} Processed setting key\n */\n process_setting_key(key) {\n return key;\n }\n // override in sub-class if needed for prefixes and variable replacements\n re_render_settings() {\n if (typeof this.opts.re_render_settings === \"function\") this.opts.re_render_settings();\n else console.warn(\"re_render_settings is not a function (must be passed in model opts)\");\n }\n /**\n * Reload model.\n */\n reload_model() {\n console.log(\"reload_model\", this.opts);\n if (typeof this.opts.reload_model === \"function\") this.opts.reload_model();\n else console.warn(\"reload_model is not a function (must be passed in model opts)\");\n }\n adapter_changed() {\n this.reload_model();\n this.re_render_settings();\n }\n model_changed() {\n this.reload_model();\n this.re_render_settings();\n }\n // /**\n // * Render settings.\n // * @param {HTMLElement} [container] - Container element\n // * @param {Object} [opts] - Render options\n // * @returns {Promise} Container element\n // */\n // async render_settings(container=this.settings_container, opts = {}) {\n // if(!this.settings_container || container !== this.settings_container) this.settings_container = container;\n // const model_type = this.constructor.name.toLowerCase().replace('smart', '').replace('model', '');\n // let model_settings_container;\n // if(this.settings_container) {\n // const container_id = `#${model_type}-model-settings-container`;\n // model_settings_container = this.settings_container.querySelector(container_id);\n // if(!model_settings_container) {\n // model_settings_container = document.createElement('div');\n // model_settings_container.id = container_id;\n // this.settings_container.appendChild(model_settings_container);\n // }\n // model_settings_container.innerHTML = '
Loading ' + this.adapter_name + ' settings...
';\n // }\n // const frag = await this.render_settings_component(this, opts);\n // if(model_settings_container) {\n // model_settings_container.innerHTML = '';\n // model_settings_container.appendChild(frag);\n // this.smart_view.on_open_overlay(model_settings_container);\n // }\n // return frag;\n // }\n};\n__publicField(SmartModel, \"defaults\", {\n // override in sub-class if needed\n});\n\n// models.json\nvar models_default = {\n \"cohere-rerank-english-v3.0\": {\n adapter: \"cohere\",\n model_name: \"rerank-english-v3.0\",\n model_description: \"Cohere Rerank English v3.0\",\n model_version: \"3.0\",\n endpoint: \"https://api.cohere.ai/v1/rerank\",\n headers: {\n \"Content-Type\": \"application/json\"\n },\n api_key_header: \"Authorization\"\n },\n \"jinaai/jina-reranker-v1-tiny-en\": {\n adapter: \"transformers\",\n model_key: \"jinaai/jina-reranker-v1-tiny-en\"\n },\n \"jinaai/jina-reranker-v1-turbo-en\": {\n adapter: \"transformers\",\n model_key: \"jinaai/jina-reranker-v1-turbo-en\"\n },\n \"mixedbread-ai/mxbai-rerank-xsmall-v1\": {\n adapter: \"transformers\",\n model_key: \"mixedbread-ai/mxbai-rerank-xsmall-v1\"\n },\n \"Xenova/bge-reranker-base\": {\n adapter: \"transformers\",\n model_key: \"Xenova/bge-reranker-base\"\n }\n};\n\n// smart_rank_model.js\nvar SmartRankModel = class extends SmartModel {\n /**\n * Load the SmartRankModel with the specified configuration.\n * @param {Object} env - Environment configurations.\n * @param {Object} opts - Configuration options.\n * @param {string} opts.model_key - Model key to select the adapter.\n * @param {Object} [opts.adapters] - Optional map of adapters to override defaults.\n * @param {Object} [opts.settings] - Optional user settings.\n * @returns {Promise} Loaded SmartRankModel instance.\n * \n * @example\n * ```javascript\n * const rankModel = await SmartRankModel.load(env, {\n * model_key: 'cohere',\n * adapter: 'cohere',\n * settings: {\n * cohere_api_key: 'your-cohere-api-key',\n * },\n * });\n * ```\n */\n /**\n * Rank documents based on a query.\n * @param {string} query - The query string.\n * @param {Array} documents - Array of document strings to rank.\n * @param {Object} [options={}] - Additional ranking options.\n * @param {number} [options.top_k] - Limit the number of returned documents.\n * @param {boolean} [options.return_documents=false] - Whether to include original documents in results.\n * @returns {Promise>} Ranked documents with properties like {index, score, text}.\n * \n * @example\n * ```javascript\n * const rankings = await rankModel.rank(\"What is the capital of the United States?\", [\n * \"Carson City is the capital city of the American state of Nevada.\",\n * \"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.\",\n * \"Washington, D.C. is the capital of the United States.\",\n * ]);\n * console.log(rankings);\n * ```\n */\n async rank(query, documents, options = {}) {\n return await this.invoke_adapter_method(\"rank\", query, documents, options);\n }\n /**\n * Get available ranking models.\n * @returns {Object} Map of ranking models.\n */\n get models() {\n return models_default;\n }\n /**\n * Get the default model key.\n * @returns {string} Default model key.\n */\n get default_model_key() {\n return \"jinaai/jina-reranker-v1-tiny-en\";\n }\n /**\n * Get settings configuration schema.\n * @returns {Object} Settings configuration object.\n */\n get settings_config() {\n const _settings_config = {\n model_key: {\n name: \"Ranking Model\",\n type: \"dropdown\",\n description: \"Select a ranking model to use.\",\n options_callback: \"get_ranking_model_options\",\n callback: \"reload_model\",\n default: this.default_model_key\n },\n \"[RANKING_ADAPTER].cohere_api_key\": {\n name: \"Cohere API Key\",\n type: \"password\",\n description: \"Enter your Cohere API key for ranking.\",\n placeholder: \"Enter Cohere API Key\"\n },\n // Add adapter-specific settings here\n ...this.adapter.settings_config || {}\n };\n return this.process_settings_config(_settings_config, \"ranking_adapter\");\n }\n /**\n * Process setting keys to replace placeholders with actual adapter names.\n * @param {string} key - The setting key with placeholders.\n * @returns {string} Processed setting key.\n */\n process_setting_key(key) {\n return key.replace(/\\[RANKING_ADAPTER\\]/g, this.adapter_name);\n }\n /**\n * Get available ranking model options.\n * @returns {Array} Array of model options with value and name.\n */\n get_ranking_model_options() {\n return Object.keys(this.adapters).map((key) => ({ value: key, name: key }));\n }\n /**\n * Reload ranking model.\n */\n reload_model() {\n if (this.adapter && typeof this.adapter.load === \"function\") {\n this.adapter.load();\n }\n }\n};\n/**\n * Default configurations for SmartRankModel.\n * @type {Object}\n */\n__publicField(SmartRankModel, \"defaults\", {\n adapter: \"transformers\",\n // Default to transformers adapter\n model_key: \"jinaai/jina-reranker-v1-tiny-en\"\n});\n\n// ../smart-model/adapters/_adapter.js\nvar SmartModelAdapter = class {\n /**\n * Create a SmartModelAdapter instance.\n * @param {SmartModel} model - The parent SmartModel instance\n */\n constructor(model2) {\n this.model = model2;\n this.state = \"unloaded\";\n }\n /**\n * Load the adapter.\n * @async\n * @returns {Promise}\n */\n async load() {\n this.set_state(\"loaded\");\n }\n /**\n * Unload the adapter.\n * @returns {void}\n */\n unload() {\n this.set_state(\"unloaded\");\n }\n /**\n * Get all settings.\n * @returns {Object} All settings\n */\n get settings() {\n return this.model.settings;\n }\n /**\n * Get the current model key.\n * @returns {string} Current model identifier\n */\n get model_key() {\n return this.model.model_key;\n }\n /**\n * Get the current model configuration.\n * @returns {Object} Model configuration\n */\n get model_config() {\n return this.model.model_config;\n }\n /**\n * Get model-specific settings.\n * @returns {Object} Settings for current model\n */\n get model_settings() {\n return this.model.model_settings;\n }\n /**\n * Get adapter-specific configuration.\n * @returns {Object} Adapter configuration\n */\n get adapter_config() {\n return this.model.adapter_config;\n }\n /**\n * Get adapter-specific settings.\n * @returns {Object} Adapter settings\n */\n get adapter_settings() {\n return this.model.adapter_settings;\n }\n /**\n * Get the models.\n * @returns {Object} Map of model objects\n */\n get models() {\n if (typeof this.adapter_config.models === \"object\" && Object.keys(this.adapter_config.models || {}).length > 0) return this.adapter_config.models;\n else {\n return {};\n }\n }\n /**\n * Get available models from the API.\n * @abstract\n * @param {boolean} [refresh=false] - Whether to refresh cached models\n * @returns {Promise} Map of model objects\n */\n async get_models(refresh = false) {\n throw new Error(\"get_models not implemented\");\n }\n /**\n * Validate the parameters for get_models.\n * @returns {boolean|Array} True if parameters are valid, otherwise an array of error objects\n */\n validate_get_models_params() {\n return true;\n }\n /**\n * Get available models as dropdown options synchronously.\n * @returns {Array} Array of model options.\n */\n get_models_as_options_sync() {\n const models = this.models;\n const params_valid = this.validate_get_models_params();\n if (params_valid !== true) return params_valid;\n if (!Object.keys(models || {}).length) {\n this.get_models(true);\n return [{ value: \"\", name: \"No models currently available\" }];\n }\n return Object.values(models).map((model2) => ({ value: model2.id, name: model2.name || model2.id })).sort((a, b) => a.name.localeCompare(b.name));\n }\n /**\n * Set the adapter's state.\n * @param {('unloaded'|'loading'|'loaded'|'unloading')} new_state - The new state\n * @throws {Error} If the state is invalid\n */\n set_state(new_state) {\n const valid_states = [\"unloaded\", \"loading\", \"loaded\", \"unloading\"];\n if (!valid_states.includes(new_state)) {\n throw new Error(`Invalid state: ${new_state}`);\n }\n this.state = new_state;\n }\n // Replace individual state getters/setters with a unified state management\n get is_loading() {\n return this.state === \"loading\";\n }\n get is_loaded() {\n return this.state === \"loaded\";\n }\n get is_unloading() {\n return this.state === \"unloading\";\n }\n get is_unloaded() {\n return this.state === \"unloaded\";\n }\n};\n\n// adapters/_adapter.js\nvar SmartRankAdapter = class extends SmartModelAdapter {\n /**\n * Create a SmartRankAdapter instance.\n * @param {SmartRankModel} model - The parent SmartRankModel instance\n */\n constructor(model2) {\n super(model2);\n this.smart_rank = model2;\n }\n /**\n * Rank documents based on a query.\n * @abstract\n * @param {string} query - The query string\n * @param {Array} documents - The documents to rank\n * @returns {Promise>} Array of ranking results {index, score, ...}\n * @throws {Error} If the method is not implemented by subclass\n */\n async rank(query, documents) {\n throw new Error(\"rank method not implemented\");\n }\n};\n\n// adapters/transformers.js\nvar transformers_defaults = {\n adapter: \"transformers\",\n description: \"Transformers\",\n default_model: \"jinaai/jina-reranker-v1-tiny-en\"\n};\nvar SmartRankTransformersAdapter = class extends SmartRankAdapter {\n /**\n * Create transformers adapter instance\n * @param {SmartRankModel} model - Parent model instance\n */\n constructor(model2) {\n super(model2);\n this.model_instance = null;\n this.tokenizer = null;\n }\n /**\n * Load model and tokenizer\n * @async\n * @returns {Promise}\n */\n async load() {\n console.log(\"TransformersAdapter initializing\");\n console.log(this.model.model_key);\n const { AutoTokenizer, AutoModelForSequenceClassification, env } = await import(\"https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.1.1\");\n env.allowLocalModels = false;\n const pipeline_opts = {\n quantized: true\n };\n if (this.model.opts.use_gpu) {\n console.log(\"[Transformers] Using GPU\");\n pipeline_opts.device = \"webgpu\";\n } else {\n console.log(\"[Transformers] Using CPU\");\n }\n this.model_instance = await AutoModelForSequenceClassification.from_pretrained(this.model.model_key, pipeline_opts);\n this.tokenizer = await AutoTokenizer.from_pretrained(this.model.model_key);\n console.log(\"TransformersAdapter initialized\");\n }\n /**\n * Rank documents based on a query\n * @param {string} query - The query string\n * @param {Array} documents - Documents to rank\n * @param {Object} [options={}] - Additional ranking options\n * @param {number} [options.top_k] - Limit the number of returned documents\n * @param {boolean} [options.return_documents=false] - Whether to include original documents in results\n * @returns {Promise>} Ranked documents with properties like {index, score, text}\n */\n async rank(query, documents, options = {}) {\n console.log(\"TransformersAdapter ranking\");\n console.log(documents);\n const { top_k = void 0, return_documents = false } = options;\n if (!this.model_instance || !this.tokenizer) await this.load();\n console.log(\"tokenizing\");\n const inputs = this.tokenizer(\n new Array(documents.length).fill(query),\n { text_pair: documents, padding: true, truncation: true }\n );\n console.log(\"running model\");\n const { logits } = await this.model_instance(inputs);\n console.log(\"done\");\n return logits.sigmoid().tolist().map(([score], i) => ({\n index: i,\n score,\n ...return_documents ? { text: documents[i] } : {}\n })).sort((a, b) => b.score - a.score).slice(0, top_k);\n }\n};\n__publicField(SmartRankTransformersAdapter, \"defaults\", transformers_defaults);\n\n// build/transformers_iframe_script.js\nvar model = null;\nasync function process_message(data) {\n const { method, params, id, iframe_id } = data;\n try {\n let result;\n switch (method) {\n case \"init\":\n console.log(\"init\");\n break;\n case \"load\":\n console.log(\"load\", params);\n model = new SmartRankModel({\n ...params,\n adapters: { transformers: SmartRankTransformersAdapter },\n adapter: \"transformers\",\n settings: {}\n });\n await model.load();\n result = { model_loaded: true };\n break;\n case \"rank\":\n if (!model) throw new Error(\"Model not loaded\");\n result = await model.rank(params.query, params.documents);\n break;\n default:\n throw new Error(`Unknown method: ${method}`);\n }\n return { id, result, iframe_id };\n } catch (error) {\n console.error(\"Error processing message:\", error);\n return { id, error: error.message, iframe_id };\n }\n}\nprocess_message({ method: \"init\" });\n"; \ No newline at end of file diff --git a/smart-rank-model/connectors/transformers_worker.js b/smart-rank-model/connectors/transformers_worker.js index 0976b1f0..35b9b4dd 100644 --- a/smart-rank-model/connectors/transformers_worker.js +++ b/smart-rank-model/connectors/transformers_worker.js @@ -1,3 +1,353 @@ +var __defProp = Object.defineProperty; +var __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value; +var __publicField = (obj, key, value) => __defNormalProp(obj, typeof key !== "symbol" ? key + "" : key, value); + +// ../smart-model/smart_model.js +var SmartModel = class { + /** + * Create a SmartModel instance. + * @param {Object} opts - Configuration options + * @param {Object} opts.adapters - Map of adapter names to adapter classes + * @param {Object} opts.settings - Model settings configuration + * @param {Object} opts.model_config - Model-specific configuration + * @param {string} opts.model_config.adapter - Name of the adapter to use + * @param {string} [opts.model_key] - Optional model identifier to override settings + * @throws {Error} If required options are missing + */ + constructor(opts = {}) { + __publicField(this, "scope_name", "smart_model"); + this.opts = opts; + this.validate_opts(opts); + this.state = "unloaded"; + this._adapter = null; + } + /** + * Initialize the model by loading the configured adapter. + * @async + * @returns {Promise} + */ + async initialize() { + this.load_adapter(this.adapter_name); + await this.load(); + } + /** + * Validate required options. + * @param {Object} opts - Configuration options + */ + validate_opts(opts) { + if (!opts.adapters) throw new Error("opts.adapters is required"); + if (!opts.settings) throw new Error("opts.settings is required"); + } + /** + * Get the current settings + * @returns {Object} Current settings + */ + get settings() { + if (!this.opts.settings) this.opts.settings = { + ...this.constructor.defaults + }; + return this.opts.settings; + } + /** + * Get the current adapter name + * @returns {string} Current adapter name + */ + get adapter_name() { + const adapter_key = this.opts.model_config?.adapter || this.opts.adapter || this.settings.adapter || Object.keys(this.adapters)[0]; + if (!adapter_key || !this.adapters[adapter_key]) throw new Error(`Platform "${adapter_key}" not supported`); + return adapter_key; + } + /** + * Get adapter-specific settings. + * @returns {Object} Settings for current adapter + */ + get adapter_settings() { + if (!this.settings[this.adapter_name]) this.settings[this.adapter_name] = {}; + return this.settings[this.adapter_name]; + } + get adapter_config() { + const base_config = this.adapters[this.adapter_name]?.defaults || {}; + return { + ...base_config, + ...this.adapter_settings, + ...this.opts.adapter_config + }; + } + /** + * Get available models. + * @returns {Object} Map of model objects + */ + get models() { + return this.adapter.models; + } + /** + * Get the default model key to use + * @returns {string} Default model identifier + */ + get default_model_key() { + throw new Error("default_model_key must be overridden in sub-class"); + } + /** + * Get the current model key + * @returns {string} Current model key + */ + get model_key() { + return this.opts.model_key || this.adapter_config.model_key || this.settings.model_key || this.default_model_key; + } + /** + * Get the current model configuration + * @returns {Object} Combined base and custom model configuration + */ + get model_config() { + const model_key = this.model_key; + const base_model_config = this.models[model_key] || {}; + return { + ...this.adapter_config, + ...base_model_config, + ...this.opts.model_config + }; + } + get model_settings() { + if (!this.settings[this.model_key]) this.settings[this.model_key] = {}; + return this.settings[this.model_key]; + } + /** + * Load the current adapter and transition to loaded state. + * @async + * @returns {Promise} + */ + async load() { + this.set_state("loading"); + if (!this.adapter?.loaded) { + await this.invoke_adapter_method("load"); + } + this.set_state("loaded"); + } + /** + * Unload the current adapter and transition to unloaded state. + * @async + * @returns {Promise} + */ + async unload() { + if (this.adapter?.loaded) { + this.set_state("unloading"); + await this.invoke_adapter_method("unload"); + this.set_state("unloaded"); + } + } + /** + * Set the model's state. + * @param {('unloaded'|'loading'|'loaded'|'unloading')} new_state - The new state + * @throws {Error} If the state is invalid + */ + set_state(new_state) { + const valid_states = ["unloaded", "loading", "loaded", "unloading"]; + if (!valid_states.includes(new_state)) { + throw new Error(`Invalid state: ${new_state}`); + } + this.state = new_state; + } + get is_loading() { + return this.state === "loading"; + } + get is_loaded() { + return this.state === "loaded"; + } + get is_unloading() { + return this.state === "unloading"; + } + get is_unloaded() { + return this.state === "unloaded"; + } + // ADAPTERS + /** + * Get the map of available adapters + * @returns {Object} Map of adapter names to adapter classes + */ + get adapters() { + return this.opts.adapters || {}; + } + /** + * Load a specific adapter by name. + * @async + * @param {string} adapter_name - Name of the adapter to load + * @throws {Error} If adapter not found or loading fails + * @returns {Promise} + */ + async load_adapter(adapter_name) { + this.set_adapter(adapter_name); + if (!this._adapter.loaded) { + this.set_state("loading"); + try { + await this.invoke_adapter_method("load"); + this.set_state("loaded"); + } catch (err) { + this.set_state("unloaded"); + throw new Error(`Failed to load adapter: ${err.message}`); + } + } + } + /** + * Set an adapter instance by name without loading it. + * @param {string} adapter_name - Name of the adapter to set + * @throws {Error} If adapter not found + */ + set_adapter(adapter_name) { + const AdapterClass = this.adapters[adapter_name]; + if (!AdapterClass) { + throw new Error(`Adapter "${adapter_name}" not found.`); + } + if (this._adapter?.constructor.name.toLowerCase() === adapter_name.toLowerCase()) { + return; + } + this._adapter = new AdapterClass(this); + } + /** + * Get the current active adapter instance + * @returns {Object} The active adapter instance + * @throws {Error} If adapter not found + */ + get adapter() { + const adapter_name = this.adapter_name; + if (!adapter_name) { + throw new Error(`Adapter not set for model.`); + } + if (!this._adapter) { + this.load_adapter(adapter_name); + } + return this._adapter; + } + /** + * Ensure the adapter is ready to execute a method. + * @param {string} method - Name of the method to check + * @throws {Error} If adapter not loaded or method not implemented + */ + ensure_adapter_ready(method) { + if (!this.adapter) { + throw new Error("No adapter loaded."); + } + if (typeof this.adapter[method] !== "function") { + throw new Error(`Adapter does not implement method: ${method}`); + } + } + /** + * Invoke a method on the current adapter. + * @async + * @param {string} method - Name of the method to call + * @param {...any} args - Arguments to pass to the method + * @returns {Promise} Result from the adapter method + * @throws {Error} If adapter not ready or method fails + */ + async invoke_adapter_method(method, ...args) { + this.ensure_adapter_ready(method); + return await this.adapter[method](...args); + } + /** + * Get platforms as dropdown options. + * @returns {Array} Array of {value, name} option objects + */ + get_platforms_as_options() { + console.log("get_platforms_as_options", this.adapters); + return Object.entries(this.adapters).map(([key, AdapterClass]) => ({ value: key, name: AdapterClass.defaults.description || key })); + } + // SETTINGS + /** + * Get the settings configuration schema + * @returns {Object} Settings configuration object + */ + get settings_config() { + return this.process_settings_config({ + adapter: { + name: "Model Platform", + type: "dropdown", + description: "Select a model platform to use with Smart Model.", + options_callback: "get_platforms_as_options", + is_scope: true, + // trigger re-render of settings when changed + callback: "adapter_changed", + default: "default" + } + }); + } + /** + * Process settings configuration with conditionals and prefixes. + * @param {Object} _settings_config - Raw settings configuration + * @param {string} [prefix] - Optional prefix for setting keys + * @returns {Object} Processed settings configuration + */ + process_settings_config(_settings_config, prefix = null) { + return Object.entries(_settings_config).reduce((acc, [key, val]) => { + if (val.conditional) { + if (!val.conditional(this)) return acc; + delete val.conditional; + } + const new_key = (prefix ? prefix + "." : "") + this.process_setting_key(key); + acc[new_key] = val; + return acc; + }, {}); + } + /** + * Process an individual setting key. + * @param {string} key - Setting key to process + * @returns {string} Processed setting key + */ + process_setting_key(key) { + return key; + } + // override in sub-class if needed for prefixes and variable replacements + re_render_settings() { + if (typeof this.opts.re_render_settings === "function") this.opts.re_render_settings(); + else console.warn("re_render_settings is not a function (must be passed in model opts)"); + } + /** + * Reload model. + */ + reload_model() { + console.log("reload_model", this.opts); + if (typeof this.opts.reload_model === "function") this.opts.reload_model(); + else console.warn("reload_model is not a function (must be passed in model opts)"); + } + adapter_changed() { + this.reload_model(); + this.re_render_settings(); + } + model_changed() { + this.reload_model(); + this.re_render_settings(); + } + // /** + // * Render settings. + // * @param {HTMLElement} [container] - Container element + // * @param {Object} [opts] - Render options + // * @returns {Promise} Container element + // */ + // async render_settings(container=this.settings_container, opts = {}) { + // if(!this.settings_container || container !== this.settings_container) this.settings_container = container; + // const model_type = this.constructor.name.toLowerCase().replace('smart', '').replace('model', ''); + // let model_settings_container; + // if(this.settings_container) { + // const container_id = `#${model_type}-model-settings-container`; + // model_settings_container = this.settings_container.querySelector(container_id); + // if(!model_settings_container) { + // model_settings_container = document.createElement('div'); + // model_settings_container.id = container_id; + // this.settings_container.appendChild(model_settings_container); + // } + // model_settings_container.innerHTML = '
Loading ' + this.adapter_name + ' settings...
'; + // } + // const frag = await this.render_settings_component(this, opts); + // if(model_settings_container) { + // model_settings_container.innerHTML = ''; + // model_settings_container.appendChild(frag); + // this.smart_view.on_open_overlay(model_settings_container); + // } + // return frag; + // } +}; +__publicField(SmartModel, "defaults", { + // override in sub-class if needed +}); + // models.json var models_default = { "cohere-rerank-english-v3.0": { @@ -5,7 +355,11 @@ var models_default = { model_name: "rerank-english-v3.0", model_description: "Cohere Rerank English v3.0", model_version: "3.0", - endpoint: "https://api.cohere.ai/v1/rerank" + endpoint: "https://api.cohere.ai/v1/rerank", + headers: { + "Content-Type": "application/json" + }, + api_key_header: "Authorization" }, "jinaai/jina-reranker-v1-tiny-en": { adapter: "transformers", @@ -26,99 +380,340 @@ var models_default = { }; // smart_rank_model.js -var SmartRankModel = class _SmartRankModel { +var SmartRankModel = class extends SmartModel { + /** + * Load the SmartRankModel with the specified configuration. + * @param {Object} env - Environment configurations. + * @param {Object} opts - Configuration options. + * @param {string} opts.model_key - Model key to select the adapter. + * @param {Object} [opts.adapters] - Optional map of adapters to override defaults. + * @param {Object} [opts.settings] - Optional user settings. + * @returns {Promise} Loaded SmartRankModel instance. + * + * @example + * ```javascript + * const rankModel = await SmartRankModel.load(env, { + * model_key: 'cohere', + * adapter: 'cohere', + * settings: { + * cohere_api_key: 'your-cohere-api-key', + * }, + * }); + * ``` + */ + /** + * Rank documents based on a query. + * @param {string} query - The query string. + * @param {Array} documents - Array of document strings to rank. + * @param {Object} [options={}] - Additional ranking options. + * @param {number} [options.top_k] - Limit the number of returned documents. + * @param {boolean} [options.return_documents=false] - Whether to include original documents in results. + * @returns {Promise>} Ranked documents with properties like {index, score, text}. + * + * @example + * ```javascript + * const rankings = await rankModel.rank("What is the capital of the United States?", [ + * "Carson City is the capital city of the American state of Nevada.", + * "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + * "Washington, D.C. is the capital of the United States.", + * ]); + * console.log(rankings); + * ``` + */ + async rank(query, documents, options = {}) { + return await this.invoke_adapter_method("rank", query, documents, options); + } + /** + * Get available ranking models. + * @returns {Object} Map of ranking models. + */ + get models() { + return models_default; + } + /** + * Get the default model key. + * @returns {string} Default model key. + */ + get default_model_key() { + return "jinaai/jina-reranker-v1-tiny-en"; + } /** - * Create a SmartRank instance. - * @param {string} env - The environment to use. - * @param {object} opts - Full model configuration object or at least a model_key and adapter + * Get settings configuration schema. + * @returns {Object} Settings configuration object. */ - constructor(env, opts = {}) { - this.env = env; - this.opts = { - ...models_default[opts.model_key] || {}, - ...opts + get settings_config() { + const _settings_config = { + model_key: { + name: "Ranking Model", + type: "dropdown", + description: "Select a ranking model to use.", + options_callback: "get_ranking_model_options", + callback: "reload_model", + default: this.default_model_key + }, + "[RANKING_ADAPTER].cohere_api_key": { + name: "Cohere API Key", + type: "password", + description: "Enter your Cohere API key for ranking.", + placeholder: "Enter Cohere API Key" + }, + // Add adapter-specific settings here + ...this.adapter.settings_config || {} }; - if (!this.opts.adapter) return console.warn("SmartRankModel adapter not set"); - if (!this.env.opts.smart_rank_adapters[this.opts.adapter]) return console.warn(`SmartRankModel adapter ${this.opts.adapter} not found`); - if (typeof navigator !== "undefined") this.opts.use_gpu = !!navigator?.gpu && this.opts.gpu_batch_size !== 0; - this.opts.use_gpu = false; - this.adapter = new this.env.opts.smart_rank_adapters[this.opts.adapter](this); - } - /** - * Used to load a model with a given configuration. - * @param {*} env - * @param {*} opts - */ - static async load(env, opts = {}) { - if (env.smart_rank_active_models?.[opts.model_key]) return env.smart_rank_active_models[opts.model_key]; - try { - const model2 = new _SmartRankModel(env, opts); - await model2.adapter.load(); - if (!env.smart_rank_active_models) env.smart_rank_active_models = {}; - env.smart_rank_active_models[opts.model_key] = model2; - return model2; - } catch (error) { - console.error(`Error loading rank model ${opts.model_key}:`, error); - return null; - } + return this.process_settings_config(_settings_config, "ranking_adapter"); } - async rank(query, documents) { - return this.adapter.rank(query, documents); + /** + * Process setting keys to replace placeholders with actual adapter names. + * @param {string} key - The setting key with placeholders. + * @returns {string} Processed setting key. + */ + process_setting_key(key) { + return key.replace(/\[RANKING_ADAPTER\]/g, this.adapter_name); + } + /** + * Get available ranking model options. + * @returns {Array} Array of model options with value and name. + */ + get_ranking_model_options() { + return Object.keys(this.adapters).map((key) => ({ value: key, name: key })); + } + /** + * Reload ranking model. + */ + reload_model() { + if (this.adapter && typeof this.adapter.load === "function") { + this.adapter.load(); + } } }; +/** + * Default configurations for SmartRankModel. + * @type {Object} + */ +__publicField(SmartRankModel, "defaults", { + adapter: "transformers", + // Default to transformers adapter + model_key: "jinaai/jina-reranker-v1-tiny-en" +}); -// adapters/_adapter.js -var SmartRankAdapter = class { - constructor(smart_rank) { - this.smart_rank = smart_rank; +// ../smart-model/adapters/_adapter.js +var SmartModelAdapter = class { + /** + * Create a SmartModelAdapter instance. + * @param {SmartModel} model - The parent SmartModel instance + */ + constructor(model2) { + this.model = model2; + this.state = "unloaded"; } + /** + * Load the adapter. + * @async + * @returns {Promise} + */ async load() { - throw new Error("Not implemented"); + this.set_state("loaded"); + } + /** + * Unload the adapter. + * @returns {void} + */ + unload() { + this.set_state("unloaded"); + } + /** + * Get all settings. + * @returns {Object} All settings + */ + get settings() { + return this.model.settings; + } + /** + * Get the current model key. + * @returns {string} Current model identifier + */ + get model_key() { + return this.model.model_key; + } + /** + * Get the current model configuration. + * @returns {Object} Model configuration + */ + get model_config() { + return this.model.model_config; + } + /** + * Get model-specific settings. + * @returns {Object} Settings for current model + */ + get model_settings() { + return this.model.model_settings; } + /** + * Get adapter-specific configuration. + * @returns {Object} Adapter configuration + */ + get adapter_config() { + return this.model.adapter_config; + } + /** + * Get adapter-specific settings. + * @returns {Object} Adapter settings + */ + get adapter_settings() { + return this.model.adapter_settings; + } + /** + * Get the models. + * @returns {Object} Map of model objects + */ + get models() { + if (typeof this.adapter_config.models === "object" && Object.keys(this.adapter_config.models || {}).length > 0) return this.adapter_config.models; + else { + return {}; + } + } + /** + * Get available models from the API. + * @abstract + * @param {boolean} [refresh=false] - Whether to refresh cached models + * @returns {Promise} Map of model objects + */ + async get_models(refresh = false) { + throw new Error("get_models not implemented"); + } + /** + * Validate the parameters for get_models. + * @returns {boolean|Array} True if parameters are valid, otherwise an array of error objects + */ + validate_get_models_params() { + return true; + } + /** + * Get available models as dropdown options synchronously. + * @returns {Array} Array of model options. + */ + get_models_as_options_sync() { + const models = this.models; + const params_valid = this.validate_get_models_params(); + if (params_valid !== true) return params_valid; + if (!Object.keys(models || {}).length) { + this.get_models(true); + return [{ value: "", name: "No models currently available" }]; + } + return Object.values(models).map((model2) => ({ value: model2.id, name: model2.name || model2.id })).sort((a, b) => a.name.localeCompare(b.name)); + } + /** + * Set the adapter's state. + * @param {('unloaded'|'loading'|'loaded'|'unloading')} new_state - The new state + * @throws {Error} If the state is invalid + */ + set_state(new_state) { + const valid_states = ["unloaded", "loading", "loaded", "unloading"]; + if (!valid_states.includes(new_state)) { + throw new Error(`Invalid state: ${new_state}`); + } + this.state = new_state; + } + // Replace individual state getters/setters with a unified state management + get is_loading() { + return this.state === "loading"; + } + get is_loaded() { + return this.state === "loaded"; + } + get is_unloading() { + return this.state === "unloading"; + } + get is_unloaded() { + return this.state === "unloaded"; + } +}; + +// adapters/_adapter.js +var SmartRankAdapter = class extends SmartModelAdapter { + /** + * Create a SmartRankAdapter instance. + * @param {SmartRankModel} model - The parent SmartRankModel instance + */ + constructor(model2) { + super(model2); + this.smart_rank = model2; + } + /** + * Rank documents based on a query. + * @abstract + * @param {string} query - The query string + * @param {Array} documents - The documents to rank + * @returns {Promise>} Array of ranking results {index, score, ...} + * @throws {Error} If the method is not implemented by subclass + */ async rank(query, documents) { - throw new Error("Not implemented"); + throw new Error("rank method not implemented"); } }; // adapters/transformers.js +var transformers_defaults = { + adapter: "transformers", + description: "Transformers", + default_model: "jinaai/jina-reranker-v1-tiny-en" +}; var SmartRankTransformersAdapter = class extends SmartRankAdapter { - constructor(smart_rank) { - super(smart_rank); - this.model = null; + /** + * Create transformers adapter instance + * @param {SmartRankModel} model - Parent model instance + */ + constructor(model2) { + super(model2); + this.model_instance = null; this.tokenizer = null; } - get use_gpu() { - return this.smart_rank.opts.use_gpu || false; - } + /** + * Load model and tokenizer + * @async + * @returns {Promise} + */ async load() { console.log("TransformersAdapter initializing"); - const { env, AutoTokenizer, AutoModelForSequenceClassification } = await import("https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.1"); - console.log("Transformers loaded"); + console.log(this.model.model_key); + const { AutoTokenizer, AutoModelForSequenceClassification, env } = await import("https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.1.1"); env.allowLocalModels = false; const pipeline_opts = { quantized: true }; - if (this.use_gpu) { + if (this.model.opts.use_gpu) { console.log("[Transformers] Using GPU"); pipeline_opts.device = "webgpu"; - pipeline_opts.dtype = "fp32"; } else { console.log("[Transformers] Using CPU"); - env.backends.onnx.wasm.numThreads = 8; } - this.model = await AutoModelForSequenceClassification.from_pretrained(this.smart_rank.opts.model_key, pipeline_opts); - console.log("Model loaded"); - this.tokenizer = await AutoTokenizer.from_pretrained(this.smart_rank.opts.model_key); - console.log("Tokenizer loaded"); + this.model_instance = await AutoModelForSequenceClassification.from_pretrained(this.model.model_key, pipeline_opts); + this.tokenizer = await AutoTokenizer.from_pretrained(this.model.model_key); console.log("TransformersAdapter initialized"); } + /** + * Rank documents based on a query + * @param {string} query - The query string + * @param {Array} documents - Documents to rank + * @param {Object} [options={}] - Additional ranking options + * @param {number} [options.top_k] - Limit the number of returned documents + * @param {boolean} [options.return_documents=false] - Whether to include original documents in results + * @returns {Promise>} Ranked documents with properties like {index, score, text} + */ async rank(query, documents, options = {}) { + console.log("TransformersAdapter ranking"); + console.log(documents); const { top_k = void 0, return_documents = false } = options; + if (!this.model_instance || !this.tokenizer) await this.load(); + console.log("tokenizing"); const inputs = this.tokenizer( new Array(documents.length).fill(query), { text_pair: documents, padding: true, truncation: true } ); - const { logits } = await this.model(inputs); + console.log("running model"); + const { logits } = await this.model_instance(inputs); + console.log("done"); return logits.sigmoid().tolist().map(([score], i) => ({ index: i, score, @@ -126,17 +721,10 @@ var SmartRankTransformersAdapter = class extends SmartRankAdapter { })).sort((a, b) => b.score - a.score).slice(0, top_k); } }; +__publicField(SmartRankTransformersAdapter, "defaults", transformers_defaults); // build/transformers_worker_script.js var model = null; -var smart_env = { - smart_rank_active_models: {}, - opts: { - smart_rank_adapters: { - transformers: SmartRankTransformersAdapter - } - } -}; async function process_message(data) { const { method, params, id, worker_id } = data; try { @@ -144,7 +732,15 @@ async function process_message(data) { switch (method) { case "load": console.log("load", params); - model = await SmartRankModel.load(smart_env, { adapter: "transformers", model_key: params.model_key, ...params }); + if (!model) { + model = new SmartRankModel({ + ...params, + adapters: { transformers: SmartRankTransformersAdapter }, + adapter: "transformers", + settings: {} + }); + await model.load(); + } result = { model_loaded: true }; break; case "rank": @@ -161,7 +757,6 @@ async function process_message(data) { } } self.addEventListener("message", async (event) => { - console.log("message", event.data); const response = await process_message(event.data); self.postMessage(response); }); diff --git a/smart-rank-model/package.json b/smart-rank-model/package.json index 2fa6c8e1..30766fa1 100644 --- a/smart-rank-model/package.json +++ b/smart-rank-model/package.json @@ -25,6 +25,7 @@ "url": "https://github.com/brianpetro/jsbrains/issues" }, "dependencies": { + "@huggingface/transformers": "^3.1.1", "smart-model": "file:../smart-model" }, "homepage": "https://jsbrains.org", diff --git a/smart-rank-model/smart_rank_model.js b/smart-rank-model/smart_rank_model.js index 96ef5077..0b2ce1fe 100644 --- a/smart-rank-model/smart_rank_model.js +++ b/smart-rank-model/smart_rank_model.js @@ -1,3 +1,5 @@ +// smart_rank_model.js + // Copyright (c) Brian Joseph Petro // Permission is hereby granted, free of charge, to any person obtaining @@ -23,11 +25,13 @@ import { SmartModel } from 'smart-model'; import rank_models from './models.json' assert { type: 'json' }; /** - * SmartRankModel - A versatile class for handling document ranking using various model backends + * SmartRankModel - A versatile class for handling document ranking using various model backends. * @extends SmartModel * * @example * ```javascript + * import { SmartRankModel } from 'smart-rank-model'; + * * const rankModel = await SmartRankModel.load(env, { * model_key: 'cohere', * adapter: 'cohere', @@ -41,69 +45,79 @@ import rank_models from './models.json' assert { type: 'json' }; * ``` */ export class SmartRankModel extends SmartModel { + /** + * Default configurations for SmartRankModel. + * @type {Object} + */ static defaults = { - model_key: 'cohere', // Default to Cohere adapter + adapter: 'transformers', // Default to transformers adapter + model_key: 'jinaai/jina-reranker-v1-tiny-en', }; - /** - * Create a SmartRankModel instance - * @param {Object} opts - Configuration options - * @param {Object} [opts.adapters] - Map of available adapter implementations - * @param {Object} [opts.settings] - User settings - * @param {string} [opts.model_key] - Model key to select the adapter - */ - constructor(opts = {}) { - super(opts); - } /** * Load the SmartRankModel with the specified configuration. - * @param {Object} env - Environment configurations - * @param {Object} opts - Configuration options - * @returns {Promise} Loaded SmartRankModel instance + * @param {Object} env - Environment configurations. + * @param {Object} opts - Configuration options. + * @param {string} opts.model_key - Model key to select the adapter. + * @param {Object} [opts.adapters] - Optional map of adapters to override defaults. + * @param {Object} [opts.settings] - Optional user settings. + * @returns {Promise} Loaded SmartRankModel instance. + * + * @example + * ```javascript + * const rankModel = await SmartRankModel.load(env, { + * model_key: 'cohere', + * adapter: 'cohere', + * settings: { + * cohere_api_key: 'your-cohere-api-key', + * }, + * }); + * ``` */ - static async load(env, opts = {}) { - if (env.smart_rank_active_models?.[opts.model_key]) { - return env.smart_rank_active_models[opts.model_key]; - } - try { - const model = new SmartRankModel(opts); - await model.adapter.load(); - if (!env.smart_rank_active_models) env.smart_rank_active_models = {}; - env.smart_rank_active_models[opts.model_key] = model; - return model; - } catch (error) { - console.error(`Error loading rank model ${opts.model_key}:`, error); - return null; - } - } /** * Rank documents based on a query. - * @param {string} query - The query string - * @param {Array} documents - Array of document strings - * @returns {Promise>} Ranked documents + * @param {string} query - The query string. + * @param {Array} documents - Array of document strings to rank. + * @param {Object} [options={}] - Additional ranking options. + * @param {number} [options.top_k] - Limit the number of returned documents. + * @param {boolean} [options.return_documents=false] - Whether to include original documents in results. + * @returns {Promise>} Ranked documents with properties like {index, score, text}. + * + * @example + * ```javascript + * const rankings = await rankModel.rank("What is the capital of the United States?", [ + * "Carson City is the capital city of the American state of Nevada.", + * "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + * "Washington, D.C. is the capital of the United States.", + * ]); + * console.log(rankings); + * ``` */ - async rank(query, documents) { - return await this.invoke_adapter_method('rank', query, documents); + async rank(query, documents, options = {}) { + return await this.invoke_adapter_method('rank', query, documents, options); } /** * Get available ranking models. - * @returns {Object} Map of ranking models + * @returns {Object} Map of ranking models. */ get models() { return rank_models; } - /** @override */ + /** + * Get the default model key. + * @returns {string} Default model key. + */ get default_model_key() { - return 'cohere'; // Ensure consistency with adapters + return 'jinaai/jina-reranker-v1-tiny-en'; // Ensure consistency with adapters } /** * Get settings configuration schema. - * @returns {Object} Settings configuration object + * @returns {Object} Settings configuration object. */ get settings_config() { const _settings_config = { @@ -127,13 +141,18 @@ export class SmartRankModel extends SmartModel { return this.process_settings_config(_settings_config, 'ranking_adapter'); } + /** + * Process setting keys to replace placeholders with actual adapter names. + * @param {string} key - The setting key with placeholders. + * @returns {string} Processed setting key. + */ process_setting_key(key) { return key.replace(/\[RANKING_ADAPTER\]/g, this.adapter_name); } /** * Get available ranking model options. - * @returns {Array} Array of model options with value and name + * @returns {Array} Array of model options with value and name. */ get_ranking_model_options() { return Object.keys(this.adapters).map(key => ({ value: key, name: key })); @@ -147,4 +166,4 @@ export class SmartRankModel extends SmartModel { this.adapter.load(); } } -} \ No newline at end of file +} diff --git a/smart-rank-model/test/cohere.test.js b/smart-rank-model/test/cohere.test.js index 166988d4..8ae3f5e8 100644 --- a/smart-rank-model/test/cohere.test.js +++ b/smart-rank-model/test/cohere.test.js @@ -1,34 +1,41 @@ import test from 'ava'; -import { CohereAdapter } from './cohere.js'; +import dotenv from 'dotenv'; import path from 'path'; -import { config } from 'dotenv'; import { fileURLToPath } from 'url'; import { dirname } from 'path'; +import { SmartRankModel } from '../smart_rank_model.js'; +import { cohere as SmartRankCohereAdapter } from '../adapters.js'; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); +dotenv.config({ path: path.join(__dirname, '..', '..', '.env') }); -config({ path: path.join(__dirname, '..', '..', '.env') }); - -// Ensure that the API key and other necessary environment variables are set -const api_key = process.env.COHERE_API_KEY; -const model_name = 'rerank-english-v2.0'; -const endpoint = "https://api.cohere.ai/v1/rerank"; +const api_key = process.env.COHERE_API_KEY || 'fake_cohere_api_key'; test('CohereAdapter rank function returns expected results', async t => { - const adapter = new CohereAdapter({config: {api_key, model_name, endpoint}}); const query = "What is the capital of the United States?"; const documents = [ "Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", - "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", - "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." + "Washington, D.C. (also known as simply Washington) is the capital of the United States.", + "Capital punishment (the death penalty) has existed in the United States since before the United States was a country." ]; - const response = await adapter.rank(query, documents); - console.log(response); - // console.log({resp: response}); // Optionally log the response for debugging + const model = await SmartRankModel.load({ + smart_rank_adapters: { + cohere: SmartRankCohereAdapter + } + }, { + adapter: 'cohere', + model_config: { + api_key, + endpoint: "https://api.cohere.ai/v1/rerank" + } + }); - t.is(documents[response.results[0].index], "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", 'The top document should correctly identify Washington, D.C. as the capital'); + const response = await model.rank(query, documents); + t.true(Array.isArray(response), 'Response should be an array of ranked results'); + // Check if the top-ranked document is the one mentioning Washington, D.C. + const topRankedDoc = documents[response[0].index]; + t.truthy(topRankedDoc.includes("Washington, D.C."), 'Top ranked doc should mention Washington, D.C.'); }); - diff --git a/smart-rank-model/test/transformers.test.js b/smart-rank-model/test/transformers.test.js index 8e971808..c3e79e70 100644 --- a/smart-rank-model/test/transformers.test.js +++ b/smart-rank-model/test/transformers.test.js @@ -1,9 +1,12 @@ import test from 'ava'; import { load_test_env } from './_env.js'; import { SmartRankModel } from '../smart_rank_model.js'; +import { SmartRankTransformersAdapter } from '../adapters/transformers.js'; test.before(async t => { await load_test_env(t); + // Initialize test models + t.context.models = {}; }); const query = "Organic skincare products for sensitive skin"; @@ -29,52 +32,69 @@ const docs2 = [ ]; const expected_top2 = "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."; -test('jina-reranker-v1-tiny-en rank function returns expected results', async t => { - const model_key = 'jinaai/jina-reranker-v1-tiny-en'; - t.timeout(30000); - const model = await SmartRankModel.load(t.context.env, {model_key}); - const response = await model.rank(query, documents, { return_documents: true, top_k: 3 }); - // console.log({response}); - t.is(documents[response[0].corpus_id], expected_top, 'The top document should correctly identify the best strategy for sustainable agriculture'); - const response2 = await model.rank(query2, docs2, { return_documents: true, top_k: 3 }); - // console.log({response2}); - t.is(docs2[response2[0].corpus_id], expected_top2, 'The top document should correctly identify Washington, D.C. as the capital'); -}); -test('jina-reranker-v1-turbo-en rank function returns expected results', async t => { - const model_key = 'jinaai/jina-reranker-v1-turbo-en'; - t.timeout(30000); - const model = await SmartRankModel.load(t.context.env, {model_key}); - const response = await model.rank(query, documents, { return_documents: true, top_k: 3 }); - // console.log({response}); - t.is(documents[response[0].corpus_id], expected_top, 'The top document should correctly identify the best strategy for sustainable agriculture'); - const response2 = await model.rank(query2, docs2, { return_documents: true, top_k: 3 }); - // console.log({response2}); - t.is(docs2[response2[0].corpus_id], expected_top2, 'The top document should correctly identify Washington, D.C. as the capital'); -}); +async function get_or_create_model(t, model_key, opts = {}) { + const cache_key = `${model_key}${opts.quantized ? '_quantized' : ''}`; + if (!t.context.models[cache_key]) { + t.context.models[cache_key] = new SmartRankModel({ + adapters: { transformers: SmartRankTransformersAdapter }, + settings: { + adapter: 'transformers', + model_key, + ...opts + }, + }); + await t.context.models[cache_key].adapter.load(); + } + return t.context.models[cache_key]; +} -test('mxbai-rerank-xsmall-v1 rank function returns expected results', async t => { - const model_key = 'mixedbread-ai/mxbai-rerank-xsmall-v1'; +async function test_model(t, model_key, opts = {}) { t.timeout(30000); - const model = await SmartRankModel.load(t.context.env, {model_key}); - const response = await model.rank(query, documents, { return_documents: true, top_k: 3 }); - // console.log({response}); - t.is(documents[response[0].corpus_id], expected_top, 'The top document should correctly identify the best strategy for sustainable agriculture'); - const response2 = await model.rank(query2, docs2, { return_documents: true, top_k: 3 }); - // console.log({response2}); - t.is(docs2[response2[0].corpus_id], expected_top2, 'The top document should correctly identify Washington, D.C. as the capital'); -}); + const model = await get_or_create_model(t, model_key, opts); -// Xenova/bge-reranker-base -test('bge-reranker-base rank function returns expected results', async t => { - const model_key = 'Xenova/bge-reranker-base'; - t.timeout(30000); - const model = await SmartRankModel.load(t.context.env, {model_key, quantized: true}); - const response = await model.rank(query, documents, { return_documents: true, top_k: 3 }); + // Test skincare ranking + const response = await model.rank(query, documents, { return_documents: true, top_k: 10 }); // console.log({response}); - t.is(documents[response[0].corpus_id], expected_top, 'The top document should correctly identify the best strategy for sustainable agriculture'); - const response2 = await model.rank(query2, docs2, { return_documents: true, top_k: 3 }); + t.is(documents[response[0].index], expected_top, 'The top document should correctly identify the most relevant skincare document'); + + // Test capital city ranking + const response2 = await model.rank(query2, docs2, { return_documents: true, top_k: 10 }); // console.log({response2}); - t.is(docs2[response2[0].corpus_id], expected_top2, 'The top document should correctly identify Washington, D.C. as the capital'); + t.is(docs2[response2[0].index], expected_top2, 'The top document should correctly identify Washington, D.C. as the capital'); +} + +// Test cases for different models +const test_cases = [ + { + name: 'jina-reranker-v1-tiny-en', + model_key: 'jinaai/jina-reranker-v1-tiny-en' + }, + { + name: 'jina-reranker-v1-turbo-en', + model_key: 'jinaai/jina-reranker-v1-turbo-en' + }, + { + name: 'mxbai-rerank-xsmall-v1', + model_key: 'mixedbread-ai/mxbai-rerank-xsmall-v1' + }, + { + name: 'bge-reranker-base', + model_key: 'Xenova/bge-reranker-base', + opts: { quantized: true } + } +]; + +// Generate test cases dynamically +for (const test_case of test_cases) { + test.serial(`${test_case.name} rank function returns expected results`, async t => { + await test_model(t, test_case.model_key, test_case.opts || {}); + }); +} + +// Clean up models after all tests +test.after.always(t => { + // Clean up any resources if needed + t.context.models = {}; });