KeerthiVM commited on
Commit
73f97d4
·
1 Parent(s): b3795f6
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -340,15 +340,16 @@ class SkinGPT4(nn.Module):
340
  return prompt
341
 
342
  def generate(self, images, user_input=None, max_length=300):
 
343
  # Get aligned features
344
  aligned_features = self.forward(images)
345
-
346
  prompt = self.build_prompt(aligned_features, user_input)
347
  inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
348
  image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
349
  image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
350
  image_embeddings[image_token_index] = aligned_features.mean(dim=1) # Pool query tokens
351
-
352
  # Generate response
353
  outputs = self.llama.generate(
354
  inputs_embeds=image_embeddings,
@@ -357,7 +358,7 @@ class SkinGPT4(nn.Module):
357
  top_p=0.9,
358
  do_sample=True
359
  )
360
-
361
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
362
 
363
  class SkinGPTClassifier:
@@ -430,12 +431,14 @@ if uploaded_file:
430
  st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
431
  image = Image.open(uploaded_file).convert("RGB")
432
  if not st.session_state.conversation:
433
- # First message - diagnosis
434
  with st.spinner("Analyzing image..."):
435
- diagnosis = classifier.predict(image)
436
- st.session_state.conversation.append(("assistant", diagnosis))
437
- with st.chat_message("assistant"):
438
- st.markdown(diagnosis)
 
 
 
439
  else:
440
  # Follow-up questions
441
  if user_query := st.chat_input("Ask a follow-up question..."):
 
340
  return prompt
341
 
342
  def generate(self, images, user_input=None, max_length=300):
343
+ print("Analysing the image to generate the diagnosis")
344
  # Get aligned features
345
  aligned_features = self.forward(images)
346
+ print("Generated the aligned features with ViT and Qformer")
347
  prompt = self.build_prompt(aligned_features, user_input)
348
  inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
349
  image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
350
  image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
351
  image_embeddings[image_token_index] = aligned_features.mean(dim=1) # Pool query tokens
352
+ print("Generating the diagnosis with llama")
353
  # Generate response
354
  outputs = self.llama.generate(
355
  inputs_embeds=image_embeddings,
 
358
  top_p=0.9,
359
  do_sample=True
360
  )
361
+ print("Generated diagnosis")
362
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
363
 
364
  class SkinGPTClassifier:
 
431
  st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
432
  image = Image.open(uploaded_file).convert("RGB")
433
  if not st.session_state.conversation:
 
434
  with st.spinner("Analyzing image..."):
435
+ result = classifier.predict(image)
436
+ if "error" in result:
437
+ st.error(result["error"])
438
+ else:
439
+ st.session_state.conversation.append(("assistant", result))
440
+ with st.chat_message("assistant"):
441
+ st.markdown(result)
442
  else:
443
  # Follow-up questions
444
  if user_query := st.chat_input("Ask a follow-up question..."):