# data_processing.py
import faiss
import torch
import json
import os
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": device}
)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query_dataset_data = {}

# Ensure data directory exists before anything is written into it
os.makedirs("data_local", exist_ok=True)

# File path for storing recently asked questions and metrics
RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

# Ensure the file exists and initialize if empty
if not os.path.exists(RECENT_QUESTIONS_FILE):
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump({"questions": []}, file, indent=4)

all_documents = []
ragbench = {}
index = None
chunk_docs = []
documents = []

# Initialize a text splitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=100
)

def chunk_documents(docs):
    chunks = [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
    return chunks

def create_faiss_index(dataset):
    # Load dataset
    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)
    documents = []
    for split in ragbench_dataset.keys():
        for row in ragbench_dataset[split]:
            # Ensure document is a string before appending
            doc = row["documents"]
            if isinstance(doc, list):
                # If doc is a list, join its elements into a single string
                doc = " ".join(doc)
            documents.append(doc)  # Extract document text

    # Chunking
    chunked_documents = chunk_documents(documents)

    # Save documents in JSON (metadata storage)
    with open(f"{dataset}_chunked_docs.json", "w") as f:
        json.dump(chunked_documents, f)
    print(len(chunked_documents))

    # Convert to embeddings
    embeddings = embedding_model.embed_documents(chunked_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Build and save the FAISS index (HNSW graph with 32 neighbours per node)
    index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    index.add(embeddings_np)
    faiss.write_index(index, f"{dataset}_chunked_index.faiss")
    print(f"{dataset} stored as individual FAISS index!")

def load_ragbench():
    global ragbench
    if ragbench:
        return ragbench
    datasets = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa',
                'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
    for dataset in datasets:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    return ragbench

def load_query_dataset(q_dataset):
    global query_dataset_data
    if query_dataset_data.get(q_dataset):
        return query_dataset_data[q_dataset]
    try:
        query_dataset_data[q_dataset] = load_dataset("rungalileo/ragbench", q_dataset)
    except Exception as e:
        print(f"Error loading dataset '{q_dataset}': {e}")
        return None  # Return None if the dataset fails to load
    return query_dataset_data[q_dataset]

def load_faiss(q_dataset):
    global index
    faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Run create_faiss_index_file() first.")

def load_chunks(q_dataset):
    global chunk_docs
    metadata_path = f"data_local/{q_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Run create_faiss_index_file() first.")

def load_data_from_faiss(q_dataset):
    load_faiss(q_dataset)
    load_chunks(q_dataset)
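
# Hedged sketch, not in the original file: once load_data_from_faiss() has populated the
# module-level `index` and `chunk_docs`, a nearest-neighbour lookup could look like the
# helper below. The name `retrieve_documents` and the `top_k` parameter are illustrative
# assumptions.
def retrieve_documents(query, top_k=10):
    # Embed the query with the same model used for the document chunks
    query_vector = np.array([embedding_model.embed_query(query)], dtype=np.float32)
    # FAISS returns (distances, indices); an index of -1 means "no result"
    distances, indices = index.search(query_vector, top_k)
    return [chunk_docs[i] for i in indices[0] if 0 <= i < len(chunk_docs)]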

def rerank_documents(query, retrieved_docs):
    scores = reranker.predict([[query, doc] for doc in retrieved_docs])
    ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs),
                                            key=lambda pair: pair[0], reverse=True)]
    return ranked_docs[:5]  # Return top 5 most relevant
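
# Hypothetical usage of the two helpers above (the question text is a placeholder):
#     candidates = retrieve_documents("What are the symptoms of COVID-19?", top_k=10)
#     top_docs = rerank_documents("What are the symptoms of COVID-19?", candidates)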

def load_recent_questions():
    if os.path.exists(RECENT_QUESTIONS_FILE):
        with open(RECENT_QUESTIONS_FILE, "r") as file:
            return json.load(file)
    return {"questions": []}  # Default structure if file doesn't exist

def save_recent_question(question, response_time):
    data = load_recent_questions()
    # Append the new question & metrics only if it is not already stored
    if question not in [q["question"] for q in data["questions"]]:
        data["questions"].append({
            "question": question,
            "response_time": response_time
        })
    # Keep only the last 5 questions
    data["questions"] = data["questions"][-5:]
    # Write back to file
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump(data, file, indent=4)
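
# Hedged end-to-end sketch, not part of the original upload: ties the loaders, the
# retrieve_documents sketch above, the reranker, and the recent-question log together for
# a single ad-hoc query. The subset name and question below are placeholders.
if __name__ == "__main__":
    import time

    q_dataset = "covidqa"
    question = "What are the symptoms of COVID-19?"

    load_data_from_faiss(q_dataset)  # populates the module-level `index` and `chunk_docs`

    start = time.time()
    candidates = retrieve_documents(question, top_k=10)
    top_docs = rerank_documents(question, candidates)
    elapsed = time.time() - start

    save_recent_question(question, elapsed)
    print(f"Top documents ({elapsed:.2f}s):")
    for doc in top_docs:
        print(doc[:200])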