Skip to content

Commit

Permalink
Better names.
Browse files Browse the repository at this point in the history
  • Loading branch information
vitenti committed Jan 9, 2025
1 parent 461a3b5 commit 7ed861b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 76 deletions.
78 changes: 35 additions & 43 deletions firecrown/data_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,16 @@ class TwoPointTracerSpec(BaseModel):

model_config = ConfigDict(extra="forbid", frozen=True)

bin_name: Annotated[str, Field(description="The name of the tracer bin.")]
bin_measurement: Annotated[
name: Annotated[str, Field(description="The name of the tracer bin.")]
measurement: Annotated[
Measurement,
Field(description="The measurement of the tracer bin."),
BeforeValidator(make_measurement),
]

@field_serializer("bin_measurement")
@field_serializer("measurement")
@classmethod
def serialize_bin_measurement(cls, value: Measurement) -> dict[str, str]:
def serialize_measurement(cls, value: Measurement) -> dict[str, str]:
"""Serialize the Measurement."""
return make_measurement_dict(value)

Expand All @@ -276,13 +276,13 @@ class TwoPointBinFilter(BaseModel):

model_config = ConfigDict(extra="forbid", frozen=True)

bin_spec: Annotated[
spec: Annotated[
list[TwoPointTracerSpec],
Field(
description="The two-point bin specification.",
),
]
bin_filter: Annotated[
interval: Annotated[
tuple[float, float],
BeforeValidator(make_interval_from_list),
Field(description="The range of the bin to filter."),
Expand All @@ -291,55 +291,47 @@ class TwoPointBinFilter(BaseModel):
@model_validator(mode="after")
def check_bin_filter(self) -> "TwoPointBinFilter":
"""Check the bin filter."""
if self.bin_filter[0] >= self.bin_filter[1]:
if self.interval[0] >= self.interval[1]:
raise ValueError("The bin filter should be a valid range.")
if 1 > len(self.bin_spec) > 2:
if 1 > len(self.spec) > 2:
raise ValueError("The bin_spec must contain one or two elements.")

Check warning on line 297 in firecrown/data_functions.py

View check run for this annotation

Codecov / codecov/patch

firecrown/data_functions.py#L297

Added line #L297 was not covered by tests
return self

@field_serializer("bin_filter")
@field_serializer("interval")
@classmethod
def serialize_bin_filter(cls, value: tuple[float, float]) -> list[float]:
def serialize_interval(cls, value: tuple[float, float]) -> list[float]:
"""Serialize the Measurement."""
return list(value)

@classmethod
def from_args(
cls,
bin_name1: str,
bin_measurement1: Measurement,
bin_name2: str,
bin_measurement2: Measurement,
bin_lower: float,
bin_upper: float,
name1: str,
measurement1: Measurement,
name2: str,
measurement2: Measurement,
lower: float,
upper: float,
) -> "TwoPointBinFilter":
"""Create a TwoPointBinFilter from the arguments."""
return cls(
bin_spec=[
TwoPointTracerSpec(
bin_name=bin_name1, bin_measurement=bin_measurement1
),
TwoPointTracerSpec(
bin_name=bin_name2, bin_measurement=bin_measurement2
),
spec=[
TwoPointTracerSpec(name=name1, measurement=measurement1),
TwoPointTracerSpec(name=name2, measurement=measurement2),
],
bin_filter=(bin_lower, bin_upper),
interval=(lower, upper),
)

@classmethod
def from_args_auto(
cls,
bin_name: str,
bin_measurement: Measurement,
bin_lower: float,
bin_upper: float,
cls, name: str, measurement: Measurement, lower: float, upper: float
) -> "TwoPointBinFilter":
"""Create a TwoPointBinFilter from the arguments."""
return cls(
bin_spec=[
TwoPointTracerSpec(bin_name=bin_name, bin_measurement=bin_measurement),
spec=[
TwoPointTracerSpec(name=name, measurement=measurement),
],
bin_filter=(bin_lower, bin_upper),
interval=(lower, upper),
)


Expand All @@ -351,12 +343,12 @@ def bin_spec_from_metadata(metadata: TwoPointReal | TwoPointHarmonic) -> BinSpec
return frozenset(
(
TwoPointTracerSpec(
bin_name=metadata.XY.x.bin_name,
bin_measurement=metadata.XY.x_measurement,
name=metadata.XY.x.bin_name,
measurement=metadata.XY.x_measurement,
),
TwoPointTracerSpec(
bin_name=metadata.XY.y.bin_name,
bin_measurement=metadata.XY.y_measurement,
name=metadata.XY.y.bin_name,
measurement=metadata.XY.y_measurement,
),
)
)
Expand All @@ -367,12 +359,12 @@ class TwoPointBinFilterCollection(BaseModel):

model_config = ConfigDict(extra="forbid", frozen=True)

bin_filters: list[TwoPointBinFilter] = Field(
filters: list[TwoPointBinFilter] = Field(
description="The list of bin filters.",
)
require_filter_for_all: bool = Field(
default=False,
description="If True, all bins should have a filter.",
description="If True, all bins should match a filter.",
)
allow_empty: bool = Field(
default=False,
Expand All @@ -388,18 +380,18 @@ class TwoPointBinFilterCollection(BaseModel):
def check_bin_filters(self) -> "TwoPointBinFilterCollection":
"""Check the bin filters."""
bin_specs = set()
for bin_filter in self.bin_filters:
bin_spec = frozenset(bin_filter.bin_spec)
for bin_filter in self.filters:
bin_spec = frozenset(bin_filter.spec)
if bin_spec in bin_specs:
raise ValueError(
f"The bin name {bin_filter.bin_spec} is repeated "
f"The bin name {bin_filter.spec} is repeated "
f"in the bin filters."
)
bin_specs.add(bin_spec)

self._bin_filter_dict = {
frozenset(bin_filter.bin_spec): bin_filter.bin_filter
for bin_filter in self.bin_filters
frozenset(bin_filter.spec): bin_filter.interval
for bin_filter in self.filters
}
return self

Expand Down
66 changes: 33 additions & 33 deletions tests/metadata/test_data_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def fixture_harmonic_filter_collection(request) -> TwoPointBinFilterCollection:
for (name_a, m_a), (name_b, m_b) in request.param
]

return TwoPointBinFilterCollection(bin_filters=bin_filters)
return TwoPointBinFilterCollection(filters=bin_filters)


ALL_REAL_BINS = list(
Expand All @@ -137,75 +137,75 @@ def fixture_real_filter_collection(request) -> TwoPointBinFilterCollection:
for (name_a, m_a), (name_b, m_b) in request.param
]

return TwoPointBinFilterCollection(bin_filters=bin_filters)
return TwoPointBinFilterCollection(filters=bin_filters)


def test_two_point_bin_filter_construct():
bin_spec = [
TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS),
TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E),
TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS),
TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E),
]
bin_filter = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.1, 0.5))
assert bin_filter.bin_spec == bin_spec
assert bin_filter.bin_filter == (0.1, 0.5)
bin_filter = TwoPointBinFilter(spec=bin_spec, interval=(0.1, 0.5))
assert bin_filter.spec == bin_spec
assert bin_filter.interval == (0.1, 0.5)

bin_filter_from_args = TwoPointBinFilter.from_args(
"bin_1", Galaxies.COUNTS, "bin_2", Galaxies.SHEAR_E, 0.1, 0.5
)

assert bin_filter_from_args.bin_spec == bin_spec
assert bin_filter_from_args.bin_filter == (0.1, 0.5)
assert bin_filter_from_args.spec == bin_spec
assert bin_filter_from_args.interval == (0.1, 0.5)


def test_two_point_bin_filter_construct_auto():
bin_spec = [TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS)]
bin_filter = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.1, 0.5))
assert bin_filter.bin_spec == bin_spec
assert bin_filter.bin_filter == (0.1, 0.5)
bin_spec = [TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS)]
bin_filter = TwoPointBinFilter(spec=bin_spec, interval=(0.1, 0.5))
assert bin_filter.spec == bin_spec
assert bin_filter.interval == (0.1, 0.5)

