chbsaikiran committed on
Commit 5e37be9 · 1 Parent(s): 7eb024a

Initial Commit

app.py ADDED
@@ -0,0 +1,234 @@
+ import gradio as gr
+ import torch
+ import torchvision
+ import torchvision.transforms as transforms
+ import random
+ import numpy as np
+ from transformers import (
+     SiglipVisionModel,
+     AutoTokenizer,
+     AutoImageProcessor,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig
+ )
+ from PIL import Image
+
+ # Initialize device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load models and processors
+ def load_models():
+     # Load SigLIP
+     siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to(device)
+     siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
+
+     # Load Phi model with 4-bit quantization
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.float16,
+         bnb_4bit_use_double_quant=False
+     )
+     phi_model = AutoModelForCausalLM.from_pretrained(
+         "phi_model_trained",  # Load from saved directory
+         quantization_config=bnb_config,
+         device_map="auto"
+     )
+     phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+     if phi_tokenizer.pad_token is None:
+         phi_tokenizer.pad_token = phi_tokenizer.eos_token
+
+     # Load trained projections
+     linear_proj = torch.load('linear_projection_final.pth', map_location=device)
+     image_text_proj = torch.load('image_text_proj.pth', map_location=device)
+
+     return (siglip_model, siglip_processor, phi_model, phi_tokenizer, linear_proj, image_text_proj)
+
+ # Load all models at startup
+ print("Loading models...")
+ models = load_models()
+ siglip_model, siglip_processor, phi_model, phi_tokenizer, linear_proj, image_text_proj = models
+ print("Models loaded successfully!")
+
+ # Load CIFAR10 test dataset
+ transform = transforms.Compose([
+     transforms.Resize((384, 384)),
+     transforms.ToTensor(),
+ ])
+
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
+
+ # Get first 100 images
+ first_100_images = [(images, labels) for images, labels in list(testset)[:100]]
+
+ # Questions list
+ questions = [
+     "Give a description of the image?",
+     "How does the main object in the image look like?",
+     "How can the main object in the image be useful to humans?",
+     "What is the color of the main object in the image?",
+     "Describe the setting of the image?"
+ ]
+
+ def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
+     with torch.no_grad():
+         # Process image through SigLIP
+         inputs = siglip_processor(image, return_tensors="pt")
+         # Move inputs to device
+         inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+         outputs = siglip_model(**inputs)
+         image_features = outputs.pooler_output
+
+         # Project through trained linear layer
+         projected_features = linear_proj(image_features)
+
+     return projected_features
+
+ def get_random_images():
+     # Select 10 random images from first 100
+     selected_indices = random.sample(range(100), 10)
+     selected_images = [first_100_images[i][0] for i in selected_indices]
+
+     # Convert to numpy arrays and transpose to correct format (H,W,C)
+     images_np = [img.permute(1, 2, 0).numpy() for img in selected_images]
+     return images_np, selected_indices
+
+ def generate_answer(image_tensor, question_index):
+     if image_tensor is None:
+         return "Please select an image first!"
+
+     try:
+         # Get image embedding
+         image_embedding = get_image_embedding(
+             image_tensor,
+             siglip_model,
+             siglip_processor,
+             linear_proj,
+             device
+         )
+
+         # Get question (the Gradio Number input arrives as a float, so cast before indexing)
+         question = questions[int(question_index)]
+
+         # Tokenize question
+         question_tokens = phi_tokenizer(
+             question,
+             padding=True,
+             truncation=True,
+             max_length=512,
+             return_tensors="pt"
+         ).to(device)
+
+         # Get question embeddings
+         question_embeds = phi_model.get_input_embeddings()(question_tokens['input_ids'])
+
+         # Project and prepare image embeddings
+         image_embeds = image_text_proj(image_embedding)
+         image_embeds = image_embeds.unsqueeze(1)
+
+         # Combine embeddings
+         combined_embedding = torch.cat([
+             image_embeds,
+             question_embeds
+         ], dim=1)
+
+         # Create attention mask
+         attention_mask = torch.ones(
+             (1, combined_embedding.size(1)),
+             dtype=torch.long,
+             device=device
+         )
+
+         # Generate answer
+         with torch.no_grad():
+             outputs = phi_model.generate(
+                 inputs_embeds=combined_embedding,
+                 attention_mask=attention_mask,
+                 max_new_tokens=100,
+                 num_beams=4,
+                 temperature=0.7,
+                 do_sample=True,
+                 pad_token_id=phi_tokenizer.pad_token_id,
+                 eos_token_id=phi_tokenizer.eos_token_id
+             )
+
+         # Decode the generated answer
+         answer = phi_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return answer
+
+     except Exception as e:
+         return f"Error generating answer: {str(e)}"
+
+ # Create Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# CIFAR10 Image Question Answering System")
+
+     # State variables
+     selected_image_tensor = gr.State(None)
+     image_indices = gr.State([])
+
+     with gr.Row():
+         with gr.Column():
+             # Button to get random images
+             random_btn = gr.Button("Get Random Images")
+             # Gallery to display images
+             gallery = gr.Gallery(
+                 label="Click an image to select it",
+                 show_label=True,
+                 elem_id="gallery",
+                 columns=[5],
+                 rows=[2],
+                 height="auto",
+                 allow_preview=False
+             )
+
+         with gr.Column():
+             # Display selected image
+             selected_img = gr.Image(label="Selected Image", height=200)
+             # Question buttons
+             q_buttons = []
+             for i, q in enumerate(questions):
+                 btn = gr.Button(f"Q{i+1}: {q}")
+                 q_buttons.append(btn)
+             # Answer textbox
+             answer_box = gr.Textbox(label="Answer", lines=3)
+
+     # Handle random image button click
+     def on_random_click():
+         images, indices = get_random_images()
+         return {
+             gallery: images,
+             image_indices: indices,
+             selected_image_tensor: None,
+             selected_img: None,
+             answer_box: ""
+         }
+
+     random_btn.click(
+         on_random_click,
+         outputs=[gallery, image_indices, selected_image_tensor, selected_img, answer_box]
+     )
+
+     # Handle image selection
+     def on_image_select(evt: gr.SelectData, images, indices):
+         if images is None or evt.index >= len(images):
+             return None, None, ""
+         selected_idx = indices[evt.index]
+         selected_tensor = first_100_images[selected_idx][0]
+         return selected_tensor, images[evt.index], ""
+
+     gallery.select(
+         on_image_select,
+         inputs=[gallery, image_indices],
+         outputs=[selected_image_tensor, selected_img, answer_box]
+     )
+
+     # Handle question button clicks
+     for i, btn in enumerate(q_buttons):
+         btn.click(
+             generate_answer,
+             inputs=[selected_image_tensor, gr.Number(value=i, visible=False)],
+             outputs=answer_box
+         )
+
+ demo.launch()
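Note on the two `torch.load` calls in `load_models` above: the training scripts later in this commit save these files with `torch.save(module.state_dict(), ...)`, so what comes back here is a plain state dict rather than a callable module. A minimal sketch (not part of the committed app.py) of how the projections could be rebuilt before use, assuming the `LinearProjection` and `ImageTextProjection` classes from `train_phi_with_siglip.py` and inferring the layer sizes from the saved weights:

import torch
import torch.nn as nn

# Sketch only: rebuild the projection modules from the saved state dicts
# so that linear_proj(...) and image_text_proj(...) are callable.
class LinearProjection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

class ImageTextProjection(nn.Module):
    def __init__(self, image_dim, text_dim):
        super().__init__()
        self.image_projection = nn.Linear(image_dim, text_dim)

    def forward(self, x):
        return self.image_projection(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Infer (in_features, out_features) from the saved weight matrices, the same
# way train_phi_with_siglip.py reads the checkpoint shape when it reloads it.
lin_state = torch.load('linear_projection_final.pth', map_location=device)
linear_proj = LinearProjection(lin_state['linear.weight'].shape[1],
                               lin_state['linear.weight'].shape[0]).to(device)
linear_proj.load_state_dict(lin_state)
linear_proj.eval()

proj_state = torch.load('image_text_proj.pth', map_location=device)
image_text_proj = ImageTextProjection(proj_state['image_projection.weight'].shape[1],
                                      proj_state['image_projection.weight'].shape[0]).to(device)
image_text_proj.load_state_dict(proj_state)
image_text_proj.eval()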
extract_answers.py ADDED
@@ -0,0 +1,45 @@
+ import os
+ import re
+ import glob
+
+ def extract_assistant_answers(input_file):
+     """Extract the text after 'Assistant:' from the input file."""
+     with open(input_file, 'r', encoding='utf-8') as f:
+         content = f.read()
+
+     # Split content by "Assistant:" to get all sections after it
+     sections = content.split("Assistant:")
+
+     # Process each section to get clean answers
+     answers = []
+     for section in sections[1:]:  # Skip the first split as it's before the first "Assistant:"
+         # Get text up to the next "Q" or "User:" or end of string
+         answer = section.split("Q")[0].split("User:")[0].strip()
+         if answer:
+             answers.append(answer)
+
+     return answers
+
+ def process_all_files():
+     """Process all image_*.txt files in the qa_outputs directory."""
+     # Get all image_*.txt files
+     input_files = glob.glob("qa_outputs/image_*.txt")
+
+     for input_file in input_files:
+         # Extract the base name without extension
+         base_name = os.path.splitext(input_file)[0]
+         output_file = f"{base_name}_extr.txt"
+
+         # Extract answers
+         answers = extract_assistant_answers(input_file)
+
+         # Write answers to the output file
+         with open(output_file, 'w', encoding='utf-8') as f:
+             for i, answer in enumerate(answers, 1):
+                 f.write(f"{answer}\n")
+
+         print(f"Processed {input_file} -> {output_file}")
+
+ if __name__ == "__main__":
+     process_all_files()
+     print("Extraction complete! Check the files with '_extr' suffix.")
image_text_proj.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ddb52b49da8704aff3b46f2c503ac20cf64e3e6efbe4844e9ac89e85d9673894
+ size 1586824
linear_projection_final.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:374fc454085ce3b227047bfbf45c2df30a97a308b593ad1f2ef5ec763cab5afb
+ size 592056
phi_model_trained/README.md ADDED
@@ -0,0 +1,202 @@
+ ---
+ base_model: microsoft/Phi-3-mini-4k-instruct
+ library_name: peft
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+ ### Framework versions
+
+ - PEFT 0.15.1
phi_model_trained/adapter_config.json ADDED
@@ -0,0 +1,36 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "microsoft/Phi-3-mini-4k-instruct",
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 32,
+   "lora_bias": false,
+   "lora_dropout": 0.05,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 16,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "mlp.dense_h_to_4h",
+     "mlp.dense_4h_to_h",
+     "self_attn.qkv_proj",
+     "self_attn.dense"
+   ],
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
phi_model_trained/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91ad139ce2f99c0ed85ff06edd5ac6b766baef76fab1d3e896c9ac32589e96fb
+ size 25174552
process_cifar10.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torchvision
+ import torchvision.transforms as transforms
+ from PIL import Image
+ import os
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from tqdm import tqdm
+
+ # Initialize model and processor
+ model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
+ processor = AutoProcessor.from_pretrained(model_path)
+ model = AutoModelForImageTextToText.from_pretrained(
+     model_path,
+     torch_dtype=torch.bfloat16
+     #_attn_implementation="flash_attention_2"
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Create output directory
+ os.makedirs("SigLIP_Training/qa_outputs", exist_ok=True)
+
+ # Load CIFAR-10 dataset
+ transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.ToPILImage()
+ ])
+
+ # Using test set instead of train set
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False,
+                                        download=True, transform=transform)
+
+ # List of questions
+ questions = [
+     "Give a description of the image?",
+     "How does the main object in the image look like?",
+     "How can the main object in the image be useful to humans?",
+     "What is the color of the main object in the image?",
+     "Describe the setting of the image?"
+ ]
+
+ def process_image(image, image_idx):
+     # Create output file
+     output_file = f"SigLIP_Training/qa_outputs/image_{image_idx}.txt"
+
+     with open(output_file, 'w') as f:
+         for q_idx, question in enumerate(questions, 1):
+             # Prepare the message for the model
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "image", "image": image},
+                         {"type": "text", "text": question}
+                     ]
+                 }
+             ]
+
+             # Process inputs
+             inputs = processor.apply_chat_template(
+                 messages,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_dict=True,
+                 return_tensors="pt"
+             ).to(model.device, dtype=torch.bfloat16)
+
+             # Generate answer
+             generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)
+             answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+             # Write to file in the correct format
+             f.write(f"Q{q_idx}: {question}\n")
+             f.write(f"A{q_idx}: {answer}\n")
+
+ # Process all images from test set
+ print("Starting to process CIFAR-10 test set images...")
+ for idx, (image, _) in enumerate(tqdm(testset)):
+     process_image(image, idx)
+     #if idx >= 1000:  # Process first 1000 test images
+     #    break
+
+ print("Processing complete! Check the SigLIP_Training/qa_outputs directory for results.")
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=2.0.0
+ transformers>=4.36.0
+ torchvision>=0.15.0
+ pillow>=9.3.0
+ tqdm>=4.65.0
+ numpy>=1.24.0
+ accelerate>=0.25.0
+ gradio>=4.19.0
+ bitsandbytes>=0.41.1
+ peft>=0.7.0
train_linear_projection.py ADDED
@@ -0,0 +1,216 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.optim import AdamW
+ from transformers import SiglipVisionModel, AutoTokenizer, AutoImageProcessor, AutoModel
+ from torchvision.datasets import CIFAR10
+ from torch.utils.data import DataLoader, Subset
+ import torchvision.transforms as transforms
+ from tqdm import tqdm
+ import os
+ import numpy as np
+ from PIL import Image
+ import argparse
+
+ def siglip_loss(image_embeddings, text_embeddings, temperature=0.07):
+     # Normalize
+     image_embeddings = F.normalize(image_embeddings, dim=-1)
+     text_embeddings = F.normalize(text_embeddings, dim=-1)
+
+     # Compute pairwise similarities
+     logits = image_embeddings @ text_embeddings.T  # [batch_size, batch_size]
+     logits = logits / temperature
+
+     # Ground truth: 1.0 for matching pairs (diagonal), 0.0 for all others
+     batch_size = logits.size(0)
+     targets = torch.eye(batch_size).to(logits.device)
+
+     # Apply binary cross-entropy with logits
+     loss = F.binary_cross_entropy_with_logits(logits, targets)
+
+     return loss
+
+ class LinearProjection(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         return self.linear(x)
+
+ def get_text_embedding(text, tokenizer, device, max_length=128):
+     # Ensure text is not empty and has minimum content
+     if not text or len(text.strip()) == 0:
+         text = "This is a placeholder description."
+
+     # Tokenize with padding and truncation
+     inputs = tokenizer(
+         text,
+         return_tensors="pt",
+         padding='max_length',  # Changed to max_length padding
+         truncation=True,
+         max_length=max_length  # Fixed max length for all inputs
+     )
+
+     # Move inputs to device and ensure correct data type
+     inputs = {
+         k: v.to(device).float() for k, v in inputs.items()
+     }
+
+     # Return the input_ids as embeddings
+     return inputs['input_ids'].float()  # Convert to float for the loss calculation
+
+ def main(num_images=100, batch_size=32, num_epochs=50, learning_rate=1e-4, load_checkpoint=True, checkpoint_path='linear_projection.pth'):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     # Load models and processors
+     siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
+     siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
+     tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+     # Set padding token if not set
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Freeze SigLIP model
+     for param in siglip_model.parameters():
+         param.requires_grad = False
+
+     siglip_model.to(device)
+
+     # Get SigLIP output dimension and text embedding dimension
+     # Create a proper dummy image (black image)
+     dummy_image = Image.new('RGB', (384, 384), color='black')
+     with torch.no_grad():
+         siglip_inputs = siglip_processor(dummy_image, return_tensors="pt").to(device)
+         siglip_outputs = siglip_model(**siglip_inputs)
+         siglip_output_dim = siglip_outputs.pooler_output.shape[-1]
+
+     # Get a sample text to determine embedding dimension
+     dummy_text = "This is a test."
+     dummy_embedding = get_text_embedding(dummy_text, tokenizer, device)
+     text_embedding_dim = dummy_embedding.shape[-1]
+
+     print(f"SigLIP output dimension: {siglip_output_dim}")
+     print(f"Text embedding dimension: {text_embedding_dim}")
+
+     # Create linear projection layer
+     linear_proj = LinearProjection(siglip_output_dim, text_embedding_dim).to(device)
+
+     # Load checkpoint if requested
+     if load_checkpoint:
+         try:
+             checkpoint = torch.load(checkpoint_path, map_location=device)
+             linear_proj.load_state_dict(checkpoint)
+             print(f"Successfully loaded checkpoint from {checkpoint_path}")
+         except Exception as e:
+             print(f"Error loading checkpoint: {e}")
+             print("Starting training from scratch instead.")
+
+     # Load CIFAR10 test dataset
+     transform = transforms.Compose([
+         transforms.Resize((384, 384)),
+         transforms.ToTensor(),
+     ])
+
+     test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
+     subset_indices = list(range(num_images))
+     subset_dataset = Subset(test_dataset, subset_indices)
+     dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)
+
+     # Create text files directory if it doesn't exist
+     os.makedirs('qa_outputs', exist_ok=True)
+
+     # Optimizer
+     optimizer = AdamW(linear_proj.parameters(), lr=learning_rate)
+
+     # Training loop
+     for epoch in range(num_epochs):
+         total_loss = 0
+         linear_proj.train()
+
+         progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
+         for batch_idx, (images, labels) in enumerate(progress_bar):
+             images = images.to(device)
+             batch_size = images.size(0)
+
+             # Get image embeddings
+             with torch.no_grad():
+                 siglip_inputs = siglip_processor(images, return_tensors="pt").to(device)
+                 siglip_outputs = siglip_model(**siglip_inputs)
+                 image_features = siglip_outputs.pooler_output
+
+             # Project image features
+             projected_image_features = linear_proj(image_features)
+
+             # Process text for each line (1 to 5)
+             total_batch_loss = 0
+             for line_num in range(5):
+                 text_embeddings_list = []
+
+                 # Read text from files for current batch
+                 for idx in range(batch_size):
+                     global_idx = batch_idx * batch_size + idx
+                     if global_idx < num_images:
+                         file_path = f'qa_outputs/image_{global_idx}_extr.txt'
+                         try:
+                             with open(file_path, 'r') as f:
+                                 lines = f.readlines()
+                                 text = lines[line_num].strip() if line_num < len(lines) else ""
+                         except:
+                             text = "No description available"
+
+                         # Get text embeddings directly from tokenizer
+                         text_embedding = get_text_embedding(text, tokenizer, device)
+                         text_embeddings_list.append(text_embedding)
+
+                 if text_embeddings_list:
+                     # Stack instead of cat since all embeddings have same size now
+                     text_embeddings = torch.stack(text_embeddings_list, dim=0).squeeze(1)
+                     loss = siglip_loss(projected_image_features, text_embeddings)
+                     total_batch_loss += loss
+
+             # Average loss over all text lines
+             avg_batch_loss = total_batch_loss / 5
+
+             # Backpropagation
+             optimizer.zero_grad()
+             avg_batch_loss.backward()
+             optimizer.step()
+
+             total_loss += avg_batch_loss.item()
+             progress_bar.set_postfix({'loss': avg_batch_loss.item()})
+
+         avg_epoch_loss = total_loss / len(dataloader)
+         print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')
+
+         # Save checkpoint after each epoch
+         # checkpoint_dir = 'checkpoints'
+         # os.makedirs(checkpoint_dir, exist_ok=True)
+         # checkpoint_file = os.path.join(checkpoint_dir, f'linear_projection_epoch_{epoch+1}.pth')
+         # torch.save(linear_proj.state_dict(), checkpoint_file)
+         # print(f"Saved checkpoint to {checkpoint_file}")
+
+     # Save final model
+     torch.save(linear_proj.state_dict(), 'linear_projection_final.pth')
+     print("Training completed. Final model saved as 'linear_projection_final.pth'")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='Train or continue training the linear projection layer')
+     parser.add_argument('--num_images', type=int, default=100, help='Number of images to train on')
+     parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
+     parser.add_argument('--num_epochs', type=int, default=50, help='Number of epochs to train')
+     parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
+     parser.add_argument('--load_checkpoint', action='store_true', help='Whether to load from checkpoint')
+     parser.add_argument('--checkpoint_path', type=str, default='linear_projection.pth', help='Path to checkpoint file')
+
+     args = parser.parse_args()
+     main(
+         num_images=args.num_images,
+         batch_size=args.batch_size,
+         num_epochs=args.num_epochs,
+         learning_rate=args.learning_rate,
+         load_checkpoint=args.load_checkpoint,
+         checkpoint_path=args.checkpoint_path
+     )
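A quick, self-contained sanity check of the `siglip_loss` used above (condensed from the definition in this file): each image/text pair in the batch is treated as the only positive (the diagonal of the similarity matrix), and an element-wise sigmoid cross-entropy is applied, which is the SigLIP-style pairwise formulation rather than a softmax contrastive loss. The tensors below are random and purely illustrative.

import torch
import torch.nn.functional as F

def siglip_loss(image_embeddings, text_embeddings, temperature=0.07):
    # Condensed version of the function defined in train_linear_projection.py.
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)
    logits = image_embeddings @ text_embeddings.T / temperature
    targets = torch.eye(logits.size(0)).to(logits.device)   # positives on the diagonal
    return F.binary_cross_entropy_with_logits(logits, targets)

# Random stand-ins for a batch of 8 projected image features and 8 text vectors.
img = torch.randn(8, 128)
txt = torch.randn(8, 128)
print(siglip_loss(img, txt).item())          # unrelated random pairs -> higher loss
print(siglip_loss(img, img.clone()).item())  # perfectly matched pairs -> lower loss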
train_phi_with_siglip.py ADDED
@@ -0,0 +1,263 @@
+ import torch
+ import torch.nn as nn
+ from torch.optim import AdamW
+ from transformers import (
+     SiglipVisionModel,
+     AutoTokenizer,
+     AutoImageProcessor,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig
+ )
+ from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
+ from torchvision.datasets import CIFAR10
+ from torch.utils.data import DataLoader, Subset
+ import torchvision.transforms as transforms
+ from tqdm import tqdm
+ import os
+ from PIL import Image
+
+ class LinearProjection(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         return self.linear(x)
+
+ class ImageTextProjection(nn.Module):
+     def __init__(self, image_dim, text_dim):
+         super().__init__()
+         self.image_projection = nn.Linear(image_dim, text_dim)
+
+     def forward(self, x):
+         return self.image_projection(x)
+
+ def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
+     with torch.no_grad():
+         # Process image through SigLIP
+         inputs = siglip_processor(image, return_tensors="pt")
+         # Move inputs to the same device as model
+         inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+         outputs = siglip_model(**inputs)
+         image_features = outputs.pooler_output
+
+         # Project through trained linear layer
+         projected_features = linear_proj(image_features)
+
+     return projected_features
+
+ def main(
+     num_images=100,
+     batch_size=4,  # Smaller batch size due to memory constraints
+     num_epochs=100,
+     learning_rate=2e-4,
+     questions=None  # List of 5 questions to be provided
+ ):
+     if questions is None or len(questions) != 5:
+         print("Please provide exactly 5 questions!")
+         return
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     # Load SigLIP model and processor
+     siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to(device)
+     siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
+
+     # Load trained linear projection
+     dummy_image = Image.new('RGB', (384, 384), color='black')
+     with torch.no_grad():
+         siglip_inputs = siglip_processor(dummy_image, return_tensors="pt")
+         # Move inputs to device
+         siglip_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in siglip_inputs.items()}
+         siglip_outputs = siglip_model(**siglip_inputs)
+         siglip_output_dim = siglip_outputs.pooler_output.shape[-1]
+
+     # First load the checkpoint to get the correct output dimension
+     checkpoint = torch.load('linear_projection_final.pth', map_location=device)
+     output_dim = checkpoint['linear.weight'].shape[0]  # Get the output dimension from saved weights
+     print(f"Loading linear projection with output dimension: {output_dim}")
+
+     # Initialize linear projection with correct dimensions
+     linear_proj = LinearProjection(siglip_output_dim, output_dim).to(device)
+     try:
+         linear_proj.load_state_dict(checkpoint)
+         print("Successfully loaded linear projection weights")
+     except Exception as e:
+         print(f"Error loading linear projection weights: {e}")
+         return
+
+     # Load Phi model with 4-bit quantization
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.float16,
+         bnb_4bit_use_double_quant=False
+     )
+
+     phi_model = AutoModelForCausalLM.from_pretrained(
+         "microsoft/Phi-3-mini-4k-instruct",
+         quantization_config=bnb_config,
+         device_map="auto"
+     )
+     phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+     # Add padding token if not present
+     if phi_tokenizer.pad_token is None:
+         phi_tokenizer.pad_token = phi_tokenizer.eos_token
+
+     # Get embedding dimension from phi model
+     phi_embed_dim = phi_model.get_input_embeddings().weight.shape[1]
+
+     # Create projection layer for image embeddings
+     image_text_proj = ImageTextProjection(output_dim, phi_embed_dim).to(device)
+
+     # Prepare model for k-bit training
+     phi_model = prepare_model_for_kbit_training(phi_model)
+
+     # Setup LoRA configuration
+     lora_config = LoraConfig(
+         r=16,
+         lora_alpha=32,
+         target_modules=["mlp.dense_h_to_4h", "mlp.dense_4h_to_h", "self_attn.qkv_proj", "self_attn.dense"],
+         lora_dropout=0.05,
+         bias="none",
+         task_type="CAUSAL_LM"
+     )
+
+     # Get PEFT model
+     phi_model = get_peft_model(phi_model, lora_config)
+
+     # Freeze SigLIP and linear projection
+     for param in siglip_model.parameters():
+         param.requires_grad = False
+     for param in linear_proj.parameters():
+         param.requires_grad = False
+
+     # Load CIFAR10 test dataset
+     transform = transforms.Compose([
+         transforms.Resize((384, 384)),
+         transforms.ToTensor(),
+     ])
+
+     test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
+     subset_indices = list(range(num_images))
+     subset_dataset = Subset(test_dataset, subset_indices)
+     dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)
+
+     # Optimizer for both phi model and image projection
+     optimizer = AdamW([
+         {'params': phi_model.parameters()},
+         {'params': image_text_proj.parameters()}
+     ], lr=learning_rate)
+
+     # Training loop
+     for epoch in range(num_epochs):
+         total_loss = 0
+         phi_model.train()
+         image_text_proj.train()
+
+         progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
+         for batch_idx, (images, _) in enumerate(progress_bar):
+             images = images.to(device)
+             batch_size = images.size(0)
+
+             # Get image embeddings
+             image_embeddings = get_image_embedding(images, siglip_model, siglip_processor, linear_proj, device)
+
+             # Process each question
+             for q_idx, question in enumerate(questions):
+                 # Read corresponding answers
+                 answers = []
+                 for idx in range(batch_size):
+                     global_idx = batch_idx * batch_size + idx
+                     if global_idx < num_images:
+                         file_path = f'qa_outputs/image_{global_idx}_extr.txt'
+                         try:
+                             with open(file_path, 'r') as f:
+                                 lines = f.readlines()
+                                 answer = lines[q_idx].strip() if q_idx < len(lines) else ""
+                                 answers.append(answer)
+                         except:
+                             answers.append("No answer available")
+
+                 # Tokenize questions and answers for the entire batch
+                 question_tokens = phi_tokenizer(
+                     [question] * batch_size,
+                     padding=True,
+                     truncation=True,
+                     max_length=512,
+                     return_tensors="pt"
+                 ).to(device)
+
+                 target_tokens = phi_tokenizer(
+                     answers,
+                     padding=True,
+                     truncation=True,
+                     max_length=512,
+                     return_tensors="pt"
+                 ).to(device)
+
+                 # Get question embeddings for the entire batch
+                 question_embeds = phi_model.get_input_embeddings()(question_tokens['input_ids'])  # [batch_size, seq_len, embed_dim]
+
+                 # Project and prepare image embeddings for the entire batch
+                 image_embeds = image_text_proj(image_embeddings)  # [batch_size, embed_dim]
+                 image_embeds = image_embeds.unsqueeze(1)  # [batch_size, 1, embed_dim]
+
+                 # Combine image embeddings with question embeddings
+                 combined_embedding = torch.cat([
+                     image_embeds,    # [batch_size, 1, embed_dim]
+                     question_embeds  # [batch_size, seq_len, embed_dim]
+                 ], dim=1)            # [batch_size, 1+seq_len, embed_dim]
+
+                 # Create attention mask for the combined sequence
+                 attention_mask = torch.ones(
+                     (batch_size, combined_embedding.size(1)),
+                     dtype=torch.long,
+                     device=device
+                 )
+
+                 # Prepare labels by shifting them right
+                 labels = target_tokens['input_ids'].clone()
+                 labels = torch.cat([
+                     torch.full((batch_size, combined_embedding.size(1) - 1), -100, device=device),
+                     labels
+                 ], dim=1)[:, :combined_embedding.size(1)]
+
+                 # Forward pass
+                 outputs = phi_model(
+                     inputs_embeds=combined_embedding,
+                     attention_mask=attention_mask,
+                     labels=labels
+                 )
+
+                 loss = outputs.loss
+                 total_loss += loss.item()
+
+                 # Backward pass
+                 loss.backward()
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+                 progress_bar.set_postfix({'loss': loss.item()})
+
+         avg_epoch_loss = total_loss / (len(dataloader) * len(questions) * batch_size)
+         print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')
+
+     # Save the trained models
+     phi_model.save_pretrained('phi_model_trained')
+     torch.save(image_text_proj.state_dict(), 'image_text_proj.pth')
+     print("Training completed. Models saved as 'phi_model_trained' and 'image_text_proj.pth'")
+
+ if __name__ == "__main__":
+     # Example questions - replace with your actual questions
+     questions = [
+         "Give a description of the image?",
+         "How does the main object in the image look like?",
+         "How can the main object in the image be useful to humans?",
+         "What is the color of the main object in the image?",
+         "Describe the setting of the image?"
+     ]
+
+     main(questions=questions)
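One detail in the training loop above that is easy to miss is the label construction: the positions that correspond to the image embedding and the question prompt are filled with -100, which Hugging Face causal-LM losses ignore, and the concatenated labels are then truncated to the length of the image+question sequence. A toy illustration with made-up sizes (purely for reading the indexing, not taken from the repository):

import torch

batch_size = 2
prefix_len = 1 + 6  # 1 projected image embedding + 6 question tokens (illustrative)
answer = torch.tensor([[11, 12, 13, 14],
                       [21, 22, 23, 24]])  # pretend answer token ids

# As in train_phi_with_siglip.py: pad the front with -100 (ignored by the loss),
# then truncate to combined_embedding.size(1), i.e. the image+question length.
labels = torch.cat([
    torch.full((batch_size, prefix_len - 1), -100),
    answer
], dim=1)[:, :prefix_len]

print(labels)
# tensor([[-100, -100, -100, -100, -100, -100,   11],
#         [-100, -100, -100, -100, -100, -100,   21]])
# With this construction, only as many answer tokens as fit inside the
# image+question window contribute to the loss.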