import faiss
import torch
import json
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L12-v2",
    model_kwargs={"device": device}
)

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
query_dataset_data = {}

# File path for storing recently asked questions and metrics
RECENT_QUESTIONS_FILE = "data_local/recent_questions.json"

# Ensure the file exists and initialize if empty
if not os.path.exists(RECENT_QUESTIONS_FILE):
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump({"questions": []}, file, indent=4)

all_documents = []
ragbench = {}
index = None
chunk_docs = []
documents = []

# Ensure data directory exists
os.makedirs("data_local", exist_ok=True)

# Initialize a text splitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=100
)

def chunk_documents(docs):
    chunks = [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
    return chunks
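
# Minimal usage sketch for the chunking helper above: a long document is split
# into overlapping pieces of up to 1024 characters (chunk_overlap=100). The
# sample text is purely illustrative.
#
#     sample_docs = ["Retrieval-augmented generation combines search with LLMs. " * 100]
#     chunks = chunk_documents(sample_docs)
#     print(len(chunks), max(len(c) for c in chunks))   # several chunks, each <= 1024 chars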

def create_faiss_index(dataset):
    # Load the requested RAGBench subset
    ragbench_dataset = load_dataset("rungalileo/ragbench", dataset)

    for split in ragbench_dataset.keys():
        for row in ragbench_dataset[split]:
            # Ensure each document is a single string before collecting it
            doc = row["documents"]
            if isinstance(doc, list):
                # If doc is a list, join its elements into a single string
                doc = " ".join(doc)
            documents.append(doc)

    # Chunk the collected documents
    chunked_documents = chunk_documents(documents)

    # Save chunked documents as JSON (metadata storage), in the same
    # data_local/ location that load_chunks() reads from
    with open(f"data_local/{dataset}_chunked_docs.json", "w") as f:
        json.dump(chunked_documents, f)

    # Report how many chunks were produced
    print(len(chunked_documents))

    # Convert chunks to embeddings and stack them into a float32 array for FAISS
    embeddings = embedding_model.embed_documents(chunked_documents)
    embeddings_np = np.array(embeddings, dtype=np.float32)

    # Build and save the FAISS index (32 = HNSW "M", the number of neighbours per node)
    index = faiss.IndexHNSWFlat(embeddings_np.shape[1], 32)
    index.add(embeddings_np)
    faiss.write_index(index, f"data_local/{dataset}_chunked_index.faiss")

    print(f"{dataset} stored as individual FAISS index!")

def load_ragbench():
    global ragbench
    if ragbench:
        return ragbench  
    datasets = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa',
                'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']
    for dataset in datasets:
        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    return ragbench

def load_query_dataset(q_dataset):
    global query_dataset_data

    if query_dataset_data.get(q_dataset):  
        return query_dataset_data[q_dataset]
    try:
        query_dataset_data[q_dataset] = load_dataset("rungalileo/ragbench", q_dataset)
    except Exception as e:
        print(f"Error loading dataset '{q_dataset}': {e}")
        return None  # Return None if the dataset fails to load

    return query_dataset_data[q_dataset]


def load_faiss(q_dataset):
    global index
    faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
    if os.path.exists(faiss_index_path):
        index = faiss.read_index(faiss_index_path)
        print("FAISS index loaded successfully.")
    else:
        print("FAISS index file not found. Run create_faiss_index_file() first.") 

def load_chunks(q_dataset):
    global chunk_docs
    metadata_path = f"data_local/{q_dataset}_chunked_docs.json"
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            chunk_docs = json.load(f)
        print("Metadata loaded successfully.")
    else:
        print("Metadata file not found. Run create_faiss_index_file() first.")

def load_data_from_faiss(q_dataset):
    load_faiss(q_dataset)
    load_chunks(q_dataset)

def rerank_documents(query, retrieved_docs):
    # Score each (query, document) pair with the cross-encoder
    scores = reranker.predict([[query, doc] for doc in retrieved_docs])
    # Sort documents by descending relevance score
    ranked = sorted(zip(scores, retrieved_docs), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked][:5]  # Return the top 5 most relevant chunks
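
# End-to-end retrieval sketch. Assumption: this helper is illustrative and not
# part of the original pipeline; it assumes the FAISS index and chunk metadata
# for q_dataset already exist under data_local/. It embeds the query with the
# same embedding model, searches the loaded index, maps the returned row ids
# back to chunk_docs, and reranks the hits with the cross-encoder.
def retrieve_and_rerank_sketch(query, q_dataset, top_k=10):
    load_data_from_faiss(q_dataset)                        # populates the global `index` and `chunk_docs`
    query_vec = np.array([embedding_model.embed_query(query)], dtype=np.float32)
    _, ids = index.search(query_vec, top_k)                # nearest-neighbour lookup in the HNSW index
    retrieved = [chunk_docs[i] for i in ids[0] if i != -1] # drop FAISS "no result" slots
    return rerank_documents(query, retrieved)
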

def load_recent_questions():
    if os.path.exists(RECENT_QUESTIONS_FILE):
        with open(RECENT_QUESTIONS_FILE, "r") as file:
            return json.load(file)
    return {"questions": []}  # Default structure if file doesn't exist

def save_recent_question(question, response_time):
    data = load_recent_questions()

    #data["questions"] = [q for q in data["questions"] if q["question"] != question]
    if "question" in data["questions"] and question not in data["questions"]["question"]:
        # Append new question & metrics
        data["questions"].append({
            "question": question,
            "response_time": response_time
        })

    # Keep only the last 5 questions
    data["questions"]  = data["questions"][-5:]

    # Write back to file
    with open(RECENT_QUESTIONS_FILE, "w") as file:
        json.dump(data, file, indent=4)
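
# Usage sketch for the recent-question log (assumption: the response time is
# measured by the caller, e.g. with time.perf_counter() around a query):
#
#     save_recent_question("What is retrieval-augmented generation?", 1.42)
#     print(load_recent_questions()["questions"])   # at most the 5 most recent entries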