Create search_ocean.py
Helps retrieve the most relevant info according to the query. It has two indexes: one for images (only two images from Kaggle, not accessible from here, just for multimodal testing use) and a text index. The text index contains over 90 English recipes in various text formats, uploaded from Kaggle.
- search_ocean.py +41 -0
search_ocean.py
ADDED
@@ -0,0 +1,41 @@
+import faiss
+from sentence_transformers import SentenceTransformer
+import clip
+from PIL import Image
+import torch
+
+# Load FAISS indexes
+text_index = faiss.read_index("text_recipes_test_br_00.ocean")
+image_index = faiss.read_index("image_recipes_test_br_00.ocean")
+
+# Initialize models for text and image processing
+text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")
+clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
+
+def search_text(query):
+    # Generate text embedding
+    query_vec = text_encoder.encode(query).astype("float32")
+
+    # Search the FAISS text index for the k nearest neighbours
+    distances, indices = text_index.search(query_vec.reshape(1, -1), k=3)
+    # Format results in a structured way
+    results = [
+        {"content": text_metadata[idx], "distance": float(dist), "relevance_score": 1.0 / (1.0 + dist)}
+        for idx, dist in zip(indices[0], distances[0])
+    ]
+    return results
+
+def search_image(image_file):
+    # Generate image embedding
+    image = clip_preprocess(Image.open(image_file)).unsqueeze(0)
+    with torch.no_grad():
+        query_vec = clip_model.encode_image(image).squeeze().numpy().astype("float32")
+
+    # Search the FAISS image index for the k nearest neighbours
+    distances, indices = image_index.search(query_vec.reshape(1, -1), k=3)
+    # Format results in a structured way
+    results = [
+        {"content": image_metadata[idx], "distance": float(dist), "relevance_score": 1.0 / (1.0 + dist)}
+        for idx, dist in zip(indices[0], distances[0])
+    ]
+    return results
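
Note that the file references text_metadata and image_metadata, which are not defined anywhere in this commit, so the two functions would raise a NameError as written; presumably the lists are built or loaded elsewhere in the Space. A minimal sketch of one way to wire them up, assuming the metadata is stored as pickled lists aligned with the two .ocean indexes (the pickle filenames below are hypothetical placeholders, not part of this commit):

import pickle

# Hypothetical metadata files: one entry per vector, in the same order as the FAISS indexes.
with open("text_recipes_test_br_00_metadata.pkl", "rb") as f:
    text_metadata = pickle.load(f)    # list of recipe texts
with open("image_recipes_test_br_00_metadata.pkl", "rb") as f:
    image_metadata = pickle.load(f)   # list of image captions or file paths

# Example query against the text index
for hit in search_text("vegetarian lasagna"):
    print(round(hit["relevance_score"], 3), hit["content"][:80])

The relevance_score simply maps the raw FAISS distance into (0, 1] via 1 / (1 + dist), so smaller distances give scores closer to 1; it assumes a distance-style index (e.g. L2) and is a display convenience, not a calibrated probability.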