Skip to content

Commit

Permalink
fix: embedder errors in embed length (#9584)
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Krick <[email protected]>
  • Loading branch information
mattkrick authored Apr 1, 2024
1 parent 8cdd901 commit 341b4b7
Show file tree
Hide file tree
Showing 19 changed files with 716 additions and 732 deletions.
7 changes: 4 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,14 @@
"@types/dotenv": "^6.1.1",
"@types/jscodeshift": "^0.11.3",
"@types/lodash.toarray": "^4.4.7",
"@typescript-eslint/eslint-plugin": "^6.21.0",
"@typescript-eslint/parser": "^6.21.0",
"@typescript-eslint/eslint-plugin": "^7.4.0",
"@typescript-eslint/parser": "^7.4.0",
"autoprefixer": "^10.4.13",
"babel-loader": "^9.1.2",
"concurrently": "^8.0.1",
"copy-webpack-plugin": "^11.0.0",
"eslint-config-prettier": "^8.5.0",
"eslint": "^8.57.0",
"eslint-config-prettier": "^9.1.0",
"graphql": "15.7.2",
"html-webpack-plugin": "^5.5.0",
"husky": "^7.0.4",
Expand Down
2 changes: 0 additions & 2 deletions packages/client/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
"@types/stripe-v2": "^2.0.1",
"babel-plugin-relay": "^12.0.0",
"debug": "^4.1.1",
"eslint": "^8.2.0",
"eslint-config-prettier": "^8.5.0",
"eslint-plugin-emotion": "^10.0.14",
"eslint-plugin-react": "^7.16.0",
"eslint-plugin-react-hooks": "^1.6.1",
Expand Down
3 changes: 2 additions & 1 deletion packages/embedder/EmbeddingsJobQueueStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import {DB} from 'parabol-server/postgres/pg'
import RootDataLoader from '../server/dataloader/RootDataLoader'
import {processJob} from './processJob'
import {Logger} from '../server/utils/Logger'
import {EmbeddingsTableName} from './ai_models/AbstractEmbeddingsModel'

export type DBJob = Selectable<DB['EmbeddingsJobQueue']>
export type EmbedJob = DBJob & {
jobType: 'embed'
jobData: {
embeddingsMetadataId: number
model: string
model: EmbeddingsTableName
}
}
export type RerankJob = DBJob & {jobType: 'rerank'; jobData: {discussionIds: string[]}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ const insertDiscussionsIntoMetadata = async (discussions: DiscussionMeta[], prio
if (!metadataRows[0]) return

const modelManager = getModelManager()
const models = modelManager.embeddingModels.map((m) => m.tableName)
const tableNames = [...modelManager.embeddingModels.keys()]
return (
pg
.with('Insert', (qc) =>
Expand All @@ -55,8 +55,9 @@ const insertDiscussionsIntoMetadata = async (discussions: DiscussionMeta[], prio
.with('Metadata', (qc) =>
qc
.selectFrom('Insert')
.fullJoin(sql<{model: string}>`UNNEST(ARRAY[${sql.join(models)}])`.as('model'), (join) =>
join.onTrue()
.fullJoin(
sql<{model: string}>`UNNEST(ARRAY[${sql.join(tableNames)}])`.as('model'),
(join) => join.onTrue()
)
.select(['id', 'model'])
)
Expand Down
56 changes: 42 additions & 14 deletions packages/embedder/ai_models/AbstractEmbeddingsModel.ts
Original file line number Diff line number Diff line change
@@ -1,43 +1,71 @@
import {sql} from 'kysely'
import getKysely from 'parabol-server/postgres/getKysely'
import {DB} from 'parabol-server/postgres/pg'
import isValid from '../../server/graphql/isValid'
import {Logger} from '../../server/utils/Logger'
import {EMBEDDER_JOB_PRIORITY} from '../EMBEDDER_JOB_PRIORITY'
import {ISO6391} from '../iso6393To1'
import {AbstractModel, ModelConfig} from './AbstractModel'
import {AbstractModel} from './AbstractModel'

export interface EmbeddingModelParams {
embeddingDimensions: number
maxInputTokens: number
tableSuffix: string
languages: ISO6391[]
}
export type EmbeddingsTable = Extract<keyof DB, `Embeddings_${string}`>
export interface EmbeddingModelConfig extends ModelConfig {
tableSuffix: string
}
export type EmbeddingsTableName = `Embeddings_${string}`
export type EmbeddingsTable = Extract<keyof DB, EmbeddingsTableName>

export abstract class AbstractEmbeddingsModel extends AbstractModel {
readonly embeddingDimensions: number
readonly maxInputTokens: number
readonly tableName: string
readonly tableName: EmbeddingsTableName
readonly languages: ISO6391[]
constructor(config: EmbeddingModelConfig) {
super(config)
const modelParams = this.constructModelParams(config)
constructor(modelId: string, url: string) {
super(url)
const modelParams = this.constructModelParams(modelId)
this.embeddingDimensions = modelParams.embeddingDimensions
this.languages = modelParams.languages
this.maxInputTokens = modelParams.maxInputTokens
this.tableName = `Embeddings_${modelParams.tableSuffix}`
}
protected abstract constructModelParams(config: EmbeddingModelConfig): EmbeddingModelParams
protected abstract constructModelParams(modelId: string): EmbeddingModelParams
abstract getEmbedding(content: string, retries?: number): Promise<number[] | Error>

abstract getTokens(content: string): Promise<number[] | Error>
splitText(content: string) {

async chunkText(content: string) {
const tokens = await this.getTokens(content)
if (tokens instanceof Error) return tokens
const isFullTextTooBig = tokens.length > this.maxInputTokens
if (!isFullTextTooBig) return [content]

for (let i = 0; i < 3; i++) {
const tokensPerWord = (4 + i) / 3
const chunks = this.splitText(content, tokensPerWord)
const chunkLengths = await Promise.all(
chunks.map(async (chunk) => {
const chunkTokens = await this.getTokens(chunk)
if (chunkTokens instanceof Error) return chunkTokens
return chunkTokens.length
})
)
const firstError = chunkLengths.find(
(chunkLength): chunkLength is Error => chunkLength instanceof Error
)
if (firstError) return firstError

const validChunks = chunkLengths.filter(isValid)
if (validChunks.every((chunkLength) => chunkLength <= this.maxInputTokens)) {
return chunks
}
}
return new Error(`Text is too long and could not be split into chunks. Is it english?`)
}
// private because the result may still be too long to go into the model; callers must verify each chunk with getTokens
private splitText(content: string, tokensPerWord = 4 / 3) {
// it's actually 4 / 3, but don't want to chance a failed split
const TOKENS_PER_WORD = 5 / 3
const WORD_LIMIT = Math.floor(this.maxInputTokens / TOKENS_PER_WORD)
const WORD_LIMIT = Math.floor(this.maxInputTokens / tokensPerWord)
const chunks: string[] = []
const delimiters = ['\n\n', '\n', '.', ' ']
const countWords = (text: string) => text.trim().split(/\s+/).length
Expand Down Expand Up @@ -98,7 +126,7 @@ export abstract class AbstractEmbeddingsModel extends AbstractModel {
'tablename'
)} = ${this.tableName}`.execute(pg)
).rows.length > 0
if (hasTable) return undefined
if (hasTable) return
const vectorDimensions = this.embeddingDimensions
Logger.log(`ModelManager: creating ${this.tableName} with ${vectorDimensions} dimensions`)
await sql`
Expand Down
11 changes: 5 additions & 6 deletions packages/embedder/ai_models/AbstractGenerationModel.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {AbstractModel, ModelConfig} from './AbstractModel'
import {AbstractModel} from './AbstractModel'

export interface GenerationOptions {
maxNewTokens?: number
Expand All @@ -11,16 +11,15 @@ export interface GenerationOptions {
export interface GenerationModelParams {
maxInputTokens: number
}
export interface GenerationModelConfig extends ModelConfig {}

export abstract class AbstractGenerationModel extends AbstractModel {
readonly maxInputTokens: number
constructor(config: GenerationModelConfig) {
super(config)
const modelParams = this.constructModelParams(config)
constructor(modelId: string, url: string) {
super(url)
const modelParams = this.constructModelParams(modelId)
this.maxInputTokens = modelParams.maxInputTokens
}

protected abstract constructModelParams(config: GenerationModelConfig): GenerationModelParams
protected abstract constructModelParams(modelId: string): GenerationModelParams
abstract summarize(content: string, options: GenerationOptions): Promise<string>
}
9 changes: 2 additions & 7 deletions packages/embedder/ai_models/AbstractModel.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
export interface ModelConfig {
model: string
url: string
}

export abstract class AbstractModel {
public readonly url: string

constructor(config: ModelConfig) {
this.url = this.normalizeUrl(config.url)
constructor(url: string) {
this.url = this.normalizeUrl(url)
}

// removes a trailing slash from the inputUrl
Expand Down
155 changes: 65 additions & 90 deletions packages/embedder/ai_models/ModelManager.ts
Original file line number Diff line number Diff line change
@@ -1,118 +1,93 @@
import {AbstractEmbeddingsModel, EmbeddingModelConfig} from './AbstractEmbeddingsModel'
import {AbstractGenerationModel, GenerationModelConfig} from './AbstractGenerationModel'
import {ModelConfig} from './AbstractModel'
import {AbstractEmbeddingsModel, EmbeddingsTableName} from './AbstractEmbeddingsModel'
import {AbstractGenerationModel} from './AbstractGenerationModel'
import OpenAIGeneration from './OpenAIGeneration'
import TextEmbeddingsInference from './TextEmbeddingsInference'
import TextGenerationInference from './TextGenerationInference'

interface ModelManagerConfig {
embeddingModels: EmbeddingModelConfig[]
generationModels: GenerationModelConfig[]
}

type EmbeddingsModelType = 'text-embeddings-inference'
type GenerationModelType = 'openai' | 'text-generation-inference'

export class ModelManager {
embeddingModels: AbstractEmbeddingsModel[]
embeddingModelsMapByTable: {[key: string]: AbstractEmbeddingsModel}
generationModels: AbstractGenerationModel[]
export interface ModelConfig {
model: `${EmbeddingsModelType | GenerationModelType}:${string}`
url: string
}

private isValidConfig(
maybeConfig: Partial<ModelManagerConfig>
): maybeConfig is ModelManagerConfig {
if (!maybeConfig.embeddingModels || !Array.isArray(maybeConfig.embeddingModels)) {
throw new Error('Invalid configuration: embedding_models is missing or not an array')
}
if (!maybeConfig.generationModels || !Array.isArray(maybeConfig.generationModels)) {
throw new Error('Invalid configuration: summarization_models is missing or not an array')
export class ModelManager {
embeddingModels: Map<EmbeddingsTableName, AbstractEmbeddingsModel>
generationModels: Map<string, AbstractGenerationModel>

private parseModelEnvVars(envVar: 'AI_EMBEDDING_MODELS' | 'AI_GENERATION_MODELS'): ModelConfig[] {
const envValue = process.env[envVar]
if (!envValue) return []
let models
try {
models = JSON.parse(envValue)
} catch (e) {
throw new Error(`Invalid Env Var: ${envVar}. Must be a valid JSON`)
}

maybeConfig.embeddingModels.forEach((model: ModelConfig) => {
this.isValidModelConfig(model)
})

maybeConfig.generationModels.forEach((model: ModelConfig) => {
this.isValidModelConfig(model)
})

return true
}

private isValidModelConfig(model: ModelConfig): model is ModelConfig {
if (typeof model.model !== 'string') {
throw new Error('Invalid ModelConfig: model field should be a string')
}
if (model.url !== undefined && typeof model.url !== 'string') {
throw new Error('Invalid ModelConfig: url field should be a string')
if (!Array.isArray(models)) {
throw new Error(`Invalid Env Var: ${envVar}. Must be an array`)
}

return true
}

constructor(config: ModelManagerConfig) {
// Validate configuration
this.isValidConfig(config)

// Initialize embeddings models
this.embeddingModelsMapByTable = {}
this.embeddingModels = config.embeddingModels.map((modelConfig) => {
const [modelType] = modelConfig.model.split(':') as [EmbeddingsModelType, string]

switch (modelType) {
case 'text-embeddings-inference': {
const embeddingsModel = new TextEmbeddingsInference(modelConfig)
this.embeddingModelsMapByTable[embeddingsModel.tableName] = embeddingsModel
return embeddingsModel
const properties = ['model', 'url']
models.forEach((model, idx) => {
properties.forEach((prop) => {
if (typeof model[prop] !== 'string') {
throw new Error(`Invalid Env Var: ${envVar}. Invalid "${prop}" at index ${idx}`)
}
default:
throw new Error(`unsupported embeddings model '${modelType}'`)
}
})
})
return models
}

// Initialize summarization models
this.generationModels = config.generationModels.map((modelConfig) => {
const [modelType, _] = modelConfig.model.split(':') as [GenerationModelType, string]

switch (modelType) {
case 'openai': {
return new OpenAIGeneration(modelConfig)
constructor() {
// Initialize embeddings models
const embeddingConfig = this.parseModelEnvVars('AI_EMBEDDING_MODELS')
this.embeddingModels = new Map(
embeddingConfig.map((modelConfig) => {
const {model, url} = modelConfig
const [modelType, modelId] = model.split(':') as [EmbeddingsModelType, string]
switch (modelType) {
case 'text-embeddings-inference': {
const embeddingsModel = new TextEmbeddingsInference(modelId, url)
return [embeddingsModel.tableName, embeddingsModel]
}
default:
throw new Error(`unsupported embeddings model '${modelType}'`)
}
case 'text-generation-inference': {
return new TextGenerationInference(modelConfig)
})
)

// Initialize generation models
const generationConfig = this.parseModelEnvVars('AI_GENERATION_MODELS')
this.generationModels = new Map<string, AbstractGenerationModel>(
generationConfig.map((modelConfig) => {
const {model, url} = modelConfig
const [modelType, modelId] = model.split(':') as [GenerationModelType, string]
switch (modelType) {
case 'openai': {
return [modelId, new OpenAIGeneration(modelId, url)]
}
case 'text-generation-inference': {
return [modelId, new TextGenerationInference(modelId, url)]
}
default:
throw new Error(`unsupported generation model '${modelType}'`)
}
default:
throw new Error(`unsupported summarization model '${modelType}'`)
}
})
})
)
}

async maybeCreateTables() {
return Promise.all(this.embeddingModels.map((model) => model.createTable()))
return Promise.all([...this.embeddingModels].map(([, model]) => model.createTable()))
}
}

let modelManager: ModelManager | undefined
export function getModelManager() {
if (modelManager) return modelManager
const {AI_EMBEDDING_MODELS, AI_GENERATION_MODELS} = process.env
const config: ModelManagerConfig = {
embeddingModels: [],
generationModels: []
}
try {
config.embeddingModels = AI_EMBEDDING_MODELS && JSON.parse(AI_EMBEDDING_MODELS)
} catch (e) {
throw new Error(`Invalid AI_EMBEDDING_MODELS .env JSON: ${e}`)
if (!modelManager) {
modelManager = new ModelManager()
}
try {
config.generationModels = AI_GENERATION_MODELS && JSON.parse(AI_GENERATION_MODELS)
} catch (e) {
throw new Error(`Invalid AI_GENERATION_MODELS .env JSON: ${e}`)
}

modelManager = new ModelManager(config)

return modelManager
}

Expand Down
Loading

0 comments on commit 341b4b7

Please sign in to comment.