khang119966 committed on
Commit
42bde0f
·
verified ·
1 Parent(s): ba17d2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +309 -1
app.py CHANGED
@@ -133,7 +133,312 @@ def load_image(image_file, input_size=448, max_num=12, target_aspect_ratio=False
133
  return pixel_values, target_aspect_ratio
134
  else:
135
  return pixel_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  model = AutoModel.from_pretrained(
138
  "khang119966/Vintern-1B-v3_5-explainableAI",
139
  torch_dtype=torch.bfloat16,
@@ -150,6 +455,9 @@ def generate_video(image, prompt, max_tokens):
150
  response, query = model.chat(tokenizer, pixel_values, '<image>\n'+prompt, generation_config, return_history=False, \
151
  attention_visualize=True,last_visualize_layers=7,raw_image_path=test_image,target_aspect_ratio=target_aspect_ratio)
152
  print(response)
 
 
 
153
  return "path_to_generated_video.mp4"
154
 
155
  with gr.Blocks() as demo:
@@ -157,7 +465,7 @@ with gr.Blocks() as demo:
157
 
158
  with gr.Row():
159
  with gr.Column():
160
- image = gr.Image(label="Upload your image")
161
  prompt = gr.Textbox(label="Describe your prompt", value="List all the text." )
162
  max_tokens = gr.Slider(label="Max token output (⚠️ Choose <100 for faster response)", minimum=1, maximum=512, value=50)
163
  btn = gr.Button("Attenion Video")
 
133
  return pixel_values, target_aspect_ratio
134
  else:
135
  return pixel_values
136
+
137
def visualize_attention_hiddenstate(attention_tensor, head=None, start_img_token_index=0, end_img_token_index=0, target_aspect_ratio=(0, 0), *, num_layers=8, top_k=5, tile_tokens=16):
    """Build image-attention heat maps and list the most-attended image tokens.

    The attention of the last ``num_layers`` layers is averaged, then either
    averaged over heads or restricted to a single head. For every row of the
    result, the scores in ``[start_img_token_index, end_img_token_index)`` are
    ranked and rearranged into a 2-D map made of ``tile_tokens`` x
    ``tile_tokens`` tiles.

    Args:
        attention_tensor: Attention scores whose first axis is the layer and
            whose third axis is the head (that is what the indexing below
            relies on). The remaining axes are assumed to be
            (beam, query, key) -- TODO confirm against the caller.
        head: Index of the head to use, or ``None`` to average over all heads.
        start_img_token_index: First flattened position belonging to the image.
        end_img_token_index: One past the last flattened image position.
        target_aspect_ratio: ``(tiles_wide, tiles_high)``; a zero entry is
            treated as 1.
        num_layers: Number of trailing layers to average (must be positive).
        top_k: Number of highest-scoring image positions to report per row.
        tile_tokens: Side length, in tokens, of one square image tile.

    Returns:
        ``(heat_maps, top_tokens)``: one 2-D array of shape
        ``(tiles_high * tile_tokens, tiles_wide * tile_tokens)`` per row, and
        one list per row with the ``top_k`` positions (offset by
        ``start_img_token_index``) in descending score order.
    """
    # Average over the trailing layers first; the layer axis is axis 0.
    layer_mean = np.mean(attention_tensor[-num_layers:], axis=0)

    if head is None:
        per_beam = layer_mean.mean(axis=1).squeeze()  # average over heads
    else:
        per_beam = layer_mean[:, head, :, :].squeeze()  # keep one head only
    # NOTE(review): squeeze() drops *every* length-1 axis, so a single-beam
    # input also loses its beam axis -- confirm callers never rely on that.

    # The tile grid does not change between rows, so resolve it once.
    tiles_high = target_aspect_ratio[1] if target_aspect_ratio[1] != 0 else 1
    tiles_wide = target_aspect_ratio[0] if target_aspect_ratio[0] != 0 else 1

    heat_maps = []
    top_tokens = []
    # Each row is treated as one beam (per the original author's comment).
    for beam_scores in per_beam:
        image_scores = beam_scores.reshape(-1)[start_img_token_index:end_img_token_index]

        # Positions of the highest scores, best first, as absolute indices.
        ranked = np.argsort(image_scores)[::-1][:top_k]
        top_tokens.append(list(ranked + start_img_token_index))

        # Scores arrive tile by tile; interleave the tile rows so the tiles
        # end up side by side in one image-shaped grid.
        tiled = image_scores.reshape(tiles_high, tiles_wide, tile_tokens, tile_tokens)
        heat = np.transpose(tiled, (0, 2, 1, 3)).reshape(
            tiles_high * tile_tokens, tiles_wide * tile_tokens
        )

        # Exponent below 1 slightly lifts small scores (for values in [0, 1]).
        heat_maps.append(np.power(heat, 0.9))

    return heat_maps, top_tokens
171
+
172
def generate_next_token_table_image(model, tokenizer, response, index_focus):
    """Render the per-layer top-3 next-token table as an RGB image.

    The table data comes from ``extract_next_token_table_data``; rows are
    listed from the last layer down to layer 0. The top-1 token of the final
    layer is shown in red, every other top-1 token in blue. The HTML is
    rasterised with Html2Image (headless browser) into a 500x1000 screenshot.

    Args:
        model: Model passed through to ``extract_next_token_table_data``.
        tokenizer: Tokenizer passed through to ``extract_next_token_table_data``.
        response: Generation output passed through to the extractor.
        index_focus: Index into ``response.hidden_states`` to visualise.

    Returns:
        The screenshot as an array with the channel order reversed from what
        ``cv2.imread`` returns (i.e. RGB).

    Raises:
        RuntimeError: If the screenshot could not be read back from disk.
    """
    from html import escape

    table = extract_next_token_table_data(model, tokenizer, response, index_focus)
    final_layer = len(table) - 1

    row_parts = []
    for layer_index, tokens in table:
        top_token, top_prob = tokens[0]
        # Red marks the final layer's top-1 token; earlier layers use blue.
        colour = "red" if layer_index == final_layer else "blue"
        # Tokens are escaped so text such as "<image>" or "&" is displayed
        # literally instead of being parsed as markup.
        cells = [
            f"<td style='font-weight: bold'>Layer {layer_index}</td>",
            f"<td><span style='color: {colour}; font-weight: bold'>{escape(top_token)}</span> ({top_prob:.2%})</td>",
        ]
        cells.extend(
            f"<td>{escape(token_str)} ({prob:.2%})</td>" for token_str, prob in tokens[1:]
        )
        row_parts.append("<tr>" + "".join(cells) + "</tr>")
    html_rows = "".join(row_parts)

    html_code = f'''
    <html>
    <head>
    <meta charset="utf-8">
    <style>
    table {{
        font-family: 'Noto Sans';
        font-size: 12px;
        border-collapse: collapse;
        table-layout: fixed;
        width: 100%;
    }}
    th, td {{
        border: 1px solid black;
        padding: 8px;
        width: 150px;
        height: 30px;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
        text-align: center;
    }}
    th.layer {{
        width: 100px;
    }}
    th.title {{
        font-size: 14px;
        padding: 10px;
        height: auto;
        white-space: normal;
        overflow: visible;
    }}
    </style>
    </head>
    <body style="background-color: white;">
    <table>
    <tr>
    <th colspan="4" class="title">
    Top hidden tokens per layer for the Prediction
    </th>
    </tr>
    <tr>
    <th class="layer">Layer ⬆️</th>
    <th>Top 1</th>
    <th>Top 2</th>
    <th>Top 3</th>
    </tr>
    {html_rows}
    </table>
    </body>
    </html>
    '''

    # The temporary directory (and the screenshot in it) is removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        hti = Html2Image(output_path=tmpdir)
        hti.browser_flags = [
            "--headless=new",  # use the new headless mode
            "--disable-gpu",  # turn the GPU off
            "--disable-software-rasterizer",  # avoid the software GPU fallback
            "--no-sandbox",  # avoid multi-threaded sandbox errors
        ]
        filename = str(uuid.uuid4()) + ".png"
        hti.screenshot(html_str=html_code, save_as=filename, size=(500, 1000))
        img_path = os.path.join(tmpdir, filename)

        image_bgr = cv2.imread(img_path)
        if image_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise RuntimeError(f"Could not read rendered table image: {img_path}")
        return image_bgr[:, :, ::-1]
278
+
279
def adjust_overlay(overlay, text_img):
    """Return ``overlay`` adjusted to have the same height as ``text_img``.

    * Portrait overlays (height > width) are resized to the target height
      with their aspect ratio kept.
    * Other overlays are copied unchanged; if they are shorter than the
      target, white rows are appended at the bottom.
    * Anything that still differs in height is resized to the target height
      with its width kept.

    Args:
        overlay: Image array, 2-D (grayscale) or 3-D (height, width, channels).
        text_img: Image array whose height is the target height.

    Returns:
        A new array with ``text_img``'s height.
    """
    target_h = text_img.shape[0]
    src_h, src_w = overlay.shape[:2]

    if src_h > src_w:
        # Portrait: scale to the target height, never collapsing to width 0.
        scaled_w = max(1, int(src_w * (target_h / src_h)))
        fitted = cv2.resize(overlay, (scaled_w, target_h))
    else:
        # Landscape or square: keep the pixels as they are.
        fitted = overlay.copy()

    missing_rows = target_h - fitted.shape[0]
    if missing_rows > 0:
        # Pad below with white; the filler copies the overlay's trailing
        # dimensions so grayscale and RGBA inputs stack as well as RGB ones.
        filler = np.full((missing_rows,) + fitted.shape[1:], 255, dtype=np.uint8)
        fitted = np.vstack((fitted, filler))

    if fitted.shape[0] != target_h:
        # Taller than the target: force the height, keep the width.
        fitted = cv2.resize(fitted, (fitted.shape[1], target_h))

    return fitted
303
+
304
def generate_text_image_with_html2image(old_text, input_token, new_token, image_width=400, min_height=1000, font_size=16):
    """Render ``old_text`` followed by ``[input_token]→[new_token]`` as an RGB image.

    ``input_token`` is shown in bold blue and ``new_token`` in bold red.
    Newlines in the combined text become ``<br>`` tags. The page is
    rasterised with Html2Image into an ``image_width`` x ``min_height``
    screenshot that is written to a temporary directory and read back.

    Args:
        old_text: Text placed before the token pair. It is inserted verbatim,
            so any markup in it is interpreted by the browser.
            NOTE(review): escape it too if callers only ever pass plain text.
        input_token: Token displayed in blue; HTML-escaped before insertion.
        new_token: Token displayed in red; HTML-escaped before insertion.
        image_width: Screenshot and body width in pixels.
        min_height: Screenshot height and body minimum height in pixels.
        font_size: Font size in pixels.

    Returns:
        The screenshot converted from BGR to RGB.

    Raises:
        RuntimeError: If the screenshot could not be read back from disk.
    """
    import tempfile
    from html import escape

    # Escape the tokens so characters like "<" or "&" show up literally.
    full_text = (
        old_text
        + f"<span style='color:blue; font-weight:bold'>[{escape(str(input_token))}]</span>"
        + "→"
        + f"<span style='color:red; font-weight:bold'>[{escape(str(new_token))}]</span>"
    )

    # Replace "\n" with the HTML <br> tag so line breaks are kept.
    full_text = full_text.replace('\n', '<br>')

    html_code = f'''
    <html>
    <head>
    <meta charset="utf-8">
    </head>
    <body style="font-family: 'DejaVu Sans', sans-serif; font-size: {font_size}px; width: {image_width}px; min-height: {min_height}px; padding: 10px; background-color: white; line-height: 1.4;">
    {full_text}
    </body>
    </html>
    '''

    # Render into a temporary directory so nothing is written to the working
    # directory and the file is cleaned up even when reading it back fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        hti = Html2Image(output_path=tmpdir)
        hti.browser_flags = [
            "--headless=new",  # use the new headless mode
            "--disable-gpu",  # turn the GPU off
            "--disable-software-rasterizer",  # avoid the software GPU fallback
            "--no-sandbox",  # avoid multi-threaded sandbox errors
        ]
        filename = str(uuid.uuid4()) + ".png"
        hti.screenshot(html_str=html_code, save_as=filename, size=(image_width, min_height))
        img_path = os.path.join(tmpdir, filename)

        text_img = cv2.imread(img_path)
        if text_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise RuntimeError(f"Could not read rendered text image: {img_path}")
        return cv2.cvtColor(text_img, cv2.COLOR_BGR2RGB)
333
+
334
def extract_next_token_table_data(model, tokenizer, response, index_focus):
    """Collect the top-3 token predictions read off every layer's hidden state.

    Each entry of ``response.hidden_states[index_focus]`` is passed through
    the language model's final norm and ``lm_head``, turned into
    probabilities with a softmax, and the three most likely tokens at
    sequence position 0 are decoded.

    Args:
        model: Object exposing ``language_model.model.norm`` and
            ``language_model.lm_head``.
        tokenizer: Object whose ``decode`` turns a token id into text.
        response: Object whose ``hidden_states[index_focus]`` is a sequence
            of per-layer tensors; presumably one entry per generation step
            -- verify against the caller.
        index_focus: Index into ``response.hidden_states``.

    Returns:
        A list of ``(layer_index, [(token_str, prob), ...])`` tuples ordered
        from the last layer down to layer 0.
    """
    language_model = model.language_model
    table = []

    # This is inspection only, so skip building an autograd graph.
    with torch.no_grad():
        for layer_index, layer_hidden in enumerate(response.hidden_states[index_focus]):
            logits = language_model.lm_head(language_model.model.norm(layer_hidden[0]))
            probs = torch.softmax(logits, -1)

            # NOTE(review): only sequence position 0 is read. If an entry
            # holds several positions (e.g. a whole prompt), the last
            # position may be the intended one -- confirm.
            first_row = probs[0]
            # top-k avoids sorting the whole vocabulary for every position.
            top_probs, top_ids = torch.topk(first_row, k=min(3, first_row.shape[-1]))

            top_tokens = [
                (tokenizer.decode(token_id), float(prob))
                for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
            ]
            table.append((layer_index, top_tokens))

    table.reverse()  # last layer first
    return table
349
+
350
def render_next_token_table_image(table_data, predict_token):
    """Render precomputed per-layer top-token data as an RGB table image.

    Rows are emitted in the order given by ``table_data`` and labelled with
    1-based layer numbers. A top-1 token equal to ``predict_token`` is shown
    in red, any other top-1 token in blue. The HTML is rasterised with
    Html2Image (headless browser) into a 500x1000 screenshot.

    Args:
        table_data: Iterable of ``(layer_index, [(token_str, prob), ...])``
            tuples, with at least one token per layer.
        predict_token: Token text to highlight in red where it is the top-1.

    Returns:
        The screenshot as an array with the channel order reversed from what
        ``cv2.imread`` returns (i.e. RGB).

    Raises:
        RuntimeError: If the screenshot could not be read back from disk.
    """
    import os
    import tempfile
    import uuid
    from html import escape

    import cv2
    from html2image import Html2Image

    row_parts = []
    for layer_index, tokens in table_data:
        top_token, top_prob = tokens[0]
        # Compare the raw token; escaping is applied only for display.
        if top_token == predict_token:
            style = "color: red; font-weight: bold"
        else:
            style = "color: blue; font-weight: bold"

        # Tokens are escaped so text such as "<image>" or "&" is displayed
        # literally instead of being parsed as markup.
        cells = [
            f"<td style='font-weight: bold'>Layer {layer_index+1}</td>",
            f"<td><span style='{style}'>{escape(top_token)}</span> ({top_prob:.2%})</td>",
        ]
        cells.extend(
            f"<td>{escape(token_str)} ({prob:.2%})</td>" for token_str, prob in tokens[1:]
        )
        row_parts.append("<tr>" + "".join(cells) + "</tr>")
    html_rows = "".join(row_parts)

    html_code = f'''
    <html>
    <head>
    <meta charset="utf-8">
    <style>
    table {{
        font-family: 'Noto Sans';
        font-size: 12px;
        border-collapse: collapse;
        table-layout: fixed;
        width: 100%;
    }}
    th, td {{
        border: 1px solid black;
        padding: 8px;
        width: 150px;
        height: 30px;
        overflow: hidden;
        text-overflow: ellipsis;
        white-space: nowrap;
        text-align: center;
    }}
    th.layer {{
        width: 100px;
    }}
    th.title {{
        font-size: 14px;
        padding: 10px;
        height: auto;
        white-space: normal;
        overflow: visible;
    }}
    </style>
    </head>
    <body style="background-color: white;">
    <table>
    <tr>
    <th colspan="4" class="title">
    Hidden states per Transformer layer (LLM) for Prediction
    </th>
    </tr>
    <tr>
    <th class="layer">Layer ⬆️</th>
    <th>Top 1</th>
    <th>Top 2</th>
    <th>Top 3</th>
    </tr>
    {html_rows}
    </table>
    </body>
    </html>
    '''

    # The temporary directory (and the screenshot in it) is removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        hti = Html2Image(output_path=tmpdir)
        hti.browser_flags = [
            "--headless=new",
            "--disable-gpu",
            "--disable-software-rasterizer",
            "--no-sandbox",
        ]
        filename = str(uuid.uuid4()) + ".png"
        hti.screenshot(html_str=html_code, save_as=filename, size=(500, 1000))
        img_path = os.path.join(tmpdir, filename)

        image_bgr = cv2.imread(img_path)
        if image_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise RuntimeError(f"Could not read rendered table image: {img_path}")
        return image_bgr[:, :, ::-1]
440
+
441
+
442
  model = AutoModel.from_pretrained(
443
  "khang119966/Vintern-1B-v3_5-explainableAI",
444
  torch_dtype=torch.bfloat16,
 
455
  response, query = model.chat(tokenizer, pixel_values, '<image>\n'+prompt, generation_config, return_history=False, \
456
  attention_visualize=True,last_visualize_layers=7,raw_image_path=test_image,target_aspect_ratio=target_aspect_ratio)
457
  print(response)
458
+ generation_output = response
459
+ raw_image_path = image
460
+
461
  return "path_to_generated_video.mp4"
462
 
463
  with gr.Blocks() as demo:
 
465
 
466
  with gr.Row():
467
  with gr.Column():
468
+ image = gr.Image(label="Upload your image", type = 'filepath')
469
  prompt = gr.Textbox(label="Describe your prompt", value="List all the text." )
470
  max_tokens = gr.Slider(label="Max token output (⚠️ Choose <100 for faster response)", minimum=1, maximum=512, value=50)
471
  btn = gr.Button("Attenion Video")