diff --git a/firecrown/data_functions.py b/firecrown/data_functions.py index 663d5c46..17a65574 100644 --- a/firecrown/data_functions.py +++ b/firecrown/data_functions.py @@ -240,16 +240,16 @@ class TwoPointTracerSpec(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) - bin_name: Annotated[str, Field(description="The name of the tracer bin.")] - bin_measurement: Annotated[ + name: Annotated[str, Field(description="The name of the tracer bin.")] + measurement: Annotated[ Measurement, Field(description="The measurement of the tracer bin."), BeforeValidator(make_measurement), ] - @field_serializer("bin_measurement") + @field_serializer("measurement") @classmethod - def serialize_bin_measurement(cls, value: Measurement) -> dict[str, str]: + def serialize_measurement(cls, value: Measurement) -> dict[str, str]: """Serialize the Measurement.""" return make_measurement_dict(value) @@ -276,13 +276,13 @@ class TwoPointBinFilter(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) - bin_spec: Annotated[ + spec: Annotated[ list[TwoPointTracerSpec], Field( description="The two-point bin specification.", ), ] - bin_filter: Annotated[ + interval: Annotated[ tuple[float, float], BeforeValidator(make_interval_from_list), Field(description="The range of the bin to filter."), @@ -291,55 +291,47 @@ class TwoPointBinFilter(BaseModel): @model_validator(mode="after") def check_bin_filter(self) -> "TwoPointBinFilter": """Check the bin filter.""" - if self.bin_filter[0] >= self.bin_filter[1]: + if self.interval[0] >= self.interval[1]: raise ValueError("The bin filter should be a valid range.") - if 1 > len(self.bin_spec) > 2: + if 1 > len(self.spec) > 2: raise ValueError("The bin_spec must contain one or two elements.") return self - @field_serializer("bin_filter") + @field_serializer("interval") @classmethod - def serialize_bin_filter(cls, value: tuple[float, float]) -> list[float]: + def serialize_interval(cls, value: tuple[float, float]) -> list[float]: """Serialize the Measurement.""" return list(value) @classmethod def from_args( cls, - bin_name1: str, - bin_measurement1: Measurement, - bin_name2: str, - bin_measurement2: Measurement, - bin_lower: float, - bin_upper: float, + name1: str, + measurement1: Measurement, + name2: str, + measurement2: Measurement, + lower: float, + upper: float, ) -> "TwoPointBinFilter": """Create a TwoPointBinFilter from the arguments.""" return cls( - bin_spec=[ - TwoPointTracerSpec( - bin_name=bin_name1, bin_measurement=bin_measurement1 - ), - TwoPointTracerSpec( - bin_name=bin_name2, bin_measurement=bin_measurement2 - ), + spec=[ + TwoPointTracerSpec(name=name1, measurement=measurement1), + TwoPointTracerSpec(name=name2, measurement=measurement2), ], - bin_filter=(bin_lower, bin_upper), + interval=(lower, upper), ) @classmethod def from_args_auto( - cls, - bin_name: str, - bin_measurement: Measurement, - bin_lower: float, - bin_upper: float, + cls, name: str, measurement: Measurement, lower: float, upper: float ) -> "TwoPointBinFilter": """Create a TwoPointBinFilter from the arguments.""" return cls( - bin_spec=[ - TwoPointTracerSpec(bin_name=bin_name, bin_measurement=bin_measurement), + spec=[ + TwoPointTracerSpec(name=name, measurement=measurement), ], - bin_filter=(bin_lower, bin_upper), + interval=(lower, upper), ) @@ -351,12 +343,12 @@ def bin_spec_from_metadata(metadata: TwoPointReal | TwoPointHarmonic) -> BinSpec return frozenset( ( TwoPointTracerSpec( - bin_name=metadata.XY.x.bin_name, - bin_measurement=metadata.XY.x_measurement, + name=metadata.XY.x.bin_name, + measurement=metadata.XY.x_measurement, ), TwoPointTracerSpec( - bin_name=metadata.XY.y.bin_name, - bin_measurement=metadata.XY.y_measurement, + name=metadata.XY.y.bin_name, + measurement=metadata.XY.y_measurement, ), ) ) @@ -367,12 +359,12 @@ class TwoPointBinFilterCollection(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) - bin_filters: list[TwoPointBinFilter] = Field( + filters: list[TwoPointBinFilter] = Field( description="The list of bin filters.", ) require_filter_for_all: bool = Field( default=False, - description="If True, all bins should have a filter.", + description="If True, all bins should match a filter.", ) allow_empty: bool = Field( default=False, @@ -388,18 +380,18 @@ class TwoPointBinFilterCollection(BaseModel): def check_bin_filters(self) -> "TwoPointBinFilterCollection": """Check the bin filters.""" bin_specs = set() - for bin_filter in self.bin_filters: - bin_spec = frozenset(bin_filter.bin_spec) + for bin_filter in self.filters: + bin_spec = frozenset(bin_filter.spec) if bin_spec in bin_specs: raise ValueError( - f"The bin name {bin_filter.bin_spec} is repeated " + f"The bin name {bin_filter.spec} is repeated " f"in the bin filters." ) bin_specs.add(bin_spec) self._bin_filter_dict = { - frozenset(bin_filter.bin_spec): bin_filter.bin_filter - for bin_filter in self.bin_filters + frozenset(bin_filter.spec): bin_filter.interval + for bin_filter in self.filters } return self diff --git a/tests/metadata/test_data_functions.py b/tests/metadata/test_data_functions.py index 8ab205a5..ee316c03 100644 --- a/tests/metadata/test_data_functions.py +++ b/tests/metadata/test_data_functions.py @@ -111,7 +111,7 @@ def fixture_harmonic_filter_collection(request) -> TwoPointBinFilterCollection: for (name_a, m_a), (name_b, m_b) in request.param ] - return TwoPointBinFilterCollection(bin_filters=bin_filters) + return TwoPointBinFilterCollection(filters=bin_filters) ALL_REAL_BINS = list( @@ -137,75 +137,75 @@ def fixture_real_filter_collection(request) -> TwoPointBinFilterCollection: for (name_a, m_a), (name_b, m_b) in request.param ] - return TwoPointBinFilterCollection(bin_filters=bin_filters) + return TwoPointBinFilterCollection(filters=bin_filters) def test_two_point_bin_filter_construct(): bin_spec = [ - TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS), - TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E), + TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS), + TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E), ] - bin_filter = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.1, 0.5)) - assert bin_filter.bin_spec == bin_spec - assert bin_filter.bin_filter == (0.1, 0.5) + bin_filter = TwoPointBinFilter(spec=bin_spec, interval=(0.1, 0.5)) + assert bin_filter.spec == bin_spec + assert bin_filter.interval == (0.1, 0.5) bin_filter_from_args = TwoPointBinFilter.from_args( "bin_1", Galaxies.COUNTS, "bin_2", Galaxies.SHEAR_E, 0.1, 0.5 ) - assert bin_filter_from_args.bin_spec == bin_spec - assert bin_filter_from_args.bin_filter == (0.1, 0.5) + assert bin_filter_from_args.spec == bin_spec + assert bin_filter_from_args.interval == (0.1, 0.5) def test_two_point_bin_filter_construct_auto(): - bin_spec = [TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS)] - bin_filter = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.1, 0.5)) - assert bin_filter.bin_spec == bin_spec - assert bin_filter.bin_filter == (0.1, 0.5) + bin_spec = [TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS)] + bin_filter = TwoPointBinFilter(spec=bin_spec, interval=(0.1, 0.5)) + assert bin_filter.spec == bin_spec + assert bin_filter.interval == (0.1, 0.5) bin_filter_from_args = TwoPointBinFilter.from_args_auto( "bin_1", Galaxies.COUNTS, 0.1, 0.5 ) - assert bin_filter_from_args.bin_spec == bin_spec - assert bin_filter_from_args.bin_filter == (0.1, 0.5) + assert bin_filter_from_args.spec == bin_spec + assert bin_filter_from_args.interval == (0.1, 0.5) def test_two_point_bin_filter_construct_invalid_range(): bin_spec = [ - TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS), - TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E), + TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS), + TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E), ] with pytest.raises( ValueError, match="Value error, The bin filter should be a valid range." ): - TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.5, 0.1)) + TwoPointBinFilter(spec=bin_spec, interval=(0.5, 0.1)) def test_two_point_bin_filter_collection_construct(): bin_spec = ( - TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS), - TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E), + TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS), + TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E), ) bin_filter = TwoPointBinFilter.from_args( "bin_1", Galaxies.COUNTS, "bin_2", Galaxies.SHEAR_E, 0.1, 0.5 ) - bin_filter_collection = TwoPointBinFilterCollection(bin_filters=[bin_filter]) - assert bin_filter_collection.bin_filters == [bin_filter] + bin_filter_collection = TwoPointBinFilterCollection(filters=[bin_filter]) + assert bin_filter_collection.filters == [bin_filter] assert bin_filter_collection.bin_filter_dict == {frozenset(bin_spec): (0.1, 0.5)} def test_two_point_bin_filter_collection_construct_same_name() -> None: bin_spec = [ - TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS), - TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E), + TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS), + TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E), ] - bin_filter_1 = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.1, 0.5)) - bin_filter_2 = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.5, 0.9)) + bin_filter_1 = TwoPointBinFilter(spec=bin_spec, interval=(0.1, 0.5)) + bin_filter_2 = TwoPointBinFilter(spec=bin_spec, interval=(0.5, 0.9)) with pytest.raises( ValueError, match="The bin name .* is repeated in the bin filters." ): - TwoPointBinFilterCollection(bin_filters=[bin_filter_1, bin_filter_2]) + TwoPointBinFilterCollection(filters=[bin_filter_1, bin_filter_2]) def test_two_point_harmonic_bin_filter_collection_filter_match( @@ -380,7 +380,7 @@ def test_two_point_harmonic_bin_filter_collection_call_require( harmonic_bin_1: InferredGalaxyZDist, ) -> None: harmonic_filter_collection_no_empty = TwoPointBinFilterCollection( - bin_filters=[ + filters=[ TwoPointBinFilter.from_args( "bin_2", Galaxies.SHEAR_E, "bin_2", Galaxies.SHEAR_E, 5, 60 ) @@ -407,7 +407,7 @@ def test_two_point_harmonic_bin_filter_collection_call_no_empty( ) -> None: cm = list(harmonic_bin_1.measurements)[0] harmonic_filter_collection_no_empty = TwoPointBinFilterCollection( - bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)], + filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)], require_filter_for_all=True, ) harmonic_bins = [ @@ -436,7 +436,7 @@ def test_two_point_harmonic_bin_filter_collection_call_empty( ) -> None: cm = list(harmonic_bin_1.measurements)[0] harmonic_filter_collection_no_empty = TwoPointBinFilterCollection( - bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)], + filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)], allow_empty=True, ) harmonic_bins = [ @@ -489,7 +489,7 @@ def test_two_point_real_bin_filter_collection_call_require( ) -> None: cm = list(real_bin_1.measurements)[0] real_filter_collection_no_empty = TwoPointBinFilterCollection( - bin_filters=[TwoPointBinFilter.from_args("bin_2", cm, "bin_2", cm, 0.1, 0.6)], + filters=[TwoPointBinFilter.from_args("bin_2", cm, "bin_2", cm, 0.1, 0.6)], require_filter_for_all=True, ) real_bins = [ @@ -512,7 +512,7 @@ def test_two_point_real_bin_filter_collection_call_no_empty( ) -> None: cm = list(real_bin_1.measurements)[0] real_filter_collection_no_empty = TwoPointBinFilterCollection( - bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)], + filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)], require_filter_for_all=True, ) real_bins = [ @@ -541,7 +541,7 @@ def test_two_point_real_bin_filter_collection_call_empty( ) -> None: cm = list(real_bin_1.measurements)[0] real_filter_collection_no_empty = TwoPointBinFilterCollection( - bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)], + filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)], allow_empty=True, ) real_bins = [