# LOTR-Sage / app.py
# Uploaded by SparshSG (Hugging Face commit 8977466, "Upload 3 files", verified).
import re
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr
# --- Helper Functions ---
def preprocess_text(text):
    """Normalise whitespace in raw script text.

    Consecutive newlines collapse into a single newline, runs of two or more
    space characters collapse into one space, and leading/trailing whitespace
    is trimmed from the result.
    """
    single_spaced_lines = re.sub(r'\n+', '\n', text)
    tidy = re.sub(r'[ ]{2,}', ' ', single_spaced_lines)
    return tidy.strip()
def chunk_text(text, max_chunk_size=500, overlap=100):
    """Split *text* into fixed-size character windows that overlap.

    Consecutive chunks start ``max_chunk_size - overlap`` characters apart, so
    neighbouring chunks share ``overlap`` characters. The final chunk may be
    shorter than ``max_chunk_size``; no chunk is emitted that would be a mere
    suffix of the previous one.

    Args:
        text: The string to split.
        max_chunk_size: Maximum number of characters per chunk; must be > 0.
        overlap: Characters shared by neighbouring chunks; must be smaller
            than ``max_chunk_size``.

    Returns:
        A list of chunk strings (empty when *text* is empty).

    Raises:
        ValueError: If ``max_chunk_size`` is not positive, or ``overlap`` is
            not smaller than ``max_chunk_size`` (the window would never
            advance and the loop would run forever).
    """
    if max_chunk_size <= 0:
        raise ValueError("max_chunk_size must be positive")
    step = max_chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than max_chunk_size")

    chunks = []
    for start in range(0, len(text), step):
        end = start + max_chunk_size
        chunks.append(text[start:end])
        # Once a window reaches the end of the text, every later window would
        # be a strict suffix of this one and only add a duplicate embedding.
        if end >= len(text):
            break
    return chunks
def retrieve_relevant_chunks(query, k=3, return_score=False):
    """Return the script chunks most similar to *query*.

    Embeds the query with the module-level ``embedding_model``, searches the
    module-level FAISS ``index`` for the ``k`` nearest chunks, and joins their
    text (nearest first) with newlines.

    Args:
        query: Natural-language question to match against the script.
        k: Number of nearest chunks to retrieve.
        return_score: When True, also return the similarity of the best match.

    Returns:
        The joined context string, or ``(context, top_score)`` when
        *return_score* is True. ``top_score`` is ``1 / (1 + d)`` for the
        nearest chunk's FAISS distance ``d`` (larger means more similar, at
        most 1.0). It is 0.0 when the index returned no neighbours.
    """
    query_embedding = np.asarray(embedding_model.encode([query]), dtype="float32")
    distances, indices = index.search(query_embedding, k)

    # FAISS pads the result with index -1 when fewer than k vectors are
    # available; skip those slots instead of raising KeyError on lookup.
    hits = [
        (int(idx), dist)
        for idx, dist in zip(indices[0], distances[0])
        if idx != -1
    ]
    context = "\n".join(chunk_lookup[idx] for idx, _ in hits)

    # The index is an IndexFlatL2, which reports squared L2 distances; map the
    # best one into (0, 1] so that higher means a closer match.
    top_score = 1 / (1 + hits[0][1]) if hits else 0.0
    return (context, top_score) if return_score else context
def build_prompt(query, context):
    """Assemble the grounded-QA prompt sent to the Falcon-7B model.

    The retrieved script *context* is placed before the user's *query*. The
    model is told to rely only on that context, and to give a fixed refusal
    sentence when the answer is absent. The prompt ends with an ``Answer:``
    cue for the model to complete.
    """
    sections = [
        "You are a helpful assistant that answers questions based only on the movie script context provided below.",
        "Context:",
        f"{context}",
        f"Question: {query}",
        'Do not answer using your own knowledge. Only use the context. If unsure or if the answer is not in the context, reply: "I cannot answer that as the information is not in the script"',
        "Answer:",
    ]
    return "\n".join(sections)
# --- Load and Preprocess Data ---
# Read the raw script, normalise its whitespace, split it into overlapping
# chunks, embed every chunk, and add the embeddings to a FAISS L2 index.
with open("LOTR_script.txt", "r", encoding="utf-8") as file:
    movie_script = file.read()
