iamrobotbear committed
Commit 5053a56 · 1 Parent(s): 3bc78d3

this is a total fucking mess.

Files changed (1)
  1. app.py +59 -47
app.py CHANGED
@@ -1,57 +1,69 @@
  import gradio as gr
- from transformers import AutoProcessor, Blip2ForConditionalGeneration
  import torch
- from PIL import Image
-
- # Load the BLIP-2 model and processor
- processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
- model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

- # Set device to GPU if available
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)

- def blip2_interface(image, prompted_caption_text, vqa_question, chat_context):
-     # Prepare image input
-     image_input = Image.fromarray(image).convert('RGB')
-     inputs = processor(image_input, return_tensors="pt").to(device, torch.float16)
-
-     # Image Captioning
-     generated_ids = model.generate(**inputs, max_new_tokens=20)
-     image_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

-     # Prompted Image Captioning
-     inputs = processor(image_input, text=prompted_caption_text, return_tensors="pt").to(device, torch.float16)
-     generated_ids = model.generate(**inputs, max_new_tokens=20)
-     prompted_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
-     # Visual Question Answering (VQA)
-     prompt = f"Question: {vqa_question} Answer:"
-     inputs = processor(image_input, text=prompt, return_tensors="pt").to(device, torch.float16)
-     generated_ids = model.generate(**inputs, max_new_tokens=10)
-     vqa_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
-     # Chat-based Prompting
-     prompt = chat_context + " Answer:"
-     inputs = processor(image_input, text=prompt, return_tensors="pt").to(device, torch.float16)
-     generated_ids = model.generate(**inputs, max_new_tokens=10)
-     chat_response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

-     return image_caption, prompted_caption, vqa_answer, chat_response

  # Define Gradio input and output components
- image_input = gr.inputs.Image(type="numpy")
- text_input = gr.inputs.Text()
  output_text = gr.outputs.Textbox()

- # Create Gradio interface
- iface = gr.Interface(
-     fn=blip2_interface,
-     inputs=[image_input, text_input, text_input, text_input],
-     outputs=[output_text, output_text, output_text, output_text],
-     title="BLIP-2 Image Captioning and VQA",
-     description="Interact with the BLIP-2 model for image captioning, prompted image captioning, visual question answering, and chat-based prompting.",
- )
-
- if __name__ == "__main__":
-     iface.launch()
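A note on the removed version above (not part of the commit): blip2_interface casts the processor outputs to torch.float16, but the OPT-2.7b checkpoint is loaded in the default float32 and only moved to the device, so on a GPU the half-precision pixel values hit float32 weights and generate fails with a dtype mismatch. A minimal sketch of a loading pattern that keeps the dtypes consistent, assuming a CUDA device (on CPU the fp16 cast would also need a float32 fallback):

import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# Load the weights in float16 so they match the .to(device, torch.float16)
# cast applied to the processor outputs inside blip2_interface; fall back to
# float32 when no GPU is available.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model.to(device)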
  import gradio as gr
  import torch
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration

+ # Check if CUDA is available
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Model ID
+ MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"

+ # Load the model and processor
+ processor = AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL)
+ model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL, load_in_8bit=True).to(device)

+ # Define a function for generating captions and answering questions
+ def generate_text(image, text, decoding_method, temperature, length_penalty, repetition_penalty):
+     if text.startswith("Caption:"):
+         # Generate caption
+         inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+         generated_ids = model.generate(
+             pixel_values=inputs.pixel_values,
+             do_sample=decoding_method == "Nucleus sampling",
+             temperature=temperature,
+             length_penalty=length_penalty,
+             repetition_penalty=repetition_penalty,
+             max_length=50,
+             min_length=1,
+             num_beams=5,
+             top_p=0.9,
+         )
+         result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+         return result
+     else:
+         # Answer question
+         inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16)
+         generated_ids = model.generate(
+             **inputs,
+             do_sample=decoding_method == "Nucleus sampling",
+             temperature=temperature,
+             length_penalty=length_penalty,
+             repetition_penalty=repetition_penalty,
+             max_length=30,
+             min_length=1,
+             num_beams=5,
+             top_p=0.9,
+         )
+         result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+         return result

  # Define Gradio input and output components
+ image_input = gr.Image(type="numpy")
+ text_input = gr.Text()
  output_text = gr.outputs.Textbox()

+ # Define Gradio interface
+ gr.Interface(
+     fn=generate_text,
+     inputs=[image_input, text_input, gr.inputs.Radio(["Beam search", "Nucleus sampling"]), gr.inputs.Slider(0.5, 1.0, 0.1), gr.inputs.Slider(-1.0, 2.0, 0.2), gr.inputs.Slider(1.0, 5.0, 0.5)],
+     outputs=output_text,
+     examples=[
+         ["house.png", "Caption:"],
+         ["flower.jpg", "What is this flower and where is its origin?"],
+         ["pizza.jpg", "Caption:"],
+         ["sunset.jpg", "Caption:"],
+         ["forbidden_city.webp", "In what dynasties was this place built?"],
+     ],
+     title="BLIP-2",
+     description="Gradio demo for BLIP-2, image-to-text generation from Salesforce Research.",
+ ).launch()
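A note on the new loading code: recent transformers releases reject .to() with a device argument on models loaded with load_in_8bit=True, so from_pretrained(..., load_in_8bit=True).to(device) is likely to raise at startup; 8-bit loading is normally paired with a device_map so accelerate places the quantized weights itself. A minimal sketch of that loading step, assuming bitsandbytes and accelerate are installed and a CUDA GPU is available (it is not the committed code):

from transformers import AutoProcessor, Blip2ForConditionalGeneration

MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"

processor = AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL)
# device_map="auto" lets accelerate put the 8-bit weights on the GPU, so no
# .to(device) call follows; the processor outputs can still be cast with
# .to(device, torch.float16) inside generate_text as the commit already does.
model = Blip2ForConditionalGeneration.from_pretrained(
    MODEL_ID_FLAN_T5_XXL,
    device_map="auto",
    load_in_8bit=True,
)

The 8-bit path needs a GPU; if the Space ends up on CPU hardware, dropping load_in_8bit and loading the checkpoint in float32 is the safer default.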
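A note on the interface code: the new version mixes Gradio API generations, using gr.Image and gr.Text alongside the legacy gr.inputs.* / gr.outputs.* namespaces that newer Gradio releases remove, and each examples row supplies two values while the interface declares six input components, which Gradio generally expects to match one-for-one. A sketch using one consistent component style, assuming Gradio 3.x and reusing the commit's generate_text; the labels and default slider values are illustrative, not from the commit:

import gradio as gr

iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Image(type="numpy"),
        gr.Text(label="Prompt (start with 'Caption:' for captioning)"),
        gr.Radio(["Beam search", "Nucleus sampling"], value="Beam search", label="Decoding method"),
        gr.Slider(0.5, 1.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(-1.0, 2.0, value=1.0, step=0.2, label="Length penalty"),
        gr.Slider(1.0, 5.0, value=1.5, step=0.5, label="Repetition penalty"),
    ],
    outputs=gr.Textbox(label="Output"),
    # Each example row carries one value per input component.
    examples=[
        ["house.png", "Caption:", "Beam search", 1.0, 1.0, 1.5],
        ["flower.jpg", "What is this flower and where is its origin?", "Beam search", 1.0, 1.0, 1.5],
    ],
    title="BLIP-2",
    description="Gradio demo for BLIP-2, image-to-text generation from Salesforce Research.",
)

if __name__ == "__main__":
    iface.launch()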