Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add eslint rule to detect side effects (which prevent tree shaking), Add script that checks presence of side effects with rollup #323

Merged
merged 8 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions langchain/.eslintrc.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,34 @@ module.exports = {
project: "./tsconfig.json",
sourceType: "module",
},
plugins: ["@typescript-eslint"],
ignorePatterns: [".eslintrc.cjs", "create-entrypoints.js", "node_modules"],
plugins: ["@typescript-eslint", "tree-shaking"],
ignorePatterns: [
".eslintrc.cjs",
"create-entrypoints.js",
"check-tree-shaking.js",
"node_modules",
],
rules: {
"tree-shaking/no-side-effects-in-initialization": [
2,
{
noSideEffectsWhenCalled: [
{
module: "@jest/globals",
functions: [
"test",
"describe",
"it",
"beforeEach",
"afterEach",
"skip",
"each",
"only",
],
},
],
},
],
"@typescript-eslint/explicit-module-boundary-types": 0,
"@typescript-eslint/no-empty-function": 0,
"@typescript-eslint/no-shadow": 0,
Expand Down
78 changes: 78 additions & 0 deletions langchain/check-tree-shaking.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import fs from "fs/promises";
import { rollup } from "rollup";

const packageJson = JSON.parse(await fs.readFile("package.json", "utf-8"));

export function listEntrypoints() {
const exports = packageJson.exports;
const entrypoints = [];

for (const [key, value] of Object.entries(exports)) {
if (typeof value === "string") {
entrypoints.push(value);
} else if (typeof value === "object") {
entrypoints.push(value.import);
}
}

return entrypoints;
}

export function listExternals() {
return [
...Object.keys(packageJson.dependencies),
...Object.keys(packageJson.peerDependencies),
/node\:/,
"axios", // axios is a dependency of openai
"pdf-parse/lib/pdf-parse.js",
];
}

export async function checkTreeShaking() {
const externals = listExternals();
const entrypoints = listEntrypoints();
const consoleLog = console.log;
const reportMap = new Map();

for (const entrypoint of entrypoints) {
let sideEffects = "";

console.log = function (...args) {
const line = args.length ? args.join(" ") : "";
if (line.trim().startsWith("First side effect in")) {
sideEffects += line + "\n";
}
};

await rollup({
external: externals,
input: entrypoint,
experimentalLogSideEffects: true,
});

reportMap.set(entrypoint, {
log: sideEffects,
hasSideEffects: sideEffects.length > 0,
});
}

console.log = consoleLog;

let failed = false;
for (const [entrypoint, report] of reportMap) {
if (report.hasSideEffects) {
failed = true;
console.log("---------------------------------");
console.log(`Tree shaking failed for ${entrypoint}`);
console.log(report.log);
}
}

if (failed) {
process.exit(1);
} else {
console.log("Tree shaking checks passed!");
}
}

checkTreeShaking();
4 changes: 3 additions & 1 deletion langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"url": "[email protected]:hwchase17/langchainjs.git"
},
"scripts": {
"build": "yarn clean && tsc --declaration --outDir dist/ && node create-entrypoints.js",
"build": "yarn clean && tsc --declaration --outDir dist/ && node create-entrypoints.js && node check-tree-shaking.js",
"build:watch": "node create-entrypoints.js && tsc --declaration --outDir dist/ --watch",
"lint": "eslint src && dpdm --exit-code circular:1 --no-warning --no-tree src/*.ts src/**/*.ts",
"lint:fix": "yarn lint --fix",
Expand Down Expand Up @@ -92,6 +92,7 @@
"eslint-config-prettier": "^8.6.0",
"eslint-plugin-import": "^2.27.5",
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-tree-shaking": "^1.10.0",
"hnswlib-node": "^1.4.2",
"husky": "^8.0.3",
"jest": "^29.5.0",
Expand All @@ -100,6 +101,7 @@
"prettier": "^2.8.3",
"puppeteer": "^19.7.2",
"redis": "^4.6.4",
"rollup": "^3.19.1",
"serpapi": "^1.1.1",
"sqlite3": "^5.1.4",
"srt-parser-2": "^1.2.2",
Expand Down
54 changes: 27 additions & 27 deletions langchain/src/agents/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,6 @@ class ParseError extends Error {
}
}