movie_script = preprocess_text(movie_script)
chunks = chunk_text(movie_script)
if not chunks:
    # Without at least one chunk the embedding matrix has no second axis and
    # the index cannot be built; fail with an actionable message instead.
    raise RuntimeError("LOTR_script.txt contains no text to index.")

embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# FAISS only accepts C-contiguous float32 matrices; make that explicit rather
# than relying on the encoder's default output dtype.
embeddings = np.ascontiguousarray(
    embedding_model.encode(chunks, show_progress_bar=True), dtype="float32"
)
dimension = embeddings.shape[1]  # e.g., 384 for MiniLM
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
# Row i of the FAISS index corresponds to chunks[i].
chunk_lookup = dict(enumerate(chunks))
# --- Load Falcon Model ---
# Load the Falcon-7B-Instruct tokenizer and weights from the Hugging Face Hub
# and wrap them in a text-generation pipeline (used by answer_question()).
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): this flag permits running modelling code shipped in the
    # Hub repository. Newer transformers releases may support Falcon natively,
    # making it unnecessary — confirm against the installed version.
    trust_remote_code=True,
    # Delegate placement of the weights across the available devices.
    device_map="auto",
    # bfloat16 weights take 2 bytes per parameter instead of float32's 4.
    torch_dtype=torch.bfloat16
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): `model` is an already-loaded instance placed via its own
    # device_map above, so this argument presumably has no effect — verify.
    device_map="auto"
)
# --- Main Answering Function ---
def answer_question(query, custom_query=None):
    """Answer a question about the script using retrieved context and Falcon.

    The Gradio UI calls this with one argument per input component: the
    dropdown selection (*query*) and the free-text box (*custom_query*). A
    non-blank free-text question takes precedence over the dropdown.

    Args:
        query: Question text, or the dropdown selection (may be None).
        custom_query: Optional free-text question. It overrides *query* when
            it contains non-whitespace text.

    Returns:
        One of three strings:
        - a prompt to ask a question, when both inputs are empty;
        - "I don't know." plus the score, when the best retrieved chunk
          scores below the threshold;
        - otherwise, the model's answer followed by the top retrieval
          similarity score.
    """
    question = (custom_query or "").strip() or (query or "").strip()
    if not question:
        return "Please select or enter a question."

    context, top_score = retrieve_relevant_chunks(question, k=3, return_score=True)
    context_str = context[:1500]  # Truncate (by characters) for model input
    threshold = 0.4
    # The score is the similarity of the single best chunk, not an average.
    if top_score < threshold:
        return f"I don't know.\n\n📊 Top Similarity Score: {top_score:.2f} (Below threshold)"

    prompt = build_prompt(question, context_str)
    # return_full_text=False yields only the newly generated text. The prompt
    # (which itself ends in "Answer:") therefore never leaks into the reply,
    # and answers that mention "Answer:" are not truncated.
    outputs = generator(
        prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,
    )
    answer = outputs[0]["generated_text"].strip()
    return f"{answer}\n\n📊 Top Similarity Score: {top_score:.2f}"
# Example questions offered in the dropdown menu.
predefined_questions = [
    "What is the main goal of the Fellowship?",
    "What is the relationship between Gandalf and Saruman?",
    "How do the hobbits react when they first see the world outside the Shire?",
    "What does the city of Isengard represent in Saruman’s betrayal?",
]

# --- Gradio Interface ---
# A dropdown of canned questions and a free-text box are wired to
# answer_question; the result is rendered as plain text.
interface = gr.Interface(
    answer_question,
    [
        gr.Dropdown(label="Select a predefined question", choices=predefined_questions),
        gr.Textbox(placeholder="Or enter your own question...", lines=2),
    ],
    "text",
    title="🧝 LOTR Sage (Movie Q&A Bot)",
    description="Ask questions about The Lord of the Rings (Fellowship of the Ring) movie script. Powered by FAISS + Falcon-7B.",
)


if __name__ == "__main__":
    # Start the Gradio web app only when this file is executed directly.
    interface.launch()