import streamlit as st
import faiss
import numpy as np
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
)
import torch
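
# Overview: a Streamlit front end for question answering over a pre-indexed PDF.
# The app loads a FAISS index and text chunks assumed to be produced by a separate
# indexing step (inferred from the file names below), first tries to match the
# question against a curated Q/A list, and otherwise answers with either an
# extractive QA pipeline over retrieved chunks or a RAG model wired to a custom
# retriever.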

INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"

EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"

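# The loaders below are wrapped in st.cache_resource so heavy objects (the FAISS
# index, embedding model, QA pipeline, curated pairs) are created once per process
# and reused across Streamlit reruns instead of being reloaded on every interaction.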
@st.cache_resource
def load_index_and_chunks():
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None

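# all-MiniLM-L6-v2 is a compact SentenceTransformer that maps text to 384-dim
# vectors; the same model is used both for chunk retrieval and for matching the
# query against the curated questions.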
@st.cache_resource
def load_embedding_model():
    try:
        model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        return model
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None

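# deepset/roberta-large-squad2 is an extractive QA model: the pipeline picks an
# answer span out of the context string it is given, so answers can only contain
# text that already appears in the retrieved chunks or conversation history.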
@st.cache_resource
def load_qa_pipeline():
    try:
        qa_pipe = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
        return qa_pipe
    except Exception as e:
        st.error(f"Error loading QA pipeline: {e}")
        return None

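# The curated file is expected (judging from how it is used below) to be a JSON
# list of objects each carrying at least "question" and "answer" keys.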
@st.cache_resource
def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            curated_qa_pairs = json.load(f)
        return curated_qa_pairs
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs from JSON: {e}")
        return []

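# Curated lookup: embed the query and all curated questions, index the curated
# embeddings in a throwaway IndexFlatL2, and accept the nearest neighbour only if
# its (squared) L2 distance falls below the threshold. Rebuilding this small index
# on every call works but re-encodes the curated questions each time; caching it
# alongside the loaded pairs would avoid that.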
def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
    try:
        if not curated_qa:
            return None
        curated_questions = [qa["question"] for qa in curated_qa]
        query_embedding = embed_model.encode([query]).astype("float32")
        curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
        curated_embeddings = np.array(curated_embeddings).astype("float32")

        dimension = curated_embeddings.shape[1]
        curated_index = faiss.IndexFlatL2(dimension)
        curated_index.add(curated_embeddings)

        k = 1
        distances, indices = curated_index.search(query_embedding, k)

        if distances[0][0] < threshold:
            idx = indices[0][0]
            return curated_qa[idx]
    except Exception as e:
        st.error(f"Error retrieving curated Q/A pair: {e}")
    return None

class CustomRagRetriever(RagRetriever):
    """
    A custom retriever that uses the app's own FAISS index and passages.
    It encodes the query with the provided embedding model, searches the
    FAISS index, and returns the top retrieved documents.

    Note: this simplified retrieve() takes a raw query string. Depending on the
    transformers version, RAG generation may call the retriever with question
    hidden states instead, in which case this override would need adapting.
    """

    def __init__(self, config, faiss_index, passages, embed_model, tokenizer, n_docs=5):
        self.faiss_index = faiss_index
        self.passages = passages
        self.embed_model = embed_model
        self.n_docs = n_docs
        self.tokenizer = tokenizer

        # No-op assigned before super().__init__() so the base class' call to
        # init_retrieval() does not try to load its default (wiki_dpr) index.
        # This relies on RagRetriever internals and may vary across versions.
        self.init_retrieval = lambda: None

        super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)

    def retrieve(self, query, n_docs=None):
        try:
            if n_docs is None:
                n_docs = self.n_docs

            query_embedding = self.embed_model.encode([query]).astype("float32")
            distances, indices = self.faiss_index.search(query_embedding, n_docs)
            retrieved_docs = [self.passages[i] for i in indices[0]]
            return {
                "doc_ids": indices,
                "doc_scores": distances,
                "retrieved_docs": retrieved_docs,
            }
        except Exception as e:
            st.error(f"Error in custom retrieval: {e}")
            return {"doc_ids": None, "doc_scores": None, "retrieved_docs": []}

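# The leading underscore on _faiss_index and _embed_model tells st.cache_resource
# not to hash those arguments (FAISS indexes and model objects are unhashable);
# the cache is keyed on the remaining arguments instead.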
@st.cache_resource
def load_rag_model(_faiss_index, passages, _embed_model):
    try:
        # "facebook/rag-sequence-nq" matches the RagSequenceForGeneration class used below.
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

        custom_retriever = CustomRagRetriever(
            config=rag_model.config,
            faiss_index=_faiss_index,
            passages=passages,
            embed_model=_embed_model,
            tokenizer=tokenizer,
            n_docs=5,
        )
        rag_model.set_retriever(custom_retriever)
        return tokenizer, rag_model
    except Exception as e:
        st.error(f"Error loading RAG model with custom retriever: {e}")
        return None, None

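# Tokenize the question, generate with the RAG model (which consults the retriever
# attached via set_retriever), and decode the first generated sequence. torch.no_grad
# keeps inference from building autograd graphs.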
def generate_rag_answer(query, tokenizer, rag_model):
    try:
        inputs = tokenizer(query, return_tensors="pt")
        with torch.no_grad():
            generated_ids = rag_model.generate(**inputs)
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        st.error(f"Error generating answer with RAG model: {e}")
        return ""

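# UI flow: keep a running conversation history in st.session_state, embed each new
# question and pull the top-k chunks from the FAISS index, then answer via the RAG
# button, a curated Q/A match, or the extractive QA pipeline, in that order.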
def main():
    st.title("Takalama")

    if "conversation_history" not in st.session_state:
        st.session_state.conversation_history = ""

    with st.spinner("Loading index and passages..."):
        index, chunks = load_index_and_chunks()
    if index is None or chunks is None:
        return

    with st.spinner("Loading embedding model..."):
        embed_model = load_embedding_model()
    if embed_model is None:
        return

    with st.spinner("Loading QA pipeline..."):
        qa_pipeline = load_qa_pipeline()
    if qa_pipeline is None:
        return

    with st.spinner("Loading curated Q/A pairs..."):
        curated_qa_pairs = load_curated_qa_pairs()

    st.write("Enter your question about the PDF document:")
    query = st.text_input("Question:")

    if query:
        st.session_state.conversation_history += f"User: {query}\n"

        with st.spinner("Retrieving relevant PDF context..."):
            try:
                query_embedding = embed_model.encode([query]).astype("float32")
                k = 3
                distances, indices = index.search(query_embedding, k)
                pdf_context = ""
                for idx in indices[0]:
                    pdf_context += chunks[idx] + "\n"
            except Exception as e:
                st.error(f"Error retrieving PDF context: {e}")
                return

        base_context = st.session_state.conversation_history + "\n"

        if st.button("Use RAG Model with Custom Retriever"):
            with st.spinner("Generating answer using RAG model..."):
                tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
                if tokenizer_rag is None or rag_model is None:
                    return
                rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
                st.write("**RAG Model Answer:**")
                st.write(rag_answer)
                st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
            return

        with st.spinner("Checking for curated Q/A pair..."):
            curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)

        if curated_pair:
            st.info("A curated Q/A pair was found and will be used for the answer by default.")
            use_full_data = st.checkbox("High Reasoning", value=False)
            if not use_full_data:
                answer = curated_pair["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
                return

        # Either no curated match was found or the user opted into full-context reasoning.
        context_to_use = base_context + pdf_context

        with st.expander("Show Full PDF Context"):
            st.write(pdf_context)

        st.subheader("Answer:")
        with st.spinner("Generating answer using standard QA pipeline..."):
            try:
                result = qa_pipeline(question=query, context=context_to_use)
                answer = result["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
            except Exception as e:
                st.error(f"Error generating answer using QA pipeline: {e}")


if __name__ == "__main__":
    main()