diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 4e5804f5134..d034d472157 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -991,6 +991,7 @@ def generate_token(
 
         if stopped:
             del batch
+            torch.cuda.empty_cache()
             # No need to return a batch if we know that all requests stopped
             return generations, None
 
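
For context, here is a minimal standalone sketch (not part of the patch) of the pattern this one-line change relies on: dropping the last reference to the batch returns its tensors to PyTorch's caching allocator, and `torch.cuda.empty_cache()` then hands the unused cached blocks back to the CUDA driver. The tensor shape and prints below are illustrative only and are not taken from `flash_causal_lm.py`.

```python
import torch

# Illustrative sketch: what the added torch.cuda.empty_cache() call does
# once the finished batch has been deleted.
if torch.cuda.is_available():
    # A large tensor standing in for a finished batch's buffers (hypothetical shape).
    batch = torch.empty((1024, 1024, 64), device="cuda")

    print("allocated:", torch.cuda.memory_allocated())  # bytes held by live tensors
    print("reserved: ", torch.cuda.memory_reserved())   # bytes held by the caching allocator

    # Dropping the last reference frees the tensor back to PyTorch's caching
    # allocator, but the memory stays reserved by this process.
    del batch
    print("after del, reserved:", torch.cuda.memory_reserved())

    # empty_cache() releases the unused cached blocks to the driver, lowering
    # the memory the process holds (e.g. as reported by nvidia-smi).
    torch.cuda.empty_cache()
    print("after empty_cache, reserved:", torch.cuda.memory_reserved())
```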