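"""RAG pipeline for a Streamlit app over Indian spiritual texts.

Implemented below: authenticate to GCP and OpenAI, download a prebuilt FAISS
index, tab-separated text chunks, and JSONL metadata from a GCS bucket, embed
the user's query with a Hugging Face sentence-embedding model (CPU only),
retrieve the top-k passages, and ask an OpenAI chat model for a concise,
cited answer. All configuration comes from `st.secrets` and the auth helpers
in `utils`.
"""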
import os
import json
import numpy as np
import faiss
import torch
import torch.nn as nn
from google.cloud import storage
from transformers import AutoTokenizer, AutoModel
import openai
import textwrap
import unicodedata
import streamlit as st
from utils import setup_gcp_auth, setup_openai_auth
import gc

# Force model to CPU for stability
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Local paths
local_embeddings_file = "all_embeddings.npy"
local_faiss_index_file = "faiss_index.faiss"
local_text_chunks_file = "text_chunks.txt"
local_metadata_file = "metadata.jsonl"
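# Configuration read from st.secrets by the functions below:
#   bucket_name_gcs                       - GCS bucket holding the prebuilt artifacts
#   metadata_file_gcs, embeddings_file_gcs,
#   faiss_index_file_gcs, text_chunks_file_gcs
#                                         - object paths inside that bucket
#   embedding_model                       - Hugging Face model id for AutoTokenizer/AutoModel
#   llm_model                             - OpenAI chat model name passed to chat.completions.create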
# Load GCP authentication from utility function
def setup_gcp_client():
    try:
        credentials = setup_gcp_auth()

        # Get bucket name from secrets - required
        try:
            bucket_name_gcs = st.secrets["bucket_name_gcs"]
        except KeyError:
            print("❌ Error: GCS bucket name not found in secrets")
            return None

        storage_client = storage.Client(credentials=credentials)
        bucket = storage_client.bucket(bucket_name_gcs)
        print("✅ GCP client initialized successfully")
        return bucket
    except Exception as e:
        print(f"❌ GCP client initialization error: {str(e)}")
        return None
# Setup OpenAI authentication
def setup_openai_client():
    try:
        setup_openai_auth()
        print("✅ OpenAI client initialized successfully")
        return True
    except Exception as e:
        print(f"❌ OpenAI client initialization error: {str(e)}")
        return False
def load_model():
    """Load the embedding model and store it in session state."""
    try:
        # Check if model already loaded
        if 'model' in st.session_state and st.session_state.model is not None:
            print("Model already loaded in session state")
            return st.session_state.tokenizer, st.session_state.model

        print("Loading new model instance...")

        # Force model to CPU
        device = torch.device("cpu")

        # Get embedding model path from secrets
        try:
            embedding_model = st.secrets["embedding_model"]
        except KeyError:
            print("❌ Error: Embedding model not found in secrets")
            return None, None

        # Load tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(embedding_model)
        model = AutoModel.from_pretrained(
            embedding_model,
            torch_dtype=torch.float16
        )

        # Move to CPU and set to eval mode
        model = model.to(device)
        model.eval()

        # Disable gradient computation
        torch.set_grad_enabled(False)

        # Store in session state
        st.session_state.tokenizer = tokenizer
        st.session_state.model = model

        print("✅ Model loaded successfully")
        return tokenizer, model
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        # Return None values - don't raise exception
        return None, None
def download_file_from_gcs(bucket, gcs_path, local_path):
    """Download a file from GCS to local storage."""
    try:
        # Check if file already exists
        if os.path.exists(local_path):
            print(f"File already exists locally: {local_path}")
            return True

        blob = bucket.blob(gcs_path)
        blob.download_to_filename(local_path)
        print(f"✅ Downloaded {gcs_path} → {local_path}")
        return True
    except Exception as e:
        print(f"❌ Error downloading {gcs_path}: {str(e)}")
        return False
def load_data_files():
    """Load the FAISS index, text chunks, and metadata."""
    # Check if already loaded in session state
    if 'faiss_index' in st.session_state and st.session_state.faiss_index is not None:
        print("Using cached data files from session state")
        return st.session_state.faiss_index, st.session_state.text_chunks, st.session_state.metadata_dict

    # Initialize clients
    bucket = setup_gcp_client()
    openai_initialized = setup_openai_client()
    if not bucket or not openai_initialized:
        print("Failed to initialize required services")
        return None, None, None

    # Get GCS paths from secrets - required
    try:
        metadata_file_gcs = st.secrets["metadata_file_gcs"]
        embeddings_file_gcs = st.secrets["embeddings_file_gcs"]
        faiss_index_file_gcs = st.secrets["faiss_index_file_gcs"]
        text_chunks_file_gcs = st.secrets["text_chunks_file_gcs"]
    except KeyError as e:
        print(f"❌ Error: Required GCS path not found in secrets: {e}")
        return None, None, None

    # Download necessary files
    success = True
    success &= download_file_from_gcs(bucket, faiss_index_file_gcs, local_faiss_index_file)
    success &= download_file_from_gcs(bucket, text_chunks_file_gcs, local_text_chunks_file)
    success &= download_file_from_gcs(bucket, metadata_file_gcs, local_metadata_file)
    if not success:
        print("Failed to download required files")
        return None, None, None

    # Load FAISS index
    try:
        faiss_index = faiss.read_index(local_faiss_index_file)
    except Exception as e:
        print(f"❌ Error loading FAISS index: {str(e)}")
        return None, None, None

    # Load text chunks; each line is expected to be: ID<TAB>Title<TAB>Author<TAB>Text
    try:
        text_chunks = {}  # {ID -> (Title, Author, Text)}
        with open(local_text_chunks_file, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) == 4:
                    text_chunks[int(parts[0])] = (parts[1], parts[2], parts[3])
    except Exception as e:
        print(f"❌ Error loading text chunks: {str(e)}")
        return None, None, None

    # Load metadata (JSONL, one object per line, keyed here by "Title")
    try:
        metadata_dict = {}
        with open(local_metadata_file, "r", encoding="utf-8") as f:
            for line in f:
                item = json.loads(line)
                metadata_dict[item["Title"]] = item
    except Exception as e:
        print(f"❌ Error loading metadata: {str(e)}")
        return None, None, None

    print(f"✅ Data loaded successfully: {len(text_chunks)} passages available")

    # Store in session state
    st.session_state.faiss_index = faiss_index
    st.session_state.text_chunks = text_chunks
    st.session_state.metadata_dict = metadata_dict

    return faiss_index, text_chunks, metadata_dict
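# Masked mean pooling: padding positions are zeroed out via the attention mask,
# then each sequence embedding is sum(hidden states) / number of real tokens,
# so padded tokens contribute to neither the numerator nor the denominator.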
def average_pool(last_hidden_states, attention_mask):
    """Average pooling for sentence embeddings."""
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
# Cache for query embeddings
query_embedding_cache = {}

def get_embedding(text):
    """Generate an embedding for a text query."""
    # Check cache first
    if text in query_embedding_cache:
        return query_embedding_cache[text]

    try:
        # Get model
        if 'model' not in st.session_state or st.session_state.model is None:
            tokenizer, model = load_model()
        else:
            tokenizer, model = st.session_state.tokenizer, st.session_state.model

        # Handle model load failure
        if model is None:
            print("Model is None, returning zero embedding")
            return np.zeros((1, 384), dtype=np.float32)

        # Prepare text
        input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"

        # Tokenize
        inputs = tokenizer(
            input_text,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
            return_attention_mask=True
        )

        # Generate embeddings
        with torch.no_grad():
            outputs = model(**inputs)
            embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
            # Cast to float32: the model runs in float16, but FAISS expects float32 queries
            embeddings = embeddings.detach().cpu().numpy().astype(np.float32)

        # Clean up
        del outputs, inputs
        gc.collect()

        # Cache and return
        query_embedding_cache[text] = embeddings
        return embeddings
    except Exception as e:
        print(f"❌ Embedding error: {str(e)}")
        return np.zeros((1, 384), dtype=np.float32)
def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
    """Retrieve the top-k most relevant passages using FAISS, with metadata."""
    try:
        print(f"\n🔍 Retrieving passages for query: {query}")

        # Get query embedding
        query_embedding = get_embedding(query)

        # Search the FAISS index (fetch extra candidates so duplicate titles can be skipped)
        distances, indices = faiss_index.search(query_embedding, top_k * 2)
        print(f"Found {len(distances[0])} potential matches")

        retrieved_passages = []
        retrieved_sources = []
        cited_titles = set()

        # Process results
        for dist, idx in zip(distances[0], indices[0]):
            print(f"Distance: {dist:.4f}, Index: {idx}")
            if idx in text_chunks and dist >= similarity_threshold:
                title_with_txt, author, text = text_chunks[idx]

                # Clean title
                clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
                clean_title = unicodedata.normalize("NFC", clean_title)

                # Skip duplicates
                if clean_title in cited_titles:
                    continue

                # Get metadata
                metadata_entry = metadata_dict.get(clean_title, {})
                author = metadata_entry.get("Author", "Unknown")
                publisher = metadata_entry.get("Publisher", "Unknown")

                # Add to results
                cited_titles.add(clean_title)
                retrieved_passages.append(text)
                retrieved_sources.append((clean_title, author, publisher))

                # Stop if we have enough
                if len(retrieved_passages) == top_k:
                    break

        print(f"Retrieved {len(retrieved_passages)} passages")
        return retrieved_passages, retrieved_sources
    except Exception as e:
        print(f"❌ Error in retrieve_passages: {str(e)}")
        return [], []
def answer_with_llm(query, context=None, word_limit=100):
    """Generate an answer using the OpenAI chat model, with formatted citations."""
    try:
        # Format context
        if context:
            formatted_contexts = []
            total_chars = 0
            max_context_chars = 4000

            for (title, author, publisher), text in context:
                remaining_space = max(0, max_context_chars - total_chars)
                excerpt_len = min(150, remaining_space)
                if excerpt_len > 50:
                    excerpt = text[:excerpt_len].strip() + "..." if len(text) > excerpt_len else text
                    formatted_context = f"[{title} by {author}, Published by {publisher}] {excerpt}"
                    formatted_contexts.append(formatted_context)
                    total_chars += len(formatted_context)
                if total_chars >= max_context_chars:
                    break

            formatted_context = "\n".join(formatted_contexts)
        else:
            formatted_context = "No relevant information available."

        # System message
        system_message = (
            "You are an AI specialized in Indian spiritual texts. "
            "Answer based on context, summarizing ideas rather than quoting verbatim. "
            "Ensure proper citation and do not include direct excerpts."
        )

        # User message
        user_message = f"""
        Context:
        {formatted_context}

        Question:
        {query}
        """

        # Get LLM model from secrets
        try:
            llm_model = st.secrets["llm_model"]
        except KeyError:
            print("❌ Error: LLM model not found in secrets")
            return "I apologize, but I'm unable to answer at the moment."

        # Call OpenAI API
        response = openai.chat.completions.create(
            model=llm_model,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ],
            max_tokens=200,
            temperature=0.7
        )
        answer = response.choices[0].message.content.strip()

        # Enforce word limit
        words = answer.split()
        if len(words) > word_limit:
            answer = " ".join(words[:word_limit])
            if not answer.endswith((".", "!", "?")):
                answer += "."

        return answer
    except Exception as e:
        print(f"❌ LLM API error: {str(e)}")
        return "I apologize, but I'm unable to answer at the moment."
def format_citations(sources):
    """Format citations so each one appears on its own line, ending with a full stop if needed."""
    formatted_citations = []
    for title, author, publisher in sources:
        # Check whether the publisher already ends with a period, question mark, or exclamation mark
        if publisher.endswith(('.', '!', '?')):
            formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}")
        else:
            formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}.")
    return "\n".join(formatted_citations)
def process_query(query, top_k=5, word_limit=100):
    """Process a query through the RAG pipeline with proper formatting."""
    print(f"\n🔍 Processing query: {query}")

    # Load data files if not already loaded
    faiss_index, text_chunks, metadata_dict = load_data_files()

    # Check if data loaded successfully
    if faiss_index is None or text_chunks is None or metadata_dict is None:
        return {
            "query": query,
            "answer_with_rag": "⚠️ System error: Data files not loaded properly.",
            "citations": "No citations available."
        }

    # Get relevant passages
    retrieved_context, retrieved_sources = retrieve_passages(
        query,
        faiss_index,
        text_chunks,
        metadata_dict,
        top_k=top_k
    )

    # Format citations
    sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."

    # Generate answer
    if retrieved_context:
        context_with_sources = list(zip(retrieved_sources, retrieved_context))
        llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
    else:
        llm_answer_with_rag = "⚠️ No relevant context found."

    return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}