Skip to content

Commit

Permalink
feat: support custom action provider (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xRAG authored Feb 1, 2025
1 parent a536df4 commit d2bf34c
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { z } from "zod";
import { CreateAction } from "./actionDecorator";
import { ActionProvider } from "./actionProvider";
import { Network } from "../network";
import { WalletProvider } from "../wallet-providers";

interface CustomActionProviderOptions<TWalletProvider extends WalletProvider> {
name: string;
description: string;
schema: z.ZodSchema;
invoke:
| ((args: any) => Promise<any>)
| ((walletProvider: TWalletProvider, args: any) => Promise<any>);
}

/**
* CustomActionProvider is a custom action provider that allows for custom action registration
*/
export class CustomActionProvider<TWalletProvider extends WalletProvider> extends ActionProvider {
/**
* Creates a new CustomActionProvider that dynamically adds decorated action methods
*
* @param actions - Array of custom actions to be added to the provider
*/
constructor(actions: CustomActionProviderOptions<TWalletProvider>[]) {
super("custom", []);

actions.forEach(({ name, description, schema, invoke }) => {
// Check if the invoke function expects a wallet provider
const takesWalletProvider = invoke.length === 2;

// Define the method on the prototype with the correct signature
Object.defineProperty(CustomActionProvider.prototype, name, {
value: takesWalletProvider
? async function (walletProvider: WalletProvider, args: unknown) {
const parsedArgs = schema.parse(args);
return await (invoke as any)(walletProvider, parsedArgs);
}
: async function (args: unknown) {
const parsedArgs = schema.parse(args);
return await (invoke as any)(parsedArgs);
},
configurable: true,
writable: true,
enumerable: true,
});

// Manually set the parameter metadata
const paramTypes = takesWalletProvider ? [WalletProvider, Object] : [Object];
Reflect.defineMetadata("design:paramtypes", paramTypes, CustomActionProvider.prototype, name);

// Apply the decorator using original name
const decoratedMethod = CreateAction({
name,
description,
schema,
})(
CustomActionProvider.prototype,
name,
Object.getOwnPropertyDescriptor(CustomActionProvider.prototype, name)!,
);

// Add the decorated method to the instance
Object.defineProperty(this, name, {
value: decoratedMethod,
configurable: true,
writable: true,
});
});
}

/**
* Custom action providers are supported on all networks
*
* @param _ - The network to checkpointSaver
* @returns true
*/
supportsNetwork(_: Network): boolean {
return true;
}
}

export const customActionProvider = <TWalletProvider extends WalletProvider>(
actions:
| CustomActionProviderOptions<TWalletProvider>
| CustomActionProviderOptions<TWalletProvider>[],
) => new CustomActionProvider<TWalletProvider>(Array.isArray(actions) ? actions : [actions]);
1 change: 1 addition & 0 deletions cdp-agentkit-core/typescript/src/action-providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ export * from "./basename";
export * from "./farcaster";
export * from "./twitter";
export * from "./wallet";
export * from "./customActionProvider";
7 changes: 2 additions & 5 deletions cdp-agentkit-core/typescript/src/agentkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ type AgentKitOptions = {
cdpApiKeyPrivateKey?: string;
walletProvider?: WalletProvider;
actionProviders?: ActionProvider[];
actions?: Action[];
};

/**
Expand All @@ -18,7 +17,6 @@ type AgentKitOptions = {
export class AgentKit {
private walletProvider: WalletProvider;
private actionProviders: ActionProvider[];
private actions?: Action[];

/**
* Initializes a new AgentKit instance
Expand All @@ -31,7 +29,6 @@ export class AgentKit {
private constructor(config: AgentKitOptions & { walletProvider: WalletProvider }) {
this.walletProvider = config.walletProvider;
this.actionProviders = config.actionProviders || [walletActionProvider()];
this.actions = config.actions || [];
}

/**
Expand Down Expand Up @@ -71,11 +68,11 @@ export class AgentKit {
* @returns An array of actions
*/
public getActions(): Action[] {
let actions: Action[] = this.actions || [];
const actions: Action[] = [];

for (const actionProvider of this.actionProviders) {
if (actionProvider.supportsNetwork(this.walletProvider.getNetwork())) {
actions = actions.concat(actionProvider.getActions(this.walletProvider));
actions.push(...actionProvider.getActions(this.walletProvider));
}
}

Expand Down

0 comments on commit d2bf34c

Please sign in to comment.