LuisMBA commited on
Commit
e8957e6
·
verified ·
1 Parent(s): 6120321

Create search_ocean.py

Browse files

Helps retrieve the most relevant info according to the query. It has two indexes: one for images (only two images, stored on Kaggle and not accessible from here — included just for multimodal testing) and one text index. The text index contains over 90 English recipes in various text formats, uploaded from Kaggle.

Files changed (1) hide show
  1. search_ocean.py +41 -0
search_ocean.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ from sentence_transformers import SentenceTransformer
3
+ import clip
4
+ from PIL import Image
5
+ import torch
6
+
7
+ # Load FAISS indexes
8
+ text_index = faiss.read_index("text_recipes_test_br_00.ocean")
9
+ image_index = faiss.read_index("image_recipes_test_br_00.ocean")
10
+
11
+ # Initialize models for text and image processing
12
+ text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")
13
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
14
+
15
+ def search_text(query):
16
+ # Generate text embedding
17
+ query_vec = text_encoder.encode(query).astype("float32")
18
+
19
+ # Search in the FAISS text index, retrieve k most common
20
+ distances, indices = text_index.search(query_vec.reshape(1, -1), k=3)
21
+ # Format results in a structured way
22
+ results = [
23
+ {"content": text_metadata[idx], "distance": float(dist), "relevance_score": 1.0 / (1.0 + dist)}
24
+ for idx, dist in zip(indices[0], distances[0])
25
+ ]
26
+ return results
27
+
28
+ def search_image(image_file):
29
+ # Generate image embedding
30
+ image = clip_preprocess(Image.open(image_file)).unsqueeze(0)
31
+ with torch.no_grad():
32
+ query_vec = clip_model.encode_image(image).squeeze().numpy().astype("float32")
33
+
34
+ # Search in the FAISS image index, retrieve k most common
35
+ distances, indices = image_index.search(query_vec.reshape(1, -1), k=3)
36
+ # Format results in a structured way
37
+ results = [
38
+ {"content": image_metadata[idx], "distance": float(dist), "relevance_score": 1.0 / (1.0 + dist)}
39
+ for idx, dist in zip(indices[0], distances[0])
40
+ ]
41
+ return results