diff --git a/phenaki_pytorch/phenaki_pytorch.py b/phenaki_pytorch/phenaki_pytorch.py index b5ee0b0..b198fe9 100644 --- a/phenaki_pytorch/phenaki_pytorch.py +++ b/phenaki_pytorch/phenaki_pytorch.py @@ -967,6 +967,9 @@ def __init__( # sampling + if exists(critic): + critic = critic.eval() + self.critic = critic self.steps = steps self.sample_temperature = sample_temperature @@ -1060,7 +1063,8 @@ def sample( if not is_last_step: if exists(self.critic): - scores = self.critic(video_token_ids) + with torch.no_grad(): + scores = self.critic(video_token_ids) noise = K * (uniform(scores.shape, device) - 0.5) * (steps_til_x0 / self.steps) scores = scores + noise diff --git a/setup.py b/setup.py index 53ad272..dfe110e 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'phenaki-pytorch', packages = find_packages(exclude=[]), - version = '0.0.7', + version = '0.0.8', license='MIT', description = 'Phenaki - Pytorch', author = 'Phil Wang',