Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update docs #18

Merged
merged 1 commit into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/unit-examples-forward/Beltrami_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@ def icbc_cond_func(x, include_p: bool = False):
* u.math.exp(-2 * d ** 2 * x['t'])
)

r = {'u_vel': u_ * unit_of_speed,
'v_vel': v * unit_of_speed,
'w_vel': w * unit_of_speed}
r = {
'u_vel': u_ * unit_of_speed,
'v_vel': v * unit_of_speed,
'w_vel': w * unit_of_speed
}
if include_p:
r['p'] = p * unit_of_pressure
return r
Expand Down
3 changes: 1 addition & 2 deletions docs/unit-examples-forward/heat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@
"outputs": [],
"source": [
"import brainstate as bst\n",
"import numpy as np\n",
"import optax\n",
"import brainunit as u\n",
"\n",
Expand Down Expand Up @@ -238,7 +237,7 @@
" lambda x : {'y': 0. * uy}\n",
")\n",
"ic = pinnx.icbc.IC(\n",
" lambda x: {'y': u.math.sin(n * u.math.pi * x['x'][:] / L, unit_to_scale=u.becquerel) * uy},\n",
" lambda x: {'y': u.math.sin(n * u.math.pi * x['x'] / L, unit_to_scale=u.becquerel) * uy},\n",
")"
]
},
Expand Down
25 changes: 25 additions & 0 deletions pinnx/_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import Union, Sequence, Callable, Optional

import brainstate as bst
Expand Down Expand Up @@ -74,6 +75,7 @@ def compile(
self,
optimizer: bst.optim.Optimizer,
metrics: Union[str, Sequence[str]] = None,
measture_train_step_compile_time: bool = False,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (typo): Parameter name contains a typo: 'measture' should be 'measure'

Suggested implementation:

        measure_train_step_compile_time: bool = False,
        if measure_train_step_compile_time:

You may need to:

  1. Update any documentation strings that reference this parameter
  2. Update any calls to this method that explicitly name this parameter

):
"""
Configures the trainer for training.
Expand Down Expand Up @@ -138,6 +140,12 @@ def _loss_fun():
self.fn_outputs_losses_test = bst.compile.jit(fn_outputs_losses_test)
self.fn_train_step = bst.compile.jit(fn_train_step)

if measture_train_step_compile_time:
t0 = time.time()
self._compile_training_step(self.batch_size)
t1 = time.time()
return self, t1 - t0
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Inconsistent return types could cause bugs - method returns tuple when measuring time but self otherwise

Consider using a consistent return type and providing the timing information through a different mechanism, such as a class attribute or logging.


return self

@utils.timing
Expand All @@ -150,6 +158,7 @@ def train(
callbacks: Union[Callback, Sequence[Callback]] = None,
model_restore_path: str = None,
model_save_path: str = None,
measture_train_step_time: bool = False,
):
"""
Trains the trainer.
Expand Down Expand Up @@ -177,6 +186,9 @@ def train(
model_save_path (String): Prefix of filenames created for the checkpoint.
"""

if measture_train_step_time:
t0 = time.time()

if self.metrics is None:
raise ValueError("Compile the trainer before training.")

Expand Down Expand Up @@ -210,8 +222,21 @@ def train(
training_display.summary(self.train_state)
if model_save_path is not None:
self.save(model_save_path, verbose=1)

if measture_train_step_time:
t1 = time.time()
return self, t1 - t0
return self

def _compile_training_step(self, batch_size=None):
# get data
self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))

# train one batch
self.fn_train_step.compile(self.train_state.X_train,
self.train_state.y_train,
**self.train_state.Aux_train)

def _train(self, iterations, display_every, batch_size, callbacks):
for i in range(iterations):
callbacks.on_epoch_begin()
Expand Down
3 changes: 1 addition & 2 deletions pinnx/geometry/geometry_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
from typing import Union, Literal

import brainstate as bst
import jax.numpy as jnp
import numpy as np
from scipy import spatial

from pinnx import utils
from pinnx.utils.sampling import sample
from .base import Geometry
from .geometry_nd import Hypercube, Hypersphere
from ..utils import isclose, vectorize
from pinnx import utils


class Disk(Hypersphere):
Expand Down
Loading