Skip to content

Commit

Permalink
Introducing a more general KroneckerFactored block class.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 531184538
  • Loading branch information
botev authored and KfacJaxDev committed May 15, 2023
1 parent 3be9b1a commit dfa68b5
Show file tree
Hide file tree
Showing 9 changed files with 667 additions and 444 deletions.
360 changes: 214 additions & 146 deletions examples/training.py

Large diffs are not rendered by default.

41 changes: 28 additions & 13 deletions kfac_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,32 @@
NegativeLogProbLoss = loss_functions.NegativeLogProbLoss
DistributionNegativeLogProbLoss = loss_functions.DistributionNegativeLogProbLoss
NormalMeanNegativeLogProbLoss = loss_functions.NormalMeanNegativeLogProbLoss
NormalMeanVarianceNegativeLogProbLoss = loss_functions.NormalMeanVarianceNegativeLogProbLoss
MultiBernoulliNegativeLogProbLoss = loss_functions.MultiBernoulliNegativeLogProbLoss
CategoricalLogitsNegativeLogProbLoss = loss_functions.CategoricalLogitsNegativeLogProbLoss
OneHotCategoricalLogitsNegativeLogProbLoss = loss_functions.OneHotCategoricalLogitsNegativeLogProbLoss
register_sigmoid_cross_entropy_loss = loss_functions.register_sigmoid_cross_entropy_loss
register_multi_bernoulli_predictive_distribution = loss_functions.register_multi_bernoulli_predictive_distribution
register_softmax_cross_entropy_loss = loss_functions.register_softmax_cross_entropy_loss
register_categorical_predictive_distribution = loss_functions.register_categorical_predictive_distribution
NormalMeanVarianceNegativeLogProbLoss = (
loss_functions.NormalMeanVarianceNegativeLogProbLoss)
MultiBernoulliNegativeLogProbLoss = (
loss_functions.MultiBernoulliNegativeLogProbLoss)
CategoricalLogitsNegativeLogProbLoss = (
loss_functions.CategoricalLogitsNegativeLogProbLoss)
OneHotCategoricalLogitsNegativeLogProbLoss = (
loss_functions.OneHotCategoricalLogitsNegativeLogProbLoss)
register_sigmoid_cross_entropy_loss = (
loss_functions.register_sigmoid_cross_entropy_loss)
register_multi_bernoulli_predictive_distribution = (
loss_functions.register_multi_bernoulli_predictive_distribution)
register_softmax_cross_entropy_loss = (
loss_functions.register_softmax_cross_entropy_loss)
register_categorical_predictive_distribution = (
loss_functions.register_categorical_predictive_distribution)
register_squared_error_loss = loss_functions.register_squared_error_loss
register_normal_predictive_distribution = loss_functions.register_normal_predictive_distribution
register_normal_predictive_distribution = (
loss_functions.register_normal_predictive_distribution)

# Curvature blocks
CurvatureBlock = curvature_blocks.CurvatureBlock
ScaledIdentity = curvature_blocks.ScaledIdentity
Diagonal = curvature_blocks.Diagonal
Full = curvature_blocks.Full
KroneckerFactored = curvature_blocks.KroneckerFactored
TwoKroneckerFactored = curvature_blocks.TwoKroneckerFactored
NaiveDiagonal = curvature_blocks.NaiveDiagonal
NaiveFull = curvature_blocks.NaiveFull
Expand All @@ -82,16 +92,20 @@
ScaleAndShiftFull = curvature_blocks.ScaleAndShiftFull
set_max_parallel_elements = curvature_blocks.set_max_parallel_elements
get_max_parallel_elements = curvature_blocks.get_max_parallel_elements
set_default_eigen_decomposition_threshold = curvature_blocks.set_default_eigen_decomposition_threshold
get_default_eigen_decomposition_threshold = curvature_blocks.get_default_eigen_decomposition_threshold
set_default_eigen_decomposition_threshold = (
curvature_blocks.set_default_eigen_decomposition_threshold)
get_default_eigen_decomposition_threshold = (
curvature_blocks.get_default_eigen_decomposition_threshold)

# Curvature estimators
CurvatureEstimator = curvature_estimator.CurvatureEstimator
BlockDiagonalCurvature = curvature_estimator.BlockDiagonalCurvature
ExplicitExactCurvature = curvature_estimator.ExplicitExactCurvature
ImplicitExactCurvature = curvature_estimator.ImplicitExactCurvature
set_default_tag_to_block_ctor = curvature_estimator.set_default_tag_to_block_ctor
get_default_tag_to_block_ctor = curvature_estimator.get_default_tag_to_block_ctor
set_default_tag_to_block_ctor = (
curvature_estimator.set_default_tag_to_block_ctor)
get_default_tag_to_block_ctor = (
curvature_estimator.get_default_tag_to_block_ctor)

# Optimizers
Optimizer = optimizer.Optimizer
Expand Down Expand Up @@ -146,6 +160,7 @@
"ScaledIdentity",
"Diagonal",
"Full",
"KroneckerFactored",
"TwoKroneckerFactored",
"NaiveDiagonal",
"NaiveFull",
Expand Down
Loading

0 comments on commit dfa68b5

Please sign in to comment.