chbsaikiran committed
Commit 6433553 · 1 Parent(s): 19077ad

trying to fix bugs

Files changed (2)
  1. app.py +26 -24
  2. requirements.txt +2 -2
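In short: app.py stops loading the fine-tuned model directly with 8-bit (bitsandbytes) quantization and instead loads the microsoft/phi-3-mini-4k-instruct base model in float32 with low_cpu_mem_usage, then applies the trained LoRA adapter from phi_model_trained via peft's PeftModel. Progress prints are added around each loading step, stale comments are dropped, and the app now launches through a bounded queue. requirements.txt swaps bitsandbytes for scipy accordingly.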
app.py CHANGED
@@ -8,8 +8,9 @@ from transformers import (
     SiglipVisionModel,
     AutoTokenizer,
     AutoImageProcessor,
-    AutoModelForCausalLM,
+    AutoModelForCausalLM
 )
+from peft import PeftModel
 from PIL import Image
 
 # Initialize device
@@ -19,23 +20,35 @@ print(f"Using device: {device}")
 # Load models and processors
 def load_models():
     # Load SigLIP
-    siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to(device)
+    print("Loading SigLIP model...")
+    siglip_model = SiglipVisionModel.from_pretrained(
+        "google/siglip-so400m-patch14-384",
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True
+    ).to(device)
     siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
 
-    # Load Phi model with 8-bit quantization that works on both CPU and GPU
-    print(f"Loading Phi model on {device}...")
-    phi_model = AutoModelForCausalLM.from_pretrained(
-        "phi_model_trained",
-        load_in_8bit=True,  # This works on both CPU and GPU
-        device_map="auto",
-        torch_dtype=torch.float32
+    # Load base Phi-3 model
+    print("Loading Phi-3 model...")
+    base_model = AutoModelForCausalLM.from_pretrained(
+        "microsoft/phi-3-mini-4k-instruct",
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True
+    ).to(device)
+
+    # Load the trained LoRA weights
+    print("Loading trained LoRA weights...")
+    phi_model = PeftModel.from_pretrained(
+        base_model,
+        "phi_model_trained"
     )
 
-    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
     if phi_tokenizer.pad_token is None:
         phi_tokenizer.pad_token = phi_tokenizer.eos_token
 
     # Load trained projections
+    print("Loading projection layers...")
     linear_proj = torch.load('linear_projection_final.pth', map_location=device)
     image_text_proj = torch.load('image_text_proj.pth', map_location=device)
 
@@ -71,15 +84,11 @@ def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
     with torch.no_grad():
         # Process image through SigLIP
         inputs = siglip_processor(image, return_tensors="pt")
-        # Move inputs to device
         inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
         outputs = siglip_model(**inputs)
         image_features = outputs.pooler_output
-
-        # Project through trained linear layer
         projected_features = linear_proj(image_features)
-
-        return projected_features
+        return projected_features
 
 def get_random_images():
     # Select 10 random images from first 100
@@ -166,9 +175,7 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column():
-            # Button to get random images
             random_btn = gr.Button("Get Random Images")
-            # Gallery to display images
             gallery = gr.Gallery(
                 label="Click an image to select it",
                 show_label=True,
@@ -180,17 +187,13 @@ with gr.Blocks() as demo:
             )
 
         with gr.Column():
-            # Display selected image
             selected_img = gr.Image(label="Selected Image", height=200)
-            # Question buttons
             q_buttons = []
             for i, q in enumerate(questions):
                 btn = gr.Button(f"Q{i+1}: {q}")
                 q_buttons.append(btn)
-            # Answer textbox
             answer_box = gr.Textbox(label="Answer", lines=3)
 
-    # Handle random image button click
     def on_random_click():
         images, indices = get_random_images()
         return {
@@ -206,7 +209,6 @@ with gr.Blocks() as demo:
         outputs=[gallery, image_indices, selected_image_tensor, selected_img, answer_box]
     )
 
-    # Handle image selection
    def on_image_select(evt: gr.SelectData, images, indices):
        if images is None or evt.index >= len(images):
            return None, None, ""
@@ -220,7 +222,6 @@ with gr.Blocks() as demo:
        outputs=[selected_image_tensor, selected_img, answer_box]
    )
 
-    # Handle question button clicks
    for i, btn in enumerate(q_buttons):
        btn.click(
            generate_answer,
@@ -228,4 +229,5 @@ with gr.Blocks() as demo:
            outputs=answer_box
        )
 
-demo.launch()
+# Launch with minimal settings
+demo.queue(max_size=1).launch(show_error=True)
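For reference, a minimal standalone sketch of the new loading path. It assumes, as the diff implies, that phi_model_trained is a LoRA adapter directory saved with PEFT's save_pretrained and a CAUSAL_LM task type; the prompt and generation settings below are placeholders, not taken from the app:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model, loaded the same way the commit does: float32, low CPU memory.
base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-3-mini-4k-instruct",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(device)

# Apply the trained LoRA weights on top of the frozen base model.
# Assumes the adapter was saved with task_type CAUSAL_LM, so the returned
# PeftModel wrapper supports generate().
phi_model = PeftModel.from_pretrained(base_model, "phi_model_trained")
phi_model.eval()

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Placeholder prompt, just to confirm the adapter-wrapped model runs.
inputs = tokenizer("Describe the image:", return_tensors="pt").to(device)
with torch.no_grad():
    out = phi_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))

One caveat: transformers releases that predate native Phi-3 support need trust_remote_code=True in the from_pretrained calls; the commit assumes a recent enough version.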
requirements.txt CHANGED
@@ -6,5 +6,5 @@ tqdm>=4.65.0
 numpy>=1.24.0
 accelerate>=0.25.0
 gradio>=4.19.0
-bitsandbytes>=0.41.1
-peft>=0.7.0
+peft>=0.7.0
+scipy>=1.11.0
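The dependency change mirrors the code: with load_in_8bit gone, bitsandbytes is no longer needed, while peft is kept for PeftModel.from_pretrained. The commit doesn't explain the new scipy pin; nothing visible in this diff imports it directly.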