fix added
app.py CHANGED
@@ -115,27 +115,6 @@ class Blip2QFormer(nn.Module):
 
         return outputs.last_hidden_state
 
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
-
-    def forward(self, x: torch.Tensor):
-        orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
-        return ret.type(orig_type)
-
-
-class ViTClassifier(nn.Module):
-    def __init__(self, vit, ln_vision, num_labels):
-        super(ViTClassifier, self).__init__()
-        self.vit = vit  # Pretrained ViT from MiniGPT-4
-        self.ln_vision = ln_vision  # LayerNorm from MiniGPT-4
-        self.classifier = nn.Linear(vit.num_features, num_labels)
-
-    def forward(self, x):
-        features = self.ln_vision(self.vit(x))  # [batch, seq_len, dim]
-        cls_token = features[:, 0, :]  # Extract CLS token
-        return self.classifier(cls_token)
-
 
 class SkinGPT4(nn.Module):
     def __init__(self, vit_checkpoint_path,
@@ -161,10 +140,7 @@ class SkinGPT4(nn.Module):
         self.q_former.load_from_pretrained(q_former_model)
         for param in self.q_former.parameters():
             param.requires_grad = False
-
-        for param in module.parameters():
-            param.requires_grad = False
-        module.eval()
+
         print("Loaded QFormer")
 
         self.tokenizer = LlamaTokenizer.from_pretrained(
@@ -185,8 +161,10 @@ class SkinGPT4(nn.Module):
         print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
         print(f"LLaMA input dim: {self.llama.config.hidden_size}")
 
-        for
-            param
+        for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
+            for param in module.parameters():
+                param.requires_grad = False
+            module.eval()
 
     def _init_vit(self, vit_checkpoint_path):
         """Initialize EVA-ViT-G with paper specifications"""
@@ -213,9 +191,6 @@ class SkinGPT4(nn.Module):
         # 4. Load weights while ignoring classifier head
         vit.load_state_dict(vit_weights, strict=False)
 
-        # 5. Freeze according to paper specs
-        for param in vit.parameters():
-            param.requires_grad = False
 
         return vit.eval()
 
@@ -226,27 +201,13 @@ class SkinGPT4(nn.Module):
             "": 0 if torch.cuda.is_available() else "cpu"
         }
         # First try loading with device_map="auto"
-
-
-
-
-
-
-
-            )
-        except ImportError:
-            # Fallback to CPU-offloading if GPU memory is insufficient
-            with init_empty_weights():
-                model = LlamaForCausalLM.from_pretrained(
-                    "meta-llama/Llama-2-13b-chat-hf",
-                    token=token,
-                    torch_dtype=torch.float16
-                )
-            model = model.to(self.device)
-
-            # Freeze all parameters
-            for param in model.parameters():
-                param.requires_grad = False
+        model = LlamaForCausalLM.from_pretrained(
+            "meta-llama/Llama-2-13b-chat-hf",
+            token=token,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+            low_cpu_mem_usage=True
+        )
 
         return model.eval()
 
@@ -259,12 +220,7 @@ class SkinGPT4(nn.Module):
                 f"Original error: {str(e)}"
             )
 
-    def
-        """Paper specifies Xavier initialization for alignment layer"""
-        nn.init.xavier_normal_(self.llama_proj.weight)
-        nn.init.constant_(self.llama_proj.bias, 0)
-
-    def _create_patches(self, x):
+    def encode_image(self, x):
         """Convert image to patch embeddings following Eq. (1)"""
         # x: (B, C, H, W)
         x = x.to(self.dtype)
@@ -276,69 +232,39 @@ class SkinGPT4(nn.Module):
         B, C, H, W = x.shape
         N = (H * W) // (self.P ** 2)
 
-        x = self.vit.patch_embed(x)
+        x = self.vit.patch_embed(x)
 
         num_patches = x.shape[1]
-        pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]
+        pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]
         x = x + pos_embed
 
         # Add class token
