-
Notifications
You must be signed in to change notification settings - Fork 721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Scaling ortholearners using Ray #800
Conversation
Signed-off-by: Vishal Verma <[email protected]>
Signed-off-by: Vishal Verma <[email protected]>
Signed-off-by: Vishal Verma <[email protected]>
Signed-off-by: Vishal Verma <[email protected]>
Signed-off-by: Vishal Verma <[email protected]>
Signed-off-by: Vishal Verma <[email protected]>
Signed-off-by: Vishal Verma <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, this looks like a great addition to the library. However, there are a few changes that need to be addressed before it can be added.
First of all, please revert your changes to setup.cfg, merge the latest main back into your branch, and then make those changes to pyproject.toml instead - sorry that we changed this out from under you while your PR was in progress, but the package metadata has been moved there instead.
In addition to my comments on individual files, here are some other thoughts:
- To be broadly useful, these changes need to be propagated to at least the main DML subclasses, rather than just OrthoLearner, RLearner, and DML, but really ideally to everything that uses _crossfit.
- The coverage report shows that most of the new code in the _OrthoLearner class is never run by any of the tests, since you set
use_ray=False
for all of the tests that use the class directly. Settinguse_ray=True
should fix that specific coverage issue, but consider whether additional tests for RLearner or DML would also be useful. - This seems like a potentially very helpful feature, so it's probably worth creating a documentation page or notebook, or at the very least an FAQ entry, describing when/why/how to use it.
.github/workflows/ci.yml
Outdated
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
- kind: other | ||
opts: '-m "cate_api" -n auto' | ||
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
- kind: dml | ||
opts: '-m "dml"' | ||
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
- kind: main | ||
opts: '-m "not (notebook or automl or dml or serial or cate_api or treatment_featurization)" -n 2' | ||
extras: "[tf,plt,dowhy]" | ||
extras: "[tf,plt,dowhy,ray]" | ||
- kind: treatment | ||
opts: '-m "treatment_featurization" -n auto' | ||
extras: "[tf,plt]" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe ray
only needs to be added to the main
test kind, since that is where the test_ortho_learner tests are run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in latest commit
.github/workflows/ci.yml
Outdated
- kind: "except-customer-scenarios" | ||
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
pattern: "(?!CustomerScenarios)" | ||
install_graphviz: true | ||
version: '3.8' # no supported version of tensorflow for 3.9 | ||
- kind: "customer-scenarios" | ||
extras: "[plt,dowhy]" | ||
extras: "[plt,dowhy,ray]" | ||
pattern: "CustomerScenarios" | ||
version: '3.9' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unless you make any changes to the notebooks to take advantage of the new ray functionality, these changes should not be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in latest commit
econml/dml/dml.py
Outdated
random_state=None): | ||
random_state=None, | ||
use_ray=False, | ||
**ray_remote_func_options |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
**ray_remote_func_options | |
ray_remote_func_options={} |
I think it would be better to make this an explicit dictionary argument, rather than having it implicitly include any other keyword arguments passed to the DML initializer since in the future we might want similar arguments for other compute backends.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(This also applies all the way up the hierarchy, to the RLearner and OrthoLearner initializer arguments)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in latest commit
setup.cfg
Outdated
@@ -66,6 +66,8 @@ plt = | |||
matplotlib < 3.6.0 | |||
dowhy = | |||
dowhy < 0.9 | |||
ray = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the inconvenience but these changes now need to be made to pyproject.toml instead - we've tried to move as much of the static metadata for the project as possible to that file.
econml/dml/_rlearner.py
Outdated
@@ -272,15 +272,17 @@ def _gen_rlearner_model_final(self): | |||
""" | |||
|
|||
def __init__(self, *, discrete_treatment, treatment_featurizer, categories, | |||
cv, random_state, mc_iters=None, mc_agg='mean'): | |||
cv, random_state, mc_iters=None, mc_agg='mean', use_ray=False, **ray_remote_func_options): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cv, random_state, mc_iters=None, mc_agg='mean', use_ray=False, **ray_remote_func_options): | |
cv, random_state, mc_iters=None, mc_agg='mean', use_ray=False, ray_remote_func_options=ray_remote_func_options): | |
econml/_ortho_learner.py
Outdated
return nuisance_temp, model, test_idxs, (score_temp if calculate_scores else None) | ||
|
||
|
||
def _crossfit(model, use_ray, folds, ray_remote_fun_option, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it makes more sense for folds to come before the ray arguments (and certainly for the ray arguments to be adjacent), and these changes make the specification match the docstring.
def _crossfit(model, use_ray, folds, ray_remote_fun_option, *args, **kwargs): | |
def _crossfit(model, folds, use_ray=False, ray_remote_fun_option={}, *args, **kwargs): |
econml/_ortho_learner.py
Outdated
@@ -60,6 +120,10 @@ def _crossfit(model, folds, *args, **kwargs): | |||
function estimates a model of the nuisance function, based on the input | |||
data to fit. Predict evaluates the fitted nuisance function on the input | |||
data to predict. | |||
use_ray: bool, default False (optional) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_ray: bool, default False (optional) | |
use_ray: bool, default False |
having a default implies optional
econml/_ortho_learner.py
Outdated
@@ -60,6 +120,10 @@ def _crossfit(model, folds, *args, **kwargs): | |||
function estimates a model of the nuisance function, based on the input | |||
data to fit. Predict evaluates the fitted nuisance function on the input | |||
data to predict. | |||
use_ray: bool, default False (optional) | |||
Flag to indicate whether to use ray to parallelize the cross-fitting step. | |||
ray_remote_fun_option: dict, default None (optional) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ray_remote_fun_option: dict, default None (optional) | |
ray_remote_fun_option: dict, default {} |
Having a default implies optional
econml/_ortho_learner.py
Outdated
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model), folds, X, y, W=y, Z=None) | ||
use_ray = False | ||
ray_remote_fun_option = {} | ||
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model),use_ray, folds,ray_remote_fun_option, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model),use_ray, folds,ray_remote_fun_option, | |
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model), folds, use_ray, ray_remote_fun_option, |
# Conflicts: # setup.cfg
1) Fixed ci.yml extras dependencies 2) Added Description of all the added option in doc string in case of dml and rlearner 3) Addressed chaneges suggested for _ortho_learner.py 4)Removed ray.shutdown(), it can be taken care of explicitly on case to case basis . 5)Made ray_remote_func_options as explicit dictionary argument. What has been added ? 1) Extended the changes to all estimators using _crossfit. 2) Added Test case to run for with_ray and without_ray for above changes 3) Added Notebook on how to use this feature. Signed-off-by: Vishal Verma <[email protected]>
What have been fixed since last commit ?
What has been added ?
@kbattocchi kindly review the latest commit and provide feedback if any ! |
… testcases Signed-off-by: Vishal Verma <[email protected]>
…mode for tests. Signed-off-by: Vishal Verma <[email protected]>
econml/dml/dml.py
Outdated
use_ray=False, | ||
ray_remote_func_options=None, | ||
): | ||
if ray_remote_func_options is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this logic to within the OrthoLearner fit function, and remove this logic from all subclass __init__
functions. That way we avoid redundant code in all of the subclass __init__
functions and maintain a scikit-learn-like API. If interested in more context, see the "Instantiation" section of this page https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects.
For instance, imagine a user does the following
est = LinearDML(use_ray=some_dict)
est.use_ray = None # user changes their mind about use_ray
est.fit(…)
We want the logic of converting None to an empty dict in .fit so we can allow for this kind of behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noted make sense, I will move the redundant code, to fit function within Ortholearner
@@ -642,6 +657,12 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML): | |||
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used | |||
by :mod:`np.random<numpy.random>`. | |||
|
|||
use_ray: bool, default False | |||
Whether to use Ray to parallelize the cross-fitting step. If True, Ray must be installed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you fix the spacing here with a new line in between the arg descriptions. Same for SparseLinearDML and KernelDML
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noted
-removed redundant code for ray_remote_function and moved to ortholearner's fit Signed-off-by: Vishal Verma <[email protected]>
Signed-off-by: Keith Battocchi <[email protected]>
a7c168a
to
d76dff0
Compare
Signed-off-by: Keith Battocchi <[email protected]>
d76dff0
to
3c2eb4b
Compare
Signed-off-by: Keith Battocchi <[email protected]>
Signed-off-by: Keith Battocchi <[email protected]>
Signed-off-by: Keith Battocchi <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've broken the tests out into a new mark and I think things look good, so I'll merge once the checks pass. Thanks for this contribution!
econml/_ortho_learner.py
Outdated
@@ -412,7 +418,6 @@ def _gen_ortho_learner_model_final(self): | |||
discrete_instrument=False, categories='auto', random_state=None) | |||
est.fit(y, X[:, 0], W=X[:, 1:]) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Or (for parallelization using ray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry if my previous comment was unclear: I think including the comment is helpful for understanding why est
is being redefined; it's just that it needs to be a comment so that the entire block is valid python code that can be run.
econml/_ortho_learner.py
Outdated
if ray_remote_func_options is None: | ||
ray_remote_func_options = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider whether just making the default {}
instead of None
would make sense. In general, we try not to put any logic in our initializers, because it's possible the user will do something like this:
est = LinearDML()
est.use_ray = True
est.ray_remote_options = None
and then the logic to turn it into {}
won't run. So I think it's fine to require it to be an actual dictionary instead of None and skip the extra logic.
Signed-off-by: Keith Battocchi <[email protected]>
Signed-off-by: Keith Battocchi <[email protected]>
Added Implementation of ray-based distributed parallelization to crossfit. --------- Signed-off-by: Vishal Verma <[email protected]> Signed-off-by: Keith Battocchi <[email protected]> Co-authored-by: Keith Battocchi <[email protected]>
issue : 793