Add pydantic config (#29)
benmalef authored Nov 22, 2024
1 parent 53d3849 commit 8c4c6dd
Showing 6 changed files with 98 additions and 4 deletions.
7 changes: 6 additions & 1 deletion GANDLF/config_manager.py
@@ -11,6 +11,7 @@

from GANDLF.metrics import surface_distance_ids
from importlib.metadata import version
+from GANDLF.utils.pydantic_config import Parameters

## dictionary to define defaults for appropriate options, which are evaluated
parameter_defaults = {
@@ -653,6 +654,7 @@ def _parseConfig(
if "opt" in params:
print("DeprecationWarning: 'opt' has been superseded by 'optimizer'")
params["optimizer"] = params["opt"]
params.pop("opt")

# initialize defaults for patch sampler
temp_patch_sampler_dict = {
@@ -747,7 +749,10 @@ def ConfigManager(
        dict: The parameter dictionary.
    """
    try:
-        return _parseConfig(config_file_path, version_check_flag)
+        parameters = Parameters(
+            **_parseConfig(config_file_path, version_check_flag)
+        ).model_dump()
+        return parameters
    except Exception as e:
        ## todo: ensure logging captures assertion errors
        assert (
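ConfigManager now round-trips the parsed dictionary through the pydantic model, so invalid or unknown keys fail fast and every declared field comes back as an explicit key in the dumped dict. A minimal sketch of the same validate-then-dump pattern, using an illustrative two-field stand-in rather than GANDLF's full Parameters schema:

from typing import Optional

from pydantic import BaseModel, ConfigDict


class MiniParams(BaseModel):
    # illustrative stand-in for GANDLF's Parameters model
    model_config = ConfigDict(extra="forbid")
    learning_rate: float
    num_epochs: int
    problem_type: Optional[str] = None


parsed = {"learning_rate": 0.001, "num_epochs": 5}  # e.g. the _parseConfig output
params = MiniParams(**parsed).model_dump()
# params == {"learning_rate": 0.001, "num_epochs": 5, "problem_type": None}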
2 changes: 1 addition & 1 deletion GANDLF/models/imagenet_unet.py
@@ -253,7 +253,7 @@ def __init__(self, parameters) -> None:
        )

        # all BatchNorm should be replaced with InstanceNorm for DP experiments
-        if "differential_privacy" in parameters:
+        if parameters["differential_privacy"] is not None:
            self.replace_batchnorm(self.model)

        if self.n_dimensions == 3:
4 changes: 2 additions & 2 deletions GANDLF/utils/data_splitter.py
@@ -22,12 +22,12 @@ def split_data(
"nested_training" in parameters
), "`nested_training` key missing in parameters"
# populate the headers
if "headers" not in parameters:
if parameters["headers"] is None:
_, parameters["headers"] = parseTrainingCSV(full_dataset)

parameters = (
populate_header_in_parameters(parameters, parameters["headers"])
if "problem_type" not in parameters
if parameters["problem_type"] is None
else parameters
)

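This hunk, like the imagenet_unet one above, is a consequence of the dump: model_dump() emits every declared field, so optional keys that were absent from the YAML now always exist with the value None, and membership tests such as '"headers" not in parameters' would never be true again. A small demonstration with a hypothetical one-field model:

from typing import Optional

from pydantic import BaseModel


class P(BaseModel):
    headers: Optional[dict] = None


parameters = P().model_dump()
print("headers" in parameters)        # True: the key is always present now
print(parameters["headers"] is None)  # True: absence is signalled by None instead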
78 changes: 78 additions & 0 deletions GANDLF/utils/pydantic_config.py
@@ -0,0 +1,78 @@
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict


class Version(BaseModel):
    minimum: str
    maximum: str


class Model(BaseModel):
    dimension: int
    base_filters: int
    architecture: str
    norm_type: str
    final_layer: str
    class_list: List[Union[int, str]]
    ignore_label_validation: Union[int, None]
    amp: bool
    print_summary: bool
    type: str
    data_type: str
    save_at_every_epoch: bool
    num_channels: Optional[int] = None


class Parameters(BaseModel):
    model_config = ConfigDict(extra="forbid")
    version: Version
    model: Model
    modality: str
    scheduler: dict
    learning_rate: float
    weighted_loss: bool
    verbose: bool
    q_verbose: bool
    medcam_enabled: bool
    save_training: bool
    save_output: bool
    in_memory: bool
    pin_memory_dataloader: bool
    scaling_factor: Union[float, int]
    q_max_length: int
    q_samples_per_volume: int
    q_num_workers: int
    num_epochs: int
    patience: int
    batch_size: int
    clip_grad: Union[None, float]
    track_memory_usage: bool
    memory_save_mode: bool
    print_rgb_label_warning: bool
    data_postprocessing: Dict  # TODO: maybe it is better to create a dedicated class
    data_preprocessing: Dict  # TODO: maybe it is better to create a dedicated class
    grid_aggregator_overlap: str
    determinism: bool
    previous_parameters: None
    metrics: Union[List, dict]
    parallel_compute_command: Union[str, bool, None]
    loss_function: Union[str, Dict]
    data_augmentation: dict  # TODO: maybe it is better to create a dedicated class
    nested_training: dict  # TODO: maybe it is better to create a dedicated class
    optimizer: Union[dict, str]
    patch_sampler: Union[dict, str]
    patch_size: Union[List[int], int]
    clip_mode: Union[str, None]
    inference_mechanism: dict
    data_postprocessing_after_reverse_one_hot_encoding: dict
    enable_padding: Optional[Union[dict, bool]] = None
    headers: Optional[dict] = None
    output_dir: Optional[str] = ""
    problem_type: Optional[str] = None
    differential_privacy: Optional[dict] = None
    # opt: Optional[Union[dict, str]] = {}  # TODO: find a better way
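Because Parameters sets extra="forbid", any key not declared on the model raises a ValidationError instead of being carried along silently; this is also why _parseConfig now pops the deprecated 'opt' key after copying it into 'optimizer'. A sketch of the failure mode, using a stripped-down stand-in for Parameters:

from pydantic import BaseModel, ConfigDict, ValidationError


class StrictParams(BaseModel):
    model_config = ConfigDict(extra="forbid")
    modality: str


try:
    StrictParams(modality="rad", opt="adam")  # 'opt' is not a declared field
except ValidationError as err:
    print(err)  # pydantic v2 reports: "opt ... Extra inputs are not permitted"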
10 changes: 10 additions & 0 deletions samples/test_config_file.yaml
@@ -0,0 +1,10 @@
version:
  {
    minimum: 0.1.2-dev,
    maximum: 0.1.2-dev # this should NOT be made a variable, but should be tested after every tag is created
  }
model:
  {
    dimension: 3, # the dimension of the model and dataset: defines dimensionality of computations
    base_filters: "adsa"
  }
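Model declares base_filters as int, so the string value above cannot be coerced and validating this sample should fail, which makes the file useful for exercising the new checks. A sketch of loading and validating it, assuming PyYAML and a trimmed stand-in for the Model class:

import yaml

from pydantic import BaseModel, ValidationError


class MiniModel(BaseModel):
    # trimmed stand-in for GANDLF's Model class
    dimension: int
    base_filters: int


with open("samples/test_config_file.yaml") as f:
    cfg = yaml.safe_load(f)

try:
    MiniModel(**cfg["model"])
except ValidationError as err:
    print(err)  # base_filters: unable to parse string as an integer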
1 change: 1 addition & 0 deletions testing/test_full.py
@@ -20,6 +20,7 @@
    get_patch_size_in_microns,
    convert_to_tiff,
)
+from GANDLF.utils.pydantic_config import Parameters
from GANDLF.config_manager import ConfigManager
from GANDLF.parseConfig import parseConfig
from GANDLF.training_manager import TrainingManager
