import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import torch
import spaces

from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown
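
# Processor and model are loaded once at startup and shared across all requests.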
MODEL_ID = "TIGER-Lab/VL-Rethinker-7B"

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()
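

# Assumption: this Space runs on ZeroGPU (the original listing imports `spaces`
# but never uses it), so the inference handler is wrapped with spaces.GPU to
# request a GPU for the duration of each call.
@spaces.GPU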
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    """
    Create chat history

    Example history value:
    [
        [('pixel.png',), None],
        ['ignore this image. just say "hi" and nothing else', 'Hi!'],
        ['just say "hi" and nothing else', 'Hi!']
    ]
    """
    all_images = []
    current_message_images = []
    messages = []
    for val in history:
        if val[0]:
            if isinstance(val[0], str):
                messages.append({
                    "role": "user",
                    "content": [
                        *[{"type": "image", "image": image} for image in current_message_images],
                        {"type": "text", "text": val[0]},
                    ],
                })
                current_message_images = []
            else:
                # Load images. These will be attached to the first user text message that comes after them.
                current_message_images = [load_image(image) for image in val[0]]
                all_images += current_message_images

        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
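
    # Attach any images uploaded with the current turn to the new user message.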
    current_message_images = [load_image(image) for image in files]
    all_images += current_message_images

    messages.append({
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in current_message_images],
            {"type": "text", "text": text},
        ],
    })

    # print(messages)

    """
    Generate and stream text
    """
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=all_images if all_images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
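
    # Generation runs in a background thread; the streamer yields decoded text
    # incrementally so partial answers appear in the chat as they are produced.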
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer


with gr.Blocks() as demo:
    examples = [
        [{"text": "Solve this question.", "files": ["example_images/document.png"]}]
    ]

    gr.HTML(html_header)
    gr.ChatInterface(
        fn=model_inference,
        description="# **VL-Rethinker-7B**",
        examples=examples,
        fill_height=True,
        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
        stop_btn="Stop Generation",
        multimodal=True,
        cache_examples=False,
    )

    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    gr.Markdown(bibtext)

demo.launch(debug=True)