ThomasSimonini (HF Staff) committed
Commit 829ac34 · verified · 1 Parent(s): 8eae749

Upload moondream.py

Files changed (1):
  1. moondream.py +188 -1
moondream.py CHANGED
@@ -1,3 +1,4 @@
+"""
 import torch
 from .vision_encoder import VisionEncoder
 from .configuration_moondream import MoondreamConfig
@@ -85,6 +86,192 @@ class Moondream(PreTrainedModel):
             inputs_embeds=inputs_embeds, streamer=streamer, **generate_config
         )
 
+        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+    def answer_question(
+        self,
+        image_embeds,
+        question,
+        tokenizer,
+        chat_history="",
+        result_queue=None,
+        **kwargs,
+    ):
+        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
+        answer = self.generate(
+            image_embeds,
+            prompt,
+            tokenizer=tokenizer,
+            max_new_tokens=512,
+            **kwargs,
+        )[0]
+        cleaned_answer = answer.strip()
+
+        # Use the result_queue to pass the result if it is provided
+        if result_queue:
+            result_queue.put(cleaned_answer)
+        else:
+            return cleaned_answer
+
+    def batch_answer(
+        self,
+        images,
+        prompts,
+        tokenizer,
+        **kwargs,
+    ):
+        image_embeds = self.encode_image(images)
+
+        templated_prompts = [
+            f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
+        ]
+        prompt_embs = [
+            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
+            for prompt, image_embed in zip(templated_prompts, image_embeds)
+        ]
+
+        bos_emb = prompt_embs[0][0]
+        max_len = max([p.shape[0] for p in prompt_embs])
+
+        inputs_embeds = torch.cat(
+            [
+                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
+                for p in prompt_embs
+            ],
+            dim=0,
+        )
+        attention_mask = torch.cat(
+            [
+                torch.cat(
+                    [
+                        torch.zeros(
+                            1,
+                            max_len - p.shape[0],
+                            device=self.device,
+                            dtype=torch.long,
+                        ),
+                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
+                    ],
+                    dim=1,
+                )
+                for p in prompt_embs
+            ],
+            dim=0,
+        )
+
+        generate_config = {
+            "eos_token_id": tokenizer.eos_token_id,
+            "bos_token_id": tokenizer.bos_token_id,
+            "pad_token_id": tokenizer.bos_token_id,
+            "max_new_tokens": 512,
+            **kwargs,
+        }
+
+        with torch.no_grad():
+            output_ids = self.text_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                **generate_config,
+            )
+
+        return [
+            x.strip()
+            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        ]
+"""
+
+import torch
+from .vision_encoder import VisionEncoder
+from .configuration_moondream import MoondreamConfig
+from transformers import PreTrainedModel, TextIteratorStreamer
+
+from .modeling_phi import PhiForCausalLM
+from .configuration_moondream import PhiConfig
+
+class Moondream(PreTrainedModel):
+    config_class = MoondreamConfig
+    _supports_flash_attn_2 = True
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.vision_encoder = VisionEncoder(
+            use_flash_attn=config._attn_implementation == "flash_attention_2"
+        )
+
+        if type(config.text_config) == dict:
+            phi_config = PhiConfig(
+                **config.text_config, attn_implementation=config._attn_implementation
+            )
+        else:
+            phi_config = config.text_config
+        self.text_model = PhiForCausalLM(phi_config)
+
+    @property
+    def device(self):
+        return self.text_model.device
+
+    def encode_image(self, image):
+        with torch.no_grad():
+            return self.vision_encoder(image)
+
+    def input_embeds(self, prompt, image_embeds, tokenizer):
+        def _tokenize(txt):
+            return tokenizer(
+                txt, return_tensors="pt", add_special_tokens=False
+            ).input_ids.to(self.device)
+
+        text_emb = self.text_model.get_input_embeddings()
+
+        # Add BOS token
+        embeds = []
+        embeds.append(
+            text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
+        )
+
+        if "<image>" not in prompt:
+            embeds.append(text_emb(_tokenize(prompt)))
+        else:
+            assert prompt.count("<image>") == 1
+            before, after = prompt.split("<image>")
+            if len(before) > 0:
+                embeds.append(text_emb(_tokenize(before)))
+            embeds.append(image_embeds.to(self.device))
+            if len(after) > 0:
+                embeds.append(text_emb(_tokenize(after)))
+
+        return torch.cat(embeds, dim=1)
+
+    def get_input_embeddings(self):
+        return self.text_model.get_input_embeddings()
+
+    def generate(
+        self,
+        image_embeds,
+        prompt,
+        tokenizer,
+        max_new_tokens=128,
+        **kwargs,
+    ):
+        generate_config = {
+            "eos_token_id": tokenizer.eos_token_id,
+            "bos_token_id": tokenizer.bos_token_id,
+            "pad_token_id": tokenizer.bos_token_id,
+            "max_new_tokens": max_new_tokens,
+            **kwargs,
+        }
+
+        with torch.no_grad():
+            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
+            streamer = TextIteratorStreamer(tokenizer)  # , timeout=10.0, skip_prompt=True, skip_special_tokens=True
+            output_ids = self.text_model.generate(
+                inputs_embeds=inputs_embeds, streamer=streamer, **generate_config
+            )
+
+        model_output = ""
+        for new_text in streamer:
+            model_output += new_text
+            print("NEWTEXT" + new_text)
+            yield new_text
 
         return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
@@ -177,4 +364,4 @@ class Moondream(PreTrainedModel):
         return [
             x.strip()
             for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        ]
+        ]
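
After this commit, `Moondream.generate` is a generator: it pushes decoded tokens through a `TextIteratorStreamer` and yields each text chunk. The sketch below shows how a caller might consume it; it is not part of the commit, and the repo id, tokenizer class, and image path are assumptions borrowed from the upstream moondream1 model card rather than from this diff.

from PIL import Image
from transformers import AutoModelForCausalLM, CodeGenTokenizerFast

model_id = "vikhyatk/moondream1"  # assumed upstream checkpoint; the repo behind this commit may differ
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
tokenizer = CodeGenTokenizerFast.from_pretrained(model_id)

image = Image.open("example.jpg")          # assumed local image
image_embeds = model.encode_image(image)   # vision encoder forward pass under no_grad

prompt = "<image>\n\nQuestion: What is in this picture?\n\nAnswer:"

# With the streaming generate, the caller iterates over yielded chunks
# instead of indexing into a list of decoded strings.
for chunk in model.generate(image_embeds, prompt, tokenizer, max_new_tokens=128):
    print(chunk, end="", flush=True)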