Spaces:
Running
Running
File size: 9,694 Bytes
85807eb 04b4725 85807eb d649e07 6af12dc 85807eb d17c60a 1476c30 85807eb d649e07 6af12dc 85807eb 6af12dc d649e07 6af12dc d649e07 6af12dc d649e07 6af12dc 85807eb d17c60a 85807eb 6af12dc 2bf42e1 c908d2d 2bf42e1 85807eb 2bf42e1 85807eb 6af12dc d649e07 6af12dc d649e07 dc842ab 6af12dc d649e07 6af12dc dc842ab 6af12dc d649e07 6af12dc d649e07 6af12dc dc842ab d649e07 6af12dc 85807eb 6af12dc 38c100a d649e07 6af12dc d649e07 1071d26 6af12dc d649e07 2bf42e1 d649e07 dc842ab d649e07 dc842ab d17c60a 6af12dc d649e07 85807eb dc842ab 6af12dc d649e07 6af12dc d649e07 6af12dc d649e07 1071d26 85807eb 6af12dc 85807eb 6af12dc dc842ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import json
import os
import uuid
from datetime import datetime, timezone
from typing import List, Dict, Optional

from fastapi import FastAPI, HTTPException
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel
from supabase import create_client, Client
# Initialize FastAPI app
app = FastAPI()

# Supabase setup.
# SECURITY: credentials belong in the environment (SUPABASE_URL / SUPABASE_KEY),
# not in source control. The literals below are kept only as backward-compatible
# fallbacks so existing deployments keep working; rotate the key and delete the
# fallbacks once the environment variables are configured.
# `or` (rather than a .get() default) also covers variables set to "".
SUPABASE_URL = os.environ.get("SUPABASE_URL") or "https://ykkbxlbonywjmvbyfvwo.supabase.co"
SUPABASE_KEY = os.environ.get("SUPABASE_KEY") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU"
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# Request model with context, user_id, and optional thread_id
class SearchQuery(BaseModel):
    """Request body accepted by the POST /search endpoint."""

    # The user's question, forwarded to the LLM prompt.
    query: str
    # Optional extra context; the endpoint substitutes a default when it is
    # missing or empty. Declared Optional so an explicit JSON null validates.
    context: Optional[str] = None
    # UUID string identifying the user whose history is read and written.
    user_id: str
    # Optional UUID string of an existing thread to append to; when omitted a
    # new thread id is generated at save time.
    thread_id: Optional[str] = None
# Response model describing one stored conversation row
class ConversationHistory(BaseModel):
    """Shape of a single row of the ``conversation_history`` table.

    NOTE(review): this model is not referenced anywhere else in the visible
    part of the file; the endpoints return the raw rows instead.
    """

    id: str  # UUID as string
    user_id: str  # UUID as string
    query: str
    response: str  # JSON-encoded response, stored as TEXT in the DB
    timestamp: str
    thread_id: str  # UUID as string
    # Optional title; Optional[...] so that a null title validates.
    title: Optional[str] = None
# Initialize LangChain with Groq.
# SECURITY: the API key belongs in the environment (GROQ_API_KEY), not in
# source control. The literal is kept only as a backward-compatible fallback;
# rotate the key and delete the fallback once the variable is configured.
# NOTE(review): confirm the model name below is still served by Groq.
llm = ChatGroq(
    temperature=0.7,
    model_name="mixtral-8x7b-32768",
    groq_api_key=os.environ.get("GROQ_API_KEY") or "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",
)
# Prompt text for each supported query type. Both templates take the same two
# placeholders: {context} and {query}.
_COMMON_THREATS_TEMPLATE = """
Context: {context}
Query: {query}
Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
"""

_GENERAL_TEMPLATE = """
Context: You are a cybersecurity expert with extensive experience in all sub-streams of the industry, including but not limited to network security, application security, cloud security, threat intelligence, penetration testing, and incident response. {context}
Query: {query}
Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
"""

