Removed secrets from history
Browse files
app.py
CHANGED
@@ -20,7 +20,8 @@ import requests
|
|
20 |
from io import BytesIO
|
21 |
import os
|
22 |
from huggingface_hub import hf_hub_download
|
23 |
-
|
|
|
24 |
|
25 |
token = os.getenv("HF_TOKEN")
|
26 |
if not token:
|
@@ -164,7 +165,6 @@ class SkinGPT4(nn.Module):
|
|
164 |
self.q_former.eval()
|
165 |
print("Loaded QFormer")
|
166 |
self.llama = self._init_llama()
|
167 |
-
self.llama = self.llama.to(device)
|
168 |
self.llama.resize_token_embeddings(len(self.tokenizer))
|
169 |
|
170 |
self.llama_proj = nn.Linear(
|
@@ -214,30 +214,16 @@ class SkinGPT4(nn.Module):
|
|
214 |
def _init_llama(self):
|
215 |
"""Initialize frozen LLaMA-2-13b-chat with proper error handling"""
|
216 |
try:
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
# Configure 4-bit quantization to reduce memory usage
|
221 |
-
# quantization_config = BitsAndBytesConfig(
|
222 |
-
# load_in_4bit=True,
|
223 |
-
# bnb_4bit_compute_dtype=torch.float16,
|
224 |
-
# bnb_4bit_use_double_quant=True,
|
225 |
-
# bnb_4bit_quant_type="nf4"
|
226 |
-
# )
|
227 |
-
quant_config = BitsAndBytesConfig(
|
228 |
-
load_in_4bit=True,
|
229 |
-
bnb_4bit_compute_dtype=torch.float16,
|
230 |
-
bnb_4bit_quant_type="nf4",
|
231 |
-
)
|
232 |
-
|
233 |
# First try loading with device_map="auto"
|
234 |
try:
|
235 |
model = LlamaForCausalLM.from_pretrained(
|
236 |
"meta-llama/Llama-2-13b-chat-hf",
|
237 |
-
# quantization_config=quant_config,
|
238 |
token=token,
|
239 |
torch_dtype=torch.float16,
|
240 |
-
device_map=
|
241 |
low_cpu_mem_usage=True
|
242 |
)
|
243 |
except ImportError:
|
@@ -355,22 +341,10 @@ class SkinGPT4(nn.Module):
|
|
355 |
|
356 |
def generate(self, images, user_input=None, max_length=300):
|
357 |
# Get aligned features
|
358 |
-
images = images.to(self.dtype)
|
359 |
-
|
360 |
aligned_features = self.forward(images)
|
361 |
|
362 |
prompt = self.build_prompt(aligned_features, user_input)
|
363 |
-
|
364 |
-
self.llama = self.llama.to(self.dtype)
|
365 |
-
|
366 |
-
# Tokenize prompt
|
367 |
-
|
368 |
-
# self.tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})
|
369 |
-
# self.llama.resize_token_embeddings(len(self.tokenizer))
|
370 |
-
|
371 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
|
372 |
-
|
373 |
-
# Replace <ImageHere> with aligned features
|
374 |
image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
375 |
image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
|
376 |
image_embeddings[image_token_index] = aligned_features.mean(dim=1) # Pool query tokens
|
@@ -386,27 +360,13 @@ class SkinGPT4(nn.Module):
|
|
386 |
|
387 |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
388 |
|
389 |
-
|
390 |
-
# def load_model(model_path):
|
391 |
-
# model_path = hf_hub_download(
|
392 |
-
# repo_id="KeerthiVM/SkinCancerDiagnosis",
|
393 |
-
# filename="dermnet_finetuned_version1.pth",
|
394 |
-
# )
|
395 |
-
# # model = SkinGPT4(vit_checkpoint_path="dermnet_finetuned_version1.pth")
|
396 |
-
# model = SkinGPT4(vit_checkpoint_path=model_path)
|
397 |
-
# model.to(device)
|
398 |
-
# model.eval()
|
399 |
-
# return model
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
class SkinGPTClassifier:
|
404 |
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
405 |
self.device = torch.device(device)
|
406 |
self.conversation_history = []
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
self.resnet_feature_extractor = None
|
411 |
|
412 |
# Image transformations
|
@@ -421,22 +381,11 @@ class SkinGPTClassifier:
|
|
421 |
repo_id="KeerthiVM/SkinCancerDiagnosis",
|
422 |
filename="dermnet_finetuned_version1.pth",
|
423 |
)
|
424 |
-
|
425 |
-
|
426 |
-
self.meta_model.to_empty(device=device)
|
427 |
-
|
428 |
-
def predict(self, image, top_k=3):
|
429 |
-
"""Make prediction for a single image"""
|
430 |
-
if self.meta_model is None:
|
431 |
-
self.load_models()
|
432 |
-
|
433 |
-
# Load and preprocess image
|
434 |
-
try:
|
435 |
-
# image = Image.open(image_path).convert('RGB')
|
436 |
-
image = image.convert('RGB')
|
437 |
-
except:
|
438 |
-
raise ValueError("Could not load image from path")
|
439 |
|
|
|
|
|
440 |
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
|
441 |
diagnosis = self.meta_model.generate(
|
442 |
image_tensor
|
@@ -446,18 +395,16 @@ class SkinGPTClassifier:
|
|
446 |
"top_predictions": diagnosis,
|
447 |
}
|
448 |
|
449 |
-
|
|
|
|
|
450 |
|
|
|
451 |
|
452 |
# === Session Init ===
|
453 |
if "messages" not in st.session_state:
|
454 |
st.session_state.messages = []
|
455 |
|
456 |
-
# === Image Processing Function ===
|
457 |
-
def run_inference(image):
|
458 |
-
result = classifier.predict(image, top_k=1)
|
459 |
-
|
460 |
-
return result
|
461 |
|
462 |
# === PDF Export ===
|
463 |
def export_chat_to_pdf(messages):
|
@@ -484,7 +431,8 @@ if uploaded_file:
|
|
484 |
image = Image.open(uploaded_file).convert("RGB")
|
485 |
if not st.session_state.conversation:
|
486 |
# First message - diagnosis
|
487 |
-
|
|
|
488 |
st.session_state.conversation.append(("assistant", diagnosis))
|
489 |
with st.chat_message("assistant"):
|
490 |
st.markdown(diagnosis)
|
|
|
20 |
from io import BytesIO
|
21 |
import os
|
22 |
from huggingface_hub import hf_hub_download
|
23 |
+
from transformers import BitsAndBytesConfig
|
24 |
+
from accelerate import init_empty_weights
|
25 |
|
26 |
token = os.getenv("HF_TOKEN")
|
27 |
if not token:
|
|
|
165 |
self.q_former.eval()
|
166 |
print("Loaded QFormer")
|
167 |
self.llama = self._init_llama()
|
|
|
168 |
self.llama.resize_token_embeddings(len(self.tokenizer))
|
169 |
|
170 |
self.llama_proj = nn.Linear(
|
|
|
214 |
def _init_llama(self):
|
215 |
"""Initialize frozen LLaMA-2-13b-chat with proper error handling"""
|
216 |
try:
|
217 |
+
device_map = {
|
218 |
+
"": 0 if torch.cuda.is_available() else "cpu"
|
219 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
# First try loading with the explicit single-device device_map built above
|
221 |
try:
|
222 |
model = LlamaForCausalLM.from_pretrained(
|
223 |
"meta-llama/Llama-2-13b-chat-hf",
|
|
|
224 |
token=token,
|
225 |
torch_dtype=torch.float16,
|
226 |
+
device_map=device_map,
|
227 |
low_cpu_mem_usage=True
|
228 |
)
|
229 |
except ImportError:
|
|
|
341 |
|
342 |
def generate(self, images, user_input=None, max_length=300):
|
343 |
# Get aligned features
|
|
|
|
|
344 |
aligned_features = self.forward(images)
|
345 |
|
346 |
prompt = self.build_prompt(aligned_features, user_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
|
|
|
|
|
348 |
image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
349 |
image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
|
350 |
image_embeddings[image_token_index] = aligned_features.mean(dim=1) # Pool query tokens
|
|
|
360 |
|
361 |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
class SkinGPTClassifier:
|
364 |
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
365 |
self.device = torch.device(device)
|
366 |
self.conversation_history = []
|
367 |
+
|
368 |
+
with st.spinner("Loading AI models (this may take several minutes)..."):
|
369 |
+
self.meta_model = self.load_models()
|
370 |
self.resnet_feature_extractor = None
|
371 |
|
372 |
# Image transformations
|
|
|
381 |
repo_id="KeerthiVM/SkinCancerDiagnosis",
|
382 |
filename="dermnet_finetuned_version1.pth",
|
383 |
)
|
384 |
+
meta_model = SkinGPT4(vit_checkpoint_path=model_path)
|
385 |
+
return meta_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
+
def predict(self, image):
|
388 |
+
image = image.convert('RGB')
|
389 |
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
|
390 |
diagnosis = self.meta_model.generate(
|
391 |
image_tensor
|
|
|
395 |
"top_predictions": diagnosis,
|
396 |
}
|
397 |
|
398 |
+
@st.cache_resource
|
399 |
+
def get_classifier():
|
400 |
+
return SkinGPTClassifier()
|
401 |
|
402 |
+
classifier = get_classifier()
|
403 |
|
404 |
# === Session Init ===
|
405 |
if "messages" not in st.session_state:
|
406 |
st.session_state.messages = []
|
407 |
|
|
|
|
|
|
|
|
|
|
|
408 |
|
409 |
# === PDF Export ===
|
410 |
def export_chat_to_pdf(messages):
|
|
|
431 |
image = Image.open(uploaded_file).convert("RGB")
|
432 |
if not st.session_state.conversation:
|
433 |
# First message - diagnosis
|
434 |
+
with st.spinner("Analyzing image..."):
|
435 |
+
diagnosis = classifier.predict(image)
|
436 |
st.session_state.conversation.append(("assistant", diagnosis))
|
437 |
with st.chat_message("assistant"):
|
438 |
st.markdown(diagnosis)
|