from threading import Thread
from typing import Any, Dict

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer


TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.2</center></h1>"

DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/BUAADreamer/PaliGemma-3B-Chat-v0.2' target='_blank'>our model page</a> for details.</center></h3>"

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
"""
model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
# torch_dtype="auto" keeps the dtype stored in the checkpoint; device_map="auto"
# places the weights on the available accelerator.
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
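

# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call.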
@spaces.GPU
def stream_chat(message: Dict[str, Any], history: list):
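    """Stream a chat reply from PaliGemma, optionally grounded in an uploaded image."""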
    # `message` arrives as {"text": str, "files": [file paths]}; use an image
    # attached to the current message if there is one.
    image_path = None
    if len(message["files"]) != 0:
        image_path = message["files"][0]

    # In tuple-style history, an image uploaded in an earlier turn is stored as
    # a (file_path,) tuple in the first entry; reuse it and drop that entry
    # from the text history.
    if len(history) != 0 and isinstance(history[0][0], tuple):
        image_path = history[0][0][0]
        history = history[1:]
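
    # PaliGemma always consumes image tokens (see the <image> prefix below),
    # so text-only conversations fall back on a blank placeholder image.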
    if image_path is not None:
        image = Image.open(image_path).convert("RGB")
    else:
        image = Image.new("RGB", (100, 100), (255, 255, 255))

    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    # Rebuild the dialogue in the chat-template message format.
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message["text"]})
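
    # Tokenize with the model's chat template, then prepend one <image>
    # placeholder token per visual position (processor.image_seq_length) so the
    # text tokens line up with `pixel_values` inside the model.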
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)
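
    # TextIteratorStreamer lets generate() emit decoded text incrementally;
    # `timeout` keeps the consuming loop from hanging if generation stalls.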
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )
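
    # generate() blocks until it finishes, so run it on a worker thread and
    # yield the accumulated text as new tokens arrive.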
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output


chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    # multimodal=True makes the textbox accept file uploads and delivers each
    # message to stream_chat as a {"text", "files"} dict.
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        chatbot=chatbot,
        fill_height=True,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()