# BLIP-SMILE / app.py
import sys

import gradio as gr
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# Make the SMILE/BLIP package importable before loading its model code.
sys.path.append('SMILE/BLIP')
from models.model import caption_model
image_size = 384
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BLIP preprocessing: bicubic resize to 384x384, then normalization with the
# CLIP image mean/std statistics.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])
# Local checkpoint paths for the two released models.
model_path = {
    'smile': 'model/blip_smile_base.pth',
    'mle_smile': 'model/blip_mle_smile_base.pth',
}

# SMILE-only checkpoint, served for the "More Descriptive" strategy.
model_smile = caption_model(pretrained=model_path['smile'], image_size=image_size, vit='base')
model_smile.eval()
model_smile = model_smile.to(device)

# MLE+SMILE checkpoint, served for the "More Accurate" strategy.
model_mle_smile = caption_model(pretrained=model_path['mle_smile'], image_size=image_size, vit='base')
model_mle_smile.eval()
model_mle_smile = model_mle_smile.to(device)
def generate_caption(raw_image, strategy):
    """Caption a PIL image with the model matching the selected strategy."""
    image = transform(raw_image).unsqueeze(0).to(device)
    with torch.no_grad():
        if strategy == "More Descriptive":
            caption = model_smile.generate(image, sample=False, num_beams=3, max_length=75, min_length=1)
        else:
            caption = model_mle_smile.generate(image, sample=False, num_beams=3, max_length=75, min_length=1)
    # Collapse spaced-out hyphens left by detokenization and return a
    # lowercase sentence with a trailing period.
    return str(caption[0]).replace(' - ', '-').lower() + '.'
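
# Optional smoke test (a minimal sketch, not part of the original app): caption
# one of the bundled example images directly, bypassing the Gradio UI. Assumes
# the repo's example/ folder sits next to this file; uncomment to try it.
# from PIL import Image
# print(generate_caption(Image.open('example/COCO_val2014_000000093534.jpg'), 'More Descriptive'))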
inputs = [
    gr.Image(type="pil"),
    # `value` sets the pre-selected choice (the old Gradio 2.x keyword was `default`).
    gr.Radio(choices=["More Descriptive", "More Accurate"], value="More Descriptive", label="Strategy"),
]
outputs = "text"
examples = [
    ["example/COCO_val2014_000000093534.jpg", "More Descriptive"],
    ["example/COCO_val2014_000000411845.jpg", "More Descriptive"],
    ["example/COCO_val2014_000000001682.jpg", "More Descriptive"],
    ["example/COCO_val2014_000000473133.jpg", "More Descriptive"],
    ["example/COCO_val2014_000000562150.jpg", "More Descriptive"],
]
description = """<p style='text-align: center'>Gradio demo for BLIP-SMILE: the most descriptive captioning model before the multimodal LLM era.</p><p style='text-align: center'><a href='https://arxiv.org/abs/2306.13460' target='_blank'>Paper</a> | <a href='https://github.com/yuezih/SMILE' target='_blank'>GitHub</a></p>"""
interface = gr.Interface(
    generate_caption,
    inputs,
    outputs,
    examples=examples,
    title="BLIP-SMILE",
    description=description,
    allow_flagging='never',
)
# `share=True` creates a public link for local runs; Hugging Face Spaces ignores it.
interface.launch(share=True)