Skip to content

Commit

Permalink
feat: add customizable token modifier costs per AI model
Browse files Browse the repository at this point in the history
chore: add token modifier tests

- Implemented 'tokenModifierRatio' for AI models in ServerConfig.ts
- Added `_calculateTokenCost` function in AIController.ts to apply the configured modifier ratio to a response's token count for a given model.
- Added tests to AIController.spec.ts to ensure proper operation

This allows us to set custom token-to-cost ratios for each model, enabling differential pricing or discounts based on the AI model in use.
  • Loading branch information
TroyceGowdy committed Nov 14, 2024
1 parent 4d7715f commit afa452c
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 2 deletions.
147 changes: 147 additions & 0 deletions src/aux-records/AIController.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,13 @@ describe('AIController', () => {
provider: 'provider2',
model: 'test-model3',
},
{
provider: 'provider1',
model: 'test-model-token-ratio',
},
],
allowedChatSubscriptionTiers: ['test-tier'],
tokenModifierRatio: { 'test-model-token-ratio': 2.0 },
},
},
generateSkybox: {
Expand Down Expand Up @@ -474,6 +479,7 @@ describe('AIController', () => {
},
],
allowedChatSubscriptionTiers: true,
tokenModifierRatio: { default: 1.0 },
},
},
generateSkybox: null,
Expand Down Expand Up @@ -621,6 +627,70 @@ describe('AIController', () => {
});
});

it('should use configured token ratio', async () => {
    // The interface reports 123 raw tokens; the config maps
    // 'test-model-token-ratio' to a 2.0 modifier, so the metrics
    // store should record 123 * 2.0 = 246 tokens.
    chatInterface.chat.mockReturnValueOnce(
        Promise.resolve({
            choices: [
                {
                    role: 'user',
                    content: 'test',
                    finishReason: 'stop',
                },
            ],
            totalTokens: 123,
        })
    );

    const result = await controller.chat({
        model: 'test-model-token-ratio',
        messages: [
            {
                role: 'user',
                content: 'test',
            },
        ],
        temperature: 0.5,
        userId,
        userSubscriptionTier,
    });

    // The adjusted token count affects only metrics, not the chat result.
    expect(result).toEqual({
        success: true,
        choices: [
            {
                role: 'user',
                content: 'test',
                finishReason: 'stop',
            },
        ],
    });
    expect(chatInterface.chat).toBeCalledWith({
        model: 'test-model-token-ratio',
        messages: [
            {
                role: 'user',
                content: 'test',
            },
        ],
        temperature: 0.5,
        userId: 'test-user',
    });

    const metrics = await store.getSubscriptionAiChatMetrics({
        ownerId: userId,
    });

    expect(metrics).toEqual({
        ownerId: userId,
        subscriptionStatus: null,
        subscriptionId: null,
        subscriptionType: 'user',
        currentPeriodStartMs: null,
        currentPeriodEndMs: null,
        // 123 raw tokens * 2.0 modifier ratio
        totalTokensInCurrentPeriod: 246,
    });
});

describe('subscriptions', () => {
beforeEach(async () => {
store.subscriptionConfiguration = buildSubscriptionConfig(
Expand Down Expand Up @@ -1068,6 +1138,7 @@ describe('AIController', () => {
},
],
allowedChatSubscriptionTiers: ['test-tier'],
tokenModifierRatio: { default: 1.0 },
},
},
generateSkybox: {
Expand Down Expand Up @@ -1489,6 +1560,7 @@ describe('AIController', () => {
},
],
allowedChatSubscriptionTiers: true,
tokenModifierRatio: { default: 1.0 },
},
},
generateSkybox: null,
Expand Down Expand Up @@ -1659,6 +1731,80 @@ describe('AIController', () => {
});
});

