diff --git a/.gitignore b/.gitignore index 2efd78b..8025c7c 100644 --- a/.gitignore +++ b/.gitignore @@ -30,4 +30,10 @@ related-projects.md docs/**/*.dat docs/**/*.npz -.DS_Store \ No newline at end of file +.DS_Store +docs/apis/ +!docs/apis/pinnx.rst +!docs/apis/pinnx.callbacks.rst +!docs/apis/pinnx.fnspace.rst +!docs/apis/pinnx.grad.rst +!docs/apis/pinnx.metrics.rst \ No newline at end of file diff --git a/docs/_templates/classtemplate.rst b/docs/_templates/classtemplate.rst new file mode 100644 index 0000000..e783a6c --- /dev/null +++ b/docs/_templates/classtemplate.rst @@ -0,0 +1,9 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: \ No newline at end of file diff --git a/docs/apis/pinnx.callbacks.rst b/docs/apis/pinnx.callbacks.rst new file mode 100644 index 0000000..e803538 --- /dev/null +++ b/docs/apis/pinnx.callbacks.rst @@ -0,0 +1,26 @@ +``pinnx.callbacks`` module +========================== + +.. currentmodule:: pinnx.callbacks +.. automodule:: pinnx.callbacks + +Callbacks +--------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + Callback + CallbackList + ModelCheckpoint + EarlyStopping + Timer + DropoutUncertainty + VariableValue + OperatorPredictor + MovieDumper + PDEPointResampler + + diff --git a/docs/apis/pinnx.fnspace.rst b/docs/apis/pinnx.fnspace.rst new file mode 100644 index 0000000..4cfa92b --- /dev/null +++ b/docs/apis/pinnx.fnspace.rst @@ -0,0 +1,23 @@ +``pinnx.fnspace`` module +======================== + +.. currentmodule:: pinnx.fnspace +.. automodule:: pinnx.fnspace + +Function Space +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + wasserstein2 + FunctionSpace + PowerSeries + Chebyshev + GRF + GRF_KL + GRF2D + + diff --git a/docs/apis/pinnx.grad.rst b/docs/apis/pinnx.grad.rst new file mode 100644 index 0000000..69fd4a7 --- /dev/null +++ b/docs/apis/pinnx.grad.rst @@ -0,0 +1,19 @@ +``pinnx.grad`` module +===================== + +.. currentmodule:: pinnx.grad +.. automodule:: pinnx.grad + +Automatic Differentiation +------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + jacobian + hessian + gradient + + diff --git a/docs/apis/pinnx.metrics.rst b/docs/apis/pinnx.metrics.rst new file mode 100644 index 0000000..abc04b3 --- /dev/null +++ b/docs/apis/pinnx.metrics.rst @@ -0,0 +1,24 @@ +``pinnx.metrics`` module +======================== + +.. currentmodule:: pinnx.metrics +.. automodule:: pinnx.metrics + +Metrics +------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + accuracy + l2_relative_error + nanl2_relative_error + mean_l2_relative_error + mean_squared_error + mean_absolute_percentage_error + max_absolute_percentage_error + absolute_percentage_error_std + + diff --git a/docs/apis/pinnx.rst b/docs/apis/pinnx.rst new file mode 100644 index 0000000..43c0552 --- /dev/null +++ b/docs/apis/pinnx.rst @@ -0,0 +1,17 @@ +``pinnx`` module +================ + +.. currentmodule:: pinnx +.. automodule:: pinnx + +Trainer +------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + Trainer + + diff --git a/docs/auto_generater.py b/docs/auto_generater.py new file mode 100644 index 0000000..b771e51 --- /dev/null +++ b/docs/auto_generater.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- + +import importlib +import inspect +import os + +block_list = ['test', 'register_pytree_node', 'call', 'namedtuple', 'jit', 'wraps', 'index', 'function'] + + +def get_class_funcs(module): + classes, functions, others = [], [], [] + # Solution from: https://stackoverflow.com/questions/43059267/how-to-do-from-module-import-using-importlib + if "__all__" in module.__dict__: + names = module.__dict__["__all__"] + else: + names = [x for x in module.__dict__ if not x.startswith("_")] + for k in names: + data = getattr(module, k) + if not inspect.ismodule(data) and not k.startswith("_"): + if inspect.isfunction(data): + functions.append(k) + elif isinstance(data, type): + classes.append(k) + else: + others.append(k) + + return classes, functions, others + + +def _write_module(module_name, automodule, filename, header=None, template=False): + module = importlib.import_module(module_name) + classes, functions, others = get_class_funcs(module) + + fout = open(filename, 'w') + # write header + if header is None: + header = f'``{module_name}`` module' + fout.write(header + '\n') + fout.write('=' * len(header) + '\n\n') + fout.write(f'.. currentmodule:: {automodule} \n') + fout.write(f'.. automodule:: {automodule} \n\n') + + # write autosummary + fout.write('.. autosummary::\n') + if template: + fout.write(' :template: classtemplate.rst\n') + fout.write(' :toctree: generated/\n\n') + for m in functions: + fout.write(f' {m}\n') + for m in classes: + fout.write(f' {m}\n') + for m in others: + fout.write(f' {m}\n') + + fout.close() + + +def _write_submodules(module_name, filename, header=None, submodule_names=(), section_names=()): + fout = open(filename, 'w') + # write header + if header is None: + header = f'``{module_name}`` module' + else: + header = header + fout.write(header + '\n') + fout.write('=' * len(header) + '\n\n') + fout.write(f'.. currentmodule:: {module_name} \n') + fout.write(f'.. automodule:: {module_name} \n\n') + + # whole module + for i, name in enumerate(submodule_names): + module = importlib.import_module(module_name + '.' + name) + classes, functions, others = get_class_funcs(module) + + fout.write(section_names[i] + '\n') + fout.write('-' * len(section_names[i]) + '\n\n') + + # write autosummary + fout.write('.. autosummary::\n') + fout.write(' :toctree: generated/\n') + fout.write(' :nosignatures:\n') + fout.write(' :template: classtemplate.rst\n\n') + for m in functions: + fout.write(f' {m}\n') + for m in classes: + fout.write(f' {m}\n') + for m in others: + fout.write(f' {m}\n') + + fout.write(f'\n\n') + + fout.close() + + +def _write_subsections(module_name, + filename, + subsections: dict, + header: str = None): + fout = open(filename, 'w') + header = f'``{module_name}`` module' if header is None else header + fout.write(header + '\n') + fout.write('=' * len(header) + '\n\n') + fout.write(f'.. currentmodule:: {module_name} \n') + fout.write(f'.. automodule:: {module_name} \n\n') + + fout.write('.. contents::' + '\n') + fout.write(' :local:' + '\n') + fout.write(' :depth: 1' + '\n\n') + + for name, values in subsections.items(): + fout.write(name + '\n') + fout.write('-' * len(name) + '\n\n') + fout.write('.. autosummary::\n') + fout.write(' :toctree: generated/\n') + fout.write(' :nosignatures:\n') + fout.write(' :template: classtemplate.rst\n\n') + for m in values: + fout.write(f' {m}\n') + fout.write(f'\n\n') + + fout.close() + + +def _write_subsections_v2(module_path, + out_path, + filename, + subsections: dict, + header: str = None): + fout = open(filename, 'w') + header = f'``{out_path}`` module' if header is None else header + fout.write(header + '\n') + fout.write('=' * len(header) + '\n\n') + fout.write(f'.. currentmodule:: {out_path} \n') + fout.write(f'.. automodule:: {out_path} \n\n') + + fout.write('.. contents::' + '\n') + fout.write(' :local:' + '\n') + fout.write(' :depth: 1' + '\n\n') + + for name, subheader in subsections.items(): + module = importlib.import_module(f'{module_path}.{name}') + classes, functions, others = get_class_funcs(module) + + fout.write(subheader + '\n') + fout.write('-' * len(subheader) + '\n\n') + fout.write('.. autosummary::\n') + fout.write(' :toctree: generated/\n') + fout.write(' :nosignatures:\n') + fout.write(' :template: classtemplate.rst\n\n') + for m in functions: + fout.write(f' {m}\n') + for m in classes: + fout.write(f' {m}\n') + for m in others: + fout.write(f' {m}\n') + fout.write(f'\n\n') + + fout.close() + + +def _write_subsections_v3(module_path, + out_path, + filename, + subsections: dict, + header: str = None): + fout = open(filename, 'w') + header = f'``{out_path}`` module' if header is None else header + fout.write(header + '\n') + fout.write('=' * len(header) + '\n\n') + fout.write(f'.. currentmodule:: {out_path} \n') + fout.write(f'.. automodule:: {out_path} \n\n') + + fout.write('.. contents::' + '\n') + fout.write(' :local:' + '\n') + fout.write(' :depth: 2' + '\n\n') + + for section in subsections: + fout.write(subsections[section]['header'] + '\n') + fout.write('-' * len(subsections[section]['header']) + '\n\n') + + fout.write(f'.. currentmodule:: {out_path}.{section} \n') + fout.write(f'.. automodule:: {out_path}.{section} \n\n') + + for name, subheader in subsections[section]['content'].items(): + module = importlib.import_module(f'{module_path}.{section}.{name}') + classes, functions, others = get_class_funcs(module) + + fout.write(subheader + '\n') + fout.write('~' * len(subheader) + '\n\n') + fout.write('.. autosummary::\n') + fout.write(' :toctree: generated/\n') + fout.write(' :nosignatures:\n') + fout.write(' :template: classtemplate.rst\n\n') + for m in functions: + fout.write(f' {m}\n') + for m in classes: + fout.write(f' {m}\n') + for m in others: + fout.write(f' {m}\n') + fout.write(f'\n\n') + + fout.close() + + +def _write_subsections_v4(module_path, + filename, + subsections: dict, + header: str = None): + fout = open(filename, 'w') + header = f'``{module_path}`` module' if header is None else header + fout.write(header + '\n') + fout.write('=' * len(header) + '\n\n') + + fout.write('.. contents::' + '\n') + fout.write(' :local:' + '\n') + fout.write(' :depth: 1' + '\n\n') + + for name, (subheader, out_path) in subsections.items(): + + module = importlib.import_module(f'{module_path}.{name}') + classes, functions, others = get_class_funcs(module) + + fout.write(subheader + '\n') + fout.write('-' * len(subheader) + '\n\n') + + fout.write(f'.. currentmodule:: {out_path} \n') + fout.write(f'.. automodule:: {out_path} \n\n') + + fout.write('.. autosummary::\n') + fout.write(' :toctree: generated/\n') + fout.write(' :nosignatures:\n') + fout.write(' :template: classtemplate.rst\n\n') + for m in functions: + fout.write(f' {m}\n') + for m in classes: + fout.write(f' {m}\n') + for m in others: + fout.write(f' {m}\n') + fout.write(f'\n\n') + + fout.close() + + +def _get_functions(obj): + return set([n for n in dir(obj) + if (n not in block_list # not in blacklist + and callable(getattr(obj, n)) # callable + and not isinstance(getattr(obj, n), type) # not class + and n[0].islower() # starts with lower char + and not n.startswith('__') # not special methods + ) + ]) + + +def _import(mod, klass=None, is_jax=False): + obj = importlib.import_module(mod) + if klass: + obj = getattr(obj, klass) + return obj, ':meth:`{}.{}.{{}}`'.format(mod, klass) + else: + if not is_jax: + return obj, ':obj:`{}.{{}}`'.format(mod) + else: + from docs import implemented_jax_funcs + return implemented_jax_funcs, ':obj:`{}.{{}}`'.format(mod) + + +def main(): + os.makedirs('apis/', exist_ok=True) + + + module_and_name = [ + ('base', 'Base Geometry Class'), + ('geometry_1d', 'Geometry in 1D'), + ('geometry_2d', 'Geometry in 2D'), + ('geometry_3d', 'Geometry in 3D'), + ('geometry_nd', 'Geometry in ND'), + ('pointcloud', 'Point Cloud'), + ('timedomain', 'Time Domain'), + ] + + _write_submodules(module_name='pinnx.geometry', + filename='apis/pinnx.geometry.rst', + header='``pinnx.geometry`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + + module_and_name = [ + ('base', 'Base Initial and Boundary Conditions Class'), + ('boundary_conditions', 'Boundary Conditions'), + ('initial_conditions', 'Initial Conditions'), + ] + + _write_submodules(module_name='pinnx.icbc', + filename='apis/pinnx.icbc.rst', + header='``pinnx.icbc`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + + module_and_name = [ + ('base', 'Base Neural Network Class'), + ('convert', 'Dict and Array Converters'), + ('deeponet_strategy', 'DeepONet Strategy'), + ('deeponet', 'DeepONet'), + ('fnn', 'Fully Connected Neural Network'), + ('mionet', 'Multiple Input Operators Network'), + ('model', 'Model'), + ] + + _write_submodules(module_name='pinnx.nn', + filename='apis/pinnx.nn.rst', + header='``pinnx.nn`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + + module_and_name = [ + ('base', 'Base Problem Class'), + ('dataset_function', 'Dataset Function'), + ('dataset_general', 'General Dataset'), + ('dataset_mf', 'Multifidelity Dataset'), + ('dataset_quadruple', 'Quadruple Point Dataset'), + ('dataset_triple', 'Triple Point Dataset'), + ('fpde', 'Forward PDE'), + ('ide', 'Inverse PDE'), + ('pde_operator', 'PDE Operator'), + ('pde', 'PDE'), + ] + + _write_submodules(module_name='pinnx.problem', + filename='apis/pinnx.problem.rst', + header='``pinnx.problem`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + + module_and_name = [ + ('_convert', 'Dict and Array Converters'), + ('_display', 'Display training progress.'), + ('array_ops', 'Array Operations'), + ('external', 'External Functions'), + ('internal', 'Internal Functions'), + ('losses', 'Loss Functions'), + ('sampler', 'Sampler'), + ('sampling', 'Sampling'), + ('transformers', 'Transformer'), + ] + + _write_submodules(module_name='pinnx.utils', + filename='apis/pinnx.utils.rst', + header='``pinnx.utils`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + + # module_and_name = [ + # ('_trainer', 'Trainer'), + # ('callbacks', 'Callbacks'), + # ('fnspace', 'Function Space'), + # ('grad', 'Automatic Differentiation'), + # ('metrics', 'Metrics'), + # ] + # + # _write_submodules(module_name='pinnx', + # filename='apis/pinnx.rst', + # header='``pinnx`` module', + # submodule_names=[k[0] for k in module_and_name], + # section_names=[k[1] for k in module_and_name]) + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 83eb221..3b77390 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,9 +29,14 @@ import os import sys +sys.path.insert(0, os.path.abspath(os.path.curdir)) sys.path.insert(0, os.path.abspath('../')) import pinnx +import auto_generater +auto_generater.main() + +os.makedirs('apis/', exist_ok=True) # -- Project information ----------------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 4c79bf0..8aa8be7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -121,3 +121,18 @@ See also the BDP ecosystem We are building the `brain dynamics programming ecosystem `_. +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: API Documentation + + apis/pinnx.rst + apis/pinnx.callbacks.rst + apis/pinnx.fnspace.rst + apis/pinnx.grad.rst + apis/pinnx.geometry.rst + apis/pinnx.icbc.rst + apis/pinnx.metrics.rst + apis/pinnx.nn.rst + apis/pinnx.problem.rst + apis/pinnx.utils.rst \ No newline at end of file diff --git a/pinnx/metrics.py b/pinnx/metrics.py index db7b260..164e2dd 100644 --- a/pinnx/metrics.py +++ b/pinnx/metrics.py @@ -20,6 +20,7 @@ def _accuracy(y_true, y_pred): def accuracy(y_true, y_pred): + """Computes accuracy across nested structures of labels and predictions.""" return jax.tree_util.tree_map(_accuracy, y_true, y_pred, is_leaf=u.math.is_quantity) @@ -28,11 +29,11 @@ def _l2_relative_error(y_true, y_pred): def l2_relative_error(y_true, y_pred): + """Computes L2 relative error across nested structures of labels and predictions.""" return jax.tree_util.tree_map(_l2_relative_error, y_true, y_pred, is_leaf=u.math.is_quantity) def _nanl2_relative_error(y_true, y_pred): - """Return the L2 relative error treating Not a Numbers (NaNs) as zero.""" err = y_true - y_pred err = u.math.nan_to_num(err) y_true = u.math.nan_to_num(y_true) @@ -40,11 +41,11 @@ def _nanl2_relative_error(y_true, y_pred): def nanl2_relative_error(y_true, y_pred): + """Computes L2 relative error across nested structures of labels and predictions.""" return jax.tree_util.tree_map(_nanl2_relative_error, y_true, y_pred, is_leaf=u.math.is_quantity) def _mean_l2_relative_error(y_true, y_pred): - """Compute the average of L2 relative error along the first axis.""" return u.math.mean( u.linalg.norm(y_true - y_pred, axis=1) / u.linalg.norm(y_true, axis=1) @@ -52,6 +53,7 @@ def _mean_l2_relative_error(y_true, y_pred): def mean_l2_relative_error(y_true, y_pred): + """Computes mean L2 relative error across nested structures of labels and predictions.""" return jax.tree_util.tree_map(_mean_l2_relative_error, y_true, y_pred, is_leaf=u.math.is_quantity) @@ -60,6 +62,7 @@ def _absolute_percentage_error(y_true, y_pred): def mean_absolute_percentage_error(y_true, y_pred): + """Computes mean absolute percentage error across nested structures of labels and predictions.""" return jax.tree_util.tree_map(lambda x, y: _absolute_percentage_error(x, y).mean(), y_true, y_pred, @@ -74,6 +77,7 @@ def max_absolute_percentage_error(y_true, y_pred): def absolute_percentage_error_std(y_true, y_pred): + """Computes standard deviation of absolute percentage error across nested structures of labels and predictions.""" return jax.tree_util.tree_map(lambda x, y: _absolute_percentage_error(x, y).std(), y_true, y_pred, @@ -85,6 +89,7 @@ def _mean_squared_error(y_true, y_pred): def mean_squared_error(y_true, y_pred): + """Computes mean squared error across nested structures of labels and predictions.""" return jax.tree_util.tree_map(_mean_squared_error, y_true, y_pred, is_leaf=u.math.is_quantity)