KeerthiVM committed
Commit b3795f6 · Parent: f1c3197

Removed secrets from history

Files changed (1):
  1. app.py +19 -71
app.py CHANGED
@@ -20,7 +20,8 @@ import requests
 from io import BytesIO
 import os
 from huggingface_hub import hf_hub_download
-
+from transformers import BitsAndBytesConfig
+from accelerate import init_empty_weights

 token = os.getenv("HF_TOKEN")
 if not token:
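Note: with the hard-coded secret scrubbed, the token is read only from the environment at startup; in a Hugging Face Space it would come from the repository's secrets settings, which are exposed as environment variables. A minimal sketch of the pattern (the failure branch is an assumption, since the hunk cuts off after the check):

    import os

    token = os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN is not set; configure it as a secret.")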
@@ -164,7 +165,6 @@ class SkinGPT4(nn.Module):
         self.q_former.eval()
         print("Loaded QFormer")
         self.llama = self._init_llama()
-        self.llama = self.llama.to(device)
         self.llama.resize_token_embeddings(len(self.tokenizer))

         self.llama_proj = nn.Linear(
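Note: the dropped .to(device) call is likely redundant once _init_llama() loads the model with an explicit device_map (next hunk): the weights already land on the target device at load time, and transformers may refuse a blanket .to() on a model it has dispatched from a device_map.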
@@ -214,30 +214,16 @@ class SkinGPT4(nn.Module):
     def _init_llama(self):
         """Initialize frozen LLaMA-2-13b-chat with proper error handling"""
         try:
-            from transformers import BitsAndBytesConfig
-            from accelerate import init_empty_weights
-
-            # Configure 4-bit quantization to reduce memory usage
-            # quantization_config = BitsAndBytesConfig(
-            #     load_in_4bit=True,
-            #     bnb_4bit_compute_dtype=torch.float16,
-            #     bnb_4bit_use_double_quant=True,
-            #     bnb_4bit_quant_type="nf4"
-            # )
-            quant_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=torch.float16,
-                bnb_4bit_quant_type="nf4",
-            )
-
+            device_map = {
+                "": 0 if torch.cuda.is_available() else "cpu"
+            }
             # First try loading with device_map="auto"
             try:
                 model = LlamaForCausalLM.from_pretrained(
                     "meta-llama/Llama-2-13b-chat-hf",
-                    # quantization_config=quant_config,
                     token=token,
                     torch_dtype=torch.float16,
-                    device_map="auto",
+                    device_map=device_map,
                     low_cpu_mem_usage=True
                 )
             except ImportError:
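Note: in an accelerate-style device_map, the empty-string key addresses the whole model, so {"": 0} pins every module onto GPU 0 (and {"": "cpu"} keeps it on the CPU) instead of letting device_map="auto" shard layers across whatever memory is free. A sketch of the resulting load call with the surrounding try/except stripped (names as in the diff):

    import torch
    from transformers import LlamaForCausalLM

    # "" maps the entire model to one device: no sharding, no offload.
    device_map = {"": 0 if torch.cuda.is_available() else "cpu"}

    model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-13b-chat-hf",
        token=token,                  # HF_TOKEN read from the environment above
        torch_dtype=torch.float16,
        device_map=device_map,
        low_cpu_mem_usage=True,
    )

The surviving comment about device_map="auto" is now stale, since the explicit map replaces auto placement.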
@@ -355,22 +341,10 @@ class SkinGPT4(nn.Module):

     def generate(self, images, user_input=None, max_length=300):
         # Get aligned features
-        images = images.to(self.dtype)
-
         aligned_features = self.forward(images)

         prompt = self.build_prompt(aligned_features, user_input)
-
-        self.llama = self.llama.to(self.dtype)
-
-        # Tokenize prompt
-
-        # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})
-        # self.llama.resize_token_embeddings(len(self.tokenizer))
-
         inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
-
-        # Replace <ImageHere> with aligned features
         image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
         image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
         image_embeddings[image_token_index] = aligned_features.mean(dim=1)  # Pool query tokens
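Note: the surviving lines are the MiniGPT-4-style placeholder substitution: the prompt carries a literal <ImageHere> token whose embedding row is overwritten with the pooled, projected Q-Former features before decoding. A condensed sketch (identifiers follow the diff; the generate kwargs are assumptions, since the hunk ends before the call):

    # Tokenize the prompt containing the "<ImageHere>" placeholder.
    inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)

    # Look up token embeddings, then patch the placeholder position with
    # the visual features pooled over the query tokens.
    embeds = self.llama.model.embed_tokens(inputs.input_ids)   # (1, T, D)
    img_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
    pos = torch.where(inputs.input_ids == img_id)
    embeds[pos] = aligned_features.mean(dim=1)                 # pooled (1, D)

    # Decode from the patched embeddings instead of raw token ids.
    outputs = self.llama.generate(
        inputs_embeds=embeds,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_length,
    )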
@@ -386,27 +360,13 @@ class SkinGPT4(nn.Module):

         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

-
-# def load_model(model_path):
-#     model_path = hf_hub_download(
-#         repo_id="KeerthiVM/SkinCancerDiagnosis",
-#         filename="dermnet_finetuned_version1.pth",
-#     )
-#     # model = SkinGPT4(vit_checkpoint_path="dermnet_finetuned_version1.pth")
-#     model = SkinGPT4(vit_checkpoint_path=model_path)
-#     model.to(device)
-#     model.eval()
-#     return model
-
-
-
 class SkinGPTClassifier:
     def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
         self.device = torch.device(device)
         self.conversation_history = []
-        # Initialize models (they'll be loaded when needed)
-        self.base_models = None
-        self.meta_model = None
+
+        with st.spinner("Loading AI models (this may take several minutes)..."):
+            self.meta_model = self.load_models()
         self.resnet_feature_extractor = None

         # Image transformations
@@ -421,22 +381,11 @@ class SkinGPTClassifier:
             repo_id="KeerthiVM/SkinCancerDiagnosis",
             filename="dermnet_finetuned_version1.pth",
         )
-        # self.meta_model = SkinGPT4(vit_checkpoint_path="dermnet_finetuned_version1.pth")
-        self.meta_model = SkinGPT4(vit_checkpoint_path=model_path)
-        self.meta_model.to_empty(device=device)
-
-    def predict(self, image, top_k=3):
-        """Make prediction for a single image"""
-        if self.meta_model is None:
-            self.load_models()
-
-        # Load and preprocess image
-        try:
-            # image = Image.open(image_path).convert('RGB')
-            image = image.convert('RGB')
-        except:
-            raise ValueError("Could not load image from path")
+        meta_model = SkinGPT4(vit_checkpoint_path=model_path)
+        return meta_model

+    def predict(self, image):
+        image = image.convert('RGB')
         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
         diagnosis = self.meta_model.generate(
             image_tensor
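Note: hf_hub_download resolves the checkpoint into the local Hugging Face cache and returns its filesystem path, fetching from the Hub only on a cache miss, so restarts reuse the file. The refactored loader also returns the model instead of mutating self, and predict() now assumes a PIL image and converts it to RGB directly rather than hiding the conversion behind a bare except. The download pattern in isolation:

    from huggingface_hub import hf_hub_download

    # Returns a local cache path; downloads only if not already cached.
    ckpt_path = hf_hub_download(
        repo_id="KeerthiVM/SkinCancerDiagnosis",
        filename="dermnet_finetuned_version1.pth",
    )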
@@ -446,18 +395,16 @@ class SkinGPTClassifier:
             "top_predictions": diagnosis,
         }

-classifier = SkinGPTClassifier()
+@st.cache_resource
+def get_classifier():
+    return SkinGPTClassifier()

+classifier = get_classifier()

 # === Session Init ===
 if "messages" not in st.session_state:
     st.session_state.messages = []

-# === Image Processing Function ===
-def run_inference(image):
-    result = classifier.predict(image, top_k=1)
-
-    return result

 # === PDF Export ===
 def export_chat_to_pdf(messages):
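Note: this hunk is the operative fix for the Streamlit app. A bare classifier = SkinGPTClassifier() at module scope would rebuild the classifier, and reload the 13B LLaMA, on every script rerun, i.e. on every user interaction; @st.cache_resource makes Streamlit run the factory once per server process and hand back the same instance on later reruns and across sessions. The pattern in isolation (the constructor body is a stand-in):

    import streamlit as st

    @st.cache_resource                 # cached across reruns and sessions
    def get_classifier():
        return SkinGPTClassifier()     # heavyweight: downloads and loads models

    classifier = get_classifier()      # cheap after the first call

Combined with the eager load_models() in __init__ (earlier hunk), this swaps the old lazy load-on-first-predict scheme for load-once-at-startup behind a spinner.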
@@ -484,7 +431,8 @@ if uploaded_file:
     image = Image.open(uploaded_file).convert("RGB")
     if not st.session_state.conversation:
         # First message - diagnosis
-        diagnosis = classifier.predict(image, top_k=1)
+        with st.spinner("Analyzing image..."):
+            diagnosis = classifier.predict(image)
         st.session_state.conversation.append(("assistant", diagnosis))
         with st.chat_message("assistant"):
             st.markdown(diagnosis)
 