# Prompt templates keyed by query type; the /search endpoint currently always
# selects "general".
prompt_templates = {
    "common_threats": PromptTemplate(
        template=_COMMON_THREATS_TEMPLATE,
        input_variables=["query", "context"],
    ),
    "general": PromptTemplate(
        template=_GENERAL_TEMPLATE,
        input_variables=["query", "context"],
    ),
}
# One LLMChain per prompt template, keyed by the same name as prompt_templates
chains = {
    template_name: LLMChain(prompt=template, llm=llm)
    for template_name, template in prompt_templates.items()
}
# Helper: fetch a user's stored conversations, optionally limited to one thread
def get_conversation_history(user_id: str, thread_id: str = None) -> List[Dict]:
    """Return rows of ``conversation_history`` for ``user_id``, newest first.

    When ``thread_id`` is truthy, only rows of that thread are returned.
    Best-effort: any exception is printed and an empty list is returned, so
    callers cannot tell a failed lookup from an empty history.
    """
    try:
        history_query = supabase.table("conversation_history").select("*")
        history_query = history_query.eq("user_id", user_id)
        if thread_id:
            history_query = history_query.eq("thread_id", thread_id)
        result = history_query.order("timestamp", desc=True).execute()
        return result.data
    except Exception as e:
        print(f"Error retrieving history: {e}")
        return []
# Helper: list a user's threads (distinct thread_ids with a title each)
def get_user_threads(user_id: str) -> List[Dict]:
    """Return one ``{"thread_id", "title"}`` dict per thread of ``user_id``.

    Rows are read newest first, so each thread is represented by its most
    recent row and threads are ordered by latest activity. A falsy title is
    replaced by ``"Thread <n>"`` where ``n`` is the thread's 1-based position
    in the result. Best-effort: any exception is printed and ``[]`` returned.
    """
    try:
        result = (
            supabase.table("conversation_history")
            .select("thread_id, title")
            .eq("user_id", user_id)
            .order("timestamp", desc=True)
            .execute()
        )
        # Dicts keep insertion order, so this both de-duplicates and preserves
        # the newest-first ordering of first occurrences.
        titles_by_thread = {}
        for row in result.data:
            tid = row["thread_id"]
            if tid in titles_by_thread:
                continue
            titles_by_thread[tid] = row["title"] or f"Thread {len(titles_by_thread) + 1}"
        return [
            {"thread_id": tid, "title": thread_title}
            for tid, thread_title in titles_by_thread.items()
        ]
    except Exception as e:
        print(f"Error retrieving threads: {e}")
        return []
# Helper function to save conversation to Supabase
def save_conversation(
    user_id: str,
    query: str,
    response: Dict,
    thread_id: Optional[str] = None,
    title: Optional[str] = None,
) -> str:
    """Insert one query/response pair into ``conversation_history``.

    Args:
        user_id: Owner of the conversation.
        query: The user's query text.
        response: Structured response; stored JSON-encoded as a string.
        thread_id: Thread to append to. A new UUID4 string is generated when
            this is falsy.
        title: Row title. Defaults to the query, truncated to 50 characters
            plus "..." when longer.

    Returns:
        The thread id the row was stored under.

    Raises:
        Exception: whatever the insert raised, after printing it.
    """
    try:
        # If no thread_id is provided, generate a new one
        if not thread_id:
            thread_id = str(uuid.uuid4())
        # If no title is provided, derive a default one from the query
        if not title:
            title = query[:50] + "..." if len(query) > 50 else query
        conversation = {
            "user_id": user_id,
            "query": query,
            "response": json.dumps(response),  # Convert response Dict to string
            # datetime.utcnow() is deprecated since Python 3.12. Dropping the
            # tzinfo afterwards keeps the stored string in exactly the same
            # naive-UTC ISO format as before.
            "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "thread_id": thread_id,
            "title": title,
        }
        supabase.table("conversation_history").insert(conversation).execute()
        return thread_id
    except Exception as e:
        print(f"Error saving conversation: {e}")
        raise
