File size: 1,799 Bytes
e8957e6
 
 
 
 
309d09c
e8957e6
 
 
 
 
b7cb86a
 
 
 
 
 
 
e8957e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import faiss
from sentence_transformers import SentenceTransformer
import clip
from PIL import Image
import torch
import json

# ---------------------------------------------------------------------------
# Retrieval resources. Everything below runs once, at import time, and is
# shared by search_text / search_image.
# ---------------------------------------------------------------------------

# Pre-built FAISS vector indexes: the text-embedding index first, then the
# image-embedding index.
text_index, image_index = (
    faiss.read_index(index_path)
    for index_path in (
        "text_recipes_test_br_00.ocean",
        "image_recipes_test_br_00.ocean",
    )
)

# Records handed back as the "content" of each hit. The search functions
# index these by FAISS row id, so entry i presumably describes vector i of
# the matching index — TODO confirm against the index-building script.
with open("text_metadata.json", "r", encoding="utf-8") as f:
    text_metadata = json.load(f)
with open("image_metadata.json", "r", encoding="utf-8") as f:
    image_metadata = json.load(f)

# Query encoders: a SentenceTransformer for text queries and CLIP ViT-B/32
# (loaded on CPU) for image queries.
text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")
clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")

def _format_results(distances, indices, metadata):
    """Convert the first row of a FAISS search result into result dicts.

    Args:
        distances: Distance matrix returned by ``index.search`` (only row 0
            is read).
        indices: Row-id matrix returned by ``index.search`` (only row 0 is
            read).
        metadata: Sequence indexed by FAISS row id.

    Returns:
        A list of ``{"content", "distance", "relevance_score"}`` dicts, in
        the order FAISS returned the hits. Numeric fields are plain Python
        floats, so the list is JSON-serializable.
    """
    results = []
    for idx, raw_dist in zip(indices[0], distances[0]):
        # FAISS pads the id row with -1 when fewer than k neighbours exist.
        # Without this check metadata[-1] would silently return the last
        # record, paired with a sentinel distance.
        if idx < 0:
            continue
        dist = float(raw_dist)
        results.append(
            {
                "content": metadata[int(idx)],
                "distance": dist,
                # NOTE(review): 1 / (1 + d) assumes non-negative distances
                # (an L2 index); for an inner-product index this score would
                # not be meaningful — confirm the index metric.
                "relevance_score": 1.0 / (1.0 + dist),
            }
        )
    return results


def search_text(query, k=3):
    """Return the ``k`` nearest entries of the text index for *query*.

    Args:
        query: Text to embed with the module-level ``text_encoder``.
        k: Number of neighbours to request from FAISS (default 3).

    Returns:
        A list of result dicts as built by ``_format_results``. It may be
        shorter than ``k`` when the index yields fewer matches.
    """
    query_vec = text_encoder.encode(query).astype("float32")
    # Search with a single-row (1, dim) float32 query matrix.
    distances, indices = text_index.search(query_vec.reshape(1, -1), k=k)
    return _format_results(distances, indices, text_metadata)


def search_image(image_file, k=3):
    """Return the ``k`` nearest entries of the image index for *image_file*.

    Args:
        image_file: Path or file object accepted by ``PIL.Image.open``.
        k: Number of neighbours to request from FAISS (default 3).

    Returns:
        A list of result dicts as built by ``_format_results``. It may be
        shorter than ``k`` when the index yields fewer matches.
    """
    # Image.open is lazy and keeps the file open. The context manager
    # releases the handle once the preprocessed tensor has been built.
    with Image.open(image_file) as img:
        pixels = clip_preprocess(img).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        query_vec = (
            clip_model.encode_image(pixels).squeeze().numpy().astype("float32")
        )
    # Search with a single-row (1, dim) float32 query matrix.
    distances, indices = image_index.search(query_vec.reshape(1, -1), k=k)
    return _format_results(distances, indices, image_metadata)