【Hackathon 5th No.57】Neural networks for topology optimization #559

Closed · wants to merge 4 commits into from
38 changes: 38 additions & 0 deletions jointContribution/TopOpt/README.md
# TopOpt

Neural networks for topology optimization

Paper link: [[ArXiv](https://arxiv.org/abs/1709.09578)]


## Highlights

- Proposes a deep-learning-based approach that speeds up topology optimization for layout problems by casting the problem as an image segmentation task
- Introduces a convolutional encoder-decoder (U-Net) architecture; the overall approach achieves high performance


## References

- <https://github.com/ISosnovik/nn4topopt>


## Dataset

Convert the original dataset into an HDF5-format dataset:

``` shell
mkdir -p ./Dataset/PreparedData/

python ./prepare_h5datasets.py
```
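
The exact layout of the generated HDF5 file is defined by `prepare_h5datasets.py`; as a minimal sketch (the file name below is an assumption, not taken from this repository), the prepared data can be inspected with `h5py`:

``` python
import h5py

# Hypothetical path -- check prepare_h5datasets.py for the actual output location.
with h5py.File("./Dataset/PreparedData/dataset.h5", "r") as f:
    # print every group/dataset name and, where available, its shape
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))
```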

## Model Training

``` shell
python ./training_case1.py
```

## Metric Results

The metric results are saved in eval_results.ipynb and can be compared with the reference results at <https://github.com/ISosnovik/nn4topopt/blob/master/results.ipynb>.
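
As a rough sketch of the kind of metric reported there (binary accuracy after thresholding the predicted densities at 0.5; the array names are placeholders, not outputs produced by this repository):

``` python
import numpy as np

def binary_accuracy(pred, target, threshold=0.5):
    # fraction of pixels whose thresholded prediction matches the 0/1 target
    pred_bin = (pred > threshold).astype(np.float32)
    return float((pred_bin == target).mean())

# hypothetical usage with arrays of shape (N, 1, 40, 40):
# acc = binary_accuracy(predicted_densities, true_densities)
```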
116 changes: 116 additions & 0 deletions jointContribution/TopOpt/TopOptModel.py
import paddle
from paddle import nn

import ppsci


# NCHW data format
class TopOptNN(ppsci.arch.UNetEx):
def __init__(
self,
input_key="input",
output_key="output",
in_channel=2,
out_channel=1,
kernel_size=3,
filters=(16, 32, 64),
layers=2,
weight_norm=False,
batch_norm=False,
activation=nn.ReLU,
):
self.in_channel = in_channel
self.out_channel = out_channel
self.filters = filters
super().__init__(
input_key=input_key,
output_key=output_key,
in_channel=in_channel,
out_channel=out_channel,
kernel_size=kernel_size,
filters=filters,
layers=layers,
weight_norm=weight_norm,
batch_norm=batch_norm,
activation=activation,
)
# Modify Layers
        self.encoder[1] = nn.Sequential(
            # 2x2 max pooling; the original passed self.in_channel here, which
            # equals 2 only with the default arguments
            nn.MaxPool2D(2, padding="SAME"),
            self.encoder[1][0],
            nn.Dropout2D(0.1),
            self.encoder[1][1],
        )
self.encoder[2] = nn.Sequential(
nn.MaxPool2D(2, padding="SAME"), self.encoder[2]
)
# Conv2D used in reference code in decoder
self.decoders[0] = nn.Sequential(
nn.Conv2D(
self.filters[-1], self.filters[-1], kernel_size=3, padding="SAME"
),
nn.ReLU(),
nn.Conv2D(
self.filters[-1], self.filters[-1], kernel_size=3, padding="SAME"
),
nn.ReLU(),
)
self.decoders[1] = nn.Sequential(
nn.Conv2D(
sum(self.filters[-2:]), self.filters[-2], kernel_size=3, padding="SAME"
),
nn.ReLU(),
nn.Dropout2D(0.1),
nn.Conv2D(
self.filters[-2], self.filters[-2], kernel_size=3, padding="SAME"
),
nn.ReLU(),
)
self.decoders[2] = nn.Sequential(
nn.Conv2D(
sum(self.filters[:-1]), self.filters[-3], kernel_size=3, padding="SAME"
),
nn.ReLU(),
nn.Conv2D(
self.filters[-3], self.filters[-3], kernel_size=3, padding="SAME"
),
nn.ReLU(),
)
self.output = nn.Sequential(
nn.Conv2D(
self.filters[-3], self.out_channel, kernel_size=3, padding="SAME"
),
nn.Sigmoid(),
)

def forward(self, x):
x = x[self.input_keys[0]].squeeze(axis=0) # squeeze additional batch dimension
# Layer 1 (bs, 2, 40, 40) -> (bs, 16, 40, 40)
conv1 = self.encoder[0](x)
up_size_2 = conv1.shape[-2:]
# Layer 2 (bs, 16, 40, 40) -> (bs, 32, 20, 20)
conv2 = self.encoder[1](conv1)
up_size_1 = conv2.shape[-2:]
# Layer 3 (bs, 32, 20, 20) -> (bs, 64, 10, 10)
conv3 = self.encoder[2](conv2)

# Layer 4 (bs, 64, 10, 10) -> (bs, 64, 10, 10)
conv4 = self.decoders[0](conv3)
# upsampling (bs, 64, 10, 10) -> (bs, 64, 20, 20)
conv4 = nn.UpsamplingNearest2D(up_size_1)(conv4)

# concat (bs, 64, 20, 20) -> (bs, 96, 20, 20)
conv5 = paddle.concat((conv2, conv4), axis=1)
# Layer 5 (bs, 96, 20, 20) -> (bs, 32, 20, 20)
conv5 = self.decoders[1](conv5)
# upsampling (bs, 32, 20, 20) -> (bs, 32, 40, 40)
conv5 = nn.UpsamplingNearest2D(up_size_2)(conv5)

# concat (bs, 32, 40, 40) -> (bs, 48, 40, 40)
conv6 = paddle.concat((conv1, conv5), axis=1)
# Layer 6 (bs, 48, 40, 40) -> (bs, 16, 40, 40)
conv6 = self.decoders[2](conv6)
# Output (bs, 16, 40, 40) -> (bs, 1, 40, 40)
out = self.output(conv6)

return {self.output_keys[0]: out}
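

# Hedged usage sketch (added for illustration, not part of the original PR):
# run one forward pass with random data shaped like the comments above,
# assuming `paddle` and `ppsci` are installed and the default arguments work.
if __name__ == "__main__":
    model = TopOptNN()
    # one extra leading batch dimension of size 1, then (bs, 2, 40, 40)
    dummy = {"input": paddle.rand([1, 8, 2, 40, 40])}
    out = model(dummy)["output"]
    print(out.shape)  # expected: [8, 1, 40, 40]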
196 changes: 196 additions & 0 deletions jointContribution/TopOpt/data_utils.py
import numpy as np
import paddle
from paddle.io import Dataset

from ppsci.constraint.base import Constraint

# from ppsci.validate.base import Validator


def augmentation(input, label):
    """Apply a random transformation from the D4 symmetry group.

    Args:
        input: dict with key "input", a tensor of shape (batch_size, channels, height, width)
            or (channels, height, width) for a single sample.
        label: dict with key "output", a tensor with the same layout as `input`.
    """
X = paddle.to_tensor(input["input"])
Y = paddle.to_tensor(label["output"])
n_obj = len(X)
indices = np.arange(n_obj)
np.random.shuffle(indices)

if len(X.shape) == 3:
# random horizontal flip
if np.random.random() > 0.5:
X = paddle.flip(X, axis=2)
Y = paddle.flip(Y, axis=2)
# random vertical flip
if np.random.random() > 0.5:
X = paddle.flip(X, axis=1)
Y = paddle.flip(Y, axis=1)
        # random 90° rotation (transpose of the spatial axes)
if np.random.random() > 0.5:
new_perm = list(range(len(X.shape)))
new_perm[1], new_perm[2] = new_perm[2], new_perm[1]
X = paddle.transpose(X, perm=new_perm)
Y = paddle.transpose(Y, perm=new_perm)
X = X.reshape([1] + X.shape)
Y = Y.reshape([1] + Y.shape)
else:
# random horizontal flip
batch_size = X.shape[0]
mask = np.random.random(size=batch_size) > 0.5
X[mask] = paddle.flip(X[mask], axis=3)
Y[mask] = paddle.flip(Y[mask], axis=3)
# random vertical flip
mask = np.random.random(size=batch_size) > 0.5
X[mask] = paddle.flip(X[mask], axis=2)
Y[mask] = paddle.flip(Y[mask], axis=2)
        # random 90° rotation (transpose of the spatial axes)
mask = np.random.random(size=batch_size) > 0.5
new_perm = list(range(len(X.shape)))
new_perm[2], new_perm[3] = new_perm[3], new_perm[2]
X[mask] = paddle.transpose(X[mask], perm=new_perm)
Y[mask] = paddle.transpose(Y[mask], perm=new_perm)

return X, Y
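

def _augmentation_demo():
    """Hedged usage sketch (added for illustration, not part of the original code).

    Apply `augmentation` to one random sample laid out as (channels, height, width);
    in this branch the function also adds a leading batch dimension of size 1.
    """
    x = {"input": paddle.rand([2, 40, 40])}
    y = {"output": paddle.rand([1, 40, 40])}
    x_aug, y_aug = augmentation(x, y)
    return x_aug.shape, y_aug.shape  # [1, 2, 40, 40] and [1, 1, 40, 40]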


def batch_transform_wrapper(sampler):
    """Build a collate_fn for paddle.io.DataLoader.

    `sampler()` returns a SIMP iteration index k; the two input channels are the
    density field at iteration k and its increment over iteration k - 1.
    """

    def batch_transform_fun(batch):
        inputs = []
        labels = []
        k = sampler()
        for i in range(len(batch)):
            x1 = batch[i][0][:, k, :, :]
            x2 = batch[i][0][:, k - 1, :, :]
            inputs.append(paddle.stack((x1, x1 - x2), axis=1))
            labels.append(batch[i][1])
        batch_input = paddle.concat(inputs, axis=0)
        batch_label = paddle.concat(labels, axis=0)
        return ({"input": batch_input}, {"output": batch_label}, {})

    return batch_transform_fun
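

def _batch_transform_demo():
    """Hedged usage sketch (added for illustration, not part of the original code).

    Build a collate_fn from a uniform random iteration sampler; the upper bound of
    100 is an assumption -- adjust it to the number of SIMP iterations stored per
    sample in the prepared dataset.
    """
    sampler = lambda: np.random.randint(1, 100)
    collate_fn = batch_transform_wrapper(sampler)
    # loader = paddle.io.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    return collate_fn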


class NewNamedArrayDataset(Dataset):
def __init__(
self,
input,
label,
weight=None,
transforms=None,
):
super().__init__()
self.input = input
self.label = label
self.input_keys = tuple(input.keys())
self.label_keys = tuple(label.keys())
self.weight = {} if weight is None else weight
self.transforms = transforms
self._len = len(next(iter(input.values())))

def __getitem__(self, idx):
input_item = {key: value[idx] for key, value in self.input.items()}
label_item = {key: value[idx] for key, value in self.label.items()}
weight_item = {key: value[idx] for key, value in self.weight.items()}

        # transforms (if any) are applied to the input and label items
if self.transforms is not None:
input_item, label_item = self.transforms(input_item, label_item)

return (input_item, label_item, weight_item)

def __len__(self):
return self._len
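

def _dataset_demo():
    """Hedged usage sketch (added for illustration, not part of the original code).

    Wrap in-memory numpy arrays, applying the D4 augmentation above on the fly;
    the array shapes are assumptions chosen to match the 40x40 problem size.
    """
    x = np.random.rand(10, 100, 40, 40).astype("float32")  # 10 samples, 100 stored iterations
    y = np.random.rand(10, 1, 40, 40).astype("float32")
    return NewNamedArrayDataset({"input": x}, {"output": y}, transforms=augmentation)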


class NewSupConstraint(Constraint):
def __init__(
self,
dataset,
data_loader,
loss,
output_expr=None,
name: str = "sup_constraint",
):
##### build dataset
_dataset = dataset

self.input_keys = _dataset.input_keys
self.output_keys = (
tuple(output_expr.keys())
if output_expr is not None
else _dataset.label_keys
)

self.output_expr = output_expr
if self.output_expr is None:
self.output_expr = {
key: (lambda out, k=key: out[k]) for key in self.output_keys
}

##### construct dataloader with dataset and dataloader_cfg
self.data_loader = data_loader
self.data_iter = iter(self.data_loader)
self.loss = loss
self.name = name

def __str__(self):
return ", ".join(
[
self.__class__.__name__,
f"name = {self.name}",
f"input_keys = {self.input_keys}",
f"output_keys = {self.output_keys}",
f"output_expr = {self.output_expr}",
f"loss = {self.loss}",
]
)
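

def _constraint_demo():
    """Hedged usage sketch (added for illustration, not part of the original code).

    Wire the demo dataset and collate_fn above into a supervised constraint;
    ppsci.loss.MSELoss is assumed to be available in the installed PaddleScience.
    """
    import ppsci

    dataset = _dataset_demo()
    loader = paddle.io.DataLoader(
        dataset,
        batch_size=2,
        collate_fn=batch_transform_wrapper(lambda: np.random.randint(1, 100)),
    )
    return NewSupConstraint(dataset, loader, loss=ppsci.loss.MSELoss("mean"))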


# class NewSupValidator(Validator):
# def __init__(
# self,
# dataset,
# data_loader,
# loss,
# output_expr = None,
# metric = None,
# name = "sup_validator",
# ):
# self.output_expr = output_expr

# ##### build dataset
# _dataset = dataset

# self.input_keys = _dataset.input_keys
# self.output_keys = (
# tuple(output_expr.keys())
# if output_expr is not None
# else _dataset.label_keys
# )

# if self.output_expr is None:
# self.output_expr = {
# key: lambda out, k=key: out[k] for key in self.output_keys
# }

# ##### construct dataloader with dataset and dataloader_cfg
# self.data_loader = data_loader
# self.data_iter = iter(self.data_loader)
# self.loss = loss
# self.metric = metric
# self.name = name

# def __str__(self):
# return ", ".join(
# [
# self.__class__.__name__,
# f"name = {self.name}",
# f"input_keys = {self.input_keys}",
# f"output_keys = {self.output_keys}",
# f"output_expr = {self.output_expr}",
# f"len(dataloader) = {len(self.data_loader)}",
# f"loss = {self.loss}",
# f"metric = {list(self.metric.keys())}",
# ]
# )