import {
  slice,
  concat,
  type Transport,
  type Chain,
  type Account,
  type Hex,
  type WalletActions,
  type Client,
  type PublicActions,
  type WriteContractParameters,
  type EncodeFunctionDataParameters,
} from "viem";
import { getAction, encodeFunctionData } from "viem/utils";
import { readContract, writeContract as viem_writeContract } from "viem/actions";
import { readHex } from "@latticexyz/common";
import {
  getKeySchema,
  getValueSchema,
  getSchemaTypes,
  decodeValueArgs,
  encodeKey,
} from "@latticexyz/protocol-parser/internal";
import worldConfig from "../../mud.config";
import { worldCallAbi } from "../worldCallAbi";

/** A World function resolved to the System (and System function) it is routed to. */
type SystemFunction = {
  /** Resource ID of the System registered for the World function. */
  systemId: Hex;
  /** Function selector on the System that the World function maps to. */
  systemFunctionSelector: Hex;
};

/** Options accepted by `callFrom`. */
type CallFromParameters = {
  /** World contract whose writes are redirected through `callFrom`. */
  worldAddress: Hex;
  /** Account on whose behalf the redirected writes are made. */
  delegatorAddress: Hex;
  /** Optional custom lookup from a World function selector to its System function; preferred over `publicClient` when set. */
  worldFunctionToSystemFunction?: (worldFunctionSelector: Hex) => Promise<SystemFunction>;
  /** Optional client used to read the mapping from the World contract; the extended client is used when omitted. */
  publicClient?: Client;
};

/**
 * Creates a viem client extension that automatically applies a delegation to World contract writes,
 * so that those writes are made on behalf of `params.delegatorAddress`. Extend a client with it after
 * the delegation has been set up.
 *
 * How `writeContract` calls are handled:
 * - Writes to any address other than `params.worldAddress` (compared case-insensitively), and calls to
 *   `call`, `callFrom`, `batchCallFrom` or `callWithSignature`, are forwarded unchanged.
 * - `batchCall` is rewritten to `batchCallFrom`, with the delegator as `from` for every system call.
 * - Any other World function is re-encoded as `callFrom(delegator, systemId, systemCalldata)`.
 *
 * The System behind a World function selector is resolved with `params.worldFunctionToSystemFunction`
 * when provided (e.g. to reuse a client-side store). Otherwise it is read from the World contract via
 * `params.publicClient`, or via the extended client itself when no `publicClient` is given.
 * Lookups are cached per World address and World function selector.
 *
 * @param params - World address, delegator address, and optional lookup function / public client.
 * @returns A client extension providing a `writeContract` action.
 * @throws Error if a `batchCall` write contains no system calls.
 */
export function callFrom(
  params: CallFromParameters,
): <chain extends Chain, account extends Account | undefined>(
  client: Client<Transport, chain, account>,
) => Pick<WalletActions<chain, account>, "writeContract"> {
  return (client) => ({
    async writeContract(writeArgs) {
      const _writeContract = getAction(client, viem_writeContract, "writeContract");

      // Skip if the contract isn't the World or the function called should not be redirected through `callFrom`.
      // Addresses are compared case-insensitively: checksummed and lowercase hex denote the same contract,
      // and a strict comparison would silently bypass the delegation.
      if (
        writeArgs.address.toLowerCase() !== params.worldAddress.toLowerCase() ||
        writeArgs.functionName === "call" ||
        writeArgs.functionName === "callFrom" ||
        writeArgs.functionName === "batchCallFrom" ||
        writeArgs.functionName === "callWithSignature"
      ) {
        return _writeContract(writeArgs);
      }

      // Wrap system calls from `batchCall` with delegator for a `batchCallFrom`
      // TODO: remove this specific workaround once https://github.com/latticexyz/mud/pull/3506 lands
      if (writeArgs.functionName === "batchCall") {
        const batchCallArgs = writeArgs as unknown as WriteContractParameters<worldCallAbi, "batchCall">;
        const [systemCalls] = batchCallArgs.args;
        if (!systemCalls.length) {
          throw new Error("`batchCall` should have at least one system call.");
        }

        return _writeContract({
          ...batchCallArgs,
          functionName: "batchCallFrom",
          args: [systemCalls.map((systemCall) => ({ from: params.delegatorAddress, ...systemCall }))],
        });
      }

      // Encode the World's calldata (which includes the World's function selector).
      const worldCalldata = encodeFunctionData({
        abi: writeArgs.abi,
        functionName: writeArgs.functionName,
        args: writeArgs.args,
      } as unknown as EncodeFunctionDataParameters);

      // The first 4 bytes of calldata represent the function selector.
      const worldFunctionSelector = slice(worldCalldata, 0, 4);

      // Get the systemId and System's function selector.
      const { systemId, systemFunctionSelector } = await worldFunctionToSystemFunction({
        ...params,
        publicClient: params.publicClient ?? client,
        worldFunctionSelector,
      });

      // Construct the System's calldata by replacing the World's function selector with the System's.
      // Use `readHex` instead of `slice` to prevent out-of-bounds errors with calldata that has no args.
      const systemCalldata = concat([systemFunctionSelector, readHex(worldCalldata, 4)]);

      // Call `writeContract` with the new args.
      return _writeContract({
        ...(writeArgs as unknown as WriteContractParameters<worldCallAbi, "callFrom">),
        functionName: "callFrom",
        args: [params.delegatorAddress, systemId, systemCalldata],
      });
    },
  });
}

