iamrobotbear committed on
Commit
6ded388
·
1 Parent(s): 235b83d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
3
+ import torch
4
+ from PIL import Image
5
+
6
# Hugging Face Hub id of the BLIP-2 checkpoint; kept in one place so the
# processor and the model can never drift apart.
MODEL_ID = "Salesforce/blip2-opt-2.7b"

# Load the BLIP-2 processor (image preprocessing + tokenizer) and the model.
# NOTE: load_in_8bit=True relies on bitsandbytes and, to my knowledge, a CUDA
# GPU; device_map="auto" lets the loader decide where the weights are placed.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Blip2ForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", load_in_8bit=True
)

# Device that the *inputs* are moved to before generation. The model itself is
# already placed by device_map above, so this does not move the model.
device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ def blip2_interface(image, prompted_caption_text, vqa_question, chat_context):
16
+ # Prepare image input
17
+ image_input = Image.fromarray(image).convert('RGB')
18
+ inputs = processor(image_input, return_tensors="pt").to(device, torch.float16)
19
+
20
+ # Image Captioning
21
+ generated_ids = model.generate(**inputs, max_new_tokens=20)
22
+ image_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
23
+
24
+ # Prompted Image Captioning
25
+ inputs = processor(image_input, text=prompted_caption_text, return_tensors="pt").to(device, torch.float16)
26
+ generated_ids = model.generate(**inputs, max_new_tokens=20)
27
+ prompted_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
28
+
29
+ # Visual Question Answering (VQA)
30
+ prompt = f"Question: {vqa_question} Answer:"
31
+ inputs = processor(image_input, text=prompt, return_tensors="pt").to(device, torch.float16)
32
+ generated_ids = model.generate(**inputs, max_new_tokens=10)
33
+ vqa_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
34
+
35
+ # Chat-based Prompting
36
+ prompt = chat_context + " Answer:"
37
+ inputs = processor(image_input, text=prompt, return_tensors="pt").to(device, torch.float16)
38
+ generated_ids = model.generate(**inputs, max_new_tokens=10)
39
+ chat_response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
40
+
41
+ return image_caption, prompted_caption, vqa_answer, chat_response
42
+
43
# Define Gradio input and output components.
# The legacy ``gr.inputs`` / ``gr.outputs`` namespaces are deprecated in
# Gradio 3.x and removed in 4.x, so the top-level component classes are used.
image_input = gr.Image(type="numpy")  # delivers the upload as a numpy array
text_input = gr.Textbox()
output_text = gr.Textbox()
47
+
48
+ # Create Gradio interface
49
+ iface = gr.Interface(
50
+