-        class_token = self.vit.cls_token.expand(
-        x = torch.cat([class_token, x], dim=1)
-        return x
-
-    def forward_encoder(self, x):
-        """ViT encoder from Eqs. (2)-(3)"""
-        # x: (B, N+1, D)
+        class_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat([class_token, x], dim=1)
         for blk in self.vit.blocks:
             x = blk(x)
         x = self.vit.norm(x)
-
-
-
-
-
-
-        vit_output = self.forward_encoder(x)
-        with torch.cuda.amp.autocast(enabled=False):
-            qformer_output = self.q_former(vit_output.float())
-        aligned_features = self.llama_proj(qformer_output.to(self.dtype))
-        return aligned_features
-
-
-    def add_to_history(self, role, content):
-        self.conversation_history.append({"role": role, "content": content})
-
-    def get_full_context(self):
-        return "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.conversation_history])
-
-    def build_prompt(self, image_embeds, user_question=None):
-        # Base prompt for initial diagnosis
-        if not user_question:
-            prompt = (
-                "### Instruction: <Img><ImageHere></Img> "
-                "Could you describe the skin disease in this image for me? "
-                "### Response:"
-            )
-        else:
-            # Follow-up prompt with conversation history
-            history = self.get_full_context()
-            prompt = (
-                f"### Instruction: <Img><ImageHere></Img> "
-                f"Based on our previous conversation:\n{history}\n"
-                f"User asks: {user_question}\n"
-                "### Response:"
-            )
+        vit_features = self.ln_vision(x)
+
+        # Q-Former forward pass
+        with torch.no_grad():
+            qformer_output = self.q_former(vit_features.float())
+        image_embeds = self.llama_proj(qformer_output.to(self.dtype))
 
-        return
+        return image_embeds
 
-    def generate(self, images, user_input=None,
+    def generate(self, images, user_input=None, max_new_tokens=300):
         print("Analysing the image to generate the diagnosis")
-
-
+
+        image_embeds = self.encode_image(images)
+        print(f"Aligned features : {image_embeds}")
         print("Generated the aligned features with ViT and Qformer")
+
         prompt = (
-            "<Img><ImageHere></Img>
+            "### Instruction: <Img><ImageHere></Img> "
+            "Could you describe the skin condition in this image? "
+            "Please provide a detailed analysis including possible diagnoses. "
+            "### Response:"
         )
         inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
         image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
@@ -347,28 +273,39 @@ class SkinGPT4(nn.Module):
             raise ValueError("Image token not found in prompt")
         # Prepare embeddings
         input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
-
-        visual_embeds = aligned_features.mean(dim=1, keepdim=True)  # [1, 1, 5120]
-        visual_embeds = visual_embeds.to(input_embeddings.dtype)
-        print(f"Visual embeddings : {visual_embeds}")
+        visual_embeds = image_embeds.mean(dim=1, keepdim=True)
         input_embeddings[image_token_pos] = visual_embeds
-
+
+        # outputs = self.llama.generate(
+        #     inputs_embeds=input_embeddings,
+        #     max_new_tokens=max_length,
+        #     temperature=0.7,
+        #     top_p=0.9,
+        #     repetition_penalty=1.2,  # Prevent repetition
+        #     do_sample=True,
+        #     pad_token_id=self.tokenizer.eos_token_id,
+        #     eos_token_id=self.tokenizer.eos_token_id
+        # )
 
         outputs = self.llama.generate(
             inputs_embeds=input_embeddings,
-            max_new_tokens=
-
-            top_p=0.9,
-            repetition_penalty=1.2,  # Prevent repetition
+            max_new_tokens=max_new_tokens,
+            num_beams=1,
             do_sample=True,
-
-
+            min_length=1,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            length_penalty=1,
+            temperature=1.0,
+            pad_token_id=self.tokenizer.eos_token_id
        )
 
 
         full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         print(f"Output from llama : {full_output}")
-
+        response = full_output.split("### Response:")[-1].strip()
+
+        return response
 
 
 class SkinGPTClassifier:
@@ -395,7 +332,6 @@ class SkinGPTClassifier:
         )
         model = SkinGPT4(vit_checkpoint_path=model_path).eval()
         model = model.to(self.device)
-        model.eval()
         return model
 
     def predict(self, image):
@@ -450,7 +386,7 @@ if uploaded_file:
             else:
                 st.session_state.conversation.append(("assistant", result))
                 with st.chat_message("assistant"):
-                    st.markdown(result)
+                    st.markdown(result["diagnosis"])
     else:
         # Follow-up questions
         if user_query := st.chat_input("Ask a follow-up question..."):