import os
os.environ['OMP_NUM_THREADS'] = '1'  # Limit OpenMP threads (set before importing faiss); may help prevent crashes
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import pickle
import json  # For reading the per-model JSON files
from tqdm import tqdm

# --- Configuration ---
MODEL_DATA_DIR = "model_data_json"  # Path to downloaded JSON data
INDEX_FILE = "index.faiss"
MAP_FILE = "index_to_metadata.pkl"  # Maps index position -> model metadata dict
EMBEDDING_MODEL = 'all-mpnet-base-v2'  # Efficient and good quality model
ENCODE_BATCH_SIZE = 32  # Process descriptions in smaller batches
# Tags to exclude from indexing text
COMMON_EXCLUDED_TAGS = {'transformers'} # Add other common tags if needed
EXCLUDED_TAG_PREFIXES = ('arxiv:', 'base_model:', 'dataset:', 'diffusers:', 'license:') # Add other prefixes if needed
MODEL_EXPLANATION_KEY = "model_explanation_gemini" # Key for the new explanation field
# ---
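
# Example of the per-model JSON layout this script expects in MODEL_DATA_DIR.
# The field names mirror the keys read in load_model_data below; the values are
# purely illustrative:
#
#   {
#     "model_id": "org/example-model",
#     "description": "Short model card description ...",
#     "tags": ["text-generation", "transformers", "license:apache-2.0"],
#     "downloads": 12345,
#     "model_explanation_gemini": "One-sentence summary of what the model does.",
#     "release_year": 2023,
#     "parameter_count": "7B",
#     "is_fine_tuned": false,
#     "category": "Text Generation",
#     "model_family": "Llama"
#   }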

def load_model_data(directory):
    """Loads model data, filters tags (by length, common words, prefixes), and combines relevant info for indexing."""
    all_texts = [] # Store combined text (model_id + description + filtered_tags)
    all_metadata = [] # Store dicts: {'model_id': ..., 'tags': ..., 'downloads': ...}
    print(f"Loading model data from JSON files in: {directory}")
    if not os.path.isdir(directory):
        print(f"Error: Directory not found: {directory}")
        return [], []

    filenames = [f for f in os.listdir(directory) if f.endswith(".json")] # Look for .json files
    for filename in tqdm(filenames, desc="Reading JSON files"):
        filepath = os.path.join(directory, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
                # Ensure required fields exist
                if 'description' in data and 'model_id' in data:
                    description = data['description']
                    model_id = data['model_id'] # Get model_id
                    if description: # Only index if description is not empty
                        original_tags = data.get('tags', [])
                        # Filter tags: remove short tags, common tags, and tags with specific prefixes
                        filtered_tags = [
                            tag for tag in original_tags
                            if (
                                tag and isinstance(tag, str)  # Ensure tag exists and is a string
                                and len(tag) > 3  # Skip very short tags
                                and tag.lower() not in COMMON_EXCLUDED_TAGS  # Skip common tags
                                and not tag.lower().startswith(EXCLUDED_TAG_PREFIXES)  # Skip prefixed tags (arxiv:, license:, ...)
                            )
                        ]
                        tag_string = " ".join(filtered_tags)
                        explanation = data.get(MODEL_EXPLANATION_KEY) # Get the new explanation

                        # Get the new metadata fields
                        release_year = data.get('release_year')
                        parameter_count = data.get('parameter_count')
                        is_fine_tuned = data.get('is_fine_tuned', False)
                        category = data.get('category', 'Other')
                        model_family = data.get('model_family')

                        # --- Construct combined text with priority weighting ---
                        text_parts = []
                        # 1. Add explanation (repeated for emphasis) if available
                        if explanation and isinstance(explanation, str):
                            text_parts.append(f"Summary: {explanation}")
                            text_parts.append(f"Summary: {explanation}") # Repeat for higher weight
                        # 2. Add model name
                        text_parts.append(f"Model: {model_id}")
                        # 3. Add filtered tags if available
                        if tag_string:
                            text_parts.append(f"Tags: {tag_string}")
                        # 4. Add category, model family and parameter count for better search
                        if category:
                            text_parts.append(f"Category: {category}")
                        if model_family:
                            text_parts.append(f"Family: {model_family}")
                        if parameter_count:
                            text_parts.append(f"Parameters: {parameter_count}")
                        if release_year:
                            text_parts.append(f"Year: {release_year}")
                        if is_fine_tuned:
                            text_parts.append("Fine-tuned model")
                        # 5. Add original description
                        text_parts.append(f"Description: {description}")

                        combined_text = " ".join(text_parts).strip() # Join all parts
                        # --- End construction ---

                        all_texts.append(combined_text)
                        # Add all metadata to the entry
                        metadata_entry = {
                            "model_id": model_id,
                            "tags": original_tags, # Keep ORIGINAL tags in metadata
                            "downloads": data.get('downloads', 0)
                        }
                        if explanation and isinstance(explanation, str):
                            metadata_entry[MODEL_EXPLANATION_KEY] = explanation
                        
                        # Add the new metadata fields
                        if release_year:
                            metadata_entry["release_year"] = release_year
                        if parameter_count:
                            metadata_entry["parameter_count"] = parameter_count
                        # is_fine_tuned always has a value (defaults to False above), so store it directly
                        metadata_entry["is_fine_tuned"] = is_fine_tuned
                        if category:
                            metadata_entry["category"] = category
                        if model_family:
                            metadata_entry["model_family"] = model_family
                            
                        all_metadata.append(metadata_entry)
                else:
                    print(f"Warning: Skipping {filename}, missing 'description' or 'model_id' key.")
        except json.JSONDecodeError:
            print(f"Warning: Skipping {filename}, invalid JSON.")
        except Exception as e:
            print(f"Warning: Could not read or process {filename}: {e}")

    print(f"Loaded data for {len(all_texts)} models with valid descriptions after tag filtering.")
    return all_texts, all_metadata
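
# For reference, one combined indexing string produced above might look like this
# (hypothetical model; note the summary is repeated to give it more weight):
#
#   "Summary: Small instruction-tuned chat model. Summary: Small instruction-tuned chat model.
#    Model: org/example-chat-7b Tags: text-generation conversational Category: Text Generation
#    Family: Llama Parameters: 7B Year: 2023 Fine-tuned model Description: ..."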

def build_and_save_index(texts_to_index, metadata_list):
    """Builds and saves the FAISS index and metadata mapping based on combined text."""
    if not texts_to_index:
        print("No text data to index.")
        return

    print(f"Loading sentence transformer model: {EMBEDDING_MODEL}")
    # Consider adding device='mps' if on Apple Silicon and PyTorch supports it well enough,
    # but start with CPU for stability.
    model = SentenceTransformer(EMBEDDING_MODEL)
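    # Minimal sketch of the device selection mentioned above (assumes the torch backend,
    # which sentence-transformers already requires, reports MPS availability):
    #   import torch
    #   device = "mps" if torch.backends.mps.is_available() else "cpu"
    #   model = SentenceTransformer(EMBEDDING_MODEL, device=device)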

    print(f"Generating embeddings for combined text in batches of {ENCODE_BATCH_SIZE}...")
    all_embeddings = []
    for i in tqdm(range(0, len(texts_to_index), ENCODE_BATCH_SIZE), desc="Encoding batches"):
        batch = texts_to_index[i:i+ENCODE_BATCH_SIZE]
        batch_embeddings = model.encode(batch, convert_to_numpy=True)
        all_embeddings.append(batch_embeddings)

    if not all_embeddings:
        print("No embeddings generated. Cannot build index.")
        return

    embeddings = np.vstack(all_embeddings) # Combine embeddings from all batches

    # Ensure embeddings are float32 for FAISS
    embeddings = embeddings.astype('float32')

    # Build FAISS index
    print("Building FAISS index...")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # Using simple L2 distance
    index.add(embeddings)
    print(f"FAISS index built with {index.ntotal} vectors.")

    # Save the index
    faiss.write_index(index, INDEX_FILE)
    print(f"FAISS index saved to: {INDEX_FILE}")

    # Create mapping from index position to metadata dictionary
    index_to_metadata = {i: metadata for i, metadata in enumerate(metadata_list)}
    with open(MAP_FILE, 'wb') as f:
        pickle.dump(index_to_metadata, f)
    print(f"Index-to-Metadata mapping saved to: {MAP_FILE}")

if __name__ == "__main__":
    combined_texts, metadata_list = load_model_data(MODEL_DATA_DIR)
    build_and_save_index(combined_texts, metadata_list)
    print("\nIndex building complete.")