Update demo/infer.py
demo/infer.py (+6 -6)
```diff
@@ -56,7 +56,7 @@ class LiveCCDemoInfer:
         self,
         message: str,
         state: dict,
-        max_pixels: int =
+        max_pixels: int = 384 * 28 * 28,
         default_query: str = 'Please describe the video.',
         do_sample: bool = False,
         repetition_penalty: float = 1.05,
@@ -122,20 +122,20 @@ class LiveCCDemoInfer:
         # 5. make conversation and send to model
         for clip, timestamps in zip(interleave_clips, interleave_timestamps):
             start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
-
+            conversation = [{
                 "role": "user",
                 "content": [
                     {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
                     {"type": "video", "video": clip}
                 ]
-            }
+            }]
         if not message and not state.get('message', None):
             message = default_query
             logger.warning(f'No query provided, use default_query={default_query}')
         if message and state.get('message', None) != message:
-
+            conversation[0]['content'].append({"type": "text", "text": message})
             state['message'] = message
-        texts = self.processor.apply_chat_template(
+        texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
         past_ids = state.get('past_ids', None)
         if past_ids is not None:
             texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
@@ -146,7 +146,6 @@ class LiveCCDemoInfer:
             return_tensors="pt",
             return_attention_mask=False
         )
-        print(texts)
         inputs.to(self.model.device)
         if past_ids is not None:
             inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
@@ -159,6 +158,7 @@ class LiveCCDemoInfer:
             return_dict_in_generate=True, do_sample=do_sample,
             repetition_penalty=repetition_penalty,
             logits_processor=logits_processor,
+            max_new_tokens=16,
         )
         state['past_key_values'] = outputs.past_key_values
         state['past_ids'] = outputs.sequences[:, :-1]
```
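For context, the net effect of the second hunk is that each clip now gets its own single-turn conversation: the time range and the video clip are always included, and the user query is appended as an extra text item only when it changes, before the whole turn is rendered with the processor's chat template. The sketch below reproduces just that construction with a generic Qwen2-VL-style processor; the checkpoint name, frame paths, and timestamps are placeholders, not necessarily what this Space loads.

```python
# Minimal sketch of the conversation/template step introduced by this change.
# Assumptions: a Qwen2-VL-style AutoProcessor; the checkpoint name and frame
# paths below are placeholders, not taken from the demo.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

start_timestamp, stop_timestamp = 0.0, 2.0     # illustrative clip boundaries
clip = ["frames/0001.jpg", "frames/0002.jpg"]  # stand-in for the decoded video clip
message = "Please describe the video."         # matches default_query in the diff

# One user turn per clip: time range + video, with the query appended as extra text.
conversation = [{
    "role": "user",
    "content": [
        {"type": "text", "text": f"Time={start_timestamp:.1f}-{stop_timestamp:.1f}s"},
        {"type": "video", "video": clip},
        {"type": "text", "text": message},
    ],
}]

# tokenize=False returns the rendered prompt string; add_generation_prompt=True
# appends the assistant header so generation starts a fresh reply.
texts = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
print(texts)
```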
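The remaining hunks touch the incremental-decoding path already visible in the context lines: previously generated ids are prepended to the new prompt, the KV cache from the last call is reused, a debug `print(texts)` is dropped, and the new `max_new_tokens=16` keeps each call short so the commentary comes out in small chunks. A rough sketch of that pattern, assuming a recent `transformers` release and using a small text-only model as a stand-in for the demo's video LM (model name and the `step` helper are illustrative, not from the demo):

```python
# Rough sketch of the incremental generate/caching pattern from the diff's
# context lines; "gpt2" is only a placeholder for the demo's own checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

state = {}  # mirrors the demo's per-session state dict

def step(prompt: str) -> str:
    """Generate a short continuation, reusing ids and KV cache from earlier calls."""
    # As in the demo, skip the attention mask so the concatenated ids stay consistent.
    inputs = tok(prompt, return_tensors="pt", return_attention_mask=False).to(model.device)
    past_ids = state.get("past_ids", None)
    if past_ids is not None:
        inputs["input_ids"] = torch.cat([past_ids, inputs.input_ids], dim=1)

    gen_kwargs = dict(return_dict_in_generate=True, max_new_tokens=16)  # same cap the commit adds
    if state.get("past_key_values", None) is not None:
        gen_kwargs["past_key_values"] = state["past_key_values"]  # resume from cached prefix

    outputs = model.generate(**inputs, **gen_kwargs)
    state["past_key_values"] = outputs.past_key_values
    state["past_ids"] = outputs.sequences[:, :-1]  # drop trailing token so ids and cache stay aligned
    return tok.decode(outputs.sequences[0, inputs["input_ids"].shape[1]:])
```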