xiangan winking636 committed on
Commit
75adb33
·
verified ·
1 Parent(s): a9fe3a6

Update README.md (#1)

Browse files

- Update README.md (940d577697f78f9e165de63ac5b35598975bfefe)


Co-authored-by: wkzhang <[email protected]>

Files changed (1) hide show
  1. README.md +82 -0
README.md CHANGED
@@ -71,11 +71,93 @@ seg_prompt = "Could you provide a segmentation mask for the right giraffe in thi
71
  pred_mask = model.seg(seg_img, seg_prompt, tokenizer, force_seg=True)
72
  ```
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  ## Example
75
 
76
 
77
  <img src="https://github.com/user-attachments/assets/85c023a1-3e0c-4ea5-a764-1eb9ee0fbddf" alt="output" width="1024"/>
78
  <img src="https://github.com/user-attachments/assets/5b767327-bd0a-4185-8f7e-b1ab0aa260c9" alt="output" width="1024"/>
 
 
 
79
 
80
  ## Citations
81
  ```
 
71
  pred_mask = model.seg(seg_img, seg_prompt, tokenizer, force_seg=True)
72
  ```
73
 
74
+ To run this code on a video, please refer to the sample below:
75
+ ```python
76
+ from transformers import AutoModel, AutoTokenizer
77
+ from PIL import Image
78
+ import torch
79
+ from torchvision import transforms
80
+ import subprocess
81
+ import os
82
+
83
+ # video path
84
+ video_path = "updownfunk.mp4"
85
+ input_dir = "frames"
86
+ output_dir = "mask_frames"
87
+ os.makedirs(input_dir, exist_ok=True)
88
+ os.makedirs(output_dir, exist_ok=True)
89
+
90
+ # assumes ffmpeg is installed; extracts the mp4 into jpg frames
91
+ cmd = [
92
+ "ffmpeg",
93
+ "-i", video_path,
94
+ "-vf", "fps=30", # 30FPS
95
+ "-qscale:v", "1",
96
+ os.path.join(input_dir, "frame_%04d.jpg")
97
+ ]
98
+ subprocess.run(cmd)
99
+
100
+ # model path
101
+ model_path = "DeepGlint-AI/MLCD-Seg" # or use your local path
102
+ mlcd_seg = AutoModel.from_pretrained(
103
+ model_path,
104
+ torch_dtype=torch.float16,
105
+ trust_remote_code=True
106
+ ).cuda()
107
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
108
+
109
+ # read jpgs
110
+ image_files = sorted([f for f in os.listdir(input_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])
111
+
112
+ for idx, filename in enumerate(image_files, start=1):
113
+
114
+ src_path = os.path.join(input_dir, filename)
115
+ seg_img = Image.open(src_path).convert('RGB')
116
+
117
+ seg_prompt = "This <video> depicts a group of people dancing.\nCould you provide a segmentation mask for the man in pink suit?"
118
+ pred_mask = mlcd_seg.predict_forward(seg_img, seg_prompt, tokenizer, force_seg=True)
119
+
120
+ # Mask visualization
121
+ pred_mask = pred_mask.squeeze(0).cpu()
122
+ pred_mask = (pred_mask > 0.5).float()
123
+ img_tensor = transforms.ToTensor()(seg_img)
124
+ alpha = 0.2 # blend weight: 20% overlay color, 80% original image
125
+ red_mask = torch.tensor([0.0, 1.0, 0.0]).view(3, 1, 1).to(img_tensor.device) # green mask
126
+ black_bg = torch.zeros_like(img_tensor) # black background
127
+ masked_area = red_mask * alpha + img_tensor * (1 - alpha)
128
+ background = black_bg * alpha + img_tensor * (1 - alpha)
129
+ combined = torch.where(pred_mask.unsqueeze(0).bool(), masked_area, background)
130
+ combined = combined.cpu() # [3, H, W], CPU
131
+
132
+ # Save masked jpgs
133
+ new_name = f"{idx:04d}{os.path.splitext(filename)[1]}"
134
+ dst_path = os.path.join(output_dir, new_name)
135
+ transforms.ToPILImage()(combined.clamp(0, 1)).save(dst_path)
136
+
137
+ cmd = [
138
+ "ffmpeg",
139
+ "-y",
140
+ "-framerate", str(30), # fps
141
+ "-i", os.path.join(output_dir, "%04d.jpg"),
142
+ "-c:v", "libx264",
143
+ "-crf", str(23),
144
+ "-pix_fmt", "yuv420p",
145
+ "-vf", "fps=" + str(23),
146
+ "updownfunk_mask.mp4" # output video
147
+ ]
148
+ # jpgs -> mp4
149
+ subprocess.run(cmd, check=True)
150
+ ```
151
+
152
+
153
  ## Example
154
 
155
 
156
  <img src="https://github.com/user-attachments/assets/85c023a1-3e0c-4ea5-a764-1eb9ee0fbddf" alt="output" width="1024"/>
157
  <img src="https://github.com/user-attachments/assets/5b767327-bd0a-4185-8f7e-b1ab0aa260c9" alt="output" width="1024"/>
158
+ <video width="80%" controls>
159
+ <source src="https://github.com/user-attachments/assets/380dee0d-47c4-4e01-8ff0-e69e62cccd7c">
160
+ </video>
161
 
162
  ## Citations
163
  ```