Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply transforms on label and weight #565

Merged
merged 4 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/zh/examples/tempoGAN.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ examples/tempoGAN/functions.py:411:427

``` py linenums="430"
--8<--
examples/tempoGAN/functions.py:430:481
examples/tempoGAN/functions.py:430:488
--8<--
```

Expand Down
13 changes: 10 additions & 3 deletions examples/tempoGAN/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,16 @@ def __init__(
self.density_min = density_min
self.max_turn = max_turn

def transform(self, input_item: Dict[str, np.ndarray]) -> Dict[str, paddle.Tensor]:
def transform(
self,
input_item: Dict[str, np.ndarray],
label_item: Dict[str, np.ndarray],
weight_item: Dict[str, np.ndarray],
) -> Union[
Dict[str, paddle.Tensor], Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]
]:
if self.tile_ratio == 1:
return input_item
return input_item, label_item, weight_item
for _ in range(self.max_turn):
rand_ratio = np.random.rand()
density_low = self.cut_data(input_item["density_low"], rand_ratio)
Expand All @@ -455,7 +462,7 @@ def transform(self, input_item: Dict[str, np.ndarray]) -> Dict[str, paddle.Tenso

input_item["density_low"] = density_low
input_item["density_high"] = density_high
return input_item
return input_item, label_item, weight_item

def cut_data(self, data: np.ndarray, rand_ratio: float) -> paddle.Tensor:
# data: C,H,W
Expand Down
5 changes: 3 additions & 2 deletions ppsci/data/dataset/array_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def __getitem__(self, idx):
label_item = {key: value[idx] for key, value in self.label.items()}
weight_item = {key: value[idx] for key, value in self.weight.items()}

# TODO(sensen): Transforms may be applied on label and weight.
if self.transforms is not None:
input_item = self.transforms(input_item)
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
)

return (input_item, label_item, weight_item)

Expand Down
5 changes: 3 additions & 2 deletions ppsci/data/dataset/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ def __getitem__(self, idx):
label_item = {key: value[idx] for key, value in self.label.items()}
weight_item = {key: value[idx] for key, value in self.weight.items()}

# TODO(sensen): Transforms may be applied on label and weight.
if self.transforms is not None:
input_item = self.transforms(input_item)
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
)

return (input_item, label_item, weight_item)

Expand Down
5 changes: 3 additions & 2 deletions ppsci/data/dataset/mat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ def __getitem__(self, idx):
label_item = {key: value[idx] for key, value in self.label.items()}
weight_item = {key: value[idx] for key, value in self.weight.items()}

# TODO(sensen): Transforms may be applied on label and weight.
if self.transforms is not None:
input_item = self.transforms(input_item)
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
)

return (input_item, label_item, weight_item)

Expand Down
5 changes: 3 additions & 2 deletions ppsci/data/dataset/npz_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@ def __getitem__(self, idx):
label_item = {key: value[idx] for key, value in self.label.items()}
weight_item = {key: value[idx] for key, value in self.weight.items()}

# TODO(sensen): Transforms may be applied on label and weight.
if self.transforms is not None:
input_item = self.transforms(input_item)
input_item, label_item, weight_item = self.transforms(
(input_item, label_item, weight_item)
)

return (input_item, label_item, weight_item)

Expand Down
80 changes: 48 additions & 32 deletions ppsci/data/process/transform/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ class Translate:
def __init__(self, offset: Dict[str, float]):
self.offset = offset

# Rendered diff hunk: +/- markers and indentation were lost in extraction.
# The `data_dict`-only signature and the lines operating on `data_dict`
# directly appear to be the removed version; the `data` signature and the
# `data_dict_copy` lines appear to be the added version.
def __call__(self, data_dict):
# New form: `data` is unpacked into (data_dict, label_dict, weight_dict).
def __call__(self, data):
data_dict, label_dict, weight_dict = data
# {**data_dict} creates a new dict that still references the same values.
data_dict_copy = {**data_dict}
for key in self.offset:
if key in data_dict:
data_dict[key] += self.offset[key]
return data_dict
# Adds self.offset[key] to every entry whose key is present in the copy.
# NOTE(review): `+=` on a shared value may still modify the caller's object
# in place if values are mutable arrays (e.g. np.ndarray) — confirm the value
# types; `x = x + offset` would avoid that.
if key in data_dict_copy:
data_dict_copy[key] += self.offset[key]
# label_dict and weight_dict are passed through unchanged.
return data_dict_copy, label_dict, weight_dict


