NotShrirang committed on
Commit
67e3cab
·
1 Parent(s): a409078

feat: add searching with image

Browse files
data_search/adapter_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
def get_adapter_model(in_shape, out_shape, hidden_dim=1024):
    """Build a three-layer fully connected adapter network.

    The layout is ``Linear(in_shape, hidden_dim) -> ReLU ->
    Linear(hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, out_shape)``.

    Args:
        in_shape: Number of input features of the first linear layer.
        out_shape: Number of output features of the last linear layer.
        hidden_dim: Width of both hidden layers. Defaults to 1024, the value
            that was previously hard-coded, so existing callers and
            checkpoints are unaffected.

    Returns:
        An ``nn.Sequential`` holding the five layers described above.
    """
    # The linear layers must stay at positions 0, 2 and 4: nn.Sequential
    # names parameters by position ("0.weight", "2.weight", ...), and saved
    # state dicts are keyed by those names.
    return nn.Sequential(
        nn.Linear(in_shape, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_shape),
    )
14
+
15
+
16
def load_adapter_model(weights_path="./weights/adapter_model.pt", map_location="cpu"):
    """Create the 512 -> 384 adapter and load its saved weights.

    Args:
        weights_path: Path to the saved state dict. Defaults to the
            location that was previously hard-coded.
        map_location: Device mapping forwarded to ``torch.load``. Defaults
            to the CPU, as before; the caller can move the model afterwards.

    Returns:
        The adapter model with the checkpoint's parameters loaded.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint's keys or shapes do not match the
            model (raised by ``load_state_dict``).
    """
    model = get_adapter_model(512, 384)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code on load. This
    # matches how the fine-tuned CLIP checkpoint is loaded elsewhere.
    state_dict = torch.load(
        weights_path,
        map_location=map_location,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    return model
data_search/data_search_page.py CHANGED
@@ -5,8 +5,9 @@ from PIL import Image
5
  import streamlit as st
6
  import sys
7
  import torch
8
- from vectordb import search_image_index, search_text_index
9
  from utils import load_image_index, load_text_index, get_local_files
 
10
 
11
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12
 
@@ -18,12 +19,17 @@ def data_search(clip_model, preprocess, text_embedding_model, device):
18
  model, preprocess = clip.load("ViT-B/32", device=device)
19
  model.load_state_dict(torch.load(f"annotations/{file_name}/finetuned_model.pt", weights_only=True))
20
  return model, preprocess
 
 
 
 
 
21
 
22
  st.title("Data Search")
23
 
24
  images = os.listdir("images/")
25
  if images == []:
26
- st.warning("No images found in the data directory.")
27
  return
28
 
29
  annotation_projects = get_local_files("annotations/", get_details=True)
@@ -51,8 +57,13 @@ def data_search(clip_model, preprocess, text_embedding_model, device):
51
  else:
52
  st.info("Using Default Model")
53
 
 
 
 
54
  text_input = st.text_input("Search Database")
55
- if st.button("Search", disabled=text_input.strip() == ""):
 
 
56
  if os.path.exists("./vectorstore/image_index.index"):
57
  image_index, image_data = load_image_index()
58
  if os.path.exists("./vectorstore/text_index.index"):
@@ -64,10 +75,21 @@ def data_search(clip_model, preprocess, text_embedding_model, device):
64
  if not os.path.exists("./vectorstore/text_data.csv"):
65
  st.warning("No Text Index Found. So not searching for text.")
66
  text_index = None
67
- if image_index is not None:
68
- image_indices = search_image_index(text_input, image_index, clip_model, k=3)
69
- if text_index is not None:
70
- text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
 
 
 
 
 
 
 
 
 
 
 
71
  if not image_index and not text_index:
72
  st.error("No Data Found! Please add data to the database.")
73
  st.subheader("Top 3 Results")
 
5
  import streamlit as st
6
  import sys
7
  import torch
8
+ from vectordb import search_image_index, search_text_index, search_image_index_with_image, search_text_index_with_image
9
  from utils import load_image_index, load_text_index, get_local_files
10
+ from data_search import adapter_utils
11
 
12
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
 
 
19
  model, preprocess = clip.load("ViT-B/32", device=device)
20
  model.load_state_dict(torch.load(f"annotations/{file_name}/finetuned_model.pt", weights_only=True))
21
  return model, preprocess
22
+
23
+ @st.cache_resource
24
+ def load_adapter():
25
+ adapter = adapter_utils.load_adapter_model()
26
+ return adapter
27
 
28
  st.title("Data Search")
29
 
30
  images = os.listdir("images/")
31
  if images == []:
32
+ st.warning("No Images Found! Please upload images to the database.")
33
  return
34
 
35
  annotation_projects = get_local_files("annotations/", get_details=True)
 
57
  else:
58
  st.info("Using Default Model")
59
 
60
+ adapter = load_adapter()
61
+ adapter.to(device)
62
+
63
  text_input = st.text_input("Search Database")
64
+ image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
65
+
66
+ if st.button("Search", disabled=text_input.strip() == "" and image_input is None):
67
  if os.path.exists("./vectorstore/image_index.index"):
68
  image_index, image_data = load_image_index()
69
  if os.path.exists("./vectorstore/text_index.index"):
 
75
  if not os.path.exists("./vectorstore/text_data.csv"):
76
  st.warning("No Text Index Found. So not searching for text.")
77
  text_index = None
78
+ if image_input:
79
+ image = Image.open(image_input)
80
+ image = preprocess(image).unsqueeze(0).to(device)
81
+ with torch.no_grad():
82
+ image_features = clip_model.encode_image(image)
83
+ adapted_text_embeddings = adapter(image_features)
84
+ if image_index is not None:
85
+ image_indices = search_image_index_with_image(image_features, image_index, clip_model, k=3)
86
+ if text_index is not None:
87
+ text_indices = search_text_index_with_image(adapted_text_embeddings, text_index, text_embedding_model, k=3)
88
+ else:
89
+ if image_index is not None:
90
+ image_indices = search_image_index(text_input, image_index, clip_model, k=3)
91
+ if text_index is not None:
92
+ text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
93
  if not image_index and not text_index:
94
  st.error("No Data Found! Please add data to the database.")
95
  st.subheader("Top 3 Results")