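"""Gradio demo for TokenFD: visualize, for each BPE token of an input text, its
similarity map over an uploaded image using a TokenFD (InternVL-based) or R50 backbone."""
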
import os
import argparse

import numpy as np
import torch
import torchvision.transforms as T
import gradio as gr
import spaces
from PIL import Image
from transformers import AutoTokenizer

from resnet50 import build_model
from utils import (
    IMAGENET_MEAN,
    IMAGENET_STD,
    build_transform_R50,
    generate_similiarity_map,
    load_tokenizer,
    post_process,
)
from internvl.train.dataset import dynamic_preprocess
from internvl.model.internvl_chat import InternVLChatModel


CHECKPOINTS = {
    "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
    "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
}
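# NOTE: load_model() below also looks up CHECKPOINTS['R50'] and CHECKPOINTS['R50_siglip'];
# add local checkpoint paths for those keys before selecting the R50 model types.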

HF_TOKEN = os.getenv("HF_TOKEN")
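
# Visualization state shared between process_image() and the Gradio callbacks.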
current_vis = []
current_bpe = []
current_index = 0


def load_model(check_type):
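    """Return the selected backbone, its tokenizer, the matching image transform, and the device."""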
    device = torch.device("cuda")

    if check_type == 'R50':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')

    elif check_type == 'R50_siglip':
        tokenizer = load_tokenizer('tokenizer_path')
        model = build_model(argparse.Namespace()).eval()
        model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
        transform = build_transform_R50(normalize_type='imagenet')

    elif 'TokenFD' in check_type:
        model_path = CHECKPOINTS[check_type]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
        model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
        ])

    return model.to(device), tokenizer, transform, device


def process_image(model, tokenizer, transform, device, check_type, image, text):
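    """Compute a similarity map for every BPE token of `text` over `image`, cache the
    maps in the module-level state, and return the image, first map, and first BPE."""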
    global current_vis, current_bpe, current_index
    src_size = image.size
    if 'TokenFD' in check_type:
        images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
                                                  image_size=model.config.force_image_size,
                                                  use_thumbnail=model.config.use_thumbnail,
                                                  return_ratio=True)
        pixel_values = torch.stack([transform(img) for img in images]).to(device)
    else:
        pixel_values = torch.stack([transform(image)]).to(device)
        target_ratio = (1, 1)

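    # Tokenize the query text and drop the first id (the tokenizer's leading special token).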
    text += ' '
    input_ids = tokenizer(text)['input_ids'][1:]
    input_ids = torch.tensor(input_ids, device=device)

    with torch.no_grad():
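        # Embed the BPE ids with the model's text-side embedding table.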
        if 'R50' in check_type:
            text_embeds = model.language_embedding(input_ids)
        else:
            text_embeds = model.tok_embeddings(input_ids)

        # Extract image-token embeddings from the vision encoder.
        vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
        vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)

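        # Cosine similarity between every BPE embedding and every image token,
        # reshaped below into one map per BPE token.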
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
        similarity = text_embeds @ vit_embeds.T
        resized_size = size1 if size1 is not None else size2

        attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])

    all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
    current_vis = generate_similiarity_map([image], attn_map,
                                           [tokenizer.decode([i]) for i in input_ids],
                                           [], target_ratio, src_size)

    current_bpe = [tokenizer.decode([i]) for i in input_ids]
    # Label the last visualization with the full input text.
    current_bpe[-1] = text
    return image, current_vis[0], current_bpe[0]


def update_index(change):
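    """Move the current BPE index by `change` (clamped) and return the matching view."""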
    global current_vis, current_bpe, current_index
    current_index = max(0, min(len(current_vis) - 1, current_index + change))
    return current_vis[current_index], format_bpe_display(current_bpe[current_index])


def format_bpe_display(bpe):
    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"


with gr.Blocks(title="BPE Visualization Demo") as demo:
    gr.Markdown("## BPE Visualization Demo - Visualizing the capabilities of the TokenFD foundation model")

    with gr.Row():
        with gr.Column(scale=0.5):
            model_type = gr.Dropdown(
                choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
                label="Select model type",
                value="TokenFD_4096_English_seg"
            )
            image_input = gr.Image(label="Upload image", type="pil")
            text_input = gr.Textbox(label="Input text")

            run_btn = gr.Button("RUN")

            gr.Examples(
                examples=[
                    [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
                    [os.path.join("examples", "examples1.jpg"), "Refreshers"],
                    [os.path.join("examples", "examples2.png"), "Vision Transformer"]
                ],
                inputs=[image_input, text_input],
                label="Sample input"
            )

        with gr.Column(scale=2):
            gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text does not appear in the image</span>, the attention map will show a lot of noise (the actual response values are very low), because the map is normalized by its relative values.</p>")

            with gr.Row():
                orig_img = gr.Image(label="Original picture", interactive=False)
                heatmap = gr.Image(label="BPE visualization", interactive=False)

            with gr.Row() as controls:
                prev_btn = gr.Button("⬅ Previous", visible=False)
                index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
                next_btn = gr.Button("⮕ Next", visible=False)

            bpe_display = gr.Markdown("Current BPE: ", visible=False)

    @spaces.GPU
    def on_run_clicked(model_type, image, text):
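        """Run the visualization for the selected model and prime the BPE slider."""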
        global current_vis, current_bpe, current_index
        current_index = 0
        image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)

        slider_max_val = len(current_bpe) - 1
        bpe_text = format_bpe_display(bpe)
        return image, vis, bpe_text, slider_max_val

    run_btn.click(
        on_run_clicked,
        inputs=[model_type, image_input, text_input],
        outputs=[orig_img, heatmap, bpe_display, index_slider],
    ).then(
        # The click above stores the slider's maximum in its value; reveal the
        # controls, set the real maximum, and reset the value to 0.
        lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
        inputs=index_slider,
        outputs=[prev_btn, index_slider, next_btn, bpe_display],
    )

    prev_btn.click(
        lambda: (*update_index(-1), current_index),
        outputs=[heatmap, bpe_display, index_slider]
    )

    next_btn.click(
        lambda: (*update_index(1), current_index),
        outputs=[heatmap, bpe_display, index_slider]
    )

    index_slider.change(
        lambda x: (current_vis[int(x)], format_bpe_display(current_bpe[int(x)]))
        if 0 <= int(x) < len(current_vis) and 0 <= int(x) < len(current_bpe)
        else (None, "Index out of range"),
        inputs=index_slider,
        outputs=[heatmap, bpe_display]
    )


if __name__ == "__main__":
    demo.launch()