Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,64 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
""
|
7 |
-
client =
|
|
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
def
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
)
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
|
39 |
-
|
40 |
-
yield response
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
|
|
|
|
|
|
62 |
|
63 |
if __name__ == "__main__":
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
from groq import Groq
|
4 |
+
import chromadb
|
5 |
+
from chromadb.config import Settings
|
6 |
+
import torch
|
7 |
+
from sentence_transformers import CrossEncoder
|
8 |
import gradio as gr
|
9 |
+
from datetime import datetime # Import datetime to get current time
|
10 |
|
11 |
+
# Load environment variables and initialize clients
|
12 |
+
load_dotenv()
|
13 |
+
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
14 |
+
client = Groq(api_key=GROQ_API_KEY)
|
15 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
|
17 |
+
def initialize_system():
|
18 |
+
chroma_client = chromadb.PersistentClient(
|
19 |
+
path="./chroma_db",
|
20 |
+
settings=Settings(anonymized_telemetry=False, allow_reset=True, is_persistent=True)
|
21 |
+
)
|
22 |
+
|
23 |
+
embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
|
24 |
+
model_name="sentence-transformers/all-mpnet-base-v2",
|
25 |
+
device=DEVICE
|
26 |
+
)
|
27 |
+
|
28 |
+
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=DEVICE)
|
29 |
+
collection = chroma_client.get_collection(name="website_content", embedding_function=embedding_function)
|
30 |
+
return chroma_client, collection, reranker
|
31 |
|
32 |
+
def get_context(message):
|
33 |
+
results = collection.query(
|
34 |
+
query_texts=[message],
|
35 |
+
n_results=500,
|
36 |
+
include=["metadatas", "documents", "distances"]
|
37 |
+
)
|
38 |
+
|
39 |
+
print(f"\n=== Search Results ===")
|
40 |
+
print(f"Initial ChromaDB results found: {len(results['documents'][0])}")
|
41 |
+
|
42 |
+
# Rerank all results
|
43 |
+
rerank_pairs = [(message, doc) for doc in results['documents'][0]]
|
44 |
+
rerank_scores = reranker.predict(rerank_pairs)
|
45 |
+
|
46 |
+
# Create list of results with scores
|
47 |
+
all_results = []
|
48 |
+
url_chunks = {} # Group chunks by URL
|
49 |
+
|
50 |
+
# Group chunks by URL and store their scores
|
51 |
+
for score, doc, metadata in zip(rerank_scores, results['documents'][0], results['metadatas'][0]):
|
52 |
+
url = metadata['url']
|
53 |
+
if url not in url_chunks:
|
54 |
+
url_chunks[url] = []
|
55 |
+
url_chunks[url].append({'text': doc, 'metadata': metadata, 'score': score})
|
56 |
+
|
57 |
+
# For each URL, select the best chunks while maintaining diversity
|
58 |
+
for url, chunks in url_chunks.items():
|
59 |
+
# Sort chunks for this URL by score
|
60 |
+
chunks.sort(key=lambda x: x['score'], reverse=True)
|
61 |
+
|
62 |
+
# Take up to 5 chunks per URL, but only if their scores are good
|
63 |
+
selected_chunks = []
|
64 |
+
for chunk in chunks[:5]: # 5 chunks per URL
|
65 |
+
# Only include if score is decent
|
66 |
+
if chunk['score'] > -10: # Increased threshold to ensure higher relevance
|
67 |
+
selected_chunks.append(chunk)
|
68 |
+
|
69 |
+
# Add selected chunks to final results
|
70 |
+
all_results.extend(selected_chunks)
|
71 |
+
|
72 |
+
# Sort all results by score for final ranking
|
73 |
+
all_results.sort(key=lambda x: x['score'], reverse=True)
|
74 |
+
|
75 |
+
# Take only top 20 results maximum
|
76 |
+
all_results = all_results[:20]
|
77 |
+
|
78 |
+
print(f"\nFinal results after reranking and filtering: {len(all_results)}")
|
79 |
+
if all_results:
|
80 |
+
print("\nTop Similarity Scores and URLs:")
|
81 |
+
for i, result in enumerate(all_results[:20], 1): # Show only top 20 in logs
|
82 |
+
print(f"{i}. Score: {result['score']:.4f} - URL: {result['metadata']['url']}")
|
83 |
+
print("=" * 50)
|
84 |
+
|
85 |
+
# Build context from filtered results
|
86 |
+
context = "\nRelevant Information:\n"
|
87 |
+
total_chars = 0
|
88 |
+
max_chars = 30000 # To ensure we don't exceed token limits
|
89 |
+
|
90 |
+
for result in all_results:
|
91 |
+
chunk_text = f"\nSource: {result['metadata']['url']}\n{result['text']}\n"
|
92 |
+
if total_chars + len(chunk_text) > max_chars:
|
93 |
+
break
|
94 |
+
context += chunk_text
|
95 |
+
total_chars += len(chunk_text)
|
96 |
+
|
97 |
+
print(f"\nFinal context length: {total_chars} characters")
|
98 |
+
return context
|
99 |
|
100 |
+
def chat_response(message, history):
|
101 |
+
"""Chat response function for Gradio interface"""
|
102 |
+
try:
|
103 |
+
# Get context
|
104 |
+
context = get_context(message)
|
105 |
+
|
106 |
+
# Get current time and date
|
107 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
108 |
+
|
109 |
+
# Build messages list
|
110 |
+
messages = [{
|
111 |
+
"role": "system",
|
112 |
+
"content": f"""You are an AI assistant for the British Antarctic Survey (BAS). Your responses should be based ONLY on the context provided below.
|
113 |
|
114 |
+
IMPORTANT INSTRUCTIONS:
|
115 |
+
1. ALWAYS thoroughly check the provided context before saying you don't have information
|
116 |
+
2. If you find ANY relevant information in the context, use it - even if it's not complete
|
117 |
+
3. If you find time-sensitive information in the context, share it - it's current as of when the context was retrieved
|
118 |
+
4. When citing sources, put them on a new line after the relevant information like this:
|
119 |
+
Here is some information about BAS.
|
120 |
+
Source: https://www.bas.ac.uk/example
|
121 |
|
122 |
+
5. Do not say things like:
|
123 |
+
- "I don't have access to real-time information"
|
124 |
+
- "I cannot browse the internet"
|
125 |
+
Instead, share what IS in the context, and only say "I don't have enough information" if you truly find nothing relevant to the users question.
|
126 |
|
127 |
+
6. Keep responses:
|
128 |
+
- With emojis where appropriate
|
129 |
+
- Without duplicate source citations
|
130 |
+
- Based strictly on the context below
|
|
|
|
|
|
|
|
|
131 |
|
132 |
+
Current Time: {current_time}
|
|
|
133 |
|
134 |
+
Context: {context}"""
|
135 |
+
}]
|
136 |
+
|
137 |
+
print("\n\n==========START Contents of the message being sent to the LLM==========\n")
|
138 |
+
print(messages)
|
139 |
+
print("\n\n==========END Contents of the message being sent to the LLM==========\n")
|
140 |
|
141 |
+
# Add history and current message
|
142 |
+
if history:
|
143 |
+
for h in history:
|
144 |
+
messages.append({"role": "user", "content": str(h[0])})
|
145 |
+
if h[1]: # If there's a response
|
146 |
+
messages.append({"role": "assistant", "content": str(h[1])})
|
147 |
+
|
148 |
+
messages.append({"role": "user", "content": str(message)})
|
149 |
+
|
150 |
+
# Get response
|
151 |
+
response = ""
|
152 |
+
completion = client.chat.completions.create(
|
153 |
+
model="llama-3.3-70b-versatile",
|
154 |
+
messages=messages,
|
155 |
+
temperature=0.7,
|
156 |
+
max_tokens=2000,
|
157 |
+
top_p=0.95,
|
158 |
+
stream=True
|
159 |
+
)
|
160 |
+
|
161 |
+
print("\n=== LLM Response Start ===")
|
162 |
+
for chunk in completion:
|
163 |
+
if chunk.choices[0].delta.content:
|
164 |
+
response += chunk.choices[0].delta.content
|
165 |
+
print(chunk.choices[0].delta.content, end='', flush=True)
|
166 |
+
yield response
|
167 |
+
print("\n=== LLM Response End ===\n")
|
168 |
|
169 |
+
except Exception as e:
|
170 |
+
error_msg = f"An error occurred: {str(e)}"
|
171 |
+
print(f"\nERROR: {error_msg}")
|
172 |
+
yield error_msg
|
173 |
|
174 |
if __name__ == "__main__":
|
175 |
+
try:
|
176 |
+
print("\n=== Starting Application ===")
|
177 |
+
|
178 |
+
# Initialise system
|
179 |
+
print("Initialising ChromaDB...")
|
180 |
+
chroma_client, collection, reranker = initialize_system()
|
181 |
+
print(f"Found {collection.count()} documents in collection")
|
182 |
+
|
183 |
+
print("\nCreating Gradio interface...")
|
184 |
+
|
185 |
+
# Create a simple Gradio interface
|
186 |
+
demo = gr.Blocks()
|
187 |
+
|
188 |
+
with demo:
|
189 |
+
gr.Markdown("# Website Chat Assistant")
|
190 |
+
gr.Markdown("Ask questions about the website.")
|
191 |
+
|
192 |
+
chatbot = gr.Chatbot(height=600)
|
193 |
+
msg = gr.Textbox(placeholder="Ask a question...", label="Your question")
|
194 |
+
clear = gr.Button("Clear")
|
195 |
+
|
196 |
+
def user(user_message, history):
|
197 |
+
return "", history + [[user_message, None]]
|
198 |
+
|
199 |
+
def bot(history):
|
200 |
+
if history and history[-1][1] is None:
|
201 |
+
for response in chat_response(history[-1][0], history[:-1]):
|
202 |
+
history[-1][1] = response
|
203 |
+
yield history
|
204 |
+
|
205 |
+
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
206 |
+
bot, chatbot, chatbot
|
207 |
+
)
|
208 |
+
|
209 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
210 |
+
|
211 |
+
# Launch with minimal configuration
|
212 |
+
demo.queue()
|
213 |
+
demo.launch(
|
214 |
+
server_name="127.0.0.1",
|
215 |
+
server_port=7860,
|
216 |
+
share=False
|
217 |
+
)
|
218 |
+
|
219 |
+
except Exception as e:
|
220 |
+
print(f"\nERROR: {str(e)}")
|
221 |
+
raise
|