import streamlit as st
import faiss
import numpy as np
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
)
import torch
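
# Streamlit RAG demo: answers questions about a PDF using a FAISS index of
# passage embeddings, a curated Q/A lookup, and either a standard extractive
# QA pipeline or a RAG generator wired to a custom retriever.
# Assuming this file is saved as app.py, it can be launched with:
#   streamlit run app.py
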
# ========================
# File Names & Model Names
# ========================
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2" # For the standard QA pipeline
# ========================
# Loading Functions (cached)
# ========================
@st.cache_resource
def load_index_and_chunks():
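    """Load the prebuilt FAISS index and the pickled PDF text chunks from disk."""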
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None

@st.cache_resource
def load_embedding_model():
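    """Load the SentenceTransformer model used to embed queries and passages."""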
    try:
        model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        return model
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None

@st.cache_resource
def load_qa_pipeline():
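    """Load the extractive question-answering pipeline (QA_MODEL_NAME)."""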
    try:
        qa_pipe = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
        return qa_pipe
    except Exception as e:
        st.error(f"Error loading QA pipeline: {e}")
        return None

@st.cache_resource
def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
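    """Load curated question/answer pairs from the JSON file, or return [] on failure."""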
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            curated_qa_pairs = json.load(f)
        return curated_qa_pairs
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs from JSON: {e}")
        return []

# ========================================
# Standard: Retrieve Curated Q/A Pair Function
# ========================================
def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
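    """Return the curated Q/A pair whose question best matches the query.

    The query and all curated questions are encoded with the shared embedding
    model, and the nearest curated question is found by L2 distance in a
    temporary FAISS index. A pair is returned only if that distance is below
    `threshold`; otherwise None is returned.
    """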
    if not curated_qa:
        return None
    try:
        curated_questions = [qa["question"] for qa in curated_qa]
        query_embedding = embed_model.encode([query]).astype("float32")
        curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
        curated_embeddings = np.array(curated_embeddings).astype("float32")
        # Build a temporary FAISS index for the curated questions
        dimension = curated_embeddings.shape[1]
        curated_index = faiss.IndexFlatL2(dimension)
        curated_index.add(curated_embeddings)
        k = 1
        distances, indices = curated_index.search(query_embedding, k)
        if distances[0][0] < threshold:
            idx = indices[0][0]
            return curated_qa[idx]
    except Exception as e:
        st.error(f"Error retrieving curated Q/A pair: {e}")
    # No sufficiently close curated question was found.
    return None

# ============================================================
# Custom RAG Retriever: Uses your FAISS index & PDF passages
# ============================================================
class CustomRagRetriever(RagRetriever):
"""
A custom retriever that uses your FAISS index and passages.
It encodes the query with the provided embedding model,
searches your FAISS index, and returns the top retrieved documents.
"""
def __init__(self, config, faiss_index, passages, embed_model, tokenizer, n_docs=5):
self.faiss_index = faiss_index # Your custom FAISS index of PDF embeddings
self.passages = passages # List of PDF passage texts
self.embed_model = embed_model # Embedding model used for encoding queries
self.n_docs = n_docs # Number of top documents to retrieve
self.tokenizer = tokenizer # Save tokenizer for internal use if needed
# Override init_retrieval to bypass loading default passages.
self.init_retrieval = lambda: None
# Call the parent constructor with the required arguments.
super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
def retrieve(self, query, n_docs=None):
try:
if n_docs is None:
n_docs = self.n_docs
# Encode the query using the embedding model
query_embedding = self.embed_model.encode([query]).astype("float32")
distances, indices = self.faiss_index.search(query_embedding, n_docs)
retrieved_docs = [self.passages[i] for i in indices[0]]
return {
"doc_ids": indices,
"doc_scores": distances,
"retrieved_docs": retrieved_docs,
}
except Exception as e:
st.error(f"Error in custom retrieval: {e}")
return {"doc_ids": None, "doc_scores": None, "retrieved_docs": []}
# ============================================================
# Load RAG Model with Custom Retriever (cached for performance)
# ============================================================
@st.cache_resource
def load_rag_model(_faiss_index, passages, _embed_model):
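    # Parameters prefixed with "_" are excluded from st.cache_resource hashing,
    # which allows unhashable objects (the FAISS index and the embedding model)
    # to be passed into this cached function.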
    try:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
        custom_retriever = CustomRagRetriever(
            config=rag_model.config,
            faiss_index=_faiss_index,
            passages=passages,
            embed_model=_embed_model,
            tokenizer=tokenizer,
            n_docs=5,
        )
        rag_model.set_retriever(custom_retriever)
        return tokenizer, rag_model
    except Exception as e:
        st.error(f"Error loading RAG model with custom retriever: {e}")
        return None, None

def generate_rag_answer(query, tokenizer, rag_model):
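    """Tokenize the query and generate an answer with the RAG sequence model."""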
    try:
        inputs = tokenizer(query, return_tensors="pt")
        with torch.no_grad():
            generated_ids = rag_model.generate(**inputs)
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        st.error(f"Error generating answer with RAG model: {e}")
        return ""

# ========================================
# Main Streamlit App
# ========================================
def main():
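    """Streamlit entry point: load resources, accept a question, and answer it
    with either the RAG model or the standard QA pipeline."""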
    st.title("Takalama")
    if "conversation_history" not in st.session_state:
        st.session_state.conversation_history = ""
    with st.spinner("Loading index and passages..."):
        index, chunks = load_index_and_chunks()
    if index is None or chunks is None:
        return
    with st.spinner("Loading embedding model..."):
        embed_model = load_embedding_model()
    if embed_model is None:
        return
    with st.spinner("Loading QA pipeline..."):
        qa_pipeline = load_qa_pipeline()
    if qa_pipeline is None:
        return
    with st.spinner("Loading curated Q/A pairs..."):
        curated_qa_pairs = load_curated_qa_pairs()
    st.write("Enter your question about the PDF document:")
    query = st.text_input("Question:")
    if query:
        st.session_state.conversation_history += f"User: {query}\n"
        with st.spinner("Retrieving relevant PDF context..."):
            try:
                query_embedding = embed_model.encode([query]).astype("float32")
                k = 3  # Retrieve top 3 chunks
                distances, indices = index.search(query_embedding, k)
                pdf_context = ""
                for idx in indices[0]:
                    pdf_context += chunks[idx] + "\n"
            except Exception as e:
                st.error(f"Error retrieving PDF context: {e}")
                return
        base_context = st.session_state.conversation_history + "\n"
        # --- Option 1: Use RAG Model with Custom Retriever ---
        if st.button("Use RAG Model with Custom Retriever"):
            with st.spinner("Generating answer using RAG model..."):
                tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
                if tokenizer_rag is None or rag_model is None:
                    return
                rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
            st.write("**RAG Model Answer:**")
            st.write(rag_answer)
            st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
            return
        # --- Option 2: Use Standard QA Pipeline with Curated Q/A Pairs ---
        with st.spinner("Checking for curated Q/A pair..."):
            curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)
        if curated_pair:
            st.info("A curated Q/A pair was found and will be used for the answer by default.")
            use_full_data = st.checkbox("High Reasoning", value=False)
            if not use_full_data:
                answer = curated_pair["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
                return
            else:
                context_to_use = base_context + pdf_context
        else:
            context_to_use = base_context + pdf_context
        with st.expander("Show Full PDF Context"):
            st.write(pdf_context)
        st.subheader("Answer:")
        with st.spinner("Generating answer using standard QA pipeline..."):
            try:
                result = qa_pipeline(question=query, context=context_to_use)
                answer = result["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
            except Exception as e:
                st.error(f"Error generating answer using QA pipeline: {e}")

if __name__ == "__main__":
    main()