import os
os.environ['OMP_NUM_THREADS'] = '1' # Limit OpenMP threads, might help prevent crashes
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import pickle
import json  # For reading the per-model JSON metadata files
from tqdm import tqdm
# --- Configuration ---
MODEL_DATA_DIR = "model_data_json" # Path to downloaded JSON data
INDEX_FILE = "index.faiss"
MAP_FILE = "index_to_metadata.pkl" # Changed filename to reflect content
EMBEDDING_MODEL = 'all-mpnet-base-v2' # Efficient and good quality model
ENCODE_BATCH_SIZE = 32 # Process descriptions in smaller batches
# Tags to exclude from indexing text
COMMON_EXCLUDED_TAGS = {'transformers'} # Add other common tags if needed
EXCLUDED_TAG_PREFIXES = ('arxiv:', 'base_model:', 'dataset:', 'diffusers:', 'license:') # Add other prefixes if needed
MODEL_EXPLANATION_KEY = "model_explanation_gemini" # Key for the new explanation field
# ---
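# Illustrative effect of the filters above (hypothetical tag list): given
# ['transformers', 'license:apache-2.0', 'nlp', 'text-classification'],
# only 'text-classification' is kept for indexing. 'transformers' is a common
# excluded tag, 'license:apache-2.0' matches an excluded prefix, and 'nlp'
# fails the length check (len <= 3).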

def load_model_data(directory):
    """Loads model data, filters tags (by length, common words, prefixes), and combines relevant info for indexing."""
    all_texts = []     # Store combined text (summary + model name + tags + description)
    all_metadata = []  # Store dicts: {'model_id': ..., 'tags': ..., 'downloads': ...}
    print(f"Loading model data from JSON files in: {directory}")
    if not os.path.isdir(directory):
        print(f"Error: Directory not found: {directory}")
        return [], []
    filenames = [f for f in os.listdir(directory) if f.endswith(".json")]
    for filename in tqdm(filenames, desc="Reading JSON files"):
        filepath = os.path.join(directory, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Ensure required fields exist
            if 'description' in data and 'model_id' in data:
                description = data['description']
                model_id = data['model_id']
                if description:  # Only index if description is non-empty
                    original_tags = data.get('tags', [])
                    # Filter tags: drop non-strings, short tags, common tags,
                    # and tags with excluded prefixes
                    filtered_tags = [
                        tag for tag in original_tags
                        if (
                            isinstance(tag, str)
                            and len(tag) > 3
                            and tag.lower() not in COMMON_EXCLUDED_TAGS
                            and not tag.lower().startswith(EXCLUDED_TAG_PREFIXES)
                        )
                    ]
                    tag_string = " ".join(filtered_tags)
                    explanation = data.get(MODEL_EXPLANATION_KEY)  # Optional explanation field
                    # --- Construct combined text with priority weighting ---
                    text_parts = []
                    # 1. Add explanation (repeated for emphasis) if available
                    if explanation and isinstance(explanation, str):
                        text_parts.append(f"Summary: {explanation}")
                        text_parts.append(f"Summary: {explanation}")  # Repeat for higher weight
                    # 2. Add model name
                    text_parts.append(f"Model: {model_id}")
                    # 3. Add filtered tags if available
                    if tag_string:
                        text_parts.append(f"Tags: {tag_string}")
                    # 4. Add original description
                    text_parts.append(f"Description: {description}")
                    combined_text = " ".join(text_parts).strip()
                    # --- End construction ---
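                    # For a hypothetical entry, combined_text now looks like:
                    # "Summary: Classifies sentiment. Summary: Classifies sentiment.
                    #  Model: org/sentiment-model Tags: text-classification
                    #  Description: A fine-tuned ..." (the duplicated Summary gives
                    # the explanation more weight in the resulting embedding)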
                    all_texts.append(combined_text)
                    # Add explanation to metadata as well for potential display
                    metadata_entry = {
                        "model_id": model_id,
                        "tags": original_tags,  # Keep ORIGINAL tags in metadata
                        "downloads": data.get('downloads', 0)
                    }
                    if explanation and isinstance(explanation, str):
                        metadata_entry[MODEL_EXPLANATION_KEY] = explanation
                    all_metadata.append(metadata_entry)
            else:
                print(f"Warning: Skipping {filename}, missing 'description' or 'model_id' key.")
        except json.JSONDecodeError:
            print(f"Warning: Skipping {filename}, invalid JSON.")
        except Exception as e:
            print(f"Warning: Could not read or process {filename}: {e}")
    print(f"Loaded data for {len(all_texts)} models with valid descriptions after tag filtering.")
    return all_texts, all_metadata

def build_and_save_index(texts_to_index, metadata_list):
    """Builds and saves the FAISS index and metadata mapping based on combined text."""
    if not texts_to_index:
        print("No text data to index.")
        return
    print(f"Loading sentence transformer model: {EMBEDDING_MODEL}")
    # Consider adding device='mps' if on Apple Silicon and PyTorch supports it
    # well enough, but start with CPU for stability.
    model = SentenceTransformer(EMBEDDING_MODEL)
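    # A hedged example of the MPS variant mentioned above (assumes an Apple
    # Silicon machine with an MPS-enabled PyTorch build; verify before relying on it):
    # model = SentenceTransformer(EMBEDDING_MODEL, device='mps')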
print(f"Generating embeddings for combined text in batches of {ENCODE_BATCH_SIZE}...")
all_embeddings = []
for i in tqdm(range(0, len(texts_to_index), ENCODE_BATCH_SIZE), desc="Encoding batches"):
batch = texts_to_index[i:i+ENCODE_BATCH_SIZE]
batch_embeddings = model.encode(batch, convert_to_numpy=True)
all_embeddings.append(batch_embeddings)
if not all_embeddings:
print("No embeddings generated. Cannot build index.")
return
embeddings = np.vstack(all_embeddings) # Combine embeddings from all batches
# Ensure embeddings are float32 for FAISS
embeddings = embeddings.astype('float32')
# Build FAISS index
print("Building FAISS index...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension) # Using simple L2 distance
index.add(embeddings)
print(f"FAISS index built with {index.ntotal} vectors.")
    # Save the index
    faiss.write_index(index, INDEX_FILE)
    print(f"FAISS index saved to: {INDEX_FILE}")
    # Create mapping from index position to metadata dictionary
    index_to_metadata = {i: metadata for i, metadata in enumerate(metadata_list)}
    with open(MAP_FILE, 'wb') as f:
        pickle.dump(index_to_metadata, f)
    print(f"Index-to-Metadata mapping saved to: {MAP_FILE}")

if __name__ == "__main__":
    combined_texts, metadata_list = load_model_data(MODEL_DATA_DIR)
    build_and_save_index(combined_texts, metadata_list)
    print("\nIndex building complete.")