chats-bug committed on
Commit 9845f41 · 1 Parent(s): 92aea9e

Upload app.py

Files changed (1): app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
import gradio as gr
import torch
import open_clip
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration


# Load the BLIP-2 model in 8-bit to reduce GPU memory usage
preprocessor_blip2_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
model_blip2_8_bit = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", load_in_8bit=True)

# Load the BLIP base model
preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load the BLIP large model
preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# Load the GIT large COCO model
preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

# Load the open_clip CoCa model
model_oc_coca, _, transform_oc_coca = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Move the models to the device. The 8-bit BLIP-2 model is already placed
# by device_map="auto" and does not support .to(), so it is left as-is.
model_blip_base.to(device)
model_blip_large.to(device)
model_git_large_coco.to(device)
model_oc_coca.to(device)


def generate_caption(
    preprocessor,
    model,
    image,
    tokenizer=None,
    use_float_16=False,
    max_length=32,
):
    """
    Generate a caption for the given image.

    Parameters
    ----------
    preprocessor : AutoProcessor
        The preprocessor for the model.
    model : BlipForConditionalGeneration, Blip2ForConditionalGeneration, or GitForCausalLM
        The captioning model to use.
    image : PIL.Image
        The image to generate a caption for.
    tokenizer : AutoTokenizer, optional
        The tokenizer used to decode the output. If None, the preprocessor's decoder is used.
    use_float_16 : bool
        Whether to cast the inputs to float16. This can speed up inference, but may lead to slightly worse results.
    max_length : int
        The maximum length of the generated caption.

    Returns
    -------
    str
        The generated caption.
    """
    inputs = preprocessor(images=image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)

    generated_ids = model.generate(
        pixel_values=inputs.pixel_values,
        max_length=max_length,
        use_cache=True,
    )

    # Decode the generated token ids, using the preprocessor unless a dedicated tokenizer was given
    if tokenizer is None:
        generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_caption


def generate_captions_clip(
    model,
    transform,
    image,
    max_length=32,
    temperature=0.9,
):
    """
    Generate a caption for the given image using the open_clip CoCa model.

    Parameters
    ----------
    model : open_clip CoCa model
        The CoCa model to use.
    transform : Callable
        The transform to apply to the image before passing it to the model.
    image : PIL.Image
        The image to generate a caption for.
    max_length : int
        The maximum length (in tokens) of the generated caption.
    temperature : float
        The sampling temperature.

    Returns
    -------
    str
        The generated caption.
    """
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        generated = model.generate(img, seq_len=max_length, temperature=temperature)

    # Decode the generated tokens and strip the start/end-of-text markers
    generated_caption = open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
    return generated_caption


def generate_captions(
    image,
    max_length,
    temperature,
):
    """
    Generate captions for the given image with every loaded model.

    Parameters
    ----------
    image : PIL.Image
        The image to generate captions for.
    max_length : int
        The maximum length of the generated captions.
    temperature : float
        The sampling temperature used by the CoCa model.

    Returns
    -------
    tuple of str
        One caption per model.
    """
    # The Gradio slider yields a float; generation expects an integer length
    max_length = int(max_length)

    # Generate a caption with the BLIP-2 8-bit model
    caption_blip2_8_bit = generate_caption(preprocessor_blip2_8_bit, model_blip2_8_bit, image, use_float_16=True, max_length=max_length).strip()

    # Generate a caption with the BLIP base model
    caption_blip_base = generate_caption(preprocessor_blip_base, model_blip_base, image, max_length=max_length).strip()

    # Generate a caption with the BLIP large model
    caption_blip_large = generate_caption(preprocessor_blip_large, model_blip_large, image, max_length=max_length).strip()

    # Generate a caption with the GIT large COCO model
    caption_git_large_coco = generate_caption(preprocessor_git_large_coco, model_git_large_coco, image, max_length=max_length).strip()

    # Generate a caption with the open_clip CoCa model
    caption_oc_coca = generate_captions_clip(model_oc_coca, transform_oc_coca, image, max_length=max_length, temperature=temperature).strip()

    return caption_blip2_8_bit, caption_blip_base, caption_blip_large, caption_git_large_coco, caption_oc_coca


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_captions,
    # Inputs: the image, plus sliders for maximum caption length and sampling temperature
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Slider(minimum=16, maximum=64, step=2, value=32, label="Max Length"),
        gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Temperature"),
    ],
    # Outputs: one caption textbox per model
    outputs=[
        gr.Textbox(label="BLIP-2 8-bit"),
        gr.Textbox(label="BLIP base"),
        gr.Textbox(label="BLIP large"),
        gr.Textbox(label="GIT large COCO"),
        gr.Textbox(label="CoCa (open_clip)"),
    ],
    title="Image Captioning",
    description="Generate captions for images using the BLIP-2 8-bit, BLIP base, BLIP large, GIT large COCO, and open_clip CoCa models.",
)

# Enable request queuing and launch the interface
iface.queue()
iface.launch()
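
For a quick local smoke test without the web UI, a minimal sketch along these lines could call the captioning pipeline directly in a Python session where the functions above are defined; the file path "example.jpg" is a hypothetical placeholder, not part of the original app:

from PIL import Image

# Hypothetical smoke test: "example.jpg" is an assumed placeholder path
image = Image.open("example.jpg").convert("RGB")
captions = generate_captions(image, max_length=32, temperature=1.0)
labels = ["BLIP-2 8-bit", "BLIP base", "BLIP large", "GIT large COCO", "CoCa (open_clip)"]
for label, caption in zip(labels, captions):
    print(f"{label}: {caption}")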