bin_filter_from_args = TwoPointBinFilter.from_args_auto(
"bin_1", Galaxies.COUNTS, 0.1, 0.5
)

assert bin_filter_from_args.bin_spec == bin_spec
assert bin_filter_from_args.bin_filter == (0.1, 0.5)
assert bin_filter_from_args.spec == bin_spec
assert bin_filter_from_args.interval == (0.1, 0.5)


def test_two_point_bin_filter_construct_invalid_range():
bin_spec = [
TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS),
TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E),
TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS),
TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E),
]
with pytest.raises(
ValueError, match="Value error, The bin filter should be a valid range."
):
TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.5, 0.1))
TwoPointBinFilter(spec=bin_spec, interval=(0.5, 0.1))


def test_two_point_bin_filter_collection_construct():
bin_spec = (
TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS),
TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E),
TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS),
TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E),
)
bin_filter = TwoPointBinFilter.from_args(
"bin_1", Galaxies.COUNTS, "bin_2", Galaxies.SHEAR_E, 0.1, 0.5
)
bin_filter_collection = TwoPointBinFilterCollection(bin_filters=[bin_filter])
assert bin_filter_collection.bin_filters == [bin_filter]
bin_filter_collection = TwoPointBinFilterCollection(filters=[bin_filter])
assert bin_filter_collection.filters == [bin_filter]
assert bin_filter_collection.bin_filter_dict == {frozenset(bin_spec): (0.1, 0.5)}


