|
import gradio as gr |
|
import torch |
|
from transformers import AutoProcessor, Blip2ForConditionalGeneration |
|
|
|
|
|
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub checkpoint: BLIP-2 paired with a Flan-T5-XXL language model.
MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"

# The processor prepares images (pixel_values) and tokenizes optional text prompts.
processor = AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL)

if device.type == "cuda":
    # 8-bit (bitsandbytes) loading requires a GPU. A quantized model is placed on its
    # device(s) by `device_map` at load time; transformers rejects `.to()` on 8-bit
    # models, so it must not be moved afterwards.
    model = Blip2ForConditionalGeneration.from_pretrained(
        MODEL_ID_FLAN_T5_XXL,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    # CPU fallback: 8-bit quantization is unavailable without a GPU, so load the
    # default (full-precision) weights and move them to the CPU device.
    # NOTE(review): this checkpoint is very large; expect heavy RAM use and slow generation.
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL).to(device)
|
|
|
|
|
def generate_text(image, text, decoding_method, temperature, length_penalty, repetition_penalty):
    """Generate a caption or an answer for ``image`` with BLIP-2.

    If ``text`` starts with ``"Caption:"`` the model is conditioned on the image only
    and produces a caption (up to 50 tokens). Otherwise ``text`` is used as a prompt
    (e.g. a question) and the answer is limited to 30 tokens.

    Args:
        image: Image from the Gradio ``gr.Image(type="numpy")`` component, or ``None``
            when nothing was uploaded.
        text: Prompt text; ``None`` is treated as an empty prompt.
        decoding_method: ``"Nucleus sampling"`` enables sampling; any other value
            (including ``None``) uses plain beam search.
        temperature: Sampling temperature passed to ``model.generate``.
        length_penalty: Beam-search length penalty passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Returns:
        The decoded text of the first generated sequence, stripped of surrounding whitespace.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    text = text or ""

    if text.startswith("Caption:"):
        # Captioning: the "Caption:" marker is not fed to the model; only the image is.
        features = processor(images=image, return_tensors="pt").to(device, model.dtype)
        model_inputs = {"pixel_values": features.pixel_values}
        max_length = 50
    else:
        # Prompted generation: the processor also returns input_ids / attention_mask.
        features = processor(images=image, text=text, return_tensors="pt").to(device, model.dtype)
        model_inputs = dict(features)
        max_length = 30

    # Cast floating-point inputs to the model's own dtype (fp16 for the 8-bit load)
    # instead of a hard-coded fp16, so a full-precision model also accepts them.
    generated_ids = model.generate(
        **model_inputs,
        do_sample=decoding_method == "Nucleus sampling",
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=1,
        num_beams=5,
        top_p=0.9,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
|
|
|
|
# UI components. The order of `inputs` passed to gr.Interface below must match the
# positional parameters of `generate_text`.
image_input = gr.Image(type="numpy")

text_input = gr.Text(label="Prompt (start with 'Caption:' to caption the image)")

decoding_method_input = gr.Radio(
    choices=["Beam search", "Nucleus sampling"],
    value="Beam search",
    label="Decoding method",
)

# Keyword arguments are required: gr.Slider's third positional parameter is `value`,
# not `step`. Each default is the slider's minimum.
temperature_input = gr.Slider(minimum=0.5, maximum=1.0, value=0.5, step=0.1, label="Temperature")
length_penalty_input = gr.Slider(minimum=-1.0, maximum=2.0, value=-1.0, step=0.2, label="Length penalty")
repetition_penalty_input = gr.Slider(
    minimum=1.0, maximum=5.0, value=1.0, step=0.5, label="Repetition penalty"
)

output_text = gr.Textbox(label="Output")

# Default generation settings appended to every example row so that each row
# supplies one value per input component (decoding method, temperature,
# length penalty, repetition penalty).
_EXAMPLE_SETTINGS = ["Beam search", 0.5, -1.0, 1.0]

demo = gr.Interface(
    fn=generate_text,
    inputs=[
        image_input,
        text_input,
        decoding_method_input,
        temperature_input,
        length_penalty_input,
        repetition_penalty_input,
    ],
    outputs=output_text,
    examples=[
        ["house.png", "Caption:", *_EXAMPLE_SETTINGS],
        ["flower.jpg", "What is this flower and where is its origin?", *_EXAMPLE_SETTINGS],
        ["pizza.jpg", "Caption:", *_EXAMPLE_SETTINGS],
        ["sunset.jpg", "Caption:", *_EXAMPLE_SETTINGS],
        ["forbidden_city.webp", "In what dynasties was this place built?", *_EXAMPLE_SETTINGS],
    ],
    title="BLIP-2",
    description="Gradio demo for BLIP-2, image-to-text generation from Salesforce Research.",
)
demo.launch()
|
|