diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index f1732ea5..9775126b 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -188,16 +188,25 @@ def forward(
             # Always reset to 0
             self.start_pos = 0
 
+        # `use_cache` is forwarded through kwargs by the `transformers`
+        # generation loop, so treat it as a hint that we are inside `generate()`.
+        hf_is_generating = False
+
+        if self.is_hf_transformers and "use_cache" in kwargs:
+            hf_is_generating = kwargs["use_cache"]
+
         # In case we re-generate, we need to refresh the starting position
         # to 0. We detect it by checking if `past_key_values` is set to None,
         # which indicates that we are on the first step of `generate()`.
         # This is only applicable for `transformers` integration
-        if (
-            self.is_hf_transformers
-            and "past_key_value" in kwargs
-            and kwargs["past_key_value"] is None
-        ):
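+        # Also reset when `transformers` runs this forward pass outside of
+        # `generate()`, so a stale position from an earlier call is not reused.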
+        if (
+            self.is_hf_transformers
+            and "past_key_value" in kwargs
+            and kwargs["past_key_value"] is None
+        ) or (self.is_hf_transformers and not hf_is_generating):
             self.start_pos = 0
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
@@ -214,8 +223,6 @@ def forward(
             if not self.use_alibi:
                 xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
-            self.cache.to(xq)
-
             values_store = xv.transpose(2, 1)
             keys_store = (
                 xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
@@ -223,6 +230,10 @@ def forward(
                 .contiguous()
             )
 
+            # Make sure the cache tensors are on the same device as the
+            # activations before writing into them.
+            self.cache.to(xq)
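+            # Write this step's keys/values into the cache starting at `start_pos`.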
             self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
 
             # Only necessary to retrieve from cache when we are not processing context
@@ -248,6 +259,11 @@ def forward(
 
             # When seqlen is 1, there is nothing else to attend to
             if attention_mask is not None and seqlen > 1:
+                # For Llama-style models the causal mask is preallocated with shape
+                # (bsz, 1, max_seq_len, max_seq_len), so slice it to the current seqlen.
+                if attention_mask.shape[-1] != seqlen:
+                    attention_mask = attention_mask[:, :, :seqlen, :seqlen]
+
                 scores = (
                     scores + attention_mask
                 )  # (bs, n_local_heads, slen, cache_len + slen)
@@ -278,11 +294,16 @@ def forward(
         attn_output = self.o_proj(attention_weight)
         self.start_pos += seqlen
 
+        # If we are not inside `generate()`, the next call will not continue
+        # from this cache, so drop the accumulated position again.
+        if self.is_hf_transformers and not hf_is_generating:
+            self.start_pos = 0
+
         # past_key_value is replaced with cache_v, cache_k, returning empty data
         # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
         # about past key length
         past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
 
         if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
             new_cache = DynamicCache()
             new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)