@app.post("/search")
async def process_search(search_query: SearchQuery):
    """Answer a query with the "general" chain and persist the exchange.

    Steps: validate the ids, build a context from the caller's context plus
    the user's stored history (limited to the thread when ``thread_id`` is
    given, otherwise across all of the user's rows), run the chain, wrap the
    raw answer in the structured response, and save it.

    Returns:
        A dict with "status", "response", "thread_id" and "classified_type".

    Raises:
        HTTPException: 400 when ``user_id`` or ``thread_id`` is not a valid
            UUID; 500 for any other failure.

    NOTE(review): the chain and Supabase calls look synchronous, which would
    block the event loop inside this ``async def`` — confirm.
    """
    # Key under which the LLM answer is stored in every structured response;
    # used both to write the new response and to read previous ones back.
    answer_key = "Use Clear Language: Avoid ambiguity and complex wording"
    try:
        # Validate user_id as UUID
        try:
            uuid.UUID(search_query.user_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
        # Validate thread_id as UUID if provided
        if search_query.thread_id:
            try:
                uuid.UUID(search_query.thread_id)
            except ValueError:
                raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None
        # Set default context if not provided
        base_context = search_query.context or "You are a cybersecurity expert."
        # Retrieve previous conversation history for context
        history = get_conversation_history(search_query.user_id, search_query.thread_id)
        history_lines = []
        for item in history:
            try:
                previous_answer = json.loads(item["response"])[answer_key]
                history_lines.append(
                    f"Previous Query: {item['query']}\nPrevious Response: {previous_answer}"
                )
            except (KeyError, TypeError, ValueError):
                # A stored row that is not valid JSON or lacks the expected
                # keys must not make every later request for this user fail.
                continue
        history_context = "\n".join(history_lines)
        full_context = f"{base_context}\n{history_context}" if history_context else base_context
        # Default to the "general" prompt template
        query_type = "general"
        # Process the query using the general chain
        raw_response = chains[query_type].run(query=search_query.query, context=full_context)
        # Structure the response according to the desired format
        structured_response = {
            "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
            "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{full_context}', guiding the response to align with cybersecurity expertise.",
            answer_key: raw_response.strip(),
            "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity.",
        }
        # Save the conversation to Supabase (append to existing thread or create new)
        thread_id = save_conversation(
            user_id=search_query.user_id,
            query=search_query.query,
            response=structured_response,
            thread_id=search_query.thread_id,
        )
        return {
            "status": "success",
            "response": structured_response,
            "thread_id": thread_id,
            "classified_type": query_type,
        }
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the 400s
        # raised above would be re-wrapped as 500s by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/history/{user_id}")
async def get_history(user_id: str, thread_id: Optional[str] = None):
    """Return the stored conversations of ``user_id``, newest first.

    ``thread_id`` is an optional query parameter restricting the result to
    one thread. The underlying helper returns an empty list when the lookup
    fails, so a backend error is reported here as an empty history.

    Raises:
        HTTPException: 400 when ``user_id`` or ``thread_id`` is not a valid
            UUID; 500 for any other failure.
    """
    try:
        # Validate user_id as UUID
        try:
            uuid.UUID(user_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
        # Validate thread_id as UUID if provided
        if thread_id:
            try:
                uuid.UUID(thread_id)
            except ValueError:
                raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None
        # Get conversation history for the user (optionally filtered by thread_id)
        history = get_conversation_history(user_id, thread_id)
        return {
            "status": "success",
            "history": history,
        }
    except HTTPException:
        # Keep the 400s raised above from being re-wrapped as 500s below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/threads/{user_id}")
async def get_threads(user_id: str):
    """Return the distinct threads (id and title) belonging to ``user_id``.

    The underlying helper returns an empty list when the lookup fails, so a
    backend error is reported here as an empty thread list.

    Raises:
        HTTPException: 400 when ``user_id`` is not a valid UUID; 500 for any
            other failure.
    """
    try:
        # Validate user_id as UUID
        try:
            uuid.UUID(user_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
        # Get all threads for the user
        threads = get_user_threads(user_id)
        return {
            "status": "success",
            "threads": threads,
        }
    except HTTPException:
        # Keep the 400 raised above from being re-wrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
async def root():
    """Return a fixed message confirming that the API is up."""
    status_message = "Search API with structured response, history, and threads is running"
    return {"message": status_message}
if __name__ == "__main__":
    # Start a uvicorn server on all interfaces, port 8000, when this file is
    # executed as a script. uvicorn is imported here so that merely importing
    # the module does not require it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)