fix visualglm attention_mask
LokeZhou committed Oct 15, 2023
1 parent 67fb559 commit fad8c3b
Showing 4 changed files with 18 additions and 8 deletions.
1 change: 0 additions & 1 deletion applications/README.md
@@ -52,7 +52,6 @@ result = task(prompt=prompt)['result']
 | [文本引导的图像放大(Text-Guided Image Upscaling)](./image2image/README.md/#文本引导的图像放大text-guided-image-upscaling) | `ldm-super-resolution-4x-openimages`||
 | [文本引导的图像编辑(Text-Guided Image Inpainting)](./Inpainting/README.md/#文本引导的图像编辑text-guided-image-inpainting) | `stable-diffusion-2-inpainting` | [fastdeploy](../ppdiffusers/deploy/README.md/#文本引导的图像编辑text-guided-image-inpainting) |
 | [文本引导的图像变换(Image-to-Image Text-Guided Generation)](./image2image/README.md/#文本引导的图像变换image-to-image-text-guided-generation) | `stable-diffusion-v1-5` | [fastdeploy](../ppdiffusers/deploy/README.md/#文本引导的图像变换image-to-image-text-guided-generation) |
-| [文本图像双引导图像生成(Dual Text and Image Guided Generation)](./image2image/README.md/#文本图像双引导图像生成dual-text-and-image-guided-generation) | `versatile-diffusion` ||
 | [文本条件的视频生成(Text-to-Video Generation)](./text2video/README.md/#文本条件的视频生成text-to-video-generation) | `text-to-video-ms-1.7b` ||
 | [音频生成图像(Audio-to-Image Generation)](./Audio2Img/README.md/#audio-to-image) | `imagebind stable-diffusion-2-1-unclip` | |
 | [音频描述(Audio-to-Caption Generation)](./Audio2Caption/README.md/#音频描述audio-to-caption-generation) | `chatglm-6b whisper` | |
1 change: 0 additions & 1 deletion applications/README_en.md
@@ -51,7 +51,6 @@ result = task(prompt=prompt)['result']
 | [文本引导的图像放大(Text-Guided Image Upscaling)](./image2image/README.md/#文本引导的图像放大text-guided-image-upscaling) | `ldm-super-resolution-4x-openimages`||
 | [文本引导的图像编辑(Text-Guided Image Inpainting)](./Inpainting/README.md/#文本引导的图像编辑text-guided-image-inpainting) | `stable-diffusion-2-inpainting` | [fastdeploy](../ppdiffusers/deploy/README.md/#文本引导的图像编辑text-guided-image-inpainting) |
 | [文本引导的图像变换(Image-to-Image Text-Guided Generation)](./image2image/README.md/#文本引导的图像变换image-to-image-text-guided-generation) | `stable-diffusion-v1-5` | [fastdeploy](../ppdiffusers/deploy/README.md/#文本引导的图像变换image-to-image-text-guided-generation) |
-| [文本图像双引导图像生成(Dual Text and Image Guided Generation)](./image2image/README.md/#文本图像双引导图像生成dual-text-and-image-guided-generation) | `versatile-diffusion` ||
 | [文本条件的视频生成(Text-to-Video Generation)](./text2video/README.md/#文本条件的视频生成text-to-video-generation) | `text-to-video-ms-1.7b` ||

 More applications under continuous development......
10 changes: 4 additions & 6 deletions paddlemix/examples/visualglm/run_predict.py
@@ -51,22 +51,20 @@ def predict(args):
     # Epoch 1
     query = "写诗描述一下这个场景"
     history = []
-    inputs = processor(image, query)
+    inputs = processor(image, query, max_length=1024)
+
     generate_ids, _ = model.generate(**inputs, **generate_kwargs)
     responses = processor.get_responses(generate_ids)
     history.append([query, responses[0]])
-    print("query->", query)
-    print("responses->", responses)
+    print(responses)

     # Epoch 2
     query = "这部电影的导演是谁?"
     inputs = processor(image, query, history=history)
     generate_ids, _ = model.generate(**inputs, **generate_kwargs)
     responses = processor.get_responses(generate_ids)
     history.append([query, responses[0]])
-    print("query->", query)
-    print("responses->", responses)
-    # print(responses)
+    print(responses)


 if __name__ == "__main__":
14 changes: 14 additions & 0 deletions paddlemix/models/visualglm/modeling.py
@@ -1445,6 +1445,19 @@ def __init__(self, config: ChatGLMConfig):
         super(ChatGLMForConditionalGenerationWithImage, self).__init__(config)
         self.config = config

+    def get_masks(self, input_ids):
+
+        batch_size, seq_length = input_ids.shape
+        context_lengths = []
+        for seq in input_ids:
+            context_lengths.append(paddle.where(seq == self.config.bos_token_id)[0][0])
+        attention_mask = paddle.tril(paddle.ones([batch_size, seq_length, seq_length]))
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
+        attention_mask = attention_mask.unsqueeze(1)
+        attention_mask = (attention_mask > 0.5).astype("int64")
+        return attention_mask
+
     def prepare_inputs_for_generation(
         self, input_ids, position_ids=None, attention_mask=None, past_key_values=None, cache=None, **kwargs
     ):
@@ -1624,6 +1637,7 @@ def generate(
         """

         image_features = self.encode_images(pixel_values)
+        attention_mask = self.language_model.get_masks(input_ids)

         outputs = self.language_model.generate(
             input_ids=input_ids,
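
For readers who want to see the shape of the mask that `get_masks` produces, the following is a minimal, dependency-free sketch of the same rule for a single sequence. It is not the Paddle code from this commit: the function name `prefix_lm_mask`, the token values, and the use of `99` as a stand-in BOS id are all hypothetical and chosen only for illustration. The rule it follows is the one visible in the diff: start from a lower-triangular (causal) mask, then make every column before the first BOS token visible to all rows.

```python
from typing import List


def prefix_lm_mask(seq: List[int], bos_token_id: int) -> List[List[int]]:
    """Build a seq_len x seq_len 0/1 mask for one sequence.

    mask[i][j] == 1 means position i may attend to position j.
    Hypothetical helper: mirrors the rule in get_masks, without Paddle.
    """
    seq_len = len(seq)
    # Index of the first BOS token; everything before it is "context".
    # list.index raises ValueError if the sequence contains no BOS token.
    context_length = seq.index(bos_token_id)

    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            causal = j <= i                  # lower-triangular part (the tril step)
            in_context = j < context_length  # context columns are visible to every row
            row.append(1 if (causal or in_context) else 0)
        mask.append(row)
    return mask


# Toy input: three context tokens, a stand-in BOS (99), two generated tokens.
tokens = [11, 12, 13, 99, 21, 22]
mask = prefix_lm_mask(tokens, bos_token_id=99)
for row in mask:
    print(row)
# [1, 1, 1, 0, 0, 0]
# [1, 1, 1, 0, 0, 0]
# [1, 1, 1, 0, 0, 0]
# [1, 1, 1, 1, 0, 0]
# [1, 1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1, 1]
```

Rows index the querying position and columns the position being attended to, so the three context tokens before the stand-in BOS are visible from every row, while everything from BOS onward stays causal. The method in the diff additionally handles a whole batch, inserts a head dimension with `unsqueeze(1)`, and casts the result to `int64`.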
