import functools
import torch
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
apply_liger_kernel_to_qwen2_vl()  # patch Qwen2-VL with Liger (Triton) kernels before the model is instantiated
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, LogitsProcessor, logging
from livecc_utils import prepare_multiturn_multimodal_inputs_for_generation, get_smart_resized_clip, get_smart_resized_video_reader
from qwen_vl_utils import process_vision_info

logger = logging.get_logger(__name__)

class ThresholdLogitsProcessor(LogitsProcessor):
    """Suppresses `token_id` unless its probability exceeds a threshold that grows with every generation step."""
    def __init__(self, token_id: int, base_threshold: float, step: float):
        self.token_id = token_id              # token to gate (here: the streaming ' ...' end-of-segment token)
        self.base_threshold = base_threshold  # probability threshold at the first generation step
        self.step = step                      # threshold increment per generated token
        self.count = 0                        # number of generation steps seen so far

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        threshold = self.base_threshold + self.step * self.count
        # Mask the gated token wherever its probability is still at or below the current threshold.
        low_confidence = torch.softmax(scores, dim=-1)[:, self.token_id] <= threshold
        if low_confidence.any():
            scores[low_confidence, self.token_id] = -float("inf")
        self.count += 1
        return scores
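
# Illustrative sketch (not part of the original demo pipeline): shows how ThresholdLogitsProcessor
# masks a designated token until its probability clears the growing threshold. The token id,
# vocabulary size, and threshold values below are arbitrary placeholders chosen for this example.
def _demo_threshold_logits_processor():
    processor = ThresholdLogitsProcessor(token_id=42, base_threshold=0.1, step=0.05)
    dummy_input_ids = torch.zeros(1, 4, dtype=torch.long)
    uniform_scores = torch.zeros(1, 100)  # uniform logits -> every token has probability 1/100
    # Step 0: threshold = 0.10, prob(token 42) = 0.01 <= 0.10, so the token is masked out.
    masked = processor(dummy_input_ids, uniform_scores.clone())
    assert masked[0, 42] == -float('inf')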
    
class LiveCCDemoInfer:
    # Sentinel objects for callers to signal playback state.
    VIDEO_PLAY_END = object()
    VIDEO_PLAY_CONTINUE = object()
    fps = 2                 # frames sampled per second of video
    initial_fps_frames = 6  # frames consumed before the first caption (3s at 2 fps)
    streaming_fps_frames = 2  # frames consumed per streaming step afterwards (1s at 2 fps)
    initial_time_interval = initial_fps_frames / fps      # 3.0s
    streaming_time_interval = streaming_fps_frames / fps  # 1.0s
    frame_time_interval = 1 / fps                         # 0.5s between sampled frames

    def __init__(self, model_path: str = None, device: str = 'cuda'):
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype="auto",
            device_map=device,
            attn_implementation='sdpa'
        )
        self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
        # Token id of the ' ...' marker the model emits to close a streaming caption segment.
        self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
        self.model.prepare_inputs_for_generation = functools.partial(prepare_multiturn_multimodal_inputs_for_generation, self.model)
        # Apply the chat template once to locate where the user turn starts, so the system
        # prompt can be stripped from subsequent (cached) turns.
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": 'livecc'},
            ]
        }
        texts = self.processor.apply_chat_template([message], tokenize=False)
        self.system_prompt_offset = texts.index('<|im_start|>user')
        self._cached_video_readers_with_hw = {}  # video_path -> (video_reader, resized_height, resized_width)

    @torch.inference_mode()
    def live_cc(
        self,
        message: str,
        state: dict,
        max_pixels: int = 384 * 28 * 28,
        default_query: str = 'Please describe the video.',
        do_sample: bool = True,
        repetition_penalty: float = 1.05,
        streaming_eos_base_threshold: float = None, 
        streaming_eos_threshold_step: float = None, 
        hf_spaces: bool = False,
        **kwargs,
    ): 
        """
        state: dict, (maybe) with keys:
            video_path: str, video path
            video_timestamp: float, current video timestamp
            last_timestamp: float, last processed video timestamp
            last_video_pts_index: int, last processed video frame index
            video_pts: np.ndarray, video pts
            last_history: list, last processed history
        """
        # 1. preparation: video_reader, and last processing info
        video_timestamp, last_timestamp = state.get('video_timestamp', 0), state.get('last_timestamp', -1 / self.fps)
        video_path = state.get('video_path', None)
        if not video_path:
            return
        if video_path not in self._cached_video_readers_with_hw:
            self._cached_video_readers_with_hw[video_path] = get_smart_resized_video_reader(video_path, max_pixels)
            video_reader = self._cached_video_readers_with_hw[video_path][0]
            video_reader.get_frame_timestamp(0)
            state['video_pts'] = torch.from_numpy(video_reader._frame_pts[:, 1])
            state['last_video_pts_index'] = -1
        video_pts = state.get('video_pts', None)
        if video_pts is None:
            return
        video_timestamp = min(video_timestamp, video_pts[-1])
        if last_timestamp + self.frame_time_interval > video_pts[-1]:
            state['video_end'] = True
            return 
        video_reader, resized_height, resized_width = self._cached_video_readers_with_hw[video_path]
        last_video_pts_index = state['last_video_pts_index']

        # 2. which frames will be processed
        initialized = last_timestamp >= 0
        if not initialized:
            video_timestamp = max(video_timestamp, self.initial_time_interval)
        if video_timestamp <= last_timestamp + self.frame_time_interval:
            return
        timestamps = torch.arange(last_timestamp + self.frame_time_interval, video_timestamp, self.frame_time_interval) # timestamps of the not-yet-processed frames, spaced by frame_time_interval
        
        # 3. fetch frames in required timestamps
        clip, clip_timestamps, clip_idxs = get_smart_resized_clip(video_reader, resized_height, resized_width, timestamps, video_pts, video_pts_index_from=last_video_pts_index+1)
        state['last_video_pts_index'] = clip_idxs[-1]
        state['last_timestamp'] = clip_timestamps[-1]

        # 4. organize to interleave frames
        interleave_clips, interleave_timestamps = [], []
        if not initialized:
            interleave_clips.append(clip[:self.initial_fps_frames])
            interleave_timestamps.append(clip_timestamps[:self.initial_fps_frames])
            clip = clip[self.initial_fps_frames:]
            clip_timestamps = clip_timestamps[self.initial_fps_frames:]
        if len(clip) > 0:
            interleave_clips.extend(list(clip.split(self.streaming_fps_frames)))
            interleave_timestamps.extend(list(clip_timestamps.split(self.streaming_fps_frames)))

        # 5. make conversation and send to model
        for clip, timestamps in zip(interleave_clips, interleave_timestamps):
            start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
            conversation = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
                    {"type": "video", "video": clip}
                ]
            }]
            if not message and not state.get('message', None):
                message = default_query
                logger.warning(f'No query provided, using default_query={default_query}')
            if message and state.get('message', None) != message:
                conversation[0]['content'].append({"type": "text", "text": message})
                state['message'] = message
            texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
            past_ids = state.get('past_ids', None)
            if past_ids is not None:
                texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
            inputs = self.processor(
                text=texts,
                images=None,
                videos=[clip],
                return_tensors="pt",
                return_attention_mask=False
            )
            inputs.to(self.model.device)
            if past_ids is not None:
                inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1) 
            if streaming_eos_base_threshold is not None:
                logits_processor = [ThresholdLogitsProcessor(self.streaming_eos_token_id, streaming_eos_base_threshold, streaming_eos_threshold_step)]
            else:
                logits_processor = None
            outputs = self.model.generate(
                **inputs, past_key_values=state.get('past_key_values', None), 
                return_dict_in_generate=True, do_sample=do_sample, 
                repetition_penalty=repetition_penalty,
                logits_processor=logits_processor,
                max_new_tokens=16,
            )
            state['past_key_values'] = outputs.past_key_values
            state['past_ids'] = outputs.sequences[:, :-1]
            response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
            if hf_spaces:
                light_state = {k: v for k, v in state.items() if k not in ['past_ids', 'past_key_values']}
                yield (start_timestamp, stop_timestamp), response, light_state
            else:
                yield (start_timestamp, stop_timestamp), response, state

    @torch.inference_mode()
    def video_qa(
        self,
        message: str,
        history: list,
        state: dict,
        do_sample: bool = True,
        repetition_penalty: float = 1.05,
        hf_spaces: bool = False,
        **kwargs,
    ): 
        """
        state: dict, (maybe) with keys:
            video_path: str, video path
            video_timestamp: float, current video timestamp
            last_timestamp: float, last processed video timestamp
            last_video_pts_index: int, last processed video frame index
            video_pts: np.ndarray, video pts
            last_history: list, last processed history
        """
        video_path = state.get('video_path', None)
        conversation = []
        if hf_spaces:
            for past_message in history:
                content = [{"type": "text", "text": past_message['content']}]
                if video_path: # only use once
                    content.insert(0, {"type": "video", "video": video_path})
                    video_path = None
                conversation.append({"role": past_message["role"], "content": content})
        else:
            pass # use past_key_values
        past_ids = state.get('past_ids', None)
        content = [{"type": "text", "text": message}]
        if past_ids is None and video_path: # only use once
            content.insert(0, {"type": "video", "video": video_path})
        conversation.append({"role": "user", "content": content})
        print(conversation)
        image_inputs, video_inputs = process_vision_info(conversation)
        texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
        if past_ids is not None:
            texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
        inputs = self.processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            return_attention_mask=False
        )
        inputs.to(self.model.device)
        if past_ids is not None:
            inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1) 
        outputs = self.model.generate(
            **inputs, past_key_values=state.get('past_key_values', None), 
            return_dict_in_generate=True, do_sample=do_sample, 
            repetition_penalty=repetition_penalty,
            max_new_tokens=512,
        )
        state['past_key_values'] = outputs.past_key_values if not hf_spaces else None
        state['past_ids'] = outputs.sequences[:, :-1] if not hf_spaces else None
        response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
        print(response)
        return response, state
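
# Minimal usage sketch, not part of the original file: drives the streaming captioner and then a
# follow-up question. It assumes a CUDA device is available; the checkpoint directory and video
# path below are placeholders that must be replaced with real local paths.
if __name__ == '__main__':
    infer = LiveCCDemoInfer(model_path='path/to/livecc-checkpoint', device='cuda')
    state = {'video_path': 'path/to/video.mp4'}
    # Simulate real-time playback: advance the clock one second at a time and let live_cc
    # caption whichever new frames have become available.
    for second in range(1, 11):
        state['video_timestamp'] = float(second)
        for (start, stop), response, state in infer.live_cc(message='Please describe the video.', state=state):
            print(f'[{start:.1f}s - {stop:.1f}s] {response}')
        if state.get('video_end'):
            break
    # Ask a follow-up question over the cached multimodal context.
    answer, state = infer.video_qa(message='What is the main subject of the video?', history=[], state=state)
    print(answer)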