Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global masking #68

Merged
merged 7 commits into from
Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion velociraptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,5 @@ def load(

if registration_file_path is not None:
catalogue.register_derived_quantities(registration_file_path)

return catalogue
89 changes: 44 additions & 45 deletions velociraptor/autoplotter/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from unyt import unyt_quantity, unyt_array, matplotlib_support
from unyt.exceptions import UnitConversionError
from numpy import log10, linspace, logspace, array, logical_and
from numpy import log10, linspace, logspace, array, logical_and, ones
from matplotlib.pyplot import Axes, Figure, close
from yaml import safe_load
from typing import Union, List, Dict, Tuple
Expand Down Expand Up @@ -95,6 +95,8 @@ class VelociraptorPlot(object):
observational_data_filenames: List[str]
observational_data_bracket_width: float
observational_data_directory: str
# global mask
global_mask: Union[None, array]

def __init__(
self,
Expand Down Expand Up @@ -743,7 +745,7 @@ def _add_lines_to_axes(self, ax: Axes, x: unyt_array, y: unyt_array) -> None:
return

def get_quantity_from_catalogue_with_mask(
self, quantity: str, catalogue: VelociraptorCatalogue
self, quantity: str, catalogue: VelociraptorCatalogue,
) -> unyt_array:
"""
Get a quantity from the catalogue using the mask.
Expand All @@ -754,61 +756,51 @@ def get_quantity_from_catalogue_with_mask(
# in versions of unyt less than 2.6.0
name = x.name

# temporary masked array to apply global mask
# in concert with plot-specific masks
x_mask = self.global_mask

EvgeniiChaikin marked this conversation as resolved.
Show resolved Hide resolved
if self.structure_mask is not None:
x = x[self.structure_mask]
# if structure_mask already set, mask and return
x_mask = logical_and(x_mask, self.structure_mask)
x = x[x_mask]
x.name = name
elif self.selection_mask is not None:
return x
EvgeniiChaikin marked this conversation as resolved.
Show resolved Hide resolved

# allow all entries by default
self.structure_mask = ones(x.shape).astype(bool)

if self.selection_mask is not None:
# Create mask
self.structure_mask = reduce(
getattr, self.selection_mask.split("."), catalogue
).astype(bool)

if self.select_structure_type is not None:
if self.select_structure_type == self.exclude_structure_type:
raise AutoPlotterError(
f"Cannot simultaneously select and exclude structure"
" type {self.select_structure_type}"
)
self.structure_mask = logical_and(
self.structure_mask,
catalogue.structure_type.structuretype
== self.select_structure_type,
)

elif self.exclude_structure_type is not None:
self.structure_mask = logical_and(
self.structure_mask,
catalogue.structure_type.structuretype
!= self.exclude_structure_type,
)

x = x[self.structure_mask]
x.name = name
elif self.select_structure_type is not None:
if self.select_structure_type is not None:
if self.select_structure_type == self.exclude_structure_type:
raise AutoPlotterError(
f"Cannot simultaneously select and exclude structure"
" type {self.select_structure_type}"
)

# Need to create mask
self.structure_mask = (
catalogue.structure_type.structuretype == self.select_structure_type
self.structure_mask = logical_and(
self.structure_mask,
catalogue.structure_type.structuretype
== self.select_structure_type,
)

x = x[self.structure_mask]
x.name = name
elif self.exclude_structure_type is not None:
# Need to create mask
self.structure_mask = (
catalogue.structure_type.structuretype != self.exclude_structure_type
if self.exclude_structure_type is not None:
self.structure_mask = logical_and(
self.structure_mask,
catalogue.structure_type.structuretype
!= self.exclude_structure_type,
)

# combine global and structure masks
x_mask = logical_and(x_mask, self.structure_mask)

x = x[self.structure_mask]
x.name = name

# apply to the unyt array of values
x = x[x_mask]
x.name = name
return x

def _make_plot_scatter(
self, catalogue: VelociraptorCatalogue
) -> Tuple[Figure, Axes]:
Expand Down Expand Up @@ -974,7 +966,7 @@ def _make_plot_cumulative_histogram(
return fig, ax

def make_plot(
self, catalogue: VelociraptorCatalogue, directory: str, file_extension: str
self, catalogue: VelociraptorCatalogue, directory: str, file_extension: str,
):
"""
Federates out data parsing to individual functions based on the
Expand Down Expand Up @@ -1058,7 +1050,9 @@ class AutoPlotter(object):
observational_data_directory: str
# Whether or not the plots were created successfully.
created_successfully: List[bool]

# global mask
global_mask: Union[None, array]

def __init__(
self,
filename: Union[str, List[str]],
Expand Down Expand Up @@ -1123,14 +1117,18 @@ def parse_yaml(self):

return

def link_catalogue(self, catalogue: VelociraptorCatalogue):
def link_catalogue(self, catalogue: VelociraptorCatalogue, global_mask_tag: str):
"""
Links a catalogue with this object so that the plots
can actually be created.
"""

self.catalogue = catalogue

if global_mask_tag is not None:
EvgeniiChaikin marked this conversation as resolved.
Show resolved Hide resolved
self.global_mask = reduce(getattr, global_mask_tag.split("."), catalogue)
else:
self.global_mask = True
return

def create_plots(
Expand All @@ -1150,6 +1148,7 @@ def create_plots(

for plot in self.plots:
try:
plot.global_mask = self.global_mask
plot.make_plot(
catalogue=self.catalogue,
directory=directory,
Expand Down
1 change: 0 additions & 1 deletion velociraptor/catalogue/catalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,4 +412,3 @@ def register_derived_quantities(self, registration_file_path: str) -> None:
self.derived_quantities = DerivedQuantities(registration_file_path, self)

return

15 changes: 14 additions & 1 deletion velociraptor/catalogue/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def registration_fail_all(
+ name: A fancy (possibly LaTeX'd) name for the field.
+ snake_case: A correct snake_case name for the field.
"""



if field_path == "ThisFieldPathWouldNeverExist":
return (
unit_system.length,
Expand Down Expand Up @@ -1442,7 +1443,18 @@ def registration_spherical_overdensities(
else:
raise RegistrationDoesNotMatchError

def registration_bgpart_masses(
field_path: str, unit_system: VelociraptorUnits
) -> (unyt.Unit, str, str):
"""
Registers the halo mass contributed by background particle interlopers.
"""
if field_path == "Mass_interloper":
return unit_system.mass, "Mass from background particles", "bgpart_masses"
else:
raise RegistrationDoesNotMatchError


EvgeniiChaikin marked this conversation as resolved.
Show resolved Hide resolved
# TODO
# lambda_B
# q
Expand Down Expand Up @@ -1502,6 +1514,7 @@ def registration_spherical_overdensities(
"cold_dense_gas_properties",
"log_element_ratios_times_masses",
"lin_element_ratios_times_masses",
"bgpart_masses",
"fail_all",
]
}