
Processing flow of MovieChat-1K_train #54

Open
LZHgrla opened this issue Apr 29, 2024 · 6 comments

LZHgrla commented Apr 29, 2024

Hi!
Could you provide the processing script or procedure for the MovieChat-1K_train dataset? We plan to fine-tune our model on this dataset and need to ensure that our pre-training phase follows the same processing procedure.

Espere-1119-Song (Collaborator) commented

For each video in the MovieChat-1K_train dataset, we uniformly sample 8192 frames, encode them with eva_clip_g at an image_size of 224, and store the features in HDF5. Our feature extraction script is as follows:

import os
import cv2
import torch
import torchvision.transforms as transforms
import h5py
from MovieChat.models.eva_vit import create_eva_vit_g

device = "cuda:0"

# Folder of raw .mp4 videos, and the folder the HDF5 features are written to.
input_folder = 'our_train_data'
output_folder = 'feature_hdf5'

if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# Videos that already have a feature folder are skipped below.
subfolders = [f.name for f in os.scandir(output_folder) if f.is_dir()]
mp4_files = [f for f in os.listdir(input_folder) if f.endswith('.mp4')]

# Number of frames to sample uniformly from each video.
frames_to_read = 8192

# Convert each frame to a PIL image, resize it to 224x224, and convert it to
# a tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def init_vision_encoder(
    img_size=224, drop_path_rate=0, use_grad_checkpoint=False, precision="fp16"
):
    # Build the eva_clip_g visual encoder and freeze it in eval mode.
    visual_encoder = create_eva_vit_g(
        img_size, drop_path_rate, use_grad_checkpoint, precision
    ).float()
    visual_encoder.eval()

    return visual_encoder



image_encoder = init_vision_encoder().to(device)

for mp4_file in mp4_files:
    # Skip videos whose features were already extracted.
    if mp4_file.split('.')[0] not in subfolders:
        try:
            video_path = os.path.join(input_folder, mp4_file)

            cap = cv2.VideoCapture(video_path)

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # max(..., 1) guards against videos shorter than frames_to_read,
            # which would make frame_interval zero and crash the modulo below.
            frame_interval = max(total_frames // frames_to_read, 1)

            # Encode frames in batches of 64; each batch is written to its
            # own .h5 piece file.
            frame_count = 0
            batch_size = 64
            piece_count = 0
            current_batch = []
            while frame_count < total_frames:
                ret, frame = cap.read()
                
                if not ret:
                    break
                
                # Keep every frame_interval-th frame.
                if frame_count % frame_interval == 0:
                    frame_tensor = transform(frame)
                    frame_tensor = frame_tensor.unsqueeze(0).to(device)
                    current_batch.append(frame_tensor)
                
                    # Once a full batch is collected, encode it and write the
                    # features to <video_name>/<piece_count>.h5.
                    if len(current_batch) == batch_size:
                        batch = torch.cat(current_batch, dim=0).to(device)
                        with torch.no_grad():
                            features = image_encoder(batch).cpu()
                            output_dict = os.path.join(output_folder, os.path.splitext(mp4_file)[0])
                            if not os.path.exists(output_dict):
                                os.makedirs(output_dict)
                            output_filename = str(piece_count) + '.h5'
                            output_path = os.path.join(output_dict, output_filename)
                            with h5py.File(output_path, "w") as hdf5_file:
                                dataset_name = f"frames_{piece_count}"
                                print(dataset_name)
                                hdf5_file.create_dataset(dataset_name, data=features)
                            print(output_filename)
                            piece_count += 1
                        current_batch = []
                
                frame_count += 1

            # Flush any remaining frames that did not fill a full batch.
            if len(current_batch) > 0:
                batch = torch.cat(current_batch, dim=0).to(device)
                with torch.no_grad():
                    features = image_encoder(batch).cpu()
                    output_dict = os.path.join(output_folder, os.path.splitext(mp4_file)[0])
                    if not os.path.exists(output_dict):
                        os.makedirs(output_dict)
                    output_filename = str(piece_count) + '.h5'
                    output_path = os.path.join(output_dict, output_filename)
                    with h5py.File(output_path, "w") as hdf5_file:
                        dataset_name = f"frames_{piece_count}"
                        print(dataset_name)
                        hdf5_file.create_dataset(dataset_name, data=features)
                    print(output_filename)
                    piece_count += 1
            
            cap.release()

        except Exception as e:
            print(e)
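
For reference, here is a minimal sketch of how the stored features could be loaded back. It assumes the layout produced by the script above (one folder per video, containing numbered .h5 pieces that each hold a single frames_{piece_count} dataset); the load_video_features helper name is ours:

import os
import h5py
import torch

def load_video_features(feature_dir):
    # Pieces were written as 0.h5, 1.h5, ...; sort them numerically so the
    # sampled frame order is preserved.
    piece_files = sorted(
        (f for f in os.listdir(feature_dir) if f.endswith('.h5')),
        key=lambda name: int(os.path.splitext(name)[0]),
    )
    pieces = []
    for piece_file in piece_files:
        piece_idx = os.path.splitext(piece_file)[0]
        with h5py.File(os.path.join(feature_dir, piece_file), 'r') as hdf5_file:
            pieces.append(torch.from_numpy(hdf5_file[f'frames_{piece_idx}'][:]))
    # Concatenate the per-batch pieces into a single (num_frames, ...) tensor.
    return torch.cat(pieces, dim=0)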


However, we didn't use the extracted features to run MovieChat ourselves. I think the main differences are in the frame reading in inference.py and the frame encoding in moviechat.py.

Hope this can be helpful to you! :)

LZHgrla (Author) commented Apr 29, 2024

@Espere-1119-Song
Awesome! Thanks very much!

LZHgrla closed this as completed Apr 29, 2024
HIT-cwh commented May 8, 2024

Hi @Espere-1119-Song !
I'm not sure I'm understanding this correctly, but in the provided code snippet, the video frames obtained through cv2.VideoCapture are in BGR format, whereas the images passed into transforms.ToPILImage() are expected to be in RGB format, leading to a potential inconsistency.
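
For what it's worth, a minimal fix would be to convert each frame right after it is read, before applying the transform, e.g.:

ret, frame = cap.read()
# OpenCV decodes frames in BGR order; convert to RGB so the
# ToPILImage/Resize/ToTensor pipeline sees the channel order it expects.
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_tensor = transform(frame)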

LZHgrla reopened this May 9, 2024
Espere-1119-Song (Collaborator) commented

Thank you for pointing out the issue! We apologize for any inconvenience caused. We are currently uploading the raw videos to Huggingface, and we expect to complete this by the weekend.

HIT-cwh commented May 10, 2024

Thanks

Espere-1119-Song (Collaborator) commented

We have uploaded the raw videos of the training set :)