// Hacky workaround to add static abstract methods. See detailed description of
// issue here: https://stackoverflow.com/a/65847601
export interface StaticAgent {
/**
* Create a prompt for this class
*
* @param tools - List of tools the agent will have access to, used to format the prompt.
* @param fields - Additional fields used to format the prompt.
*
* @returns A PromptTemplate assembled from the given tools and fields.
* */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
createPrompt(tools: Tool[], fields?: Record<string, any>): BasePromptTemplate;
/** Construct an agent from an LLM and a list of tools */
fromLLMAndTools(
llm: BaseLLM,
tools: Tool[],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
args?: Record<string, any>
): Agent;
validateTools(_: Tool[]): void;
}

export const staticImplements = <T>(_: T) => {};

/**
* Class responsible for calling a language model and deciding an action.
*
Expand Down Expand Up @@ -95,11 +70,36 @@ export abstract class Agent {
*/
prepareForNewCall(): void {}

/**
* Create a prompt for this class
*
* @param tools - List of tools the agent will have access to, used to format the prompt.
* @param fields - Additional fields used to format the prompt.
*
* @returns A PromptTemplate assembled from the given tools and fields.
* */
static createPrompt(
_tools: Tool[],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
_fields?: Record<string, any>
): BasePromptTemplate {
throw new Error("Not implemented");
}

/** Construct an agent from an LLM and a list of tools */
static fromLLMAndTools(
_llm: BaseLLM,
_tools: Tool[],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
_args?: Record<string, any>
): Agent {
throw new Error("Not implemented");
}

/**
* Validate that appropriate tools are passed in
*/
// eslint-disable-next-line no-unused-vars
static validateTools(_: Tool[]): void {}
static validateTools(_tools: Tool[]): void {}

