import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import torch
import spaces

from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown

MODEL_ID = "TIGER-Lab/VL-Rethinker-7B"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()


@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    """
    Create chat history

    Example history value:
    [
        [('pixel.png',), None],
        ['ignore this image. just say "hi" and nothing else', 'Hi!'],
        ['just say "hi" and nothing else', 'Hi!']
    ]
    """
    all_images = []
    current_message_images = []
    messages = []

    for val in history:
        if val[0]:
            if isinstance(val[0], str):
                # A text turn: attach any images collected since the last text turn.
                messages.append({
                    "role": "user",
                    "content": [
                        *[{"type": "image", "image": image} for image in current_message_images],
                        {"type": "text", "text": val[0]},
                    ],
                })
                current_message_images = []
            else:
                # Load images. These will be appended to the next user text message.
                current_message_images = [load_image(image) for image in val[0]]
                all_images += current_message_images

        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    # Append the current turn: newly uploaded files plus the new text query.
    current_message_images = [load_image(image) for image in files]
    all_images += current_message_images
    messages.append({
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in current_message_images],
            {"type": "text", "text": text},
        ],
    })

    """
    Generate and stream text
    """
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=all_images if all_images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Run generation in a background thread so decoded tokens can be yielded as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer


with gr.Blocks() as demo:
    examples = [
        [{"text": "Solve this question.", "files": ["example_images/document.png"]}]
    ]

    gr.HTML(html_header)

    gr.ChatInterface(
        fn=model_inference,
        description="# **VL-Rethinker-7B**",
        examples=examples,
        fill_height=True,
        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
        stop_btn="Stop Generation",
        multimodal=True,
        cache_examples=False,
    )

    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    gr.Markdown(bibtext)

demo.launch(debug=True)