import faiss
import torch
import json
import os
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": device}
)

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# File path for storing recently asked questions and metrics
RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

# Ensure data directory exists before any file in it is created
os.makedirs("data_local", exist_ok=True)

# Ensure the recent-questions file exists and initialize it if empty
if not os.path.exists(RECENT_QUESTIONS_FILE):
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump({"questions": []}, file, indent=4)

all_documents = []
ragbench = {}
index = None
chunk_docs = []
documents = []
query_dataset_data = {}

# Initialize a text splitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=100
)

def chunk_documents(docs):
    # Split each document into overlapping chunks and flatten the result
    return [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
def create_faiss_index(dataset):
    # Load dataset
    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
    documents = []  # fresh list so chunks from previously indexed datasets are not mixed in
    for split in ragbench_dataset.keys():
        for row in ragbench_dataset[split]:
            # Ensure document is a string before appending
            doc = row["documents"]
            if isinstance(doc, list):
                # If doc is a list, join its elements into a single string
                doc = " ".join(doc)
            documents.append(doc)  # Extract document text

    # Chunking
    chunked_documents = chunk_documents(documents)

    # Save chunked documents in JSON (metadata storage), in the same
    # data_local/ location that load_chunks() reads from
    with open(f"data_local/{dataset}_chunked_docs.json", "w") as f:
        json.dump(chunked_documents, f)
    print(f"{dataset}: {len(chunked_documents)} chunks")

    # Convert to embeddings
    embeddings = embedding_model.embed_documents(chunked_documents)

    # Convert embeddings to a NumPy array
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Build and save the FAISS index (HNSW graph, 32 neighbors per node)
    index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    index.add(embeddings_np)
    faiss.write_index(index, f"data_local/{dataset}_chunked_index.faiss")

    print(f"{dataset} stored as individual FAISS index!")
def load_ragbench():
    global ragbench
    if ragbench:
        return ragbench
    datasets = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa',
                'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
    for dataset in datasets:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    return ragbench

def load_query_dataset(q_dataset):
    global query_dataset_data
    if query_dataset_data.get(q_dataset):
        return query_dataset_data[q_dataset]
    try:
        query_dataset_data[q_dataset] = load_dataset("rungalileo/ragbench", q_dataset)
    except Exception as e:
        print(f"Error loading dataset '{q_dataset}': {e}")
        return None  # Return None if the dataset fails to load
    return query_dataset_data[q_dataset]
def load_faiss(q_dataset):
    global index
    faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Run create_faiss_index() first.")

def load_chunks(q_dataset):
    global chunk_docs
    metadata_path = f"data_local/{q_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Run create_faiss_index() first.")

def load_data_from_faiss(q_dataset):
    load_faiss(q_dataset)
    load_chunks(q_dataset)
def rerank_documents(query, retrieved_docs):
    # Score each (query, document) pair with the cross-encoder
    scores = reranker.predict([[query, doc] for doc in retrieved_docs])
    # Sort documents by score, highest first
    ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), key=lambda pair: pair[0], reverse=True)]
    return ranked_docs[:5]  # Return top 5 most relevant
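
# The loaders above bring the FAISS index and chunk texts into memory, but the
# retrieval step itself is not shown in this module. The helper below is a
# minimal sketch of that step (an assumed addition, not the original app's
# code): embed the query with the same embedding model, search the HNSW index,
# map the hits back to chunk texts, and rerank them with the cross-encoder.
def retrieve_documents(query, top_k=10):
    if index is None or not chunk_docs:
        raise RuntimeError("Call load_data_from_faiss(<dataset>) before retrieving.")
    # Embed the query and search the FAISS index
    query_embedding = np.array([embedding_model.embed_query(query)], dtype=np.float32)
    _, positions = index.search(query_embedding, top_k)
    # Map FAISS positions back to chunk texts, skipping any -1 padding
    candidates = [chunk_docs[i] for i in positions[0] if i != -1]
    # Rerank the candidates and keep the 5 most relevant
    return rerank_documents(query, candidates)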
def load_recent_questions():
    if os.path.exists(RECENT_QUESTIONS_FILE):
        with open(RECENT_QUESTIONS_FILE, "r") as file:
            return json.load(file)
    return {"questions": []}  # Default structure if file doesn't exist
def save_recent_question(question, response_time):
    data = load_recent_questions()
    # Append only if this question is not already in the recent list
    if question not in [q["question"] for q in data["questions"]]:
        data["questions"].append({
            "question": question,
            "response_time": response_time
        })
    # Keep only the last 5 questions
    data["questions"] = data["questions"][-5:]
    # Write back to file
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump(data, file, indent=4)
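
# Minimal end-to-end sketch of how these helpers are presumably wired together
# (added for illustration, not the original app's entry point). Note that
# load_faiss() expects data_local/<dataset>_quantized.faiss, which this module
# does not produce; it is presumably derived from the *_chunked_index.faiss
# file written by create_faiss_index() in a separate quantization step.
if __name__ == "__main__":
    import time

    q_dataset = "covidqa"            # any RAGBench subset name works here
    # create_faiss_index(q_dataset)  # uncomment on first run to build the chunk/index artifacts

    load_data_from_faiss(q_dataset)
    question = "What are the symptoms of COVID-19?"  # illustrative query

    start = time.time()
    top_docs = retrieve_documents(question, top_k=10)
    response_time = round(time.time() - start, 2)

    for rank, doc in enumerate(top_docs, start=1):
        print(f"{rank}. {doc[:120]}...")

    save_recent_question(question, response_time)
    print(load_recent_questions())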