_stop(): string[] {
return [`\n${this.observationPrefix()}`];
Expand Down
4 changes: 1 addition & 3 deletions langchain/src/agents/chat/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { LLMChain } from "../../chains/index.js";
import { Agent, StaticAgent, staticImplements } from "../agent.js";
import { Agent } from "../agent.js";
import {
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
Expand Down Expand Up @@ -27,9 +27,7 @@ type ZeroShotAgentInput = AgentInput;
/**
* Agent for the MRKL chain.
* @augments Agent
* @augments StaticAgent
*/
@(staticImplements<StaticAgent>)
export class ChatAgent extends Agent {
constructor(input: ZeroShotAgentInput) {
super(input);
Expand Down
4 changes: 1 addition & 3 deletions langchain/src/agents/chat_convo/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { LLMChain } from "../../chains/index.js";
import { Agent, StaticAgent, staticImplements } from "../agent.js";
import { Agent } from "../agent.js";
import {
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
Expand Down Expand Up @@ -67,9 +67,7 @@ type ZeroShotAgentInput = AgentInput;
/**
* Agent for the MRKL chain.
* @augments Agent
* @augments StaticAgent
*/
@(staticImplements<StaticAgent>)
export class ChatConversationalAgent extends Agent {
outputParser: BaseOutputParser;

Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export {
SerializedZeroShotAgent,
SerializedAgent,
} from "./types.js";
export { Agent, StaticAgent, staticImplements } from "./agent.js";
export { Agent } from "./agent.js";
export { AgentExecutor } from "./executor.js";
export { ZeroShotAgent } from "./mrkl/index.js";
export { ChatAgent } from "./chat/index.js";
Expand Down
4 changes: 1 addition & 3 deletions langchain/src/agents/mrkl/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
SerializedFromLLMAndTools,
SerializedZeroShotAgent,
} from "../types.js";
import { Agent, StaticAgent, staticImplements } from "../agent.js";
import { Agent } from "../agent.js";
import { Tool } from "../tools/base.js";

const FINAL_ANSWER_ACTION = "Final Answer:";
Expand All @@ -27,9 +27,7 @@ type ZeroShotAgentInput = AgentInput;
/**
* Agent for the MRKL chain.
* @augments Agent
* @augments StaticAgent
*/
@(staticImplements<StaticAgent>)
export class ZeroShotAgent extends Agent {
constructor(input: ZeroShotAgentInput) {
super(input);
Expand Down
28 changes: 16 additions & 12 deletions langchain/src/cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@ export abstract class BaseCache<T = Generation[]> {
abstract update(prompt: string, llmKey: string, value: T): Promise<void>;
}

const GLOBAL_MAP = new Map();

export class InMemoryCache<T = Generation[]> extends BaseCache<T> {
#cache: Map<string, T>;
private cache: Map<string, T>;

constructor() {
constructor(map?: Map<string, T>) {
super();
this.#cache = new Map();
this.cache = map ?? new Map();
}

lookup(prompt: string, llmKey: string): Promise<T | null> {
return Promise.resolve(
this.#cache.get(getCacheKey(prompt, llmKey)) ?? null
);
return Promise.resolve(this.cache.get(getCacheKey(prompt, llmKey)) ?? null);
}

async update(prompt: string, llmKey: string, value: T): Promise<void> {
this.#cache.set(getCacheKey(prompt, llmKey), value);
this.cache.set(getCacheKey(prompt, llmKey), value);
}

static global(): InMemoryCache {
return new InMemoryCache(GLOBAL_MAP);
}
}

Expand All @@ -44,11 +48,11 @@ export class InMemoryCache<T = Generation[]> extends BaseCache<T> {
* TODO: Generalize to support other types.
*/
export class RedisCache extends BaseCache<Generation[]> {
#redisClient: RedisClientType;
private redisClient: RedisClientType;

constructor(redisClient: RedisClientType) {
super();
this.#redisClient = redisClient;
this.redisClient = redisClient;
}

public async lookup(
Expand All @@ -57,7 +61,7 @@ export class RedisCache extends BaseCache<Generation[]> {
): Promise<Generation[] | null> {
let idx = 0;
let key = getCacheKey(prompt, llmKey, String(idx));
let value = await this.#redisClient.get(key);
let value = await this.redisClient.get(key);
const generations: Generation[] = [];

while (value) {
Expand All @@ -68,7 +72,7 @@ export class RedisCache extends BaseCache<Generation[]> {
generations.push({ text: value });
idx += 1;
key = getCacheKey(prompt, llmKey, String(idx));
value = await this.#redisClient.get(key);
value = await this.redisClient.get(key);
}

return generations.length > 0 ? generations : null;
Expand All @@ -81,7 +85,7 @@ export class RedisCache extends BaseCache<Generation[]> {
): Promise<void> {
for (let i = 0; i < value.length; i += 1) {
const key = getCacheKey(prompt, llmKey, String(i));
await this.#redisClient.set(key, value[i].text);
await this.redisClient.set(key, value[i].text);
}
}
}
1 change: 1 addition & 0 deletions langchain/src/callbacks/tests/tracer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
const TEST_SESSION_ID = 2023;
const _DATE = 1620000000000;

// eslint-disable-next-line tree-shaking/no-side-effects-in-initialization
Date.now = jest.fn(() => _DATE);

class FakeTracer extends BaseTracer {
Expand Down
1 change: 0 additions & 1 deletion langchain/src/chains/chat_vector_db_chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ export class ChatVectorDBQAChain
const question_generator_prompt = PromptTemplate.fromTemplate(
questionGeneratorTemplate || question_generator_template
);

const qa_prompt = PromptTemplate.fromTemplate(qaTemplate || qa_template);

const qaChain = loadQAStuffChain(llm, { prompt: qa_prompt });
Expand Down
12 changes: 6 additions & 6 deletions langchain/src/chains/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ Current conversation:
Human: {input}
AI:`;

const defaultPrompt = new PromptTemplate({
template: defaultTemplate,
inputVariables: ["history", "input"],
});

export class ConversationChain extends LLMChain {
constructor(fields: {
llm: BaseLanguageModel;
Expand All @@ -24,7 +19,12 @@ export class ConversationChain extends LLMChain {
memory?: BaseMemory;
}) {
super({
prompt: fields.prompt ?? defaultPrompt,
prompt:
fields.prompt ??
new PromptTemplate({
template: defaultTemplate,
inputVariables: ["history", "input"],
}),
llm: fields.llm,
outputKey: fields.outputKey ?? "response",
});
Expand Down
Loading