diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index ec0a6c80d1..10b48a768c 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -840,15 +840,42 @@ fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowE (f.data_type(), target_field.data_type()) { try_cast_batch(fields0, fields1) - } else if !can_cast_types(f.data_type(), target_field.data_type()) { - Err(ArrowError::SchemaError(format!( - "Cannot cast field {} from {} to {}", - f.name(), - f.data_type(), - target_field.data_type() - ))) } else { - Ok(()) + match (f.data_type(), target_field.data_type()) { + ( + DataType::Decimal128(left_precision, left_scale) | DataType::Decimal256(left_precision, left_scale), + DataType::Decimal128(right_precision, right_scale) + ) => { + if left_precision <= right_precision && left_scale <= right_scale { + Ok(()) + } else { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) + } + }, + ( + _, + DataType::Decimal256(_, _), + ) => { + unreachable!("Target field can never be Decimal 256. According to the protocol: 'The precision and scale can be up to 38.'") + }, + (left, right) => { + if !can_cast_types(left, right) { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) + } else { + Ok(()) + } + } + } } } else { Err(ArrowError::SchemaError(format!( diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 9baab32d9a..eb8244dbb3 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -1498,3 +1498,33 @@ def test_empty(existing_table: DeltaTable): empty_table = pa.Table.from_pylist([], schema=schema) with pytest.raises(DeltaError, match="No data source supplied to write command"): write_deltalake(existing_table, empty_table, mode="append", engine="rust") + + +def test_rust_decimal_cast(tmp_path: pathlib.Path): + import re + from decimal import Decimal + + data = pa.table({"x": pa.array([Decimal("100.1")])}) + + write_deltalake(tmp_path, data, mode="append", engine="rust") + + assert DeltaTable(tmp_path).to_pyarrow_table()["x"][0].as_py() == Decimal("100.1") + + # Write smaller decimal, works since it's fits in the previous decimal precision, scale + data = pa.table({"x": pa.array([Decimal("10.1")])}) + write_deltalake(tmp_path, data, mode="append", engine="rust") + + data = pa.table({"x": pa.array([Decimal("1000.1")])}) + # write decimal that is larger than target type in table + with pytest.raises( + SchemaMismatchError, + match=re.escape( + "Cannot cast field x from Decimal128(5, 1) to Decimal128(4, 1)" + ), + ): + write_deltalake(tmp_path, data, mode="append", engine="rust") + + with pytest.raises(SchemaMismatchError, match="Cannot merge types decimal"): + write_deltalake( + tmp_path, data, mode="append", schema_mode="merge", engine="rust" + )