blip-vqa-gradio / app.py
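# Gradio demo for BLIP-2 (Flan-T5 XXL) from Salesforce: image captioning and visual
# question answering. A prompt that starts with "Caption:" produces a caption; any
# other prompt is answered as a question about the image.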
import gradio as gr
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model ID
MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"

# Load the processor and model. 8-bit loading (bitsandbytes) requires a GPU, and a
# quantized model must not be moved with .to(), so fall back to full precision on CPU.
processor = AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL)
if device.type == "cuda":
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL, load_in_8bit=True, device_map="auto")
else:
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL).to(device)
# Define a function for generating captions and answering questions
def generate_text(image, text, decoding_method, temperature, length_penalty, repetition_penalty):
    # Cast floating-point inputs to float16 on GPU; keep float32 on CPU
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    if text.startswith("Caption:"):
        # Generate caption
        inputs = processor(images=image, return_tensors="pt").to(device, dtype)
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values,
            do_sample=(decoding_method == "Nucleus sampling"),
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            max_length=50,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )
    else:
        # Answer question
        inputs = processor(images=image, text=text, return_tensors="pt").to(device, dtype)
        generated_ids = model.generate(
            **inputs,
            do_sample=(decoding_method == "Nucleus sampling"),
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )

    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return result
# Define Gradio input and output components
image_input = gr.Image(type="numpy", label="Image")
text_input = gr.Text(label='Prompt (start with "Caption:" to caption the image, or ask a question)')
decoding_input = gr.Radio(["Beam search", "Nucleus sampling"], label="Decoding method")
temperature_input = gr.Slider(0.5, 1.0, step=0.1, label="Temperature")
length_penalty_input = gr.Slider(-1.0, 2.0, step=0.2, label="Length penalty")
repetition_penalty_input = gr.Slider(1.0, 5.0, step=0.5, label="Repetition penalty")
output_text = gr.Textbox(label="Output")

# Define Gradio interface
gr.Interface(
    fn=generate_text,
    inputs=[
        image_input,
        text_input,
        decoding_input,
        temperature_input,
        length_penalty_input,
        repetition_penalty_input,
    ],
    outputs=output_text,
    examples=[
        ["house.png", "Caption:"],
        ["flower.jpg", "What is this flower and where is its origin?"],
        ["pizza.jpg", "Caption:"],
        ["sunset.jpg", "Caption:"],
        ["forbidden_city.webp", "In what dynasties was this place built?"],
    ],
    title="BLIP-2",
    description="Gradio demo for BLIP-2, image-to-text generation from Salesforce Research.",
).launch()
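# Note: the example rows assume the listed image files (house.png, flower.jpg, pizza.jpg,
# sunset.jpg, forbidden_city.webp) are present alongside app.py in the Space repository.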