Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from PIL import Image
 import gradio as gr
 import spaces
-from transformers import AutoProcessor,
+from transformers import AutoProcessor, AutoModel, CLIPVisionModel
 import torch.nn.functional as F

 #---------------------------------
@@ -13,16 +13,18 @@ def load_biomedclip_model():
     """Loads the BiomedCLIP model and tokenizer."""
     biomedclip_model_name = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
     processor = AutoProcessor.from_pretrained(biomedclip_model_name)
-
-
+    config = AutoModel.from_pretrained(biomedclip_model_name).config
+    vision_model = CLIPVisionModel.from_pretrained(config.vision_config._name_or_path, torch_dtype=torch.float16).cuda().eval()
+    text_model = AutoModel.from_pretrained(config.text_config._name_or_path).cuda().eval()
+    return vision_model, text_model, processor

-def compute_similarity(image, text,
+def compute_similarity(image, text, vision_model, text_model, biomedclip_processor):
     """Computes similarity scores using BiomedCLIP."""
     with torch.no_grad():
-        inputs = biomedclip_processor(text=text, images=image, return_tensors="pt", padding=True).to(
-
-
-
+        inputs = biomedclip_processor(text=text, images=image, return_tensors="pt", padding=True).to(text_model.device)
+        text_embeds = text_model(**inputs).last_hidden_state[:,0,:] # Extract the [CLS] token
+        image_inputs = biomedclip_processor(images=image, return_tensors="pt").to(vision_model.device)
+        image_embeds = vision_model(**image_inputs).last_hidden_state[:,0,:] # Extract the image embedding
         image_embeds = F.normalize(image_embeds, dim=-1)
         text_embeds = F.normalize(text_embeds, dim=-1)
         similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()
@@ -55,12 +57,12 @@ def gradio_ask(user_message, chatbot, chat_state):
     return '', chatbot, chat_state

 @spaces.GPU
-def gradio_answer(chatbot, chat_state, img_list,
+def gradio_answer(chatbot, chat_state, img_list, vision_model, text_model, biomedclip_processor, similarity_output):
     """Computes and displays similarity scores."""
     if not img_list:
         return chatbot, chat_state, img_list, similarity_output

-    similarity_score = compute_similarity(img_list[0], chatbot[-1][0],
+    similarity_score = compute_similarity(img_list[0], chatbot[-1][0], vision_model, text_model, biomedclip_processor)
     print(f'Similarity Score is: {similarity_score}')

     similarity_text = f"Similarity Score: {similarity_score:.3f}"
@@ -77,7 +79,7 @@ examples_list=[
 ]

 # Load models and related resources outside of the Gradio block for loading on startup
-
+vision_model, text_model, biomedclip_processor = load_biomedclip_model()

 with gr.Blocks() as demo:
     gr.Markdown(title)
@@ -100,7 +102,7 @@ with gr.Blocks() as demo:
     upload_button.click(upload_img, [image, text_input, chat_state, similarity_output], [image, text_input, upload_button, chat_state, img_list, similarity_output])

     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_list,
+        gradio_answer, [chatbot, chat_state, img_list, vision_model, text_model, biomedclip_processor, similarity_output], [chatbot, chat_state, img_list, similarity_output]
     )
     clear.click(gradio_reset, [chat_state, img_list, similarity_output], [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output], queue=False)
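Review notes on this change:

The checkpoint microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224 is distributed in open_clip format, so the AutoProcessor/AutoModel/CLIPVisionModel route above may not load at all, which would be consistent with the Space's "Runtime error" status. A minimal loading sketch following the model card, assuming the open_clip_torch package is installed:

import torch
import open_clip

# Load the model, its preprocessing transform, and its tokenizer via open_clip,
# as documented on the BiomedCLIP model card.
model_name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
model, preprocess = open_clip.create_model_from_pretrained(model_name)
tokenizer = open_clip.get_tokenizer(model_name)
model = model.cuda().eval()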
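Even if the transformers route loads, compute_similarity rebuilds the two towers from config.vision_config._name_or_path and config.text_config._name_or_path, i.e. from their base checkpoints. Those towers lack BiomedCLIP's contrastively trained projection heads, so the raw [CLS] hidden states of the text and vision encoders do not share an embedding space and their cosine similarity is not meaningful; there is also a dtype mismatch, since the vision tower is loaded in float16 while the processor emits float32 pixel values. A sketch that stays in the model's joint embedding space, reusing the open_clip objects from the snippet above:

from PIL import Image
import torch
import torch.nn.functional as F

def compute_similarity(image: Image.Image, text: str) -> float:
    """Cosine similarity between one PIL image and one text prompt."""
    with torch.no_grad():
        image_input = preprocess(image).unsqueeze(0).cuda()  # (1, 3, 224, 224)
        text_input = tokenizer([text]).cuda()                # (1, context_len)
        image_embeds = F.normalize(model.encode_image(image_input), dim=-1)
        text_embeds = F.normalize(model.encode_text(text_input), dim=-1)
        return (text_embeds @ image_embeds.T).item()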
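Separately, the new event wiring passes vision_model, text_model, and biomedclip_processor inside the Gradio inputs list (new line 105). Gradio event inputs must be components, so handing it raw torch modules should fail when the Blocks graph is built. The usual pattern is to keep only components in the event lists and let the callback reach the module-level models; a sketch under that assumption, with the return shape inferred from the outputs list:

@spaces.GPU
def gradio_answer(chatbot, chat_state, img_list, similarity_output):
    """Computes and displays similarity scores."""
    if not img_list:
        return chatbot, chat_state, img_list, similarity_output
    # compute_similarity closes over the module-level model objects.
    similarity_score = compute_similarity(img_list[0], chatbot[-1][0])
    print(f'Similarity Score is: {similarity_score}')
    similarity_text = f"Similarity Score: {similarity_score:.3f}"
    # Assumed return, matching [chatbot, chat_state, img_list, similarity_output].
    return chatbot, chat_state, img_list, similarity_text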
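The matching .then() call then lists only components in both inputs and outputs:

text_input.submit(gradio_ask, [text_input, chatbot, chat_state],
                  [text_input, chatbot, chat_state]).then(
    gradio_answer,
    [chatbot, chat_state, img_list, similarity_output],
    [chatbot, chat_state, img_list, similarity_output],
)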
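Finally, a quick way to exercise the similarity path outside the UI; the image path here is a placeholder:

# Hypothetical smoke test; 'sample_xray.png' is a placeholder file name.
img = Image.open('sample_xray.png').convert('RGB')
print(compute_similarity(img, 'chest X-ray showing pneumonia.'))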