import os
os.environ['OMP_NUM_THREADS'] = '1' # Limit OpenMP threads, might help prevent crashes
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import pickle
import json # Import json module
from tqdm import tqdm

# --- Configuration ---
MODEL_DATA_DIR = "model_data_json"  # Path to downloaded JSON data
INDEX_FILE = "index.faiss"
MAP_FILE = "index_to_metadata.pkl" # Maps FAISS index positions to metadata dicts
EMBEDDING_MODEL = 'all-mpnet-base-v2'  # General-purpose model with strong retrieval quality
ENCODE_BATCH_SIZE = 32  # Process descriptions in smaller batches
# Tags to exclude from indexing text
COMMON_EXCLUDED_TAGS = {'transformers'} # Add other common tags if needed
EXCLUDED_TAG_PREFIXES = ('arxiv:', 'base_model:', 'dataset:', 'diffusers:', 'license:') # Add other prefixes if needed
MODEL_EXPLANATION_KEY = "model_explanation_gemini" # Key for the new explanation field
# ---
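
# Each JSON file in MODEL_DATA_DIR is expected to look roughly like this
# (illustrative sketch; the values are made up, only the keys matter to the loader):
# {
#   "model_id": "org/model-name",
#   "description": "Free-text description from the model card...",
#   "tags": ["text-classification", "en", "license:apache-2.0"],
#   "downloads": 12345,
#   "model_explanation_gemini": "One-sentence summary of what the model does."
# }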

def load_model_data(directory):
    """Loads model data, filters tags (by length, common words, prefixes), and combines relevant info for indexing."""
    all_texts = [] # Store combined text (model_id + description + filtered_tags)
    all_metadata = [] # Store dicts: {'model_id': ..., 'tags': ..., 'downloads': ...}
    print(f"Loading model data from JSON files in: {directory}")
    if not os.path.isdir(directory):
        print(f"Error: Directory not found: {directory}")
        return [], []

    filenames = [f for f in os.listdir(directory) if f.endswith(".json")] # Look for .json files
    for filename in tqdm(filenames, desc="Reading JSON files"):
        filepath = os.path.join(directory, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
                # Ensure required fields exist
                if 'description' in data and 'model_id' in data:
                    description = data['description']
                    model_id = data['model_id'] # Get model_id
                    if description: # Only index if description is not empty
                        original_tags = data.get('tags', [])
                        # Filter tags: remove short tags, common tags, and tags with specific prefixes
                        filtered_tags = [
                            tag for tag in original_tags
                            if (
                                isinstance(tag, str) and  # Ensure tag is a string
                                len(tag) > 3 and
                                tag.lower() not in COMMON_EXCLUDED_TAGS and
                                not tag.lower().startswith(EXCLUDED_TAG_PREFIXES)  # Check for prefixes
                            )
                        ]
                        tag_string = " ".join(filtered_tags)
                        explanation = data.get(MODEL_EXPLANATION_KEY) # Get the new explanation

                        # --- Construct combined text with priority weighting ---
                        text_parts = []
                        # 1. Add explanation (repeated for emphasis) if available
                        if explanation and isinstance(explanation, str):
                            text_parts.append(f"Summary: {explanation}")
                            text_parts.append(f"Summary: {explanation}") # Repeat for higher weight
                        # 2. Add model name
                        text_parts.append(f"Model: {model_id}")
                        # 3. Add filtered tags if available
                        if tag_string:
                            text_parts.append(f"Tags: {tag_string}")
                        # 4. Add original description
                        text_parts.append(f"Description: {description}")

                        combined_text = " ".join(text_parts).strip() # Join all parts
                        # --- End construction ---

                        all_texts.append(combined_text)
                        # Add explanation to metadata as well for potential display
                        metadata_entry = {
                            "model_id": model_id,
                            "tags": original_tags, # Keep ORIGINAL tags in metadata
                            "downloads": data.get('downloads', 0)
                        }
                        if explanation and isinstance(explanation, str):
                            metadata_entry[MODEL_EXPLANATION_KEY] = explanation
                        all_metadata.append(metadata_entry)
                else:
                    print(f"Warning: Skipping {filename}, missing 'description' or 'model_id' key.")
        except json.JSONDecodeError:
            print(f"Warning: Skipping {filename}, invalid JSON.")
        except Exception as e:
            print(f"Warning: Could not read or process {filename}: {e}")

    print(f"Loaded data for {len(all_texts)} models with valid descriptions after tag filtering.")
    return all_texts, all_metadata

def build_and_save_index(texts_to_index, metadata_list):
    """Builds and saves the FAISS index and metadata mapping based on combined text."""
    if not texts_to_index:
        print("No text data to index.")
        return

    print(f"Loading sentence transformer model: {EMBEDDING_MODEL}")
    # Consider adding device='mps' if on Apple Silicon and PyTorch supports it well enough,
    # but start with CPU for stability.
    model = SentenceTransformer(EMBEDDING_MODEL)
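    # Optional sketch (an assumption about the local machine, not part of this pipeline):
    # the encode step can be moved to MPS or CUDA when a suitable backend is available.
    # import torch
    # device = "mps" if torch.backends.mps.is_available() else (
    #     "cuda" if torch.cuda.is_available() else "cpu")
    # model = SentenceTransformer(EMBEDDING_MODEL, device=device)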

    print(f"Generating embeddings for combined text in batches of {ENCODE_BATCH_SIZE}...")
    all_embeddings = []
    for i in tqdm(range(0, len(texts_to_index), ENCODE_BATCH_SIZE), desc="Encoding batches"):
        batch = texts_to_index[i:i+ENCODE_BATCH_SIZE]
        batch_embeddings = model.encode(batch, convert_to_numpy=True)
        all_embeddings.append(batch_embeddings)

    if not all_embeddings:
        print("No embeddings generated. Cannot build index.")
        return

    embeddings = np.vstack(all_embeddings) # Combine embeddings from all batches

    # Ensure embeddings are float32 for FAISS
    embeddings = embeddings.astype('float32')

    # Build FAISS index
    print("Building FAISS index...")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # Using simple L2 distance
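    # Alternative sketch: for cosine similarity instead of raw L2 distance,
    # normalize the embeddings and use an inner-product index. Left commented
    # out because the search script would have to normalize queries the same way.
    # faiss.normalize_L2(embeddings)
    # index = faiss.IndexFlatIP(dimension)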
    index.add(embeddings)
    print(f"FAISS index built with {index.ntotal} vectors.")

    # Save the index
    faiss.write_index(index, INDEX_FILE)
    print(f"FAISS index saved to: {INDEX_FILE}")

    # Create mapping from index position to metadata dictionary
    index_to_metadata = {i: metadata for i, metadata in enumerate(metadata_list)}
    with open(MAP_FILE, 'wb') as f:
        pickle.dump(index_to_metadata, f)
    print(f"Index-to-Metadata mapping saved to: {MAP_FILE}")

if __name__ == "__main__":
    combined_texts, metadata_list = load_model_data(MODEL_DATA_DIR)
    build_and_save_index(combined_texts, metadata_list)
    print("\nIndex building complete.")