Skip to content

Commit

Permalink
Fix OrderBookDepth10 Arrow decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Jan 10, 2024
1 parent a58a8d7 commit 86b3265
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
25 changes: 19 additions & 6 deletions nautilus_core/persistence/src/arrow/depth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,13 @@ impl DecodeFromRecordBatch for OrderBookDepth10 {
)?);
}

let flags =
extract_column::<UInt8Array>(cols, "flags", 4 * DEPTH10_LEN + 2, DataType::UInt8)?;
let flags = extract_column::<UInt8Array>(cols, "flags", 6 * DEPTH10_LEN, DataType::UInt8)?;
let sequence =
extract_column::<UInt64Array>(cols, "sequence", 4 * DEPTH10_LEN + 3, DataType::UInt64)?;
extract_column::<UInt64Array>(cols, "sequence", 6 * DEPTH10_LEN + 1, DataType::UInt64)?;
let ts_event =
extract_column::<UInt64Array>(cols, "ts_event", 4 * DEPTH10_LEN + 4, DataType::UInt64)?;
extract_column::<UInt64Array>(cols, "ts_event", 6 * DEPTH10_LEN + 2, DataType::UInt64)?;
let ts_init =
extract_column::<UInt64Array>(cols, "ts_init", 4 * DEPTH10_LEN + 5, DataType::UInt64)?;
extract_column::<UInt64Array>(cols, "ts_init", 6 * DEPTH10_LEN + 3, DataType::UInt64)?;

// Map record batch rows to vector of OrderBookDepth10
let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
Expand Down Expand Up @@ -505,9 +504,11 @@ mod tests {
];
let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
assert_eq!(schema, expected_schema);
assert_eq!(schema.metadata()["instrument_id"], "AAPL.XNAS");
assert_eq!(schema.metadata()["price_precision"], "2");
assert_eq!(schema.metadata()["size_precision"], "0");
}

#[ignore] // WIP
#[rstest]
fn test_get_schema_map() {
let schema_map = OrderBookDepth10::get_schema_map();
Expand Down Expand Up @@ -800,4 +801,16 @@ mod tests {
assert_eq!(ts_init_values.len(), 1);
assert_eq!(ts_init_values.value(0), 2);
}

#[rstest]
fn test_decode_batch(stub_depth10: OrderBookDepth10) {
let instrument_id = InstrumentId::from("AAPL.XNAS");
let metadata = OrderBookDepth10::get_metadata(&instrument_id, 2, 0);

let data = vec![stub_depth10];
let record_batch = OrderBookDepth10::encode_batch(&metadata, &data).unwrap();
let decoded_data = OrderBookDepth10::decode_batch(&metadata, record_batch).unwrap();

assert_eq!(decoded_data.len(), 1);
}
}
7 changes: 7 additions & 0 deletions nautilus_trader/serialization/arrow/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from nautilus_trader.model.data import CustomData
from nautilus_trader.model.data import OrderBookDelta
from nautilus_trader.model.data import OrderBookDeltas
from nautilus_trader.model.data import OrderBookDepth10
from nautilus_trader.model.data import QuoteTick
from nautilus_trader.model.data import TradeTick
from nautilus_trader.model.events import AccountState
Expand Down Expand Up @@ -146,6 +147,11 @@ def rust_defined_to_record_batch(data: list[Data], data_cls: type) -> pa.Table |
elif data_cls == Bar:
pyo3_bars = Bar.to_pyo3_list(data)
batch_bytes = DataTransformer.pyo3_bars_to_record_batch_bytes(pyo3_bars)
elif data_cls == OrderBookDepth10:
raise RuntimeError(
f"Unsupported Rust defined data type for catalog write, was `{data_cls}`. "
"You need to use a loader which returns `nautilus_pyo3.OrderBookDepth10` objects.",
)
else:
raise RuntimeError(
f"Unsupported Rust defined data type for catalog write, was `{data_cls}`",
Expand Down Expand Up @@ -289,6 +295,7 @@ def dicts_to_record_batch(data: list[dict], schema: pa.Schema) -> pa.RecordBatch
Bar,
OrderBookDelta,
OrderBookDeltas,
OrderBookDepth10,
}
RUST_STR_SERIALIZERS = {s.__name__ for s in RUST_SERIALIZERS}

Expand Down

0 comments on commit 86b3265

Please sign in to comment.