Commit

providers + prep (#99)
roodboi authored Feb 1, 2024
1 parent c7aec7c commit c9ab910
Showing 13 changed files with 201 additions and 293 deletions.
5 changes: 5 additions & 0 deletions .changeset/selfish-birds-appear.md
@@ -0,0 +1,5 @@
---
"@instructor-ai/instructor": patch
---

Adds explicit support for non-OpenAI providers (currently Anyscale and Together AI) and performs explicit checks of the selected mode against the provider and model.
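
A minimal usage sketch of what the new checks gate, assuming the package's default export accepts a `{ client, mode }` options object and that Anyscale is reached via a base URL containing `api.endpoints.anyscale` (both assumptions; the exact endpoint is not part of this diff):

```typescript
import Instructor from "@instructor-ai/instructor"
import OpenAI from "openai"

// Hypothetical Anyscale endpoint; the provider check added in this commit only
// matches on the "api.endpoints.anyscale" substring of the configured baseURL.
const anyscale = new OpenAI({
  apiKey: process.env.ANYSCALE_API_KEY,
  baseURL: "https://api.endpoints.anyscale.com/v1"
})

// "TOOLS" is one of the modes allowed for the ANYSCALE provider; an unsupported
// mode such as "FUNCTIONS" would now throw when the client is constructed.
const client = Instructor({
  client: anyscale,
  mode: "TOOLS"
})
```
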
3 changes: 2 additions & 1 deletion .example.env
@@ -1,2 +1,3 @@
OPENAI_API_KEY=
ANYSCALE_API_KEY=
ANYSCALE_API_KEY=
TOGETHER_API_KEY=
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -14,6 +14,7 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANYSCALE_API_KEY: ${{ secrets.ANYSCALE_API_KEY }}
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}

steps:
- uses: actions/checkout@v3
Binary file modified bun.lockb
Binary file not shown.
3 changes: 1 addition & 2 deletions package.json
@@ -51,8 +51,7 @@
},
"homepage": "https://github.com/instructor-ai/instructor-js#readme",
"dependencies": {
"zod-stream": "^0.0.5",
"zod-to-json-schema": "^3.22.3",
"zod-stream": "0.0.6",
"zod-validation-error": "^2.1.0"
},
"peerDependencies": {
35 changes: 0 additions & 35 deletions src/constants/modes.ts

This file was deleted.

67 changes: 67 additions & 0 deletions src/constants/providers.ts
@@ -0,0 +1,67 @@
import { MODE, type Mode } from "zod-stream"

export const PROVIDERS = {
OAI: "OAI",
ANYSCALE: "ANYSCALE",
TOGETHER: "TOGETHER",
OTHER: "OTHER"
} as const

export type Provider = keyof typeof PROVIDERS

export const PROVIDER_SUPPORTED_MODES: {
[key in Provider]: Mode[]
} = {
[PROVIDERS.OTHER]: [MODE.FUNCTIONS, MODE.TOOLS, MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA],
[PROVIDERS.OAI]: [MODE.FUNCTIONS, MODE.TOOLS, MODE.JSON, MODE.MD_JSON],
[PROVIDERS.ANYSCALE]: [MODE.TOOLS, MODE.JSON, MODE.JSON_SCHEMA],
[PROVIDERS.TOGETHER]: [MODE.TOOLS, MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA]
} as const

export const NON_OAI_PROVIDER_URLS = {
[PROVIDERS.ANYSCALE]: "api.endpoints.anyscale",
[PROVIDERS.TOGETHER]: "api.together.xyz",
[PROVIDERS.OAI]: "api.openai.com"
} as const

export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
[PROVIDERS.OTHER]: {
[MODE.FUNCTIONS]: ["*"],
[MODE.TOOLS]: ["*"],
[MODE.JSON]: ["*"],
[MODE.MD_JSON]: ["*"],
[MODE.JSON_SCHEMA]: ["*"]
},
[PROVIDERS.OAI]: {
[MODE.FUNCTIONS]: ["*"],
[MODE.TOOLS]: ["*"],
[MODE.JSON]: [
"gpt-3.5-turbo-1106",
"gpt-4-1106-preview",
"gpt-4-0125-preview",
"gpt-4-turbo-preview"
],
[MODE.MD_JSON]: ["*"]
},
[PROVIDERS.TOGETHER]: {
[MODE.JSON_SCHEMA]: [
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.1",
"togethercomputer/CodeLlama-34b-Instruct"
],
[MODE.MD_JSON]: ["*"],
[MODE.TOOLS]: [
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.1",
"togethercomputer/CodeLlama-34b-Instruct"
]
},
[PROVIDERS.ANYSCALE]: {
[MODE.JSON_SCHEMA]: [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1"
],
[MODE.MD_JSON]: ["*"],
[MODE.TOOLS]: ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
}
}
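
The tables above feed the new validation in src/instructor.ts further down; a small sketch of the two lookups, assuming the `@/` path alias used elsewhere in the repo:

```typescript
import { MODE } from "zod-stream"

import {
  PROVIDER_SUPPORTED_MODES,
  PROVIDER_SUPPORTED_MODES_BY_MODEL,
  PROVIDERS
} from "@/constants/providers"

// Provider-level check: is JSON_SCHEMA usable with Anyscale at all?
const modeSupported = PROVIDER_SUPPORTED_MODES[PROVIDERS.ANYSCALE].includes(MODE.JSON_SCHEMA) // true

// Model-level check: "*" acts as a wildcard, otherwise the model must be listed.
const models = PROVIDER_SUPPORTED_MODES_BY_MODEL[PROVIDERS.ANYSCALE][MODE.JSON_SCHEMA]
const modelSupported =
  models.includes("*") || models.includes("mistralai/Mistral-7B-Instruct-v0.1") // true
```
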
6 changes: 2 additions & 4 deletions src/dsl/validator.ts
@@ -2,8 +2,7 @@ import { OAIClientExtended } from "@/instructor"
import type { ChatCompletionCreateParams } from "openai/resources/chat/completions.mjs"
import { RefinementCtx, z } from "zod"

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type AsyncSuperRefineFunction = (data: any, ctx: RefinementCtx) => Promise<any>
type AsyncSuperRefineFunction = (data: string, ctx: RefinementCtx) => Promise<void>

export const LLMValidator = (
instructor: OAIClientExtended,
@@ -15,7 +14,7 @@ export const LLMValidator = (
reason: z.string().optional()
})

const fn = async (value, ctx) => {
return async (value, ctx) => {
const validated = await instructor.chat.completions.create({
max_retries: 0,
...params,
@@ -41,5 +40,4 @@
})
}
}
return fn
}
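
The tightened signature matches how the validator is meant to be consumed: the returned async function is passed to `z.superRefine`. A hedged sketch; only the `instructor` parameter is visible in this hunk, so the remaining arguments (a guiding statement and base completion params) are assumptions:

```typescript
import { z } from "zod"

import type { OAIClientExtended } from "@/instructor"
import { LLMValidator } from "@/dsl/validator"

// An existing client created via the package's default export.
declare const instructor: OAIClientExtended

const QuestionAnswer = z.object({
  question: z.string(),
  answer: z.string().superRefine(
    // The statement and params shown here are assumed, not taken from the diff.
    LLMValidator(instructor, "don't say objectionable things", {
      model: "gpt-4-1106-preview"
    })
  )
})
```

Because the refinement is async, the schema has to be validated with `parseAsync` or `safeParseAsync` rather than the synchronous variants.
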
57 changes: 47 additions & 10 deletions src/instructor.ts
@@ -2,21 +2,27 @@ import {
ChatCompletionCreateParamsWithModel,
InstructorConfig,
LogLevel,
Mode,
ReturnTypeBasedOnParams
} from "@/types"
import OpenAI from "openai"
import { z } from "zod"
import ZodStream, { OAIStream, withResponseModel } from "zod-stream"
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
import { fromZodError } from "zod-validation-error"

import { MODE, MODE_TO_PARSER } from "@/constants/modes"
import {
NON_OAI_PROVIDER_URLS,
Provider,
PROVIDER_SUPPORTED_MODES,
PROVIDER_SUPPORTED_MODES_BY_MODEL,
PROVIDERS
} from "./constants/providers"

const MAX_RETRIES_DEFAULT = 0

class Instructor {
readonly client: OpenAI
readonly mode: Mode
readonly provider: Provider
readonly debug: boolean = false

/**
@@ -29,11 +35,39 @@ class Instructor {
this.mode = mode
this.debug = debug

//TODO: probably some more sophisticated validation we can do here re: modes and otherwise.
// but just throwing quick here for now.
if (mode === MODE.JSON_SCHEMA) {
if (!this.client.baseURL.includes("anyscale")) {
throw new Error("JSON_SCHEMA mode is only support on Anyscale.")
const provider =
this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.ANYSCALE) ? PROVIDERS.ANYSCALE
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.TOGETHER) ? PROVIDERS.TOGETHER
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.OAI) ? PROVIDERS.OAI
: PROVIDERS.OTHER

this.provider = provider

this.validateOptions()
}

private validateOptions() {
const isModeSupported = PROVIDER_SUPPORTED_MODES[this.provider].includes(this.mode)

if (this.provider === PROVIDERS.OTHER) {
this.log("debug", "Unknown provider - can't validate options.")
}

if (!isModeSupported) {
throw new Error(`Mode ${this.mode} is not supported by provider ${this.provider}`)
}
}

private validateModelModeSupport<T extends z.AnyZodObject>(
params: ChatCompletionCreateParamsWithModel<T>
) {
if (this.provider !== PROVIDERS.OAI) {
const modelSupport = PROVIDER_SUPPORTED_MODES_BY_MODEL[this.provider][this.mode]

if (!modelSupport.includes("*") && !modelSupport.includes(params.model)) {
throw new Error(
`Model ${params.model} is not supported by provider ${this.provider} in mode ${this.mode}`
)
}
}
}
@@ -98,9 +132,10 @@ class Instructor {
this.log("debug", response_model.name, "making completion call with params: ", resolvedParams)

const completion = await this.client.chat.completions.create(resolvedParams)
const parser = MODE_TO_PARSER[this.mode]

const parsedCompletion = parser(completion as OpenAI.Chat.Completions.ChatCompletion)
const parsedCompletion = OAIResponseParser(
completion as OpenAI.Chat.Completions.ChatCompletion
)
try {
return JSON.parse(parsedCompletion) as z.infer<T>
} catch (error) {
@@ -200,6 +235,8 @@
>(
params: P
): Promise<ReturnTypeBasedOnParams<P>> => {
this.validateModelModeSupport(params)

if (this.isChatCompletionCreateParamsWithModel(params)) {
if (params.stream) {
return this.chatCompletionStream(params) as ReturnTypeBasedOnParams<
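
Taken together, the constructor-time and per-call checks behave roughly as below; the Together base URL and the `{ schema, name }` response_model shape are assumptions carried over from the repo's examples rather than from this diff:

```typescript
import Instructor from "@instructor-ai/instructor"
import OpenAI from "openai"
import { z } from "zod"

// Hypothetical Together endpoint; the provider is inferred from the
// "api.together.xyz" substring of the baseURL.
const together = new OpenAI({
  apiKey: process.env.TOGETHER_API_KEY,
  baseURL: "https://api.together.xyz/v1"
})

// TOOLS is listed for TOGETHER, so construction passes validateOptions()...
const client = Instructor({ client: together, mode: "TOOLS" })

const UserSchema = z.object({ name: z.string(), age: z.number() })

async function extract() {
  // ...but validateModelModeSupport() rejects a model that is not on the TOOLS
  // list for TOGETHER, throwing:
  // "Model gpt-4 is not supported by provider TOGETHER in mode TOOLS"
  return client.chat.completions.create({
    model: "gpt-4",
    response_model: { schema: UserSchema, name: "User" },
    messages: [{ role: "user", content: "Jason is 30 years old" }]
  })
}

extract().catch(console.error)
```
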
100 changes: 0 additions & 100 deletions src/oai/params.ts

This file was deleted.