// Module-level memo of resolved System functions, keyed by `concat([worldAddress, worldFunctionSelector])`.
const systemFunctionCache: Map<Hex, SystemFunction> = new Map();

/**
 * Resolves the System function that a World function selector is routed to.
 *
 * Uses `params.worldFunctionToSystemFunction` when provided. Otherwise it reads the World's function
 * selector table via `params.publicClient`.
 *
 * Resolved results are memoized in `systemFunctionCache`, keyed by World address + World function selector.
 * A result with an empty (all-zero) `systemId` is returned but NOT cached. Such a result presumably means
 * the selector is not registered yet, and caching it would make a later registration invisible for the
 * lifetime of the module.
 *
 * @param params - World address, delegator, World function selector, and the lookup sources.
 * @returns The System ID and System function selector for the World function selector.
 */
async function worldFunctionToSystemFunction(params: {
  worldAddress: Hex;
  delegatorAddress: Hex;
  worldFunctionSelector: Hex;
  worldFunctionToSystemFunction?: (worldFunctionSelector: Hex) => Promise<SystemFunction>;
  publicClient: Client;
}): Promise<SystemFunction> {
  const cacheKey = concat([params.worldAddress, params.worldFunctionSelector]);

  // Use cache if the function has been resolved previously.
  const cached = systemFunctionCache.get(cacheKey);
  if (cached) return cached;

  // If a mapping function is provided, use it. Otherwise, call the World contract.
  const systemFunction = params.worldFunctionToSystemFunction
    ? await params.worldFunctionToSystemFunction(params.worldFunctionSelector)
    : await retrieveSystemFunctionFromContract(params.publicClient, params.worldAddress, params.worldFunctionSelector);

  // Only memoize mappings that point at an actual System; an empty ID is retried on the next lookup.
  const isUnresolved = /^0x0*$/.test(systemFunction.systemId);
  if (!isUnresolved) {
    systemFunctionCache.set(cacheKey, systemFunction);
  }

  return systemFunction;
}

/**
 * Reads the `world__FunctionSelectors` record for a World function selector directly from the World
 * contract (via the Store's `getRecord` view function) and decodes it.
 *
 * @param publicClient - Client used to perform the `readContract` call.
 * @param worldAddress - Address of the World contract to read from.
 * @param worldFunctionSelector - World function selector used as the record key.
 * @returns The decoded `systemId` and `systemFunctionSelector` of the record.
 */
async function retrieveSystemFunctionFromContract(
  publicClient: Client,
  worldAddress: Hex,
  worldFunctionSelector: Hex,
): Promise<SystemFunction> {
  const functionSelectorsTable = worldConfig.tables.world__FunctionSelectors;
  const keySchema = getSchemaTypes(getKeySchema(functionSelectorsTable));
  const valueSchema = getSchemaTypes(getValueSchema(functionSelectorsTable));

  // Minimal ABI fragment for `getRecord(tableId, keyTuple) returns (staticData, encodedLengths, dynamicData)`.
  const getRecordAbi = [
    {
      type: "function",
      name: "getRecord",
      inputs: [
        { name: "tableId", type: "bytes32", internalType: "ResourceId" },
        { name: "keyTuple", type: "bytes32[]", internalType: "bytes32[]" },
      ],
      outputs: [
        { name: "staticData", type: "bytes", internalType: "bytes" },
        { name: "encodedLengths", type: "bytes32", internalType: "EncodedLengths" },
        { name: "dynamicData", type: "bytes", internalType: "bytes" },
      ],
      stateMutability: "view",
    },
  ] as const;

  const readContractAction = getAction(publicClient, readContract, "readContract") as PublicActions["readContract"];

  const record = await readContractAction({
    address: worldAddress,
    abi: getRecordAbi,
    functionName: "getRecord",
    args: [functionSelectorsTable.tableId, encodeKey(keySchema, { worldFunctionSelector })],
  });
  const [staticData, encodedLengths, dynamicData] = record;

  // Decode the raw record and keep only the two fields that make up a SystemFunction.
  const { systemId, systemFunctionSelector } = decodeValueArgs(valueSchema, {
    staticData,
    encodedLengths,
    dynamicData,
  });

  return { systemId, systemFunctionSelector };
}