it('should use configured token ratio', async () => {
    // Streaming variant of the token-ratio test: a single chunk reports
    // 123 raw tokens, and the 2.0 modifier for 'test-model-token-ratio'
    // should make the metrics store record 246 tokens.
    chatInterface.chatStream.mockReturnValueOnce(
        asyncIterable<AIChatInterfaceStreamResponse>([
            Promise.resolve({
                choices: [
                    {
                        role: 'user',
                        content: 'test',
                        finishReason: 'stop',
                    },
                ],
                totalTokens: 123,
            }),
        ])
    );

    const result = await unwindAndCaptureAsync(
        controller.chatStream({
            model: 'test-model-token-ratio',
            messages: [
                {
                    role: 'user',
                    content: 'test',
                },
            ],
            temperature: 0.5,
            userId,
            userSubscriptionTier,
        })
    );

    // The adjusted token count affects only metrics, not the stream output.
    expect(result).toEqual({
        result: {
            success: true,
        },
        states: [
            {
                choices: [
                    {
                        role: 'user',
                        content: 'test',
                        finishReason: 'stop',
                    },
                ],
            },
        ],
    });
    expect(chatInterface.chatStream).toBeCalledWith({
        model: 'test-model-token-ratio',
        messages: [
            {
                role: 'user',
                content: 'test',
            },
        ],
        temperature: 0.5,
        userId: 'test-user',
    });

    const metrics = await store.getSubscriptionAiChatMetrics({
        ownerId: userId,
    });

    expect(metrics).toEqual({
        ownerId: userId,
        subscriptionStatus: null,
        subscriptionId: null,
        subscriptionType: 'user',
        currentPeriodStartMs: null,
        currentPeriodEndMs: null,
        // 123 raw tokens * 2.0 modifier ratio
        totalTokensInCurrentPeriod: 246,
    });
});

describe('subscriptions', () => {
beforeEach(async () => {
store.subscriptionConfiguration = buildSubscriptionConfig(
Expand Down Expand Up @@ -2112,6 +2258,7 @@ describe('AIController', () => {
},
],
allowedChatSubscriptionTiers: ['test-tier'],
tokenModifierRatio: { default: 1.0 },
},
},
generateSkybox: {
Expand Down
27 changes: 25 additions & 2 deletions src/aux-records/AIController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
} from '@casual-simulation/aux-common/Errors';
import {
AIChatInterface,
AIChatInterfaceResponse,
AIChatInterfaceStreamResponse,
AIChatMessage,
AIChatStreamMessage,
Expand Down Expand Up @@ -470,10 +471,12 @@ export class AIController {
});

if (result.totalTokens > 0) {
const adjustedTokens = this._calculateTokenCost(result, model);

await this._metrics.recordChatMetrics({
userId: request.userId,
createdAtMs: Date.now(),
tokens: result.totalTokens,
tokens: adjustedTokens,
});
}

Expand All @@ -495,6 +498,17 @@ export class AIController {
}
}

private _calculateTokenCost(
result: AIChatInterfaceResponse | AIChatInterfaceStreamResponse,
model: string
) {
const totalTokens = result.totalTokens;
const tokenModifierRatio = this._chatOptions.tokenModifierRatio;
const modifier = tokenModifierRatio[model] ?? 1.0;
const adjustedTokens = modifier * totalTokens;
return adjustedTokens;
}

@traced(TRACE_NAME)
async *chatStream(
request: AIChatRequest
Expand Down Expand Up @@ -684,10 +698,14 @@ export class AIController {

for await (let chunk of result) {
if (chunk.totalTokens > 0) {
const adjustedTokens = this._calculateTokenCost(
chunk,
model
);
await this._metrics.recordChatMetrics({
userId: request.userId,
createdAtMs: Date.now(),
tokens: chunk.totalTokens,
tokens: adjustedTokens,
});
}

Expand Down Expand Up @@ -1456,6 +1474,11 @@ export interface AIChatRequest {
* If the AI generates a sequence of tokens that match one of the given words, then it will stop generating tokens.
*/
stopWords?: string[];

/**
 * The total number of tokens used by the request and its response.
 * NOTE(review): this field looks misplaced — `totalTokens` is reported on the
 * response/stream interfaces, not the request; confirm it is needed here.
 */
totalTokens?: number;
}

export type AIChatResponse = AIChatSuccess | AIChatFailure;
Expand Down

0 comments on commit afa452c

Please sign in to comment.