iamrobotbear committed
Commit f69bea2 · 1 Parent(s): 5053a56

Update app.py

Files changed (1)
  1. app.py +47 -59
app.py CHANGED
@@ -1,69 +1,57 @@
  import gradio as gr
- import torch
  from transformers import AutoProcessor, Blip2ForConditionalGeneration

- # Check if CUDA is available
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- # Model ID
- MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"

- # Load the model and processor
- processor = AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL)
- model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL, load_in_8bit=True).to(device)

- # Define a function for generating captions and answering questions
- def generate_text(image, text, decoding_method, temperature, length_penalty, repetition_penalty):
-     if text.startswith("Caption:"):
-         # Generate caption
-         inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
-         generated_ids = model.generate(
-             pixel_values=inputs.pixel_values,
-             do_sample=decoding_method == "Nucleus sampling",
-             temperature=temperature,
-             length_penalty=length_penalty,
-             repetition_penalty=repetition_penalty,
-             max_length=50,
-             min_length=1,
-             num_beams=5,
-             top_p=0.9,
-         )
-         result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-         return result
-     else:
-         # Answer question
-         inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16)
-         generated_ids = model.generate(
-             **inputs,
-             do_sample=decoding_method == "Nucleus sampling",
-             temperature=temperature,
-             length_penalty=length_penalty,
-             repetition_penalty=repetition_penalty,
-             max_length=30,
-             min_length=1,
-             num_beams=5,
-             top_p=0.9,
-         )
-         result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-         return result

  # Define Gradio input and output components
- image_input = gr.Image(type="numpy")
- text_input = gr.Text()
  output_text = gr.outputs.Textbox()

- # Define Gradio interface
- gr.Interface(
-     fn=generate_text,
-     inputs=[image_input, text_input, gr.inputs.Radio(["Beam search", "Nucleus sampling"]), gr.inputs.Slider(0.5, 1.0, 0.1), gr.inputs.Slider(-1.0, 2.0, 0.2), gr.inputs.Slider(1.0, 5.0, 0.5)],
-     outputs=output_text,
-     examples=[
-         ["house.png", "Caption:"],
-         ["flower.jpg", "What is this flower and where is its origin?"],
-         ["pizza.jpg", "Caption:"],
-         ["sunset.jpg", "Caption:"],
-         ["forbidden_city.webp", "In what dynasties was this place built?"],
-     ],
-     title="BLIP-2",
-     description="Gradio demo for BLIP-2, image-to-text generation from Salesforce Research.",
- ).launch()
 
  import gradio as gr
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
+ import torch
+ from PIL import Image
+
+ # Load the BLIP-2 model and processor
+ processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
+ model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
+
+ # Set device to GPU if available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ def blip2_interface(image, prompted_caption_text, vqa_question, chat_context):
+     # Prepare image input; keep tensors in the model's default dtype so they match
+     # the weights (the model above is not loaded with torch_dtype=torch.float16)
+     image_input = Image.fromarray(image).convert('RGB')
+     inputs = processor(image_input, return_tensors="pt").to(device)
+
+     # Image Captioning
+     generated_ids = model.generate(**inputs, max_new_tokens=20)
+     image_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+
+     # Prompted Image Captioning
+     inputs = processor(image_input, text=prompted_caption_text, return_tensors="pt").to(device)
+     generated_ids = model.generate(**inputs, max_new_tokens=20)
+     prompted_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+
+     # Visual Question Answering (VQA)
+     prompt = f"Question: {vqa_question} Answer:"
+     inputs = processor(image_input, text=prompt, return_tensors="pt").to(device)
+     generated_ids = model.generate(**inputs, max_new_tokens=10)
+     vqa_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+
+     # Chat-based Prompting
+     prompt = chat_context + " Answer:"
+     inputs = processor(image_input, text=prompt, return_tensors="pt").to(device)
+     generated_ids = model.generate(**inputs, max_new_tokens=10)
+     chat_response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+
+     return image_caption, prompted_caption, vqa_answer, chat_response

  # Define Gradio input and output components
+ # (distinct, labeled component instances, one per input and output field)
+ image_input = gr.inputs.Image(type="numpy", label="Image")
+ prompted_caption_input = gr.inputs.Textbox(label="Prompted caption text")
+ vqa_question_input = gr.inputs.Textbox(label="VQA question")
+ chat_context_input = gr.inputs.Textbox(label="Chat context")
+ caption_output = gr.outputs.Textbox(label="Image caption")
+ prompted_caption_output = gr.outputs.Textbox(label="Prompted caption")
+ vqa_answer_output = gr.outputs.Textbox(label="VQA answer")
+ chat_response_output = gr.outputs.Textbox(label="Chat response")
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=blip2_interface,
+     inputs=[image_input, prompted_caption_input, vqa_question_input, chat_context_input],
+     outputs=[caption_output, prompted_caption_output, vqa_answer_output, chat_response_output],
+     title="BLIP-2 Image Captioning and VQA",
+     description="Interact with the BLIP-2 model for image captioning, prompted image captioning, visual question answering, and chat-based prompting.",
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
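
For a quick sanity check of the new app.py, blip2_interface can also be exercised directly, without launching the Gradio UI (importing app.py does not call launch() because it sits behind the __main__ guard). This is only a minimal sketch under assumptions not in the commit: "test.jpg" is a placeholder path and the three prompt strings are arbitrary examples.

import numpy as np
from PIL import Image
from app import blip2_interface  # loads the model; does not start the Gradio server

# Placeholder image path; replace with any local image file.
image_array = np.array(Image.open("test.jpg").convert("RGB"))

caption, prompted_caption, vqa_answer, chat_response = blip2_interface(
    image_array,
    "a photo of",                                  # prompted caption prefix (example)
    "What is shown in the picture?",               # VQA question (example)
    "This photo was taken outdoors. Question: What season is it?",  # chat context (example)
)
print(caption, prompted_caption, vqa_answer, chat_response, sep="\n")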