Spaces: Runtime error
Commit 5e37be9
Parent(s): 7eb024a
Initial Commit
Files changed:
- app.py +234 -0
- extract_answers.py +45 -0
- image_text_proj.pth +3 -0
- linear_projection_final.pth +3 -0
- phi_model_trained/README.md +202 -0
- phi_model_trained/adapter_config.json +36 -0
- phi_model_trained/adapter_model.safetensors +3 -0
- process_cifar10.py +81 -0
- requirements.txt +10 -0
- train_linear_projection.py +216 -0
- train_phi_with_siglip.py +263 -0
app.py
ADDED
@@ -0,0 +1,234 @@
import gradio as gr
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import random
import numpy as np
from transformers import (
    SiglipVisionModel,
    AutoTokenizer,
    AutoImageProcessor,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from PIL import Image

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Projection modules, matching the definitions in train_linear_projection.py and
# train_phi_with_siglip.py. The training scripts save state_dicts, so the weights
# must be loaded back into callable modules rather than used directly.
class LinearProjection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

class ImageTextProjection(nn.Module):
    def __init__(self, image_dim, text_dim):
        super().__init__()
        self.image_projection = nn.Linear(image_dim, text_dim)

    def forward(self, x):
        return self.image_projection(x)

# Load models and processors
def load_models():
    # Load SigLIP
    siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to(device)
    siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

    # Load Phi model with 4-bit quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False
    )
    phi_model = AutoModelForCausalLM.from_pretrained(
        "phi_model_trained",  # LoRA adapter directory saved by train_phi_with_siglip.py
        quantization_config=bnb_config,
        device_map="auto"
    )
    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    if phi_tokenizer.pad_token is None:
        phi_tokenizer.pad_token = phi_tokenizer.eos_token

    # Load trained projections. Dimensions are inferred from the saved weight shapes,
    # then the state_dicts are loaded into the freshly built modules.
    linear_state = torch.load('linear_projection_final.pth', map_location=device)
    linear_proj = LinearProjection(
        linear_state['linear.weight'].shape[1],
        linear_state['linear.weight'].shape[0]
    ).to(device)
    linear_proj.load_state_dict(linear_state)
    linear_proj.eval()

    proj_state = torch.load('image_text_proj.pth', map_location=device)
    image_text_proj = ImageTextProjection(
        proj_state['image_projection.weight'].shape[1],
        proj_state['image_projection.weight'].shape[0]
    ).to(device)
    image_text_proj.load_state_dict(proj_state)
    image_text_proj.eval()

    return (siglip_model, siglip_processor, phi_model, phi_tokenizer, linear_proj, image_text_proj)

# Load all models at startup
print("Loading models...")
models = load_models()
siglip_model, siglip_processor, phi_model, phi_tokenizer, linear_proj, image_text_proj = models
print("Models loaded successfully!")

# Load CIFAR10 test dataset
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
])

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Get first 100 images
first_100_images = [(images, labels) for images, labels in list(testset)[:100]]

# Questions list
questions = [
    "Give a description of the image?",
    "How does the main object in the image look like?",
    "How can the main object in the image be useful to humans?",
    "What is the color of the main object in the image?",
    "Describe the setting of the image?"
]

def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
    with torch.no_grad():
        # Process image through SigLIP
        inputs = siglip_processor(image, return_tensors="pt")
        # Move inputs to device
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        outputs = siglip_model(**inputs)
        image_features = outputs.pooler_output

        # Project through trained linear layer
        projected_features = linear_proj(image_features)

    return projected_features

def get_random_images():
    # Select 10 random images from first 100
    selected_indices = random.sample(range(100), 10)
    selected_images = [first_100_images[i][0] for i in selected_indices]

    # Convert to numpy arrays and transpose to correct format (H,W,C)
    images_np = [img.permute(1, 2, 0).numpy() for img in selected_images]
    return images_np, selected_indices

def generate_answer(image_tensor, question_index):
    if image_tensor is None:
        return "Please select an image first!"

    try:
        # Gradio Number components pass floats, so coerce to an int index
        question_index = int(question_index)

        # Get image embedding
        image_embedding = get_image_embedding(
            image_tensor,
            siglip_model,
            siglip_processor,
            linear_proj,
            device
        )

        # Get question
        question = questions[question_index]

        # Tokenize question
        question_tokens = phi_tokenizer(
            question,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(device)

        # Get question embeddings
        question_embeds = phi_model.get_input_embeddings()(question_tokens['input_ids'])

        # Project and prepare image embeddings
        image_embeds = image_text_proj(image_embedding)
        image_embeds = image_embeds.unsqueeze(1)
        # Match the dtype of the (possibly half-precision) Phi embeddings before concatenation
        image_embeds = image_embeds.to(question_embeds.dtype)

        # Combine embeddings
        combined_embedding = torch.cat([
            image_embeds,
            question_embeds
        ], dim=1)

        # Create attention mask
        attention_mask = torch.ones(
            (1, combined_embedding.size(1)),
            dtype=torch.long,
            device=device
        )

        # Generate answer
        with torch.no_grad():
            outputs = phi_model.generate(
                inputs_embeds=combined_embedding,
                attention_mask=attention_mask,
                max_new_tokens=100,
                num_beams=4,
                temperature=0.7,
                do_sample=True,
                pad_token_id=phi_tokenizer.pad_token_id,
                eos_token_id=phi_tokenizer.eos_token_id
            )

        # Decode the generated answer
        answer = phi_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer

    except Exception as e:
        return f"Error generating answer: {str(e)}"

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CIFAR10 Image Question Answering System")

    # State variables
    selected_image_tensor = gr.State(None)
    image_indices = gr.State([])

    with gr.Row():
        with gr.Column():
            # Button to get random images
            random_btn = gr.Button("Get Random Images")
            # Gallery to display images
            gallery = gr.Gallery(
                label="Click an image to select it",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                height="auto",
                allow_preview=False
            )

        with gr.Column():
            # Display selected image
            selected_img = gr.Image(label="Selected Image", height=200)
            # Question buttons
            q_buttons = []
            for i, q in enumerate(questions):
                btn = gr.Button(f"Q{i+1}: {q}")
                q_buttons.append(btn)
            # Answer textbox
            answer_box = gr.Textbox(label="Answer", lines=3)

    # Handle random image button click
    def on_random_click():
        images, indices = get_random_images()
        return {
            gallery: images,
            image_indices: indices,
            selected_image_tensor: None,
            selected_img: None,
            answer_box: ""
        }

    random_btn.click(
        on_random_click,
        outputs=[gallery, image_indices, selected_image_tensor, selected_img, answer_box]
    )

    # Handle image selection
    def on_image_select(evt: gr.SelectData, images, indices):
        if images is None or evt.index >= len(images):
            return None, None, ""
        selected_idx = indices[evt.index]
        selected_tensor = first_100_images[selected_idx][0]
        return selected_tensor, images[evt.index], ""

    gallery.select(
        on_image_select,
        inputs=[gallery, image_indices],
        outputs=[selected_image_tensor, selected_img, answer_box]
    )

    # Handle question button clicks
    for i, btn in enumerate(q_buttons):
        btn.click(
            generate_answer,
            inputs=[selected_image_tensor, gr.Number(value=i, visible=False)],
            outputs=answer_box
        )

demo.launch()
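The core trick in generate_answer above is to treat the projected SigLIP vector as a single extra "image token" embedding prepended to the question's token embeddings. A self-contained shape check of that fusion (a sketch with random tensors; 3072 is Phi-3-mini's hidden size, the sequence length is arbitrary):

import torch

batch, seq_len, embed_dim = 1, 12, 3072
image_embeds = torch.randn(batch, embed_dim).unsqueeze(1)     # [1, 1, 3072] -> one "image token"
question_embeds = torch.randn(batch, seq_len, embed_dim)      # token embeddings of the question
combined = torch.cat([image_embeds, question_embeds], dim=1)  # [1, 1 + seq_len, 3072]
attention_mask = torch.ones(batch, combined.size(1), dtype=torch.long)
print(combined.shape, attention_mask.shape)                   # [1, 13, 3072] and [1, 13]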
extract_answers.py
ADDED
@@ -0,0 +1,45 @@
import os
import re
import glob

def extract_assistant_answers(input_file):
    """Extract the text after 'Assistant:' from the input file."""
    with open(input_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # Split content by "Assistant:" to get all sections after it
    sections = content.split("Assistant:")

    # Process each section to get clean answers
    answers = []
    for section in sections[1:]:  # Skip the first split as it's before first "Assistant:"
        # Get text up to next "Q" or "User:" or end of string
        answer = section.split("Q")[0].split("User:")[0].strip()
        if answer:
            answers.append(answer)

    return answers

def process_all_files():
    """Process all image_*.txt files in the qa_outputs directory."""
    # Get all image_*.txt files
    input_files = glob.glob("qa_outputs/image_*.txt")

    for input_file in input_files:
        # Extract the base name without extension
        base_name = os.path.splitext(input_file)[0]
        output_file = f"{base_name}_extr.txt"

        # Extract answers
        answers = extract_assistant_answers(input_file)

        # Write answers to the output file
        with open(output_file, 'w', encoding='utf-8') as f:
            for i, answer in enumerate(answers, 1):
                f.write(f"{answer}\n")

        print(f"Processed {input_file} -> {output_file}")

if __name__ == "__main__":
    process_all_files()
    print("Extraction complete! Check the files with '_extr' suffix.")
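extract_answers.py relies on the decoded SmolVLM2 transcripts containing literal "User:" and "Assistant:" markers; note that any capital "Q" inside an answer truncates it, which is a limitation of this heuristic. A hypothetical round-trip (the file content below is an illustrative stand-in for a real qa_outputs file, and the snippet assumes it runs next to extract_answers.py):

import os
from extract_answers import extract_assistant_answers  # the __main__ guard keeps the import side-effect free

os.makedirs("qa_outputs", exist_ok=True)
sample = (
    "Q1: Give a description of the image?\n"
    "A1: User: Give a description of the image? Assistant: A small cat sitting on grass.\n"
    "Q2: What is the color of the main object in the image?\n"
    "A2: User: ... Assistant: It is brown and white.\n"
)
with open("qa_outputs/image_demo.txt", "w", encoding="utf-8") as f:
    f.write(sample)

print(extract_assistant_answers("qa_outputs/image_demo.txt"))
# -> ['A small cat sitting on grass.', 'It is brown and white.']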
image_text_proj.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ddb52b49da8704aff3b46f2c503ac20cf64e3e6efbe4844e9ac89e85d9673894
size 1586824
linear_projection_final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:374fc454085ce3b227047bfbf45c2df30a97a308b593ad1f2ef5ec763cab5afb
size 592056
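Both .pth files in this commit are plain state_dicts saved with torch.save(module.state_dict(), ...) by the training scripts further down, so they have to be loaded back into a matching nn.Module before they can be called. A quick inspection sketch (assumes the two files are in the working directory):

import torch

for path in ("linear_projection_final.pth", "image_text_proj.pth"):
    state = torch.load(path, map_location="cpu")
    print(path, {name: tuple(w.shape) for name, w in state.items()})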
phi_model_trained/README.md
ADDED
@@ -0,0 +1,202 @@
---
base_model: microsoft/Phi-3-mini-4k-instruct
library_name: peft
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]

### Framework versions

- PEFT 0.15.1
phi_model_trained/adapter_config.json
ADDED
@@ -0,0 +1,36 @@
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "microsoft/Phi-3-mini-4k-instruct",
  "bias": "none",
  "corda_config": null,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_bias": false,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 16,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "mlp.dense_h_to_4h",
    "mlp.dense_4h_to_h",
    "self_attn.qkv_proj",
    "self_attn.dense"
  ],
  "task_type": "CAUSAL_LM",
  "trainable_token_indices": null,
  "use_dora": false,
  "use_rslora": false
}
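This adapter config records the LoRA settings used by train_phi_with_siglip.py (r=16, lora_alpha=32, dropout 0.05) on top of microsoft/Phi-3-mini-4k-instruct. app.py points AutoModelForCausalLM.from_pretrained at the phi_model_trained directory and relies on transformers' PEFT integration to resolve the base model; a minimal sketch of loading the adapter explicitly with PEFT instead (assumes the directory from this commit is on disk):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "phi_model_trained")  # applies the adapter described above
model.eval()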
phi_model_trained/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91ad139ce2f99c0ed85ff06edd5ac6b766baef76fab1d3e896c9ac32589e96fb
size 25174552
process_cifar10.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import os
from transformers import AutoProcessor, AutoModelForImageTextToText
from tqdm import tqdm

# Initialize model and processor
model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForImageTextToText.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16
    #_attn_implementation="flash_attention_2"
).to("cuda" if torch.cuda.is_available() else "cpu")

# Create output directory
os.makedirs("SigLIP_Training/qa_outputs", exist_ok=True)

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage()
])

# Using test set instead of train set
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# List of questions
questions = [
    "Give a description of the image?",
    "How does the main object in the image look like?",
    "How can the main object in the image be useful to humans?",
    "What is the color of the main object in the image?",
    "Describe the setting of the image?"
]

def process_image(image, image_idx):
    # Create output file
    output_file = f"SigLIP_Training/qa_outputs/image_{image_idx}.txt"

    with open(output_file, 'w') as f:
        for q_idx, question in enumerate(questions, 1):
            # Prepare the message for the model
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": question}
                    ]
                }
            ]

            # Process inputs
            inputs = processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(model.device, dtype=torch.bfloat16)

            # Generate answer
            generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)
            answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            # Write to file in the correct format
            f.write(f"Q{q_idx}: {question}\n")
            f.write(f"A{q_idx}: {answer}\n")

# Process all images from test set
print("Starting to process CIFAR-10 test set images...")
for idx, (image, _) in enumerate(tqdm(testset)):
    process_image(image, idx)
    #if idx >= 1000:  # Process first 1000 test images
    #    break

print("Processing complete! Check the SigLIP_Training/qa_outputs directory for results.")
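Each image_<idx>.txt written above holds five "Q{n}: ..." / "A{n}: ..." pairs, where the answer is the full decoded chat (prompt included), which is why extract_answers.py splits on "Assistant:". Note that this script writes to SigLIP_Training/qa_outputs/ while extract_answers.py and the training scripts read qa_outputs/, so the files presumably need to be moved or the paths aligned. A trivial peek at one output file (a sketch; assumes the script above has already run):

with open("SigLIP_Training/qa_outputs/image_0.txt", "r", encoding="utf-8") as f:
    print(f.read())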
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch>=2.0.0
transformers>=4.36.0
torchvision>=0.15.0
pillow>=9.3.0
tqdm>=4.65.0
numpy>=1.24.0
accelerate>=0.25.0
gradio>=4.19.0
bitsandbytes>=0.41.1
peft>=0.7.0
train_linear_projection.py
ADDED
@@ -0,0 +1,216 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import SiglipVisionModel, AutoTokenizer, AutoImageProcessor, AutoModel
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import os
import numpy as np
from PIL import Image
import argparse

def siglip_loss(image_embeddings, text_embeddings, temperature=0.07):
    # Normalize
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)

    # Compute pairwise similarities
    logits = image_embeddings @ text_embeddings.T  # [batch_size, batch_size]
    logits = logits / temperature

    # Ground truth: 1.0 for matching pairs (diagonal), 0.0 for all others
    batch_size = logits.size(0)
    targets = torch.eye(batch_size).to(logits.device)

    # Apply binary cross-entropy with logits
    loss = F.binary_cross_entropy_with_logits(logits, targets)

    return loss

class LinearProjection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

def get_text_embedding(text, tokenizer, device, max_length=128):
    # Ensure text is not empty and has minimum content
    if not text or len(text.strip()) == 0:
        text = "This is a placeholder description."

    # Tokenize with padding and truncation
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding='max_length',  # Changed to max_length padding
        truncation=True,
        max_length=max_length  # Fixed max length for all inputs
    )

    # Move inputs to device and ensure correct data type
    inputs = {
        k: v.to(device).float() for k, v in inputs.items()
    }

    # Return the input_ids as embeddings
    return inputs['input_ids'].float()  # Convert to float for the loss calculation

def main(num_images=100, batch_size=32, num_epochs=50, learning_rate=1e-4, load_checkpoint=True, checkpoint_path='linear_projection.pth'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load models and processors
    siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
    siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

    # Set padding token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Freeze SigLIP model
    for param in siglip_model.parameters():
        param.requires_grad = False

    siglip_model.to(device)

    # Get SigLIP output dimension and text embedding dimension
    # Create a proper dummy image (black image)
    dummy_image = Image.new('RGB', (384, 384), color='black')
    with torch.no_grad():
        siglip_inputs = siglip_processor(dummy_image, return_tensors="pt").to(device)
        siglip_outputs = siglip_model(**siglip_inputs)
        siglip_output_dim = siglip_outputs.pooler_output.shape[-1]

    # Get a sample text to determine embedding dimension
    dummy_text = "This is a test."
    dummy_embedding = get_text_embedding(dummy_text, tokenizer, device)
    text_embedding_dim = dummy_embedding.shape[-1]

    print(f"SigLIP output dimension: {siglip_output_dim}")
    print(f"Text embedding dimension: {text_embedding_dim}")

    # Create linear projection layer
    linear_proj = LinearProjection(siglip_output_dim, text_embedding_dim).to(device)

    # Load checkpoint if requested
    if load_checkpoint:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            linear_proj.load_state_dict(checkpoint)
            print(f"Successfully loaded checkpoint from {checkpoint_path}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch instead.")

    # Load CIFAR10 test dataset
    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
    ])

    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    subset_indices = list(range(num_images))
    subset_dataset = Subset(test_dataset, subset_indices)
    dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)

    # Create text files directory if it doesn't exist
    os.makedirs('qa_outputs', exist_ok=True)

    # Optimizer
    optimizer = AdamW(linear_proj.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        linear_proj.train()

        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(device)
            batch_size = images.size(0)

            # Get image embeddings
            with torch.no_grad():
                siglip_inputs = siglip_processor(images, return_tensors="pt").to(device)
                siglip_outputs = siglip_model(**siglip_inputs)
                image_features = siglip_outputs.pooler_output

            # Project image features
            projected_image_features = linear_proj(image_features)

            # Process text for each line (1 to 5)
            total_batch_loss = 0
            for line_num in range(5):
                text_embeddings_list = []

                # Read text from files for current batch
                for idx in range(batch_size):
                    global_idx = batch_idx * batch_size + idx
                    if global_idx < num_images:
                        file_path = f'qa_outputs/image_{global_idx}_extr.txt'
                        try:
                            with open(file_path, 'r') as f:
                                lines = f.readlines()
                                text = lines[line_num].strip() if line_num < len(lines) else ""
                        except:
                            text = "No description available"

                        # Get text embeddings directly from tokenizer
                        text_embedding = get_text_embedding(text, tokenizer, device)
                        text_embeddings_list.append(text_embedding)

                if text_embeddings_list:
                    # Stack instead of cat since all embeddings have same size now
                    text_embeddings = torch.stack(text_embeddings_list, dim=0).squeeze(1)
                    loss = siglip_loss(projected_image_features, text_embeddings)
                    total_batch_loss += loss

            # Average loss over all text lines
            avg_batch_loss = total_batch_loss / 5

            # Backpropagation
            optimizer.zero_grad()
            avg_batch_loss.backward()
            optimizer.step()

            total_loss += avg_batch_loss.item()
            progress_bar.set_postfix({'loss': avg_batch_loss.item()})

        avg_epoch_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')

        # Save checkpoint after each epoch
        # checkpoint_dir = 'checkpoints'
        # os.makedirs(checkpoint_dir, exist_ok=True)
        # checkpoint_file = os.path.join(checkpoint_dir, f'linear_projection_epoch_{epoch+1}.pth')
        # torch.save(linear_proj.state_dict(), checkpoint_file)
        # print(f"Saved checkpoint to {checkpoint_file}")

    # Save final model
    torch.save(linear_proj.state_dict(), 'linear_projection_final.pth')
    print("Training completed. Final model saved as 'linear_projection_final.pth'")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train or continue training the linear projection layer')
    parser.add_argument('--num_images', type=int, default=100, help='Number of images to train on')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=50, help='Number of epochs to train')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--load_checkpoint', action='store_true', help='Whether to load from checkpoint')
    parser.add_argument('--checkpoint_path', type=str, default='linear_projection.pth', help='Path to checkpoint file')

    args = parser.parse_args()
    main(
        num_images=args.num_images,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        load_checkpoint=args.load_checkpoint,
        checkpoint_path=args.checkpoint_path
    )
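siglip_loss above is a sigmoid-style pairwise loss: every image-text pair in the batch is scored independently with binary cross-entropy, with the matching pairs on the diagonal as positives, rather than the softmax over the batch used in CLIP-style training; unlike the original SigLIP recipe, the temperature here is fixed and there is no learnable bias. A toy call (a sketch; random vectors stand in for the projected SigLIP features and the token-id "text embeddings" this script uses, and it assumes it runs next to train_linear_projection.py):

import torch
from train_linear_projection import siglip_loss  # the __main__ guard keeps the import side-effect free

img = torch.randn(4, 128)     # 4 projected image vectors (width = max_length = 128 in this script)
txt = torch.randn(4, 128)     # 4 matching text vectors
print(siglip_loss(img, txt))  # scalar tensor; entry (i, i) is the positive pair for image i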
train_phi_with_siglip.py
ADDED
@@ -0,0 +1,263 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import (
    SiglipVisionModel,
    AutoTokenizer,
    AutoImageProcessor,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import os
from PIL import Image

class LinearProjection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

class ImageTextProjection(nn.Module):
    def __init__(self, image_dim, text_dim):
        super().__init__()
        self.image_projection = nn.Linear(image_dim, text_dim)

    def forward(self, x):
        return self.image_projection(x)

def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
    with torch.no_grad():
        # Process image through SigLIP
        inputs = siglip_processor(image, return_tensors="pt")
        # Move inputs to the same device as model
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        outputs = siglip_model(**inputs)
        image_features = outputs.pooler_output

        # Project through trained linear layer
        projected_features = linear_proj(image_features)

    return projected_features

def main(
    num_images=100,
    batch_size=4,  # Smaller batch size due to memory constraints
    num_epochs=100,
    learning_rate=2e-4,
    questions=None  # List of 5 questions to be provided
):
    if questions is None or len(questions) != 5:
        print("Please provide exactly 5 questions!")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load SigLIP model and processor
    siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to(device)
    siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

    # Load trained linear projection
    dummy_image = Image.new('RGB', (384, 384), color='black')
    with torch.no_grad():
        siglip_inputs = siglip_processor(dummy_image, return_tensors="pt")
        # Move inputs to device
        siglip_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in siglip_inputs.items()}
        siglip_outputs = siglip_model(**siglip_inputs)
        siglip_output_dim = siglip_outputs.pooler_output.shape[-1]

    # First load the checkpoint to get the correct output dimension
    checkpoint = torch.load('linear_projection_final.pth', map_location=device)
    output_dim = checkpoint['linear.weight'].shape[0]  # Get the output dimension from saved weights
    print(f"Loading linear projection with output dimension: {output_dim}")

    # Initialize linear projection with correct dimensions
    linear_proj = LinearProjection(siglip_output_dim, output_dim).to(device)
    try:
        linear_proj.load_state_dict(checkpoint)
        print("Successfully loaded linear projection weights")
    except Exception as e:
        print(f"Error loading linear projection weights: {e}")
        return

    # Load Phi model with 4-bit quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False
    )

    phi_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct",
        quantization_config=bnb_config,
        device_map="auto"
    )
    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

    # Add padding token if not present
    if phi_tokenizer.pad_token is None:
        phi_tokenizer.pad_token = phi_tokenizer.eos_token

    # Get embedding dimension from phi model
    phi_embed_dim = phi_model.get_input_embeddings().weight.shape[1]

    # Create projection layer for image embeddings
    image_text_proj = ImageTextProjection(output_dim, phi_embed_dim).to(device)

    # Prepare model for k-bit training
    phi_model = prepare_model_for_kbit_training(phi_model)

    # Setup LoRA configuration
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["mlp.dense_h_to_4h", "mlp.dense_4h_to_h", "self_attn.qkv_proj", "self_attn.dense"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    # Get PEFT model
    phi_model = get_peft_model(phi_model, lora_config)

    # Freeze SigLIP and linear projection
    for param in siglip_model.parameters():
        param.requires_grad = False
    for param in linear_proj.parameters():
        param.requires_grad = False

    # Load CIFAR10 test dataset
    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
    ])

    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    subset_indices = list(range(num_images))
    subset_dataset = Subset(test_dataset, subset_indices)
    dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)

    # Optimizer for both phi model and image projection
    optimizer = AdamW([
        {'params': phi_model.parameters()},
        {'params': image_text_proj.parameters()}
    ], lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        phi_model.train()
        image_text_proj.train()

        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch_idx, (images, _) in enumerate(progress_bar):
            images = images.to(device)
            batch_size = images.size(0)

            # Get image embeddings
            image_embeddings = get_image_embedding(images, siglip_model, siglip_processor, linear_proj, device)

            # Process each question
            for q_idx, question in enumerate(questions):
                # Read corresponding answers
                answers = []
                for idx in range(batch_size):
                    global_idx = batch_idx * batch_size + idx
                    if global_idx < num_images:
                        file_path = f'qa_outputs/image_{global_idx}_extr.txt'
                        try:
                            with open(file_path, 'r') as f:
                                lines = f.readlines()
                                answer = lines[q_idx].strip() if q_idx < len(lines) else ""
                                answers.append(answer)
                        except:
                            answers.append("No answer available")

                # Tokenize questions and answers for the entire batch
                question_tokens = phi_tokenizer(
                    [question] * batch_size,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                ).to(device)

                target_tokens = phi_tokenizer(
                    answers,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                ).to(device)

                # Get question embeddings for the entire batch
                question_embeds = phi_model.get_input_embeddings()(question_tokens['input_ids'])  # [batch_size, seq_len, embed_dim]

                # Project and prepare image embeddings for the entire batch
                image_embeds = image_text_proj(image_embeddings)  # [batch_size, embed_dim]
                image_embeds = image_embeds.unsqueeze(1)  # [batch_size, 1, embed_dim]

                # Combine image embeddings with question embeddings
                combined_embedding = torch.cat([
                    image_embeds,     # [batch_size, 1, embed_dim]
                    question_embeds   # [batch_size, seq_len, embed_dim]
                ], dim=1)  # [batch_size, 1+seq_len, embed_dim]

                # Create attention mask for the combined sequence
                attention_mask = torch.ones(
                    (batch_size, combined_embedding.size(1)),
                    dtype=torch.long,
                    device=device
                )

                # Prepare labels by shifting them right
                labels = target_tokens['input_ids'].clone()
                labels = torch.cat([
                    torch.full((batch_size, combined_embedding.size(1) - 1), -100, device=device),
                    labels
                ], dim=1)[:, :combined_embedding.size(1)]

                # Forward pass
                outputs = phi_model(
                    inputs_embeds=combined_embedding,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss
                total_loss += loss.item()

                # Backward pass
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                progress_bar.set_postfix({'loss': loss.item()})

        avg_epoch_loss = total_loss / (len(dataloader) * len(questions) * batch_size)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')

    # Save the trained models
    phi_model.save_pretrained('phi_model_trained')
    torch.save(image_text_proj.state_dict(), 'image_text_proj.pth')
    print("Training completed. Models saved as 'phi_model_trained' and 'image_text_proj.pth'")

if __name__ == "__main__":
    # Example questions - replace with your actual questions
    questions = [
        "Give a description of the image?",
        "How does the main object in the image look like?",
        "How can the main object in the image be useful to humans?",
        "What is the color of the main object in the image?",
        "Describe the setting of the image?"
    ]

    main(questions=questions)
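Taken together, the likely intended order of the pipeline in this commit is: process_cifar10.py generates Q/A text for the CIFAR-10 test images with SmolVLM2; extract_answers.py strips the assistant answers into image_*_extr.txt files; train_linear_projection.py fits the SigLIP-to-text projection against those answers; train_phi_with_siglip.py trains the LoRA adapter and the image_text_proj layer on top of the frozen projection; and app.py finally serves the Gradio demo from the saved artifacts (phi_model_trained, linear_projection_final.pth, image_text_proj.pth).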