from transformers import AutoTokenizer, TextIteratorStreamer
from PIL import Image
import torch
from threading import Thread
import gradio as gr
from gradio import FileData
import time
import spaces
from unsloth import FastVisionModel
# Load model and tokenizer (for vision models, the returned "tokenizer"
# also acts as the image/text processor)
ckpt = "Daemontatox/DocumentLlama"
model, tokenizer = FastVisionModel.from_pretrained(
    ckpt,
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
)

# Enable inference mode
FastVisionModel.for_inference(model)
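
# Sketch of the message structure built below and consumed by
# tokenizer.apply_chat_template (assumed from the Llama 3.2 Vision chat
# format; {"type": "image"} entries are placeholders, the actual PIL images
# are passed to the processor separately):
#
#   messages = [
#       {"role": "user", "content": [{"type": "image"},
#                                    {"type": "text", "text": "Describe this."}]},
#       {"role": "assistant", "content": [{"type": "text", "text": "..."}]},
#   ]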
@spaces.GPU  # request a GPU when running on ZeroGPU Spaces
def bot_streaming(message, history, max_new_tokens=2048):
    txt = message["text"]

    messages = []
    images = []

    # Process history. In Gradio's tuple format, an uploaded image arrives as
    # its own [(filepath,), None] entry, followed by a [text, reply] entry.
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            # Image turn: pair the image with the text of the following entry.
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": history[i + 1][0]},
                    {"type": "image"}
                ]
            })
            messages.append({
                "role": "assistant",
                "content": [{"type": "text", "text": history[i + 1][1]}]
            })
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif isinstance(history[i - 1][0], tuple) and isinstance(msg[0], str):
            # Text entry belonging to the image turn above: already handled.
            pass
        elif isinstance(history[i - 1][0], str) and isinstance(msg[0], str):
            # Text-only turn.
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": msg[0]}]
            })
            messages.append({
                "role": "assistant",
                "content": [{"type": "text", "text": msg[1]}]
            })
    # Handle the current message
    if len(message["files"]) == 1:
        if isinstance(message["files"][0], str):  # plain path (examples)
            image = Image.open(message["files"][0]).convert("RGB")
        else:  # FileData dict (regular upload)
            image = Image.open(message["files"][0]["path"]).convert("RGB")
        images.append(image)
        messages.append({
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": txt}
            ]
        })
    else:
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": txt}]
        })
    # Prepare inputs; for image turns the tokenizer is called as a processor,
    # taking the image alongside the templated text
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    if images:
        inputs = tokenizer(
            images[-1],  # use the most recent image
            input_text,
            add_special_tokens=False,
            return_tensors="pt"
        ).to("cuda")
    else:
        inputs = tokenizer(
            input_text,
            add_special_tokens=False,
            return_tensors="pt"
        ).to("cuda")
    # Setup streaming: TextIteratorStreamer (unlike TextStreamer, which only
    # prints to stdout) lets this thread iterate over generated text while
    # model.generate runs in the background
    text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    generation_kwargs = dict(
        **inputs,
        streamer=text_streamer,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        temperature=1.5,
        min_p=0.1
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in text_streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
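
# Note on the streaming contract assumed here: gr.ChatInterface accepts a
# generator as fn and re-renders the reply on every yield, so yielding the
# growing `buffer` above is what produces the token-by-token effect.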
# Setup Gradio interface
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Document Analyzer",
    examples=[
        [{"text": "Which era does this piece belong to? Give details about the era.", "files": ["./examples/rococo.jpg"]}, 200],
        [{"text": "Where do the droughts happen according to this diagram?", "files": ["./examples/weather_events.png"]}, 250],
        [{"text": "What happens when you take out the white cat from this chain?", "files": ["./examples/ai2d_test.jpg"]}, 250],
        [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files": ["./examples/invoice.png"]}, 250],
        [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files": ["./examples/wat_arun.jpg"]}, 250],
    ],
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            maximum=2048,  # raised so the default value below stays in range
            value=2048,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    cache_examples=False,
    description="MLLM document analysis demo",
    stop_btn="Stop Generation",
    fill_height=True,
    multimodal=True
)
demo.launch(debug=True)
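
# Usage (a sketch, assuming a CUDA machine): on Hugging Face Spaces this file
# runs automatically; locally, start it with `python app.py`. The model loads
# in 4-bit and inputs are moved to "cuda" above, so a GPU is required.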