Spaces:
Running
on
Zero
Running
on
Zero
fix
Browse files- demo/infer.py +54 -48
demo/infer.py
CHANGED
@@ -156,57 +156,63 @@ class LiveCCDemoInfer:
|
|
156 |
state['past_ids'] = outputs.sequences[:, :-1]
|
157 |
yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
|
158 |
|
|
|
159 |
def video_qa(
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
185 |
else:
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
"
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
inputs = processor(
|
198 |
-
text=
|
199 |
images=image_inputs,
|
200 |
videos=video_inputs,
|
201 |
return_tensors="pt",
|
|
|
202 |
)
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
|
|
|
156 |
state['past_ids'] = outputs.sequences[:, :-1]
|
157 |
yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
|
158 |
|
159 |
+
@torch.inference_mode()
|
160 |
def video_qa(
|
161 |
+
self,
|
162 |
+
query: str,
|
163 |
+
state: dict,
|
164 |
+
default_query: str = 'Please describe the video.',
|
165 |
+
do_sample: bool = False,
|
166 |
+
repetition_penalty: float = 1.05,
|
167 |
+
**kwargs,
|
168 |
+
):
|
169 |
+
"""
|
170 |
+
state: dict, (maybe) with keys:
|
171 |
+
video_path: str, video path
|
172 |
+
video_timestamp: float, current video timestamp
|
173 |
+
last_timestamp: float, last processed video timestamp
|
174 |
+
last_video_pts_index: int, last processed video frame index
|
175 |
+
video_pts: np.ndarray, video pts
|
176 |
+
last_history: list, last processed history
|
177 |
+
"""
|
178 |
+
video_path = state.get('video_path', None)
|
179 |
+
if video_path:
|
180 |
+
message = {
|
181 |
+
"role": "user",
|
182 |
+
"content": [
|
183 |
+
{"type": "video", "video": video_path},
|
184 |
+
{"type": "text", "text": query if query else default_query},
|
185 |
+
],
|
186 |
+
}
|
187 |
+
|
188 |
else:
|
189 |
+
message = {
|
190 |
+
"role": "user",
|
191 |
+
"content": [
|
192 |
+
{"type": "text", "text": query if query else default_query},
|
193 |
+
],
|
194 |
+
}
|
195 |
+
image_inputs, video_inputs = process_vision_info([message])
|
196 |
+
texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
|
197 |
+
past_ids = state.get('past_ids', None)
|
198 |
+
if past_ids is not None:
|
199 |
+
texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
|
200 |
+
inputs = self.processor(
|
201 |
+
text=texts,
|
202 |
images=image_inputs,
|
203 |
videos=video_inputs,
|
204 |
return_tensors="pt",
|
205 |
+
return_attention_mask=False
|
206 |
)
|
207 |
+
inputs.to('cuda')
|
208 |
+
if past_ids is not None:
|
209 |
+
inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
|
210 |
+
outputs = self.model.generate(
|
211 |
+
**inputs, past_key_values=state.get('past_key_values', None),
|
212 |
+
return_dict_in_generate=True, do_sample=do_sample,
|
213 |
+
repetition_penalty=repetition_penalty,
|
214 |
+
max_new_tokens=512,
|
215 |
+
)
|
216 |
+
state['past_key_values'] = outputs.past_key_values
|
217 |
+
state['past_ids'] = outputs.sequences[:, :-1]
|
218 |
+
return self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
|