Commit 6433553 (parent: 19077ad): "trying to fix bugs"

Files changed: app.py (+26, -24), requirements.txt (+2, -2)
app.py (CHANGED):

@@ -8,8 +8,9 @@ from transformers import (
     SiglipVisionModel,
     AutoTokenizer,
     AutoImageProcessor,
-    AutoModelForCausalLM
+    AutoModelForCausalLM
 )
+from peft import PeftModel
 from PIL import Image
 
 # Initialize device
@@ -19,23 +20,35 @@ print(f"Using device: {device}")
 # Load models and processors
 def load_models():
     # Load SigLIP
-
+    print("Loading SigLIP model...")
+    siglip_model = SiglipVisionModel.from_pretrained(
+        "google/siglip-so400m-patch14-384",
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True
+    ).to(device)
     siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
 
-    # Load Phi model
-    print(
-
-        "
-
-
-
+    # Load base Phi-3 model
+    print("Loading Phi-3 model...")
+    base_model = AutoModelForCausalLM.from_pretrained(
+        "microsoft/phi-3-mini-4k-instruct",
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True
+    ).to(device)
+
+    # Load the trained LoRA weights
+    print("Loading trained LoRA weights...")
+    phi_model = PeftModel.from_pretrained(
+        base_model,
+        "phi_model_trained"
     )
 
-    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/
+    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
     if phi_tokenizer.pad_token is None:
         phi_tokenizer.pad_token = phi_tokenizer.eos_token
 
     # Load trained projections
+    print("Loading projection layers...")
     linear_proj = torch.load('linear_projection_final.pth', map_location=device)
     image_text_proj = torch.load('image_text_proj.pth', map_location=device)
 
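The substance of the fix is in this hunk: the Space now loads a frozen Phi-3-mini base model and attaches the trained LoRA adapter with PEFT instead of the old single-model path. A minimal standalone sketch of that pattern, assuming the "phi_model_trained" directory holds adapter files written by save_pretrained() (the merge step at the end is illustrative, not part of this commit):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the frozen base model first; the adapter only stores LoRA deltas.
base = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-3-mini-4k-instruct",
    torch_dtype=torch.float32,   # float32 keeps CPU inference simple
    low_cpu_mem_usage=True,
).to(device)

# Attach the trained LoRA weights; "phi_model_trained" must contain
# adapter_config.json plus the adapter weights from save_pretrained().
model = PeftModel.from_pretrained(base, "phi_model_trained")

# Optionally fold the adapter into the base weights for faster inference.
model = model.merge_and_unload()
```

One caveat worth noting: torch.load() on the two .pth files deserializes whole pickled modules, which ties the checkpoints to the exact class definitions in scope; saving state_dicts and loading them into freshly constructed layers is the more portable pattern.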
@@ -71,15 +84,11 @@ def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, devi
     with torch.no_grad():
         # Process image through SigLIP
         inputs = siglip_processor(image, return_tensors="pt")
-        # Move inputs to device
         inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
         outputs = siglip_model(**inputs)
         image_features = outputs.pooler_output
-
-        # Project through trained linear layer
         projected_features = linear_proj(image_features)
-
-        return projected_features
+        return projected_features
 
 def get_random_images():
     # Select 10 random images from first 100
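The pipeline this hunk tightens is: SigLIP's pooled image features pass through the trained linear projection to match the language model's embedding width. A standalone sketch with an untrained stand-in for linear_proj; the 1152 and 3072 dimensions are assumptions (SigLIP-so400m's pooled width and Phi-3-mini's hidden size), not something stated in the diff:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, SiglipVisionModel

device = "cpu"
vision = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to(device)
processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

# Stand-in for the trained projection; 1152 -> 3072 assumed
# (SigLIP-so400m pooled width -> Phi-3-mini hidden size).
linear_proj = torch.nn.Linear(1152, 3072).to(device)

image = Image.new("RGB", (384, 384))  # placeholder input
with torch.no_grad():
    inputs = processor(image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    pooled = vision(**inputs).pooler_output   # shape (1, 1152)
    projected = linear_proj(pooled)           # shape (1, 3072)
print(projected.shape)
```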
@@ -166,9 +175,7 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column():
-            # Button to get random images
             random_btn = gr.Button("Get Random Images")
-            # Gallery to display images
             gallery = gr.Gallery(
                 label="Click an image to select it",
                 show_label=True,
@@ -180,17 +187,13 @@ with gr.Blocks() as demo:
             )
 
         with gr.Column():
-            # Display selected image
             selected_img = gr.Image(label="Selected Image", height=200)
-            # Question buttons
             q_buttons = []
             for i, q in enumerate(questions):
                 btn = gr.Button(f"Q{i+1}: {q}")
                 q_buttons.append(btn)
-            # Answer textbox
             answer_box = gr.Textbox(label="Answer", lines=3)
 
-    # Handle random image button click
     def on_random_click():
         images, indices = get_random_images()
         return {
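on_random_click returns a dict keyed by output components, which is Gradio's mechanism for updating several outputs from one handler without depending on positional order. A small self-contained sketch of that return style (the components here are illustrative placeholders, not the app's):

```python
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Refresh")
    first = gr.Textbox(label="First")
    second = gr.Textbox(label="Second")

    def refresh():
        # A dict keyed by component updates exactly the outputs named;
        # declared outputs missing from the dict are left unchanged.
        return {first: "new value", second: "another value"}

    btn.click(refresh, inputs=None, outputs=[first, second])

demo.launch()
```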
@@ -206,7 +209,6 @@ with gr.Blocks() as demo:
         outputs=[gallery, image_indices, selected_image_tensor, selected_img, answer_box]
     )
 
-    # Handle image selection
     def on_image_select(evt: gr.SelectData, images, indices):
         if images is None or evt.index >= len(images):
             return None, None, ""
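on_image_select leans on Gradio's event-data injection: a parameter annotated with gr.SelectData is filled in by the framework with click metadata (notably evt.index) and does not appear in the inputs list. A minimal sketch of the same wiring, with placeholder components:

```python
import numpy as np
import gradio as gr

with gr.Blocks() as demo:
    # Two flat-color placeholder images so the demo is self-contained.
    images = [np.full((64, 64, 3), 60, dtype=np.uint8),
              np.full((64, 64, 3), 200, dtype=np.uint8)]
    gallery = gr.Gallery(value=images, label="Click an image")
    picked = gr.Textbox(label="Selection")

    def on_select(evt: gr.SelectData):
        # evt is injected because of the annotation; evt.index is the
        # position of the clicked gallery item.
        return f"clicked item {evt.index}"

    gallery.select(on_select, inputs=None, outputs=picked)

demo.launch()
```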
@@ -220,7 +222,6 @@ with gr.Blocks() as demo:
         outputs=[selected_image_tensor, selected_img, answer_box]
     )
 
-    # Handle question button clicks
     for i, btn in enumerate(q_buttons):
         btn.click(
             generate_answer,
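Registering handlers in a loop like this is safe as long as the per-button data travels through Gradio inputs rather than the loop variable; a bare lambda closing over q would see only the loop's final value. A hedged sketch of the default-argument idiom that sidesteps that pitfall, with placeholder names:

```python
import gradio as gr

questions = ["What is in the image?", "How many objects are there?"]

def generate_answer(question):
    return f"(model answer for: {question})"  # stand-in for real inference

with gr.Blocks() as demo:
    answer_box = gr.Textbox(label="Answer")
    for q in questions:
        btn = gr.Button(q)
        # q=q binds the current question at definition time; without it,
        # every button would dispatch the last question in the list.
        btn.click(lambda q=q: generate_answer(q), inputs=None, outputs=answer_box)

demo.launch()
```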
@@ -228,4 +229,5 @@ with gr.Blocks() as demo:
             outputs=answer_box
         )
 
-
+# Launch with minimal settings
+demo.queue(max_size=1).launch(show_error=True)
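The new launch line ties in with the commit message: queue(max_size=1) keeps at most one request waiting so a CPU-only Space doesn't stack up generation jobs, and show_error=True surfaces Python exceptions in the browser instead of a bare "Runtime error" banner. The same configuration on a trivial app, for illustration:

```python
import gradio as gr

def echo(text):
    return text

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(echo, inputs=inp, outputs=out)

# max_size=1: requests beyond the one already queued are rejected
# immediately rather than piling up on a single worker.
# show_error=True: exceptions are shown in the UI, which makes
# debugging a crashing Space much easier.
demo.queue(max_size=1).launch(show_error=True)
```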
requirements.txt (CHANGED):

@@ -6,5 +6,5 @@ tqdm>=4.65.0
 numpy>=1.24.0
 accelerate>=0.25.0
 gradio>=4.19.0
-
-
+peft>=0.7.0
+scipy>=1.11.0
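The dependency changes mirror the code: peft provides PeftModel for loading the LoRA adapter, while the reason for scipy isn't stated in the commit (it is a common transitive requirement of LoRA and quantization tooling). A quick sanity check that an environment meets the new pins:

```python
# Verify the two new dependencies are importable and recent enough.
import peft
import scipy

print("peft", peft.__version__)    # expect >= 0.7.0
print("scipy", scipy.__version__)  # expect >= 1.11.0
```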