Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brash6 committed Nov 14, 2024
1 parent 9da25c0 commit ac77369
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 928 deletions.
6 changes: 2 additions & 4 deletions src/med_bench/estimation/mediation_tmle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from sklearn.linear_model import LinearRegression

from med_bench.estimation.base import Estimator
from med_bench.nuisances.utils import _get_regressor
from med_bench.utils.decorators import fitted

ALPHA = 10
Expand Down Expand Up @@ -85,7 +84,7 @@ def _one_step_correction_direct(self, t, m, x, y):
mu_t0_mx_star = mu_t0_mx + epsilon_h * h_corrector_t0
mu_t1_mx_star = mu_t1_mx + epsilon_h * h_corrector_t1

regressor_y = _get_regressor(self._regularize, self._use_forest)
regressor_y = self.regressor
reg_cross = clone(regressor_y)
reg_cross.fit(
x[t == 0], (mu_t1_mx_star[t == 0] -
Expand Down Expand Up @@ -143,8 +142,7 @@ def _one_step_correction_indirect(self, t, m, x, y):
h_corrector_t1 = t1 / p_x - t1 * ratio
mu_t1_mx_star = mu_t1_mx + epsilon_h * h_corrector_t1

regressor_y = _get_regressor(self._regularize,
self._use_forest)
regressor_y = self.regressor

reg_cross = clone(regressor_y)
reg_cross.fit(x[t == 0], mu_t1_mx_star[t == 0])
Expand Down
2 changes: 1 addition & 1 deletion src/med_bench/get_estimation_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from med_bench.estimation.mediation_g_computation import GComputation
from med_bench.estimation.mediation_ipw import InversePropensityWeighting
from med_bench.estimation.mediation_mr import MultiplyRobust
from med_bench.nuisances.utils import _get_regularization_parameters
from med_bench.utils.utils import _get_regularization_parameters
from med_bench.utils.constants import CV_FOLDS

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
Expand Down
1 change: 0 additions & 1 deletion src/med_bench/mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
import pandas as pd
from sklearn.base import clone
from sklearn.linear_model import RidgeCV


from .utils.nuisances import (
Expand Down
80 changes: 0 additions & 80 deletions src/med_bench/nuisances/conditional_outcome.py

This file was deleted.

238 changes: 0 additions & 238 deletions src/med_bench/nuisances/cross_conditional_outcome.py

This file was deleted.

Loading

0 comments on commit ac77369

Please sign in to comment.