Upload 3 files
- LOTR_script.txt +0 -0
- app.py +127 -0
- requirements.txt +6 -0
LOTR_script.txt
ADDED
The diff for this file is too large to render.
See raw diff
app.py
ADDED
@@ -0,0 +1,127 @@
import re
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr


# --- Helper Functions ---
def preprocess_text(text):
    """Clean and preprocess the text by removing repeated newlines and extra spaces."""
    text = re.sub(r'\n+', '\n', text)
    text = re.sub(r'[ ]{2,}', ' ', text)
    return text.strip()


def chunk_text(text, max_chunk_size=500, overlap=100):
    """Chunk the text into smaller, overlapping parts so sentences cut at a
    chunk boundary still appear intact in the next chunk."""
    chunks = []
    start = 0
    while start < len(text):
        end = start + max_chunk_size
        chunk = text[start:end]
        chunks.append(chunk)
        start += max_chunk_size - overlap
    return chunks


def retrieve_relevant_chunks(query, k=3, return_score=False):
    """Retrieve the k most relevant chunks from the script for the given query."""
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k)
    retrieved_chunks = [chunk_lookup[i] for i in indices[0]]
    # Convert L2 distance to a similarity score (closer to 1 is better).
    similarity_scores = [1 / (1 + d) for d in distances[0]]
    context = "\n".join(retrieved_chunks)
    top_score = similarity_scores[0]  # Best match
    return (context, top_score) if return_score else context


def build_prompt(query, context):
    """Build a prompt for the Falcon-7B model with the retrieved context."""
    return f"""You are a helpful assistant that answers questions based only on the movie script context provided below.

Context:
{context}

Question: {query}

Do not answer using your own knowledge. Only use the context. If unsure or if the answer is not in the context, reply: "I cannot answer that as the information is not in the script"
Answer:"""


# --- Load and Preprocess Data ---
with open("LOTR_script.txt", "r", encoding="utf-8") as file:
    movie_script = file.read()

movie_script = preprocess_text(movie_script)
chunks = chunk_text(movie_script)
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = embedding_model.encode(chunks, show_progress_bar=True)

dimension = embeddings.shape[1]  # 384 for all-MiniLM-L6-v2
index = faiss.IndexFlatL2(dimension)
index.add(np.array(embeddings))
chunk_lookup = {i: chunk for i, chunk in enumerate(chunks)}

# --- Load Falcon Model ---
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
)
# The model is already dispatched across devices by device_map="auto" above,
# so the pipeline needs no placement argument of its own.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)


# --- Main Answering Function ---
def answer_question(selected_question, custom_question):
    """Generate an answer from the retrieved chunks, preferring the typed
    question over the dropdown selection."""
    query = (custom_question or "").strip() or selected_question
    if not query:
        return "Please select or enter a question."

    context, top_score = retrieve_relevant_chunks(query, k=3, return_score=True)
    context_str = context[:1500]  # Truncate so the prompt stays within the model's input budget
    threshold = 0.4
    if top_score < threshold:
        return f"I don't know.\n\n📊 Top Similarity Score: {round(top_score, 2)} (below threshold)"

    prompt = build_prompt(query, context_str)
    response = generator(prompt, max_new_tokens=200, do_sample=True, temperature=0.7)[0]["generated_text"]

    # The pipeline echoes the prompt, so keep only the text after "Answer:".
    if "Answer:" in response:
        answer = response.split("Answer:")[-1].strip()
    else:
        answer = response.strip()

    return f"{answer}\n\n📊 Top Similarity Score: {round(top_score, 2)}"


# Predefined questions for the dropdown menu
predefined_questions = [
    "What is the main goal of the Fellowship?",
    "What is the relationship between Gandalf and Saruman?",
    "How do the hobbits react when they first see the world outside the Shire?",
    "What does the city of Isengard represent in Saruman’s betrayal?"
]

# --- Gradio Interface ---
# Both input components are passed to answer_question in order:
# the dropdown selection first, then the free-text box.
interface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Dropdown(choices=predefined_questions, label="Select a predefined question"),
        gr.Textbox(lines=2, placeholder="Or enter your own question..."),
    ],
    outputs="text",
    title="🧝 LOTR Sage (Movie Q&A Bot)",
    description="Ask questions about The Lord of the Rings (Fellowship of the Ring) movie script. Powered by FAISS + Falcon-7B."
)


if __name__ == "__main__":
    interface.launch()
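The retrieval step can be smoke-tested on its own before launching the UI. A minimal sketch, assuming app.py and LOTR_script.txt sit in the same directory (the file name quick_check.py and the sample query are hypothetical; importing app builds the FAISS index and loads Falcon-7B, so the first run is slow):

# quick_check.py -- hypothetical smoke test for the retrieval step
from app import retrieve_relevant_chunks

context, score = retrieve_relevant_chunks("Who volunteers to carry the Ring?", k=3, return_score=True)
print(f"Top similarity: {score:.2f}")  # scores below the 0.4 threshold make the app answer "I don't know."
print(context[:300])  # a peek at the retrieved script text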
requirements.txt
ADDED
@@ -0,0 +1,6 @@
gradio==3.19.0
sentence-transformers==2.2.0
faiss-cpu==1.7.4
torch==1.13.0
transformers==4.26.1
accelerate==0.16.0
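To reproduce the Space locally, install the pinned dependencies with pip install -r requirements.txt and start the app with python app.py. Note that tiiuae/falcon-7b-instruct in bfloat16 occupies roughly 14 GB of accelerator memory, so a GPU of that size (or swapping in a smaller instruct model) is effectively a prerequisite.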