Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,8 @@ from langchain.prompts import PromptTemplate
|
|
6 |
from supabase import create_client, Client
|
7 |
from datetime import datetime
|
8 |
from typing import List, Dict
|
|
|
|
|
9 |
|
10 |
# Initialize FastAPI app
|
11 |
app = FastAPI()
|
@@ -15,17 +17,22 @@ SUPABASE_URL = "https://ykkbxlbonywjmvbyfvwo.supabase.co"
|
|
15 |
SUPABASE_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU"
|
16 |
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
17 |
|
18 |
-
# Create a request model with context and
|
19 |
class SearchQuery(BaseModel):
|
20 |
query: str
|
21 |
context: str = None # Optional context field
|
22 |
-
user_id: str #
|
|
|
23 |
|
24 |
# Create a response model for history
|
25 |
class ConversationHistory(BaseModel):
|
|
|
|
|
26 |
query: str
|
27 |
-
response:
|
28 |
timestamp: str
|
|
|
|
|
29 |
|
30 |
# Initialize LangChain with Groq
|
31 |
llm = ChatGroq(
|
@@ -34,7 +41,7 @@ llm = ChatGroq(
|
|
34 |
groq_api_key="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" # Replace with your actual Groq API key
|
35 |
)
|
36 |
|
37 |
-
# Define prompt templates
|
38 |
prompt_templates = {
|
39 |
"common_threats": PromptTemplate(
|
40 |
input_variables=["query", "context"],
|
@@ -52,46 +59,97 @@ prompt_templates = {
|
|
52 |
Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
|
53 |
"""
|
54 |
),
|
55 |
-
# You can keep other templates here if you want to manually select them later
|
56 |
}
|
57 |
|
58 |
# Initialize chains for each prompt
|
59 |
chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}
|
60 |
|
61 |
-
# Helper function to get conversation history for a user
|
62 |
-
def get_conversation_history(user_id: str) -> List[Dict]:
|
63 |
try:
|
64 |
-
|
|
|
|
|
|
|
65 |
return response.data
|
66 |
except Exception as e:
|
67 |
print(f"Error retrieving history: {e}")
|
68 |
return []
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
# Helper function to save conversation to Supabase
|
71 |
-
def save_conversation(user_id: str, query: str, response: Dict):
|
72 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
conversation = {
|
74 |
"user_id": user_id,
|
75 |
"query": query,
|
76 |
-
"response": response,
|
77 |
-
"timestamp": datetime.utcnow().isoformat()
|
|
|
|
|
78 |
}
|
79 |
supabase.table("conversation_history").insert(conversation).execute()
|
|
|
80 |
except Exception as e:
|
81 |
print(f"Error saving conversation: {e}")
|
|
|
82 |
|
83 |
@app.post("/search")
|
84 |
async def process_search(search_query: SearchQuery):
|
85 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
# Set default context if not provided
|
87 |
base_context = search_query.context or "You are a cybersecurity expert."
|
88 |
|
89 |
-
# Retrieve previous conversation history for context
|
90 |
-
history = get_conversation_history(search_query.user_id)
|
91 |
-
history_context = "\n".join([f"Previous Query: {item['query']}\nPrevious Response: {item['response']['Use Clear Language: Avoid ambiguity and complex wording']}" for item in history])
|
92 |
full_context = f"{base_context}\n{history_context}" if history_context else base_context
|
93 |
|
94 |
-
# Default to the "general" prompt template
|
95 |
query_type = "general"
|
96 |
|
97 |
# Process the query using the general chain
|
@@ -105,21 +163,41 @@ async def process_search(search_query: SearchQuery):
|
|
105 |
"Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity."
|
106 |
}
|
107 |
|
108 |
-
# Save the conversation to Supabase
|
109 |
-
save_conversation(
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
return {
|
112 |
"status": "success",
|
113 |
"response": structured_response,
|
|
|
114 |
"classified_type": query_type
|
115 |
}
|
116 |
except Exception as e:
|
117 |
raise HTTPException(status_code=500, detail=str(e))
|
118 |
|
119 |
@app.get("/history/{user_id}")
|
120 |
-
async def get_history(user_id: str):
|
121 |
try:
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
return {
|
124 |
"status": "success",
|
125 |
"history": history
|
@@ -127,9 +205,27 @@ async def get_history(user_id: str):
|
|
127 |
except Exception as e:
|
128 |
raise HTTPException(status_code=500, detail=str(e))
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
@app.get("/")
|
131 |
async def root():
|
132 |
-
return {"message": "Search API with structured response and
|
133 |
|
134 |
if __name__ == "__main__":
|
135 |
import uvicorn
|
|
|
6 |
from supabase import create_client, Client
|
7 |
from datetime import datetime
|
8 |
from typing import List, Dict
|
9 |
+
import json
|
10 |
+
import uuid
|
11 |
|
12 |
# Initialize FastAPI app
|
13 |
app = FastAPI()
|
|
|
17 |
SUPABASE_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU"
|
18 |
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
19 |
|
20 |
+
# Create a request model with context, user_id, and optional thread_id
|
21 |
class SearchQuery(BaseModel):
|
22 |
query: str
|
23 |
context: str = None # Optional context field
|
24 |
+
user_id: str # UUID string to identify the user for storing history
|
25 |
+
thread_id: str = None # Optional thread_id to append to an existing thread
|
26 |
|
27 |
# Create a response model for history
|
28 |
class ConversationHistory(BaseModel):
|
29 |
+
id: str # UUID as string
|
30 |
+
user_id: str # UUID as string
|
31 |
query: str
|
32 |
+
response: str # Response stored as TEXT in the DB
|
33 |
timestamp: str
|
34 |
+
thread_id: str # UUID as string
|
35 |
+
title: str = None # Optional title
|
36 |
|
37 |
# Initialize LangChain with Groq
|
38 |
llm = ChatGroq(
|
|
|
41 |
groq_api_key="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" # Replace with your actual Groq API key
|
42 |
)
|
43 |
|
44 |
+
# Define prompt templates
|
45 |
prompt_templates = {
|
46 |
"common_threats": PromptTemplate(
|
47 |
input_variables=["query", "context"],
|
|
|
59 |
Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
|
60 |
"""
|
61 |
),
|
|
|
62 |
}
|
63 |
|
64 |
# Initialize chains for each prompt
|
65 |
chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}
|
66 |
|
67 |
+
# Helper function to get conversation history for a user (by user_id and optionally thread_id)
|
68 |
+
def get_conversation_history(user_id: str, thread_id: str = None) -> List[Dict]:
|
69 |
try:
|
70 |
+
query = supabase.table("conversation_history").select("*").eq("user_id", user_id)
|
71 |
+
if thread_id:
|
72 |
+
query = query.eq("thread_id", thread_id)
|
73 |
+
response = query.order("timestamp", desc=True).execute()
|
74 |
return response.data
|
75 |
except Exception as e:
|
76 |
print(f"Error retrieving history: {e}")
|
77 |
return []
|
78 |
|
79 |
+
# Helper function to get all threads for a user (distinct thread_ids with their titles)
|
80 |
+
def get_user_threads(user_id: str) -> List[Dict]:
|
81 |
+
try:
|
82 |
+
# Select distinct threads with their titles
|
83 |
+
response = supabase.table("conversation_history")\
|
84 |
+
.select("thread_id, title")\
|
85 |
+
.eq("user_id", user_id)\
|
86 |
+
.order("timestamp", desc=True)\
|
87 |
+
.execute()
|
88 |
+
# Remove duplicates while preserving order
|
89 |
+
seen = set()
|
90 |
+
threads = []
|
91 |
+
for item in response.data:
|
92 |
+
if item["thread_id"] not in seen:
|
93 |
+
seen.add(item["thread_id"])
|
94 |
+
threads.append({
|
95 |
+
"thread_id": item["thread_id"],
|
96 |
+
"title": item["title"] or f"Thread {len(threads) + 1}"
|
97 |
+
})
|
98 |
+
return threads
|
99 |
+
except Exception as e:
|
100 |
+
print(f"Error retrieving threads: {e}")
|
101 |
+
return []
|
102 |
+
|
103 |
# Helper function to save conversation to Supabase
|
104 |
+
def save_conversation(user_id: str, query: str, response: Dict, thread_id: str = None, title: str = None):
|
105 |
try:
|
106 |
+
# If no thread_id is provided, generate a new one
|
107 |
+
if not thread_id:
|
108 |
+
thread_id = str(uuid.uuid4())
|
109 |
+
|
110 |
+
# If no title is provided, generate a default one based on the query
|
111 |
+
if not title:
|
112 |
+
title = query[:50] + "..." if len(query) > 50 else query
|
113 |
+
|
114 |
conversation = {
|
115 |
"user_id": user_id,
|
116 |
"query": query,
|
117 |
+
"response": json.dumps(response), # Convert response Dict to string
|
118 |
+
"timestamp": datetime.utcnow().isoformat(),
|
119 |
+
"thread_id": thread_id,
|
120 |
+
"title": title
|
121 |
}
|
122 |
supabase.table("conversation_history").insert(conversation).execute()
|
123 |
+
return thread_id
|
124 |
except Exception as e:
|
125 |
print(f"Error saving conversation: {e}")
|
126 |
+
raise
|
127 |
|
128 |
@app.post("/search")
|
129 |
async def process_search(search_query: SearchQuery):
|
130 |
try:
|
131 |
+
# Validate user_id as UUID
|
132 |
+
try:
|
133 |
+
uuid.UUID(search_query.user_id)
|
134 |
+
except ValueError:
|
135 |
+
raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID")
|
136 |
+
|
137 |
+
# Validate thread_id as UUID if provided
|
138 |
+
if search_query.thread_id:
|
139 |
+
try:
|
140 |
+
uuid.UUID(search_query.thread_id)
|
141 |
+
except ValueError:
|
142 |
+
raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID")
|
143 |
+
|
144 |
# Set default context if not provided
|
145 |
base_context = search_query.context or "You are a cybersecurity expert."
|
146 |
|
147 |
+
# Retrieve previous conversation history for context (within the same thread if thread_id is provided)
|
148 |
+
history = get_conversation_history(search_query.user_id, search_query.thread_id)
|
149 |
+
history_context = "\n".join([f"Previous Query: {item['query']}\nPrevious Response: {json.loads(item['response'])['Use Clear Language: Avoid ambiguity and complex wording']}" for item in history])
|
150 |
full_context = f"{base_context}\n{history_context}" if history_context else base_context
|
151 |
|
152 |
+
# Default to the "general" prompt template
|
153 |
query_type = "general"
|
154 |
|
155 |
# Process the query using the general chain
|
|
|
163 |
"Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity."
|
164 |
}
|
165 |
|
166 |
+
# Save the conversation to Supabase (append to existing thread or create new)
|
167 |
+
thread_id = save_conversation(
|
168 |
+
user_id=search_query.user_id,
|
169 |
+
query=search_query.query,
|
170 |
+
response=structured_response,
|
171 |
+
thread_id=search_query.thread_id
|
172 |
+
)
|
173 |
|
174 |
return {
|
175 |
"status": "success",
|
176 |
"response": structured_response,
|
177 |
+
"thread_id": thread_id,
|
178 |
"classified_type": query_type
|
179 |
}
|
180 |
except Exception as e:
|
181 |
raise HTTPException(status_code=500, detail=str(e))
|
182 |
|
183 |
@app.get("/history/{user_id}")
|
184 |
+
async def get_history(user_id: str, thread_id: str = None):
|
185 |
try:
|
186 |
+
# Validate user_id as UUID
|
187 |
+
try:
|
188 |
+
uuid.UUID(user_id)
|
189 |
+
except ValueError:
|
190 |
+
raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID")
|
191 |
+
|
192 |
+
# Validate thread_id as UUID if provided
|
193 |
+
if thread_id:
|
194 |
+
try:
|
195 |
+
uuid.UUID(thread_id)
|
196 |
+
except ValueError:
|
197 |
+
raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID")
|
198 |
+
|
199 |
+
# Get conversation history for the user (optionally filtered by thread_id)
|
200 |
+
history = get_conversation_history(user_id, thread_id)
|
201 |
return {
|
202 |
"status": "success",
|
203 |
"history": history
|
|
|
205 |
except Exception as e:
|
206 |
raise HTTPException(status_code=500, detail=str(e))
|
207 |
|
208 |
+
@app.get("/threads/{user_id}")
|
209 |
+
async def get_threads(user_id: str):
|
210 |
+
try:
|
211 |
+
# Validate user_id as UUID
|
212 |
+
try:
|
213 |
+
uuid.UUID(user_id)
|
214 |
+
except ValueError:
|
215 |
+
raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID")
|
216 |
+
|
217 |
+
# Get all threads for the user
|
218 |
+
threads = get_user_threads(user_id)
|
219 |
+
return {
|
220 |
+
"status": "success",
|
221 |
+
"threads": threads
|
222 |
+
}
|
223 |
+
except Exception as e:
|
224 |
+
raise HTTPException(status_code=500, detail=str(e))
|
225 |
+
|
226 |
@app.get("/")
|
227 |
async def root():
|
228 |
+
return {"message": "Search API with structured response, history, and threads is running"}
|
229 |
|
230 |
if __name__ == "__main__":
|
231 |
import uvicorn
|