mgbam committed on
Commit
48ac3b6
·
verified ·
1 Parent(s): 9fd9472

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
- from transformers import AutoProcessor, AutoModelForImageTextRetrieval
6
  import torch.nn.functional as F
7
 
8
  #---------------------------------
@@ -13,16 +13,18 @@ def load_biomedclip_model():
13
  """Loads the BiomedCLIP model and tokenizer."""
14
  biomedclip_model_name = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
15
  processor = AutoProcessor.from_pretrained(biomedclip_model_name)
16
- model = AutoModelForImageTextRetrieval.from_pretrained(biomedclip_model_name).cuda().eval()
17
- return model, processor
 
 
18
 
19
- def compute_similarity(image, text, biomedclip_model, biomedclip_processor):
20
  """Computes similarity scores using BiomedCLIP."""
21
  with torch.no_grad():
22
- inputs = biomedclip_processor(text=text, images=image, return_tensors="pt", padding=True).to(biomedclip_model.device)
23
- outputs = biomedclip_model(**inputs)
24
- image_embeds = outputs.image_embeds
25
- text_embeds = outputs.text_embeds
26
  image_embeds = F.normalize(image_embeds, dim=-1)
27
  text_embeds = F.normalize(text_embeds, dim=-1)
28
  similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()
@@ -55,12 +57,12 @@ def gradio_ask(user_message, chatbot, chat_state):
55
  return '', chatbot, chat_state
56
 
57
  @spaces.GPU
58
- def gradio_answer(chatbot, chat_state, img_list, biomedclip_model, biomedclip_processor, similarity_output):
59
  """Computes and displays similarity scores."""
60
  if not img_list:
61
  return chatbot, chat_state, img_list, similarity_output
62
 
63
- similarity_score = compute_similarity(img_list[0], chatbot[-1][0], biomedclip_model, biomedclip_processor)
64
  print(f'Similarity Score is: {similarity_score}')
65
 
66
  similarity_text = f"Similarity Score: {similarity_score:.3f}"
@@ -77,7 +79,7 @@ examples_list=[
77
  ]
78
 
79
  # Load models and related resources outside of the Gradio block for loading on startup
80
- biomedclip_model, biomedclip_processor = load_biomedclip_model()
81
 
82
  with gr.Blocks() as demo:
83
  gr.Markdown(title)
@@ -100,7 +102,7 @@ with gr.Blocks() as demo:
100
  upload_button.click(upload_img, [image, text_input, chat_state, similarity_output], [image, text_input, upload_button, chat_state, img_list, similarity_output])
101
 
102
  text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
103
- gradio_answer, [chatbot, chat_state, img_list, biomedclip_model, biomedclip_processor, similarity_output], [chatbot, chat_state, img_list, similarity_output]
104
  )
105
  clear.click(gradio_reset, [chat_state, img_list, similarity_output], [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output], queue=False)
106
 
 
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
+ from transformers import AutoProcessor, AutoModel, CLIPVisionModel
6
  import torch.nn.functional as F
7
 
8
  #---------------------------------
 
13
  """Loads the BiomedCLIP model and tokenizer."""
14
  biomedclip_model_name = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
15
  processor = AutoProcessor.from_pretrained(biomedclip_model_name)
16
+ config = AutoModel.from_pretrained(biomedclip_model_name).config
17
+ vision_model = CLIPVisionModel.from_pretrained(config.vision_config._name_or_path, torch_dtype=torch.float16).cuda().eval()
18
+ text_model = AutoModel.from_pretrained(config.text_config._name_or_path).cuda().eval()
19
+ return vision_model, text_model, processor
20
 
21
+ def compute_similarity(image, text, vision_model, text_model, biomedclip_processor):
22
  """Computes similarity scores using BiomedCLIP."""
23
  with torch.no_grad():
24
+ inputs = biomedclip_processor(text=text, images=image, return_tensors="pt", padding=True).to(text_model.device)
25
+ text_embeds = text_model(**inputs).last_hidden_state[:,0,:] # Extract the [CLS] token
26
+ image_inputs = biomedclip_processor(images=image, return_tensors="pt").to(vision_model.device)
27
+ image_embeds = vision_model(**image_inputs).last_hidden_state[:,0,:] # Extract the image embedding
28
  image_embeds = F.normalize(image_embeds, dim=-1)
29
  text_embeds = F.normalize(text_embeds, dim=-1)
30
  similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()
 
57
  return '', chatbot, chat_state
58
 
59
  @spaces.GPU
60
+ def gradio_answer(chatbot, chat_state, img_list, vision_model, text_model, biomedclip_processor, similarity_output):
61
  """Computes and displays similarity scores."""
62
  if not img_list:
63
  return chatbot, chat_state, img_list, similarity_output
64
 
65
+ similarity_score = compute_similarity(img_list[0], chatbot[-1][0], vision_model, text_model, biomedclip_processor)
66
  print(f'Similarity Score is: {similarity_score}')
67
 
68
  similarity_text = f"Similarity Score: {similarity_score:.3f}"
 
79
  ]
80
 
81
  # Load models and related resources outside of the Gradio block for loading on startup
82
+ vision_model, text_model, biomedclip_processor = load_biomedclip_model()
83
 
84
  with gr.Blocks() as demo:
85
  gr.Markdown(title)
 
102
  upload_button.click(upload_img, [image, text_input, chat_state, similarity_output], [image, text_input, upload_button, chat_state, img_list, similarity_output])
103
 
104
  text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
105
+ gradio_answer, [chatbot, chat_state, img_list, vision_model, text_model, biomedclip_processor, similarity_output], [chatbot, chat_state, img_list, similarity_output]
106
  )
107
  clear.click(gradio_reset, [chat_state, img_list, similarity_output], [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output], queue=False)
108