class Scale:
Expand All @@ -59,11 +61,13 @@ class Scale:
def __init__(self, scale: Dict[str, float]):
self.scale = scale

# Rendered diff hunk: +/- markers and indentation were lost in extraction.
# The `data_dict`-only signature and the lines operating on `data_dict`
# directly appear to be the removed version; the `data` signature and the
# `data_dict_copy` lines appear to be the added version.
def __call__(self, data_dict):
# New form: `data` is unpacked into (data_dict, label_dict, weight_dict).
def __call__(self, data):
data_dict, label_dict, weight_dict = data
# {**data_dict} creates a new dict that still references the same values.
data_dict_copy = {**data_dict}
for key in self.scale:
if key in data_dict:
data_dict[key] *= self.scale[key]
return data_dict
# Multiplies every entry whose key is present in the copy by self.scale[key].
# NOTE(review): `*=` on a shared value may still modify the caller's object
# in place if values are mutable arrays (e.g. np.ndarray) — confirm the value
# types; `x = x * scale` would avoid that.
if key in data_dict_copy:
data_dict_copy[key] *= self.scale[key]
# label_dict and weight_dict are passed through unchanged.
return data_dict_copy, label_dict, weight_dict


class Normalize:
Expand Down Expand Up @@ -95,13 +99,15 @@ def __init__(

# Rendered diff hunk: +/- markers and indentation were lost in extraction.
# Lines that write to `input_item` / `label_item` appear to be the removed
# version; lines that write to the `*_copy` dicts appear to be the added one.
# Applies (value - self.mean) / self.std to the groups named in
# self.apply_keys ("input" and/or "label") and returns a 3-tuple.
def __call__(self, data):
input_item, label_item, weight_item = data
# New dicts referencing the same values; results are assigned into these,
# so the dicts received from the caller are not reassigned.
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item.items():
input_item[key] = (value - self.mean) / self.std
# Only existing keys are reassigned while iterating, so the dict size does
# not change during the loop.
for key, value in input_item_copy.items():
input_item_copy[key] = (value - self.mean) / self.std
if "label" in self.apply_keys:
for key, value in label_item.items():
label_item[key] = (value - self.mean) / self.std
return input_item, label_item, weight_item
for key, value in label_item_copy.items():
label_item_copy[key] = (value - self.mean) / self.std
# weight_item is passed through unchanged.
return input_item_copy, label_item_copy, weight_item


class Log1p:
Expand Down Expand Up @@ -130,13 +136,15 @@ def __init__(

# Rendered diff hunk: +/- markers and indentation were lost in extraction.
# Lines that write to `input_item` / `label_item` appear to be the removed
# version; lines that write to the `*_copy` dicts appear to be the added one.
# Applies np.log1p(value / self.scale) to the groups named in
# self.apply_keys ("input" and/or "label") and returns a 3-tuple.
def __call__(self, data):
input_item, label_item, weight_item = data
# New dicts referencing the same values; results are assigned into these,
# so the dicts received from the caller are not reassigned.
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item.items():
input_item[key] = np.log1p(value / self.scale)
for key, value in input_item_copy.items():
input_item_copy[key] = np.log1p(value / self.scale)
if "label" in self.apply_keys:
for key, value in label_item.items():
label_item[key] = np.log1p(value / self.scale)
return input_item, label_item, weight_item
for key, value in label_item_copy.items():
label_item_copy[key] = np.log1p(value / self.scale)
# weight_item is passed through unchanged.
return input_item_copy, label_item_copy, weight_item


class CropData:
Expand Down Expand Up @@ -168,17 +176,19 @@ def __init__(

# Rendered diff hunk: +/- markers and indentation were lost in extraction.
# Lines that write to `input_item` / `label_item` appear to be the removed
# version; lines that write to the `*_copy` dicts appear to be the added one.
# Crops each value of the groups named in self.apply_keys to the window
# [xmin[0]:xmax[0], xmin[1]:xmax[1]] on its second and third axes, keeping
# the first axis whole.
# NOTE(review): the three-index slice assumes every value has at least three
# axes — confirm against the callers.
def __call__(self, data):
input_item, label_item, weight_item = data
# New dicts referencing the same values; the cropped results are assigned
# into these, so the dicts received from the caller are not reassigned.
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item.items():
input_item[key] = value[
for key, value in input_item_copy.items():
input_item_copy[key] = value[
:, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1]
]
if "label" in self.apply_keys:
for key, value in label_item.items():
label_item[key] = value[
for key, value in label_item_copy.items():
label_item_copy[key] = value[
:, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1]
]
return input_item, label_item, weight_item
# weight_item is passed through unchanged.
return input_item_copy, label_item_copy, weight_item


class SqueezeData:
Expand All @@ -201,25 +211,27 @@ def __init__(self, apply_keys: Tuple[str, ...] = ("input", "label")):

# Rendered diff hunk: +/- markers and indentation were lost in extraction.
# Lines that iterate or write `input_item` / `label_item` appear to be the
# removed version; lines using the `*_copy` dicts appear to be the added one.
# For the groups named in self.apply_keys, reshapes 4-D values of shape
# (B, C, H, W) to (B * C, H, W) and raises ValueError for unsupported ranks.
def __call__(self, data):
input_item, label_item, weight_item = data
# New dicts referencing the same values; reshaped results are assigned into
# these, so the dicts received from the caller are not reassigned.
input_item_copy = {**input_item}
label_item_copy = {**label_item}
if "input" in self.apply_keys:
for key, value in input_item.items():
for key, value in input_item_copy.items():
if value.ndim == 4:
B, C, H, W = value.shape
input_item[key] = value.reshape((B * C, H, W))
input_item_copy[key] = value.reshape((B * C, H, W))
# NOTE(review): `value` is not rebound after the reshape. If this check is a
# sibling of the `ndim == 4` branch (nesting is not recoverable here), it
# tests the original rank and would raise for every 4-D input — confirm
# against the real file.
if value.ndim != 3:
raise ValueError(
f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
)
if "label" in self.apply_keys:
for key, value in label_item.items():
for key, value in label_item_copy.items():
if value.ndim == 4:
B, C, H, W = value.shape
label_item[key] = value.reshape((B * C, H, W))
label_item_copy[key] = value.reshape((B * C, H, W))
# NOTE(review): same concern as the input branch — the check reads the
# pre-reshape `value`.
if value.ndim != 3:
raise ValueError(
f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}"
)
return input_item, label_item, weight_item
# weight_item is passed through unchanged.
return input_item_copy, label_item_copy, weight_item


class FunctionalTransform:
Expand All @@ -231,11 +243,11 @@ class FunctionalTransform:
Examples:
>>> import ppsci
>>> import numpy as np
>>> def transform_func(data_dict):
>>> def transform_func(data_dict, label_dict, weight_dict):
... rand_ratio = np.random.rand()
... for key in data_dict:
... data_dict[key] = data_dict[key] * rand_ratio
... return data_dict
... return data_dict, label_dict, weight_dict
>>> transform_cfg = {
... "transforms": (
... {
Expand All @@ -253,5 +265,9 @@ def __init__(
):
self.transform_func = transform_func

# Rendered diff hunk: +/- markers and indentation were lost in extraction.
# The single-argument `data_dict` form appears to be the removed version; the
# tuple-taking `data` form appears to be the added version.
def __call__(self, data_dict: Dict[str, np.ndarray]):
return self.transform_func(data_dict)
# New form: unpacks (data_dict, label_dict, weight_dict) from `data`, makes
# new dicts that reference the same values, and forwards them as three
# positional arguments to the user-supplied transform_func.
def __call__(self, data: Tuple[Dict[str, np.ndarray], ...]):
data_dict, label_dict, weight_dict = data
data_dict_copy = {**data_dict}
label_dict_copy = {**label_dict}
# A falsy weight_dict (None or an empty dict) is forwarded as None.
# NOTE(review): the other transforms in this diff return the weight item as
# received, so an empty dict becomes None only here — confirm this is intended.
weight_dict_copy = {**weight_dict} if weight_dict else None
# The result of transform_func is returned as-is.
# NOTE(review): callers elsewhere in this diff unpack three values from the
# transform result, so transform_func presumably must return a 3-tuple — confirm.
return self.transform_func(data_dict_copy, label_dict_copy, weight_dict_copy)