import os
os.environ['OMP_NUM_THREADS'] = '1'  # Limit OpenMP threads; may help prevent crashes

import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import pickle
import json
from tqdm import tqdm

# --- Configuration ---
MODEL_DATA_DIR = "model_data_json"  # Path to the downloaded JSON data
INDEX_FILE = "index.faiss"
MAP_FILE = "index_to_metadata.pkl"  # Changed filename to reflect content
EMBEDDING_MODEL = 'all-mpnet-base-v2'  # Efficient, good-quality model
ENCODE_BATCH_SIZE = 32  # Encode descriptions in smaller batches
# Tags to exclude from the indexing text
COMMON_EXCLUDED_TAGS = {'transformers'}  # Add other common tags if needed
EXCLUDED_TAG_PREFIXES = ('arxiv:', 'base_model:', 'dataset:', 'diffusers:', 'license:')  # Add other prefixes if needed
MODEL_EXPLANATION_KEY = "model_explanation_gemini"  # Key for the explanation field
# ---


def load_model_data(directory):
    """Loads model data, filters tags (by length, common words, prefixes),
    and combines the relevant fields into one text per model for indexing."""
    all_texts = []     # Combined text (summary + model_id + filtered tags + description)
    all_metadata = []  # Dicts: {'model_id': ..., 'tags': ..., 'downloads': ...}
    print(f"Loading model data from JSON files in: {directory}")
    if not os.path.isdir(directory):
        print(f"Error: Directory not found: {directory}")
        return [], []
    filenames = [f for f in os.listdir(directory) if f.endswith(".json")]
    for filename in tqdm(filenames, desc="Reading JSON files"):
        filepath = os.path.join(directory, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Ensure the required fields exist
            if 'description' in data and 'model_id' in data:
                description = data['description']
                model_id = data['model_id']
                if description:  # Only index models with a non-empty description
                    original_tags = data.get('tags', [])
                    # Filter tags: drop non-strings, short tags, common tags,
                    # and tags with excluded prefixes
                    filtered_tags = [
                        tag for tag in original_tags
                        if (
                            tag and isinstance(tag, str)
                            and len(tag) > 3
                            and tag.lower() not in COMMON_EXCLUDED_TAGS
                            and not tag.lower().startswith(EXCLUDED_TAG_PREFIXES)
                        )
                    ]
                    tag_string = " ".join(filtered_tags)
                    explanation = data.get(MODEL_EXPLANATION_KEY)  # Curated explanation, if present
                    # Additional metadata fields
                    release_year = data.get('release_year')
                    parameter_count = data.get('parameter_count')
                    is_fine_tuned = data.get('is_fine_tuned', False)
                    category = data.get('category', 'Other')
                    model_family = data.get('model_family')

                    # --- Construct combined text with priority weighting ---
                    text_parts = []
                    # 1. Add the explanation (repeated for emphasis) if available
                    if explanation and isinstance(explanation, str):
                        text_parts.append(f"Summary: {explanation}")
                        text_parts.append(f"Summary: {explanation}")  # Repeat for higher weight
                    # 2. Add the model name
                    text_parts.append(f"Model: {model_id}")
                    # 3. Add the filtered tags if available
                    if tag_string:
                        text_parts.append(f"Tags: {tag_string}")
                    # 4. Add category, family, parameter count and year for better search
                    if category:
                        text_parts.append(f"Category: {category}")
                    if model_family:
                        text_parts.append(f"Family: {model_family}")
                    if parameter_count:
                        text_parts.append(f"Parameters: {parameter_count}")
                    if release_year:
                        text_parts.append(f"Year: {release_year}")
                    if is_fine_tuned:
                        text_parts.append("Fine-tuned model")
                    # 5. Add the original description
                    text_parts.append(f"Description: {description}")
                    combined_text = " ".join(text_parts).strip()
                    # --- End construction ---
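                    # Illustrative (hypothetical) example of a combined text:
                    #   "Summary: Small English NER model. Summary: Small English NER model.
                    #    Model: org/ner-model Tags: token-classification Category: Other
                    #    Description: ..."
                    # Repeating the summary biases the embedding toward the curated
                    # explanation relative to the longer description.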
                    all_texts.append(combined_text)
                    # Build the metadata entry for this model
                    metadata_entry = {
                        "model_id": model_id,
                        "tags": original_tags,  # Keep the ORIGINAL tags in metadata
                        "downloads": data.get('downloads', 0)
                    }
                    if explanation and isinstance(explanation, str):
                        metadata_entry[MODEL_EXPLANATION_KEY] = explanation
                    # Add the additional metadata fields when present
                    if release_year:
                        metadata_entry["release_year"] = release_year
                    if parameter_count:
                        metadata_entry["parameter_count"] = parameter_count
                    if is_fine_tuned is not None:  # Guard against an explicit null in the JSON
                        metadata_entry["is_fine_tuned"] = is_fine_tuned
                    if category:
                        metadata_entry["category"] = category
                    if model_family:
                        metadata_entry["model_family"] = model_family
                    all_metadata.append(metadata_entry)
            else:
                print(f"Warning: Skipping {filename}, missing 'description' or 'model_id' key.")
        except json.JSONDecodeError:
            print(f"Warning: Skipping {filename}, invalid JSON.")
        except Exception as e:
            print(f"Warning: Could not read or process {filename}: {e}")
    print(f"Loaded data for {len(all_texts)} models with valid descriptions after tag filtering.")
    return all_texts, all_metadata


def build_and_save_index(texts_to_index, metadata_list):
    """Builds and saves the FAISS index and the index-to-metadata mapping."""
    if not texts_to_index:
        print("No text data to index.")
        return
    print(f"Loading sentence transformer model: {EMBEDDING_MODEL}")
    # Consider device='mps' on Apple Silicon if PyTorch supports it well enough,
    # but start with CPU for stability.
    model = SentenceTransformer(EMBEDDING_MODEL)

    print(f"Generating embeddings for combined text in batches of {ENCODE_BATCH_SIZE}...")
    all_embeddings = []
    for i in tqdm(range(0, len(texts_to_index), ENCODE_BATCH_SIZE), desc="Encoding batches"):
        batch = texts_to_index[i:i + ENCODE_BATCH_SIZE]
        batch_embeddings = model.encode(batch, convert_to_numpy=True)
        all_embeddings.append(batch_embeddings)
    if not all_embeddings:
        print("No embeddings generated. Cannot build index.")
        return
    embeddings = np.vstack(all_embeddings)     # Combine embeddings from all batches
    embeddings = embeddings.astype('float32')  # FAISS expects float32

    # Build the FAISS index
    print("Building FAISS index...")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # Simple exact L2-distance index
    index.add(embeddings)
    print(f"FAISS index built with {index.ntotal} vectors.")

    # Save the index
    faiss.write_index(index, INDEX_FILE)
    print(f"FAISS index saved to: {INDEX_FILE}")

    # Map index position -> metadata dict so search results can be resolved
    index_to_metadata = {i: metadata for i, metadata in enumerate(metadata_list)}
    with open(MAP_FILE, 'wb') as f:
        pickle.dump(index_to_metadata, f)
    print(f"Index-to-Metadata mapping saved to: {MAP_FILE}")


if __name__ == "__main__":
    combined_texts, metadata_list = load_model_data(MODEL_DATA_DIR)
    build_and_save_index(combined_texts, metadata_list)
    print("\nIndex building complete.")
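
# --- Example: querying the saved index (sketch) ---
# A minimal sketch of how the artifacts written above might be searched; it
# assumes the index and mapping files produced by this script and the same
# EMBEDDING_MODEL. The query string and k=5 are illustrative.
#
#   index = faiss.read_index(INDEX_FILE)
#   with open(MAP_FILE, 'rb') as f:
#       index_to_metadata = pickle.load(f)
#   model = SentenceTransformer(EMBEDDING_MODEL)
#   query = model.encode(["summarization model for news articles"],
#                        convert_to_numpy=True).astype('float32')
#   distances, indices = index.search(query, 5)  # Top-5 nearest neighbours
#   for dist, idx in zip(distances[0], indices[0]):
#       print(index_to_metadata[int(idx)]['model_id'], f"L2={dist:.3f}")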