Skip to content

Commit

Permalink
add hessian support in output def. (#3246)
Browse files Browse the repository at this point in the history
hessian not implemented in neither tf nor pt.

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Feb 8, 2024
1 parent 5ad3d96 commit cfdda1d
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 62 deletions.
24 changes: 19 additions & 5 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,12 @@ class OutputVariableDef:
are differentiable.
Virial, the transposed negative gradient with cell tensor times
cell tensor, will be calculated, see eq 40 JCP 159, 054801 (2023).
atomic : bool
If the variable is defined for each atom.
category : int
The category of the output variable.
hessian : bool
If hessian is requred
"""

def __init__(
Expand All @@ -185,6 +189,7 @@ def __init__(
c_differentiable: bool = False,
atomic: bool = True,
category: int = OutputVariableCategory.OUT.value,
r_hessian: bool = False,
):
self.name = name
self.shape = list(shape)
Expand All @@ -194,13 +199,15 @@ def __init__(
self.c_differentiable = c_differentiable
if self.c_differentiable and not self.r_differentiable:
raise ValueError("c differentiable requires r_differentiable")
if not self.reduciable and self.r_differentiable:
raise ValueError("only reduciable variable are r differentiable")
if not self.reduciable and self.c_differentiable:
raise ValueError("only reduciable variable are c differentiable")
if self.reduciable and not self.atomic:
raise ValueError("a reduciable variable should be atomic")
self.category = category
self.r_hessian = r_hessian
if self.r_hessian:
if not self.reduciable:
raise ValueError("only reduciable variable can calculate hessian")
if not self.r_differentiable:
raise ValueError("only r_differentiable variable can calculate hessian")


class FittingOutputDef:
Expand Down Expand Up @@ -257,6 +264,7 @@ def __init__(
self.def_outp = fit_defs
self.def_redu = do_reduce(self.def_outp.get_data())
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp.get_data())
self.def_hess_r, _ = do_derivative(self.def_derv_r)
self.def_derv_c_redu = do_reduce(self.def_derv_c)
self.var_defs: Dict[str, OutputVariableDef] = {}
for ii in [
Expand All @@ -265,6 +273,7 @@ def __init__(
self.def_derv_c,
self.def_derv_r,
self.def_derv_c_redu,
self.def_hess_r,
]:
self.var_defs.update(ii)

Expand Down Expand Up @@ -292,6 +301,9 @@ def keys_redu(self):
def keys_derv_r(self):
return self.def_derv_r.keys()

def keys_hess_r(self):
return self.def_hess_r.keys()

def keys_derv_c(self):
return self.def_derv_c.keys()

Expand Down Expand Up @@ -392,7 +404,9 @@ def do_derivative(
rkr,
vv.shape + [3], # noqa: RUF005
reduciable=False,
r_differentiable=False,
r_differentiable=(
vv.r_hessian and vv.category == OutputVariableCategory.OUT.value
),
c_differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_R),
Expand Down
Loading

0 comments on commit cfdda1d

Please sign in to comment.