Issue fix
Browse files
app.py
CHANGED
@@ -340,15 +340,16 @@ class SkinGPT4(nn.Module):
|
|
340 |
return prompt
|
341 |
|
342 |
def generate(self, images, user_input=None, max_length=300):
|
|
|
343 |
# Get aligned features
|
344 |
aligned_features = self.forward(images)
|
345 |
-
|
346 |
prompt = self.build_prompt(aligned_features, user_input)
|
347 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
|
348 |
image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
349 |
image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
|
350 |
image_embeddings[image_token_index] = aligned_features.mean(dim=1) # Pool query tokens
|
351 |
-
|
352 |
# Generate response
|
353 |
outputs = self.llama.generate(
|
354 |
inputs_embeds=image_embeddings,
|
@@ -357,7 +358,7 @@ class SkinGPT4(nn.Module):
|
|
357 |
top_p=0.9,
|
358 |
do_sample=True
|
359 |
)
|
360 |
-
|
361 |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
362 |
|
363 |
class SkinGPTClassifier:
|
@@ -430,12 +431,14 @@ if uploaded_file:
|
|
430 |
st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
|
431 |
image = Image.open(uploaded_file).convert("RGB")
|
432 |
if not st.session_state.conversation:
|
433 |
-
# First message - diagnosis
|
434 |
with st.spinner("Analyzing image..."):
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
439 |
else:
|
440 |
# Follow-up questions
|
441 |
if user_query := st.chat_input("Ask a follow-up question..."):
|
|
|
340 |
return prompt
|
341 |
|
342 |
def generate(self, images, user_input=None, max_length=300):
|
343 |
+
print("Analysing the image to generate the diagnosis")
|
344 |
# Get aligned features
|
345 |
aligned_features = self.forward(images)
|
346 |
+
print("Generated the aligned features with ViT and Qformer")
|
347 |
prompt = self.build_prompt(aligned_features, user_input)
|
348 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
|
349 |
image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
350 |
image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
|
351 |
image_embeddings[image_token_index] = aligned_features.mean(dim=1) # Pool query tokens
|
352 |
+
print("Generating the diagnosis with llama")
|
353 |
# Generate response
|
354 |
outputs = self.llama.generate(
|
355 |
inputs_embeds=image_embeddings,
|
|
|
358 |
top_p=0.9,
|
359 |
do_sample=True
|
360 |
)
|
361 |
+
print("Generated diagnosis")
|
362 |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
363 |
|
364 |
class SkinGPTClassifier:
|
|
|
431 |
st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
|
432 |
image = Image.open(uploaded_file).convert("RGB")
|
433 |
if not st.session_state.conversation:
|
|
|
434 |
with st.spinner("Analyzing image..."):
|
435 |
+
result = classifier.predict(image)
|
436 |
+
if "error" in result:
|
437 |
+
st.error(result["error"])
|
438 |
+
else:
|
439 |
+
st.session_state.conversation.append(("assistant", result))
|
440 |
+
with st.chat_message("assistant"):
|
441 |
+
st.markdown(result)
|
442 |
else:
|
443 |
# Follow-up questions
|
444 |
if user_query := st.chat_input("Ask a follow-up question..."):
|