From d2b71eaa3d0aa82f60c70c1abfc0d178fe49f950 Mon Sep 17 00:00:00 2001 From: PinetreePantry Date: Thu, 11 Apr 2024 11:47:58 -0700 Subject: [PATCH] Fix loading IntervenableModel for its subclasses --- pyreft/reft_model.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pyreft/reft_model.py b/pyreft/reft_model.py index 97f7afe..9ff6e0d 100644 --- a/pyreft/reft_model.py +++ b/pyreft/reft_model.py @@ -13,6 +13,19 @@ class ReftModel(pv.IntervenableModel): def __init__(self, config, model, **kwargs): super().__init__(config, model, **kwargs) + @staticmethod + def _convert_to_reft_model(intervenable_model): + reft_model = ReftModel(intervenable_model.config, intervenable_model.model) + # Copy any other necessary attributes + for attr in vars(intervenable_model): + setattr(reft_model, attr, getattr(intervenable_model, attr)) + return reft_model + + @staticmethod + def load(*args, **kwargs): + model = pv.IntervenableModel.load(*args, **kwargs) + return ReftModel._convert_to_reft_model(model) + def print_trainable_parameters(self): """ Print trainable parameters.