From a3cc81c233489d41de308f1fe77fc8f4e2f741cd Mon Sep 17 00:00:00 2001 From: Osvaldo Martin Date: Fri, 8 Oct 2021 20:04:40 +0300 Subject: [PATCH] move BART to its own module (#5058) * move BART to its own module * add missing file --- pymc/__init__.py | 1 + pymc/bart/__init__.py | 19 +++++++++++++++++++ pymc/{distributions => bart}/bart.py | 0 pymc/{step_methods => bart}/pgbart.py | 4 ++-- pymc/{distributions => bart}/tree.py | 0 pymc/distributions/__init__.py | 2 -- pymc/sampling.py | 2 +- pymc/step_methods/__init__.py | 1 - pymc/step_methods/hmc/nuts.py | 2 +- pymc/tests/test_bart.py | 4 ++-- 10 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 pymc/bart/__init__.py rename pymc/{distributions => bart}/bart.py (100%) rename pymc/{step_methods => bart}/pgbart.py (99%) rename pymc/{distributions => bart}/tree.py (100%) diff --git a/pymc/__init__.py b/pymc/__init__.py index 13599000d02..cb30dd2d487 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -83,6 +83,7 @@ def __set_compiler_flags(): to_inference_data, ) from pymc.backends.tracetab import * +from pymc.bart import * from pymc.blocking import * from pymc.data import * from pymc.distributions import * diff --git a/pymc/bart/__init__.py b/pymc/bart/__init__.py new file mode 100644 index 00000000000..abace693c1f --- /dev/null +++ b/pymc/bart/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pymc.bart.bart import BART +from pymc.bart.pgbart import PGBART + +__all__ = ["BART", "PGBART"] diff --git a/pymc/distributions/bart.py b/pymc/bart/bart.py similarity index 100% rename from pymc/distributions/bart.py rename to pymc/bart/bart.py diff --git a/pymc/step_methods/pgbart.py b/pymc/bart/pgbart.py similarity index 99% rename from pymc/step_methods/pgbart.py rename to pymc/bart/pgbart.py index a837f21c199..a591a93586d 100644 --- a/pymc/step_methods/pgbart.py +++ b/pymc/bart/pgbart.py @@ -23,9 +23,9 @@ from pandas import DataFrame, Series from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements +from pymc.bart.bart import BARTRV +from pymc.bart.tree import LeafNode, SplitNode, Tree from pymc.blocking import RaveledVars -from pymc.distributions.bart import BARTRV -from pymc.distributions.tree import LeafNode, SplitNode, Tree from pymc.model import modelcontext from pymc.step_methods.arraystep import ArrayStepShared, Competence diff --git a/pymc/distributions/tree.py b/pymc/bart/tree.py similarity index 100% rename from pymc/distributions/tree.py rename to pymc/bart/tree.py diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index dc721e241bb..d9c26738ba7 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -22,7 +22,6 @@ logpt_sum, ) -from pymc.distributions.bart import BART from pymc.distributions.bound import Bound from pymc.distributions.continuous import ( AsymmetricLaplace, @@ -190,7 +189,6 @@ "Rice", "Moyal", "Simulator", - "BART", "CAR", "PolyaGamma", "logpt", diff --git a/pymc/sampling.py b/pymc/sampling.py index e00437e0fa7..994677f3551 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -41,6 +41,7 @@ from pymc.backends.arviz import _DefaultTrace from pymc.backends.base import BaseTrace, MultiTrace from pymc.backends.ndarray import NDArray +from pymc.bart.pgbart import PGBART from pymc.blocking import DictToArrayBijection from pymc.distributions import NoDistribution from pymc.exceptions import IncorrectArgumentsError, SamplingError @@ -48,7 +49,6 @@ from pymc.parallel_sampling import Draw, _cpu_count from pymc.step_methods import ( NUTS, - PGBART, BinaryGibbsMetropolis, BinaryMetropolis, CategoricalGibbsMetropolis, diff --git a/pymc/step_methods/__init__.py b/pymc/step_methods/__init__.py index 112ebaaa345..2b419feecc6 100644 --- a/pymc/step_methods/__init__.py +++ b/pymc/step_methods/__init__.py @@ -35,5 +35,4 @@ MetropolisMLDA, RecursiveDAProposal, ) -from pymc.step_methods.pgbart import PGBART from pymc.step_methods.slicer import Slice diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 1050fffac14..b55650eb806 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -18,7 +18,7 @@ from pymc.aesaraf import floatX from pymc.backends.report import SamplerWarning, WarningType -from pymc.distributions.bart import BARTRV +from pymc.bart.bart import BARTRV from pymc.math import logbern, logdiffexp_numpy from pymc.step_methods.arraystep import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index fe3b268eed6..20b14c6966d 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -7,7 +7,7 @@ def test_split_node(): - split_node = pm.distributions.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0) + split_node = pm.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0) assert split_node.index == 5 assert split_node.idx_split_variable == 2 assert split_node.split_value == 3.0 @@ -18,7 +18,7 @@ def test_split_node(): def test_leaf_node(): - leaf_node = pm.distributions.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3]) + leaf_node = pm.bart.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3]) assert leaf_node.index == 5 assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3]) assert leaf_node.value == 3.14