Skip to content

Commit

Permalink
fix issue with shaping in permuted-write cases
Browse files Browse the repository at this point in the history
  • Loading branch information
johnkerl committed Oct 16, 2024
1 parent 4e54e27 commit 251542b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def from_isolated_dataframe(
next_soma_joinid += 1
return cls(data=data, field_name=index_field_name)

def get_shape(self) -> int:
if len(self.data.values()) == 0:
return 0
else:
return 1 + max(self.data.values())

def to_json(self) -> str:
return json.dumps(self, default=attrs.asdict, sort_keys=True, indent=4)

Expand Down Expand Up @@ -490,20 +496,15 @@ def get_obs_shape(self) -> int:
"""Reports the new obs shape which the experiment will need to be
resized to in order to accommodate the data contained within the
registration."""
if len(self.obs_axis.data.values()) == 0:
return 0
return 1 + max(self.obs_axis.data.values())
return self.obs_axis.get_shape()

def get_var_shapes(self) -> Dict[str, int]:
"""Reports the new var shapes, one per measurement, which the experiment
will need to be resized to in order to accommodate the data contained
within the registration."""
retval: Dict[str, int] = {}
for key, axis in self.var_axes.items():
if len(axis.data.values()) == 0:
retval[key] = 0
else:
retval[key] = 1 + max(axis.data.values())
retval[key] = axis.get_shape()
return retval

def to_json(self) -> str:
Expand Down
6 changes: 6 additions & 0 deletions apis/python/src/tiledbsoma/io/_registration/id_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def is_identity(self) -> bool:
return False
return True

def get_shape(self) -> int:
if len(self.data) == 0:
return 0
else:
return 1 + max(self.data)

@classmethod
def identity(cls, n: int) -> Self:
"""This maps 0-up input-file offsets to 0-up soma_joinid values. This is
Expand Down
13 changes: 11 additions & 2 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,7 @@ def _write_dataframe(
df,
df_uri,
id_column_name,
shape=axis_mapping.get_shape(),
ingestion_params=ingestion_params,
additional_metadata=additional_metadata,
original_index_metadata=original_index_metadata,
Expand All @@ -1178,6 +1179,7 @@ def _write_dataframe_impl(
df_uri: str,
id_column_name: Optional[str],
*,
shape: int,
ingestion_params: IngestionParams,
additional_metadata: AdditionalMetadata = None,
original_index_metadata: OriginalIndexMetadata = None,
Expand Down Expand Up @@ -1206,7 +1208,7 @@ def _write_dataframe_impl(
try:
domain = None
if NEW_SHAPE_FEATURE_FLAG_ENABLED:
domain = ((0, int(df.shape[0]) - 1),)
domain = ((0, shape - 1),)
soma_df = DataFrame.create(
df_uri,
schema=arrow_table.schema,
Expand Down Expand Up @@ -1312,7 +1314,12 @@ def _create_from_matrix(
shape: Sequence[Union[int, None]] = ()
# A SparseNDArray must be appendable in soma.io.
if NEW_SHAPE_FEATURE_FLAG_ENABLED:
shape = tuple(int(e) for e in matrix.shape)
# Instead of
# shape = tuple(int(e) for e in matrix.shape)
# we consult the registration mapping. This is important
# in the case when multiple H5ADs/AnnDatas are being
# ingested to an experiment which doesn't pre-exist.
shape = (axis_0_mapping.get_shape(), axis_1_mapping.get_shape())
elif cls.is_sparse:
shape = tuple(None for _ in matrix.shape)

Check warning on line 1324 in apis/python/src/tiledbsoma/io/ingest.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/io/ingest.py#L1323-L1324

Added lines #L1323 - L1324 were not covered by tests
else:
Expand Down Expand Up @@ -2722,6 +2729,7 @@ def _ingest_uns_1d_string_array(
df,
df_uri,
None,
shape=df.shape[0],
ingestion_params=ingestion_params,
platform_config=platform_config,
context=context,
Expand Down Expand Up @@ -2767,6 +2775,7 @@ def _ingest_uns_2d_string_array(
df,
df_uri,
None,
shape=df.shape[0],
ingestion_params=ingestion_params,
additional_metadata=additional_metadata,
platform_config=platform_config,
Expand Down
20 changes: 11 additions & 9 deletions apis/python/src/tiledbsoma/io/shaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,18 @@ def resize_experiment(
output_handle=output_handle,
)

# Do an early check on the nvars keys vs the experiment's
# measurent names. This isn't a can-do status for the experiment;
# it's a failure of the user's arguments.
# Extra user-provided keys not relevant to the experiment are ignored. This
# is important for the case when a new measurement, which is registered from
# AnnData/H5AD inputs, is registered and is about to be created but does not
# exist just yet in the experiment storage.
#
# If the user hasn't provided a key -- e.g. a from-anndata-append-with-resize
# on one measurement while the experiment's other measurements aren't being
# updated -- then we need to find those other measurements' var-shapes.
with tiledbsoma.Experiment.open(uri) as exp:
arg_keys = sorted(nvars.keys())
ms_keys = sorted(exp.ms.keys())
if arg_keys != ms_keys:
raise ValueError(
f"resize_experiment: provided nvar keys {arg_keys} do not match experiment keys {ms_keys}"
)
for ms_key in exp.ms.keys():
if ms_key not in nvars.keys():
nvars[ms_key] = exp.ms[ms_key].var._maybe_soma_joinid_shape or 1

ok = _treewalk(
uri,
Expand Down
5 changes: 0 additions & 5 deletions apis/python/tests/test_registration_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,6 @@ def test_multiples_without_experiment(
var_field_name=var_field_name,
)

# XXX TO DO

assert rd.obs_axis.id_mapping_from_values(["AGAG", "GGAG"]).data == (2, 8)
assert rd.var_axes["measname"].id_mapping_from_values(["ESR1", "VEGFA"]).data == (
2,
Expand Down Expand Up @@ -466,7 +464,6 @@ def test_multiples_without_experiment(
nvars=rd.get_var_shapes(),
)

# XXX FIXME
tiledbsoma.io.from_h5ad(
experiment_uri,
h5ad_file_name,
Expand Down Expand Up @@ -860,7 +857,6 @@ def test_append_with_disjoint_measurements(
)

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
# XXX FIXME
tiledbsoma.io.resize_experiment(
soma_uri,
nobs=rd.get_obs_shape(),
Expand Down Expand Up @@ -1223,7 +1219,6 @@ def test_enum_bit_width_append(tmp_path, all_at_once, nobs_a, nobs_b):
)

if tiledbsoma._flags.NEW_SHAPE_FEATURE_FLAG_ENABLED:
# XXX FIXME
tiledbsoma.io.resize_experiment(
soma_uri,
nobs=rd.get_obs_shape(),
Expand Down

0 comments on commit 251542b

Please sign in to comment.