diff --git a/export_prior.py b/export_prior.py index 8052532..408232f 100644 --- a/export_prior.py +++ b/export_prior.py @@ -106,20 +106,6 @@ def forward(self, temp: torch.Tensor): x = x[..., -1:] model(x) -logging.info("scripting cached modules") -n_cache = 0 - -cached_modules = [ - cc.CachedConv1d, -] - -for n, m in model.named_modules(): - if any(list(map(lambda c: isinstance(m, c), cached_modules))): - m.script_cache() - n_cache += 1 - -logging.info(f"{n_cache} cached modules found and scripted") - logging.info("script model") model = TraceModel(model) model = torch.jit.script(model)