From 5fabf1a1afa28c45f62700955b1b4538f4db29bd Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 2 Jan 2025 13:32:30 +0800 Subject: [PATCH] update --- docs/unit-examples-forward/Beltrami_flow.py | 8 ++++--- docs/unit-examples-forward/heat.ipynb | 3 +-- pinnx/_trainer.py | 25 +++++++++++++++++++++ pinnx/geometry/geometry_2d.py | 3 +-- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/docs/unit-examples-forward/Beltrami_flow.py b/docs/unit-examples-forward/Beltrami_flow.py index 08ecfe2..6578bbd 100644 --- a/docs/unit-examples-forward/Beltrami_flow.py +++ b/docs/unit-examples-forward/Beltrami_flow.py @@ -154,9 +154,11 @@ def icbc_cond_func(x, include_p: bool = False): * u.math.exp(-2 * d ** 2 * x['t']) ) - r = {'u_vel': u_ * unit_of_speed, - 'v_vel': v * unit_of_speed, - 'w_vel': w * unit_of_speed} + r = { + 'u_vel': u_ * unit_of_speed, + 'v_vel': v * unit_of_speed, + 'w_vel': w * unit_of_speed + } if include_p: r['p'] = p * unit_of_pressure return r diff --git a/docs/unit-examples-forward/heat.ipynb b/docs/unit-examples-forward/heat.ipynb index 8d26907..8e1227a 100644 --- a/docs/unit-examples-forward/heat.ipynb +++ b/docs/unit-examples-forward/heat.ipynb @@ -146,7 +146,6 @@ "outputs": [], "source": [ "import brainstate as bst\n", - "import numpy as np\n", "import optax\n", "import brainunit as u\n", "\n", @@ -238,7 +237,7 @@ " lambda x : {'y': 0. * uy}\n", ")\n", "ic = pinnx.icbc.IC(\n", - " lambda x: {'y': u.math.sin(n * u.math.pi * x['x'][:] / L, unit_to_scale=u.becquerel) * uy},\n", + " lambda x: {'y': u.math.sin(n * u.math.pi * x['x'] / L, unit_to_scale=u.becquerel) * uy},\n", ")" ] }, diff --git a/pinnx/_trainer.py b/pinnx/_trainer.py index 0abaef5..7492a4e 100644 --- a/pinnx/_trainer.py +++ b/pinnx/_trainer.py @@ -1,3 +1,4 @@ +import time from typing import Union, Sequence, Callable, Optional import brainstate as bst @@ -74,6 +75,7 @@ def compile( self, optimizer: bst.optim.Optimizer, metrics: Union[str, Sequence[str]] = None, + measture_train_step_compile_time: bool = False, ): """ Configures the trainer for training. @@ -138,6 +140,12 @@ def _loss_fun(): self.fn_outputs_losses_test = bst.compile.jit(fn_outputs_losses_test) self.fn_train_step = bst.compile.jit(fn_train_step) + if measture_train_step_compile_time: + t0 = time.time() + self._compile_training_step(self.batch_size) + t1 = time.time() + return self, t1 - t0 + return self @utils.timing @@ -150,6 +158,7 @@ def train( callbacks: Union[Callback, Sequence[Callback]] = None, model_restore_path: str = None, model_save_path: str = None, + measture_train_step_time: bool = False, ): """ Trains the trainer. @@ -177,6 +186,9 @@ def train( model_save_path (String): Prefix of filenames created for the checkpoint. """ + if measture_train_step_time: + t0 = time.time() + if self.metrics is None: raise ValueError("Compile the trainer before training.") @@ -210,8 +222,21 @@ def train( training_display.summary(self.train_state) if model_save_path is not None: self.save(model_save_path, verbose=1) + + if measture_train_step_time: + t1 = time.time() + return self, t1 - t0 return self + def _compile_training_step(self, batch_size=None): + # get data + self.train_state.set_data_train(*self.problem.train_next_batch(batch_size)) + + # train one batch + self.fn_train_step.compile(self.train_state.X_train, + self.train_state.y_train, + **self.train_state.Aux_train) + def _train(self, iterations, display_every, batch_size, callbacks): for i in range(iterations): callbacks.on_epoch_begin() diff --git a/pinnx/geometry/geometry_2d.py b/pinnx/geometry/geometry_2d.py index dcc8368..1e28299 100644 --- a/pinnx/geometry/geometry_2d.py +++ b/pinnx/geometry/geometry_2d.py @@ -17,15 +17,14 @@ from typing import Union, Literal import brainstate as bst -import jax.numpy as jnp import numpy as np from scipy import spatial +from pinnx import utils from pinnx.utils.sampling import sample from .base import Geometry from .geometry_nd import Hypercube, Hypersphere from ..utils import isclose, vectorize -from pinnx import utils class Disk(Hypersphere):