diff --git a/docs/zh/examples/tempoGAN.md b/docs/zh/examples/tempoGAN.md index ffa69b5cf..4d3382b7f 100644 --- a/docs/zh/examples/tempoGAN.md +++ b/docs/zh/examples/tempoGAN.md @@ -247,7 +247,7 @@ examples/tempoGAN/functions.py:411:427 ``` py linenums="430" --8<-- -examples/tempoGAN/functions.py:430:481 +examples/tempoGAN/functions.py:430:488 --8<-- ``` diff --git a/examples/tempoGAN/functions.py b/examples/tempoGAN/functions.py index bcf27a384..535126e14 100644 --- a/examples/tempoGAN/functions.py +++ b/examples/tempoGAN/functions.py @@ -443,9 +443,16 @@ def __init__( self.density_min = density_min self.max_turn = max_turn - def transform(self, input_item: Dict[str, np.ndarray]) -> Dict[str, paddle.Tensor]: + def transform( + self, + input_item: Dict[str, np.ndarray], + label_item: Dict[str, np.ndarray], + weight_item: Dict[str, np.ndarray], + ) -> Union[ + Dict[str, paddle.Tensor], Dict[str, paddle.Tensor], Dict[str, paddle.Tensor] + ]: if self.tile_ratio == 1: - return input_item + return input_item, label_item, weight_item for _ in range(self.max_turn): rand_ratio = np.random.rand() density_low = self.cut_data(input_item["density_low"], rand_ratio) @@ -455,7 +462,7 @@ def transform(self, input_item: Dict[str, np.ndarray]) -> Dict[str, paddle.Tenso input_item["density_low"] = density_low input_item["density_high"] = density_high - return input_item + return input_item, label_item, weight_item def cut_data(self, data: np.ndarray, rand_ratio: float) -> paddle.Tensor: # data: C,H,W diff --git a/ppsci/data/dataset/array_dataset.py b/ppsci/data/dataset/array_dataset.py index 7f0621702..8e16e6b80 100644 --- a/ppsci/data/dataset/array_dataset.py +++ b/ppsci/data/dataset/array_dataset.py @@ -62,9 +62,10 @@ def __getitem__(self, idx): label_item = {key: value[idx] for key, value in self.label.items()} weight_item = {key: value[idx] for key, value in self.weight.items()} - # TODO(sensen): Transforms may be applied on label and weight. if self.transforms is not None: - input_item = self.transforms(input_item) + input_item, label_item, weight_item = self.transforms( + (input_item, label_item, weight_item) + ) return (input_item, label_item, weight_item) diff --git a/ppsci/data/dataset/csv_dataset.py b/ppsci/data/dataset/csv_dataset.py index eec11e3b9..22e5f0f8a 100644 --- a/ppsci/data/dataset/csv_dataset.py +++ b/ppsci/data/dataset/csv_dataset.py @@ -138,9 +138,10 @@ def __getitem__(self, idx): label_item = {key: value[idx] for key, value in self.label.items()} weight_item = {key: value[idx] for key, value in self.weight.items()} - # TODO(sensen): Transforms may be applied on label and weight. if self.transforms is not None: - input_item = self.transforms(input_item) + input_item, label_item, weight_item = self.transforms( + (input_item, label_item, weight_item) + ) return (input_item, label_item, weight_item) diff --git a/ppsci/data/dataset/mat_dataset.py b/ppsci/data/dataset/mat_dataset.py index 344efb3d7..f681b5b44 100644 --- a/ppsci/data/dataset/mat_dataset.py +++ b/ppsci/data/dataset/mat_dataset.py @@ -138,9 +138,10 @@ def __getitem__(self, idx): label_item = {key: value[idx] for key, value in self.label.items()} weight_item = {key: value[idx] for key, value in self.weight.items()} - # TODO(sensen): Transforms may be applied on label and weight. if self.transforms is not None: - input_item = self.transforms(input_item) + input_item, label_item, weight_item = self.transforms( + (input_item, label_item, weight_item) + ) return (input_item, label_item, weight_item) diff --git a/ppsci/data/dataset/npz_dataset.py b/ppsci/data/dataset/npz_dataset.py index 9f5526e96..6be0a1720 100644 --- a/ppsci/data/dataset/npz_dataset.py +++ b/ppsci/data/dataset/npz_dataset.py @@ -136,9 +136,10 @@ def __getitem__(self, idx): label_item = {key: value[idx] for key, value in self.label.items()} weight_item = {key: value[idx] for key, value in self.weight.items()} - # TODO(sensen): Transforms may be applied on label and weight. if self.transforms is not None: - input_item = self.transforms(input_item) + input_item, label_item, weight_item = self.transforms( + (input_item, label_item, weight_item) + ) return (input_item, label_item, weight_item) diff --git a/ppsci/data/process/transform/preprocess.py b/ppsci/data/process/transform/preprocess.py index 9a7ec0be8..fd4ad31e7 100644 --- a/ppsci/data/process/transform/preprocess.py +++ b/ppsci/data/process/transform/preprocess.py @@ -37,11 +37,13 @@ class Translate: def __init__(self, offset: Dict[str, float]): self.offset = offset - def __call__(self, data_dict): + def __call__(self, data): + data_dict, label_dict, weight_dict = data + data_dict_copy = {**data_dict} for key in self.offset: - if key in data_dict: - data_dict[key] += self.offset[key] - return data_dict + if key in data_dict_copy: + data_dict_copy[key] += self.offset[key] + return data_dict_copy, label_dict, weight_dict class Scale: @@ -59,11 +61,13 @@ class Scale: def __init__(self, scale: Dict[str, float]): self.scale = scale - def __call__(self, data_dict): + def __call__(self, data): + data_dict, label_dict, weight_dict = data + data_dict_copy = {**data_dict} for key in self.scale: - if key in data_dict: - data_dict[key] *= self.scale[key] - return data_dict + if key in data_dict_copy: + data_dict_copy[key] *= self.scale[key] + return data_dict_copy, label_dict, weight_dict class Normalize: @@ -95,13 +99,15 @@ def __init__( def __call__(self, data): input_item, label_item, weight_item = data + input_item_copy = {**input_item} + label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item.items(): - input_item[key] = (value - self.mean) / self.std + for key, value in input_item_copy.items(): + input_item_copy[key] = (value - self.mean) / self.std if "label" in self.apply_keys: - for key, value in label_item.items(): - label_item[key] = (value - self.mean) / self.std - return input_item, label_item, weight_item + for key, value in label_item_copy.items(): + label_item_copy[key] = (value - self.mean) / self.std + return input_item_copy, label_item_copy, weight_item class Log1p: @@ -130,13 +136,15 @@ def __init__( def __call__(self, data): input_item, label_item, weight_item = data + input_item_copy = {**input_item} + label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item.items(): - input_item[key] = np.log1p(value / self.scale) + for key, value in input_item_copy.items(): + input_item_copy[key] = np.log1p(value / self.scale) if "label" in self.apply_keys: - for key, value in label_item.items(): - label_item[key] = np.log1p(value / self.scale) - return input_item, label_item, weight_item + for key, value in label_item_copy.items(): + label_item_copy[key] = np.log1p(value / self.scale) + return input_item_copy, label_item_copy, weight_item class CropData: @@ -168,17 +176,19 @@ def __init__( def __call__(self, data): input_item, label_item, weight_item = data + input_item_copy = {**input_item} + label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item.items(): - input_item[key] = value[ + for key, value in input_item_copy.items(): + input_item_copy[key] = value[ :, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1] ] if "label" in self.apply_keys: - for key, value in label_item.items(): - label_item[key] = value[ + for key, value in label_item_copy.items(): + label_item_copy[key] = value[ :, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1] ] - return input_item, label_item, weight_item + return input_item_copy, label_item_copy, weight_item class SqueezeData: @@ -201,25 +211,27 @@ def __init__(self, apply_keys: Tuple[str, ...] = ("input", "label")): def __call__(self, data): input_item, label_item, weight_item = data + input_item_copy = {**input_item} + label_item_copy = {**label_item} if "input" in self.apply_keys: - for key, value in input_item.items(): + for key, value in input_item_copy.items(): if value.ndim == 4: B, C, H, W = value.shape - input_item[key] = value.reshape((B * C, H, W)) + input_item_copy[key] = value.reshape((B * C, H, W)) if value.ndim != 3: raise ValueError( f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}" ) if "label" in self.apply_keys: - for key, value in label_item.items(): + for key, value in label_item_copy.items(): if value.ndim == 4: B, C, H, W = value.shape - label_item[key] = value.reshape((B * C, H, W)) + label_item_copy[key] = value.reshape((B * C, H, W)) if value.ndim != 3: raise ValueError( f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}" ) - return input_item, label_item, weight_item + return input_item_copy, label_item_copy, weight_item class FunctionalTransform: @@ -231,11 +243,11 @@ class FunctionalTransform: Examples: >>> import ppsci >>> import numpy as np - >>> def transform_func(data_dict): + >>> def transform_func(data_dict, label_dict, weight_dict): ... rand_ratio = np.random.rand() ... for key in data_dict: ... data_dict[key] = data_dict[key] * rand_ratio - ... return data_dict + ... return data_dict, label_dict, weight_dict >>> transform_cfg = { ... "transforms": ( ... { @@ -253,5 +265,9 @@ def __init__( ): self.transform_func = transform_func - def __call__(self, data_dict: Dict[str, np.ndarray]): - return self.transform_func(data_dict) + def __call__(self, data: Tuple[Dict[str, np.ndarray], ...]): + data_dict, label_dict, weight_dict = data + data_dict_copy = {**data_dict} + label_dict_copy = {**label_dict} + weight_dict_copy = {**weight_dict} if weight_dict else None + return self.transform_func(data_dict_copy, label_dict_copy, weight_dict_copy)