import gradio as gr
from transformers import AutoTokenizer
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from Models.modeling_llavaqw import LlavaQwModel
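
# Standard ImageNet per-channel mean/std used to normalize input images.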
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Load the tokenizer and the custom LlavaQw model in bfloat16 on the GPU.
model_name = "torettomarui/Llava-qw"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
model = LlavaQwModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to(torch.bfloat16).eval().cuda()
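
# A minimal device-fallback sketch (an assumption, not part of the original Space,
# which targets CUDA only): pick CPU + float32 when no GPU is available.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# dtype = torch.bfloat16 if device == "cuda" else torch.float32
# model = LlavaQwModel.from_pretrained(
#     model_name, torch_dtype=dtype, trust_remote_code=True
# ).to(device).eval()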

def build_transform(input_size):
    """Build the preprocessing pipeline: RGB conversion, resize, tensor conversion, ImageNet normalization."""
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
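
# Usage sketch for the transform (assumes Pillow and a hypothetical local file "sample.jpg"):
# from PIL import Image
# tensor = build_transform(448)(Image.open("sample.jpg"))  # -> float tensor of shape (3, 448, 448)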

def preprocess_image(image, image_size=448):
    # `image` is a PIL image handed over by Gradio; returns a (1, 3, 448, 448) bfloat16 batch on the GPU.
    transform = build_transform(image_size)
    pixel_values = transform(image)
    return torch.stack([pixel_values]).to(torch.bfloat16).cuda()

def generate_response(image, text):
    pixel_values = preprocess_image(image)
    generation_config = dict(max_new_tokens=2048, do_sample=False)
    # Prepend the image placeholder token to the user's question.
    question = '<image>\n' + text
    response = model.chat(tokenizer, pixel_values, question, generation_config)
    return response
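
# Quick sanity check outside Gradio, using the bundled example image (assumes Pillow):
# from PIL import Image
# print(generate_response(Image.open("./text.png"), "What is the text in the image?"))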

# Example image and question shown in the Gradio UI
examples = [
    ["./text.png", "What is the text in the image?"],
]

iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(lines=2, placeholder="Type your question..."),
    ],
    outputs="text",
    title="Llava-QW",
    description="Upload an image and ask a question; the model will generate an answer.",
    examples=examples,  # example inputs shown below the interface
)

iface.launch()
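
# When running locally (outside Hugging Face Spaces), iface.launch(share=True) can be
# used instead to get a temporary public URL.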