def test_two_point_bin_filter_collection_construct_same_name() -> None:
bin_spec = [
TwoPointTracerSpec(bin_name="bin_1", bin_measurement=Galaxies.COUNTS),
TwoPointTracerSpec(bin_name="bin_2", bin_measurement=Galaxies.SHEAR_E),
TwoPointTracerSpec(name="bin_1", measurement=Galaxies.COUNTS),
TwoPointTracerSpec(name="bin_2", measurement=Galaxies.SHEAR_E),
]
bin_filter_1 = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.1, 0.5))
bin_filter_2 = TwoPointBinFilter(bin_spec=bin_spec, bin_filter=(0.5, 0.9))
bin_filter_1 = TwoPointBinFilter(spec=bin_spec, interval=(0.1, 0.5))
bin_filter_2 = TwoPointBinFilter(spec=bin_spec, interval=(0.5, 0.9))
with pytest.raises(
ValueError, match="The bin name .* is repeated in the bin filters."
):
TwoPointBinFilterCollection(bin_filters=[bin_filter_1, bin_filter_2])
TwoPointBinFilterCollection(filters=[bin_filter_1, bin_filter_2])


def test_two_point_harmonic_bin_filter_collection_filter_match(
Expand Down Expand Up @@ -380,7 +380,7 @@ def test_two_point_harmonic_bin_filter_collection_call_require(
harmonic_bin_1: InferredGalaxyZDist,
) -> None:
harmonic_filter_collection_no_empty = TwoPointBinFilterCollection(
bin_filters=[
filters=[
TwoPointBinFilter.from_args(
"bin_2", Galaxies.SHEAR_E, "bin_2", Galaxies.SHEAR_E, 5, 60
)
Expand All @@ -407,7 +407,7 @@ def test_two_point_harmonic_bin_filter_collection_call_no_empty(
) -> None:
cm = list(harmonic_bin_1.measurements)[0]
harmonic_filter_collection_no_empty = TwoPointBinFilterCollection(
bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)],
filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)],
require_filter_for_all=True,
)
harmonic_bins = [
Expand Down Expand Up @@ -436,7 +436,7 @@ def test_two_point_harmonic_bin_filter_collection_call_empty(
) -> None:
cm = list(harmonic_bin_1.measurements)[0]
harmonic_filter_collection_no_empty = TwoPointBinFilterCollection(
bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)],
filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 1000, 2000)],
allow_empty=True,
)
harmonic_bins = [
Expand Down Expand Up @@ -489,7 +489,7 @@ def test_two_point_real_bin_filter_collection_call_require(
) -> None:
cm = list(real_bin_1.measurements)[0]
real_filter_collection_no_empty = TwoPointBinFilterCollection(
bin_filters=[TwoPointBinFilter.from_args("bin_2", cm, "bin_2", cm, 0.1, 0.6)],
filters=[TwoPointBinFilter.from_args("bin_2", cm, "bin_2", cm, 0.1, 0.6)],
require_filter_for_all=True,
)
real_bins = [
Expand All @@ -512,7 +512,7 @@ def test_two_point_real_bin_filter_collection_call_no_empty(
) -> None:
cm = list(real_bin_1.measurements)[0]
real_filter_collection_no_empty = TwoPointBinFilterCollection(
bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)],
filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)],
require_filter_for_all=True,
)
real_bins = [
Expand Down Expand Up @@ -541,7 +541,7 @@ def test_two_point_real_bin_filter_collection_call_empty(
) -> None:
cm = list(real_bin_1.measurements)[0]
real_filter_collection_no_empty = TwoPointBinFilterCollection(
bin_filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)],
filters=[TwoPointBinFilter.from_args("bin_1", cm, "bin_1", cm, 10.1, 10.6)],
allow_empty=True,
)
real_bins = [
Expand Down

0 comments on commit 7ed861b

Please sign in to comment.