Skip to content

Latest commit

 

History

History
78 lines (62 loc) · 2.35 KB

README.md

File metadata and controls

78 lines (62 loc) · 2.35 KB

JAXBT: JAX Backtesting framework codecov

WIP: Differentiable backtesting with JAX.

Installation

pip install https://github.com/richwomanbtc/jaxbt.git

Example

import jax
import pandas as pd
from jaxbt.backtest import backtest_from_order_func, Backtest, OHLC, OrderType

pd = pd.DataFrame(
    columns=["timestamp", "open", "high", "low", "close"],
    data=[
        [1, 1.0, 1.0, 1.0, 1.0],
        [2, 2.0, 2.0, 2.0, 2.0],
        [3, 3.0, 3.0, 3.0, 3.0],
        [4, 4.0, 4.0, 4.0, 4.0],
        [5, 5.0, 5.0, 5.0, 5.0],
    ]
)

ohlc = OHLC.from_pandas(df)

@jax.jit
def f(bt: Backtest, idx: int):
    order_type = jax.lax.cond(
        bt.position[idx] == 0,
        lambda _: OrderType.MARKET_BUY, # if position is 0, perform market buy
        lambda _: OrderType.MARKET_SELL, # if position is not 0, perform market sell
        None
    )
    return order_type, 1., jnp.nan

result = backtest_from_order_func(ohlc, f)

@jax.jit
def f_param(params: jax.Array, bt: Backtest, idx: int):
    order_type = jax.lax.cond(
        bt.position[idx] == 0,
        lambda _: OrderType.MARKET_BUY, # if position is 0, perform market buy
        lambda _: OrderType.MARKET_SELL, # if position is not 0, perform market sell
        None
    )
    return order_type, params[0], jnp.nan

@jax.jit
def loss(param: jax.Array):
    result = backtest_from_order_func(
        df, lambda bt, idx: f_param(param, bt, idx)
    )
    return -result.pl.sum()

grad_fun = jax.grad(loss, argnums=0)

@jax.jit
def train(epoch, params, lr=0.01):
    def body_fun(idx, params):
        grads = grad_fun(params)
        params = params - lr * grads
        return params

    params = jax.lax.fori_loop(0, epoch, body_fun, params)
    return params

init_params = jnp.array([0.1])
result_params = train(100, init_params)
print(loss(init_params), loss(result_params))