# intelliSOC / app.py
# Author: rajrakeshdr — last commit 6af12dc ("Update app.py", verified)
import json
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel
from supabase import Client, create_client
# Initialize FastAPI app
app = FastAPI()

# Supabase connection settings. They are read from the environment so a
# deployment can point at another project, or rotate the key, without a code
# change. The previously hard-coded values remain as fallbacks (used when the
# variable is unset or empty) so existing deployments keep working.
# NOTE(review): the fallback key's JWT payload declares role "anon", but it is
# still a credential committed to source — prefer supplying it via SUPABASE_KEY only.
SUPABASE_URL = os.environ.get("SUPABASE_URL") or "https://ykkbxlbonywjmvbyfvwo.supabase.co"
SUPABASE_KEY = os.environ.get("SUPABASE_KEY") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU"
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# Request body for POST /search: the user's query, optional extra context, the
# caller's user_id, and an optional thread_id to continue an existing thread.
class SearchQuery(BaseModel):
    query: str
    # Optional[...] (not a bare `str = None`) so that an explicit JSON null is
    # accepted as well as omitting the field; pydantic v2 rejects null for `str`.
    context: Optional[str] = None  # Optional context field
    user_id: str  # UUID string to identify the user for storing history
    thread_id: Optional[str] = None  # Optional thread_id to append to an existing thread
# Response model describing one stored conversation_history row.
# NOTE(review): not referenced by the endpoints in this file — confirm whether it
# is meant to be used as a response_model.
class ConversationHistory(BaseModel):
    id: str  # UUID as string
    user_id: str  # UUID as string
    query: str
    response: str  # Response stored as TEXT in the DB (the JSON-encoded structured response)
    timestamp: str
    thread_id: str  # UUID as string
    # Optional[...] so a NULL title coming back from the database validates
    # instead of failing under pydantic v2's strict `str` type.
    title: Optional[str] = None  # Optional title
# Initialize LangChain with Groq.
# The API key and model id come from the environment (GROQ_API_KEY / GROQ_MODEL),
# so the secret can be rotated and the model swapped without editing code. The
# previous literals remain as fallbacks when a variable is unset or empty.
# NOTE(review): the fallback API key has been committed to source and should be
# treated as leaked — rotate it and remove the fallback once GROQ_API_KEY is set.
# NOTE(review): confirm the fallback model id is still served by Groq.
llm = ChatGroq(
    temperature=0.7,
    model_name=os.environ.get("GROQ_MODEL") or "mixtral-8x7b-32768",
    groq_api_key=os.environ.get("GROQ_API_KEY") or "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",
)
# Prompt texts, one per query type. Each is filled with the same two variables,
# {context} and {query}.
_COMMON_THREATS_TEMPLATE = """
Context: {context}
Query: {query}
Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
"""

_GENERAL_TEMPLATE = """
Context: You are a cybersecurity expert with extensive experience in all sub-streams of the industry, including but not limited to network security, application security, cloud security, threat intelligence, penetration testing, and incident response. {context}
Query: {query}
Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
"""

# Map each query-type name to its PromptTemplate (a fresh variable list per template).
prompt_templates = {
    type_name: PromptTemplate(input_variables=["query", "context"], template=text)
    for type_name, text in (
        ("common_threats", _COMMON_THREATS_TEMPLATE),
        ("general", _GENERAL_TEMPLATE),
    )
}
def _build_chain(prompt: PromptTemplate) -> LLMChain:
    """Wrap *prompt* in an LLMChain bound to the shared Groq client ``llm``."""
    return LLMChain(llm=llm, prompt=prompt)


# One chain per prompt template, keyed by the same query-type names.
chains = {type_name: _build_chain(tmpl) for type_name, tmpl in prompt_templates.items()}
def get_conversation_history(user_id: str, thread_id: str = None) -> List[Dict]:
    """Fetch stored conversation rows for *user_id*, newest first.

    When *thread_id* is truthy, only rows from that thread are returned.
    This is a best-effort read: any error is printed and an empty list is
    returned instead of raising.
    """
    try:
        builder = supabase.table("conversation_history").select("*").eq("user_id", user_id)
        # Narrow to a single thread only when one was requested.
        builder = builder.eq("thread_id", thread_id) if thread_id else builder
        return builder.order("timestamp", desc=True).execute().data
    except Exception as e:
        print(f"Error retrieving history: {e}")
        return []
def get_user_threads(user_id: str) -> List[Dict]:
    """List the distinct threads of *user_id*, most recently active first.

    Each entry is ``{"thread_id": ..., "title": ...}``. The title comes from the
    newest row of the thread; when that title is empty, a positional
    placeholder ("Thread N") is used instead. Errors are printed and reported
    as an empty list.
    """
    try:
        rows = (
            supabase.table("conversation_history")
            .select("thread_id, title")
            .eq("user_id", user_id)
            .order("timestamp", desc=True)
            .execute()
            .data
        )
        unique_threads: List[Dict] = []
        known_ids = set()
        for row in rows:
            tid = row["thread_id"]
            if tid in known_ids:
                continue  # older row of a thread already listed
            known_ids.add(tid)
            position = len(unique_threads) + 1
            unique_threads.append({"thread_id": tid, "title": row["title"] or f"Thread {position}"})
        return unique_threads
    except Exception as e:
        print(f"Error retrieving threads: {e}")
        return []
def save_conversation(user_id: str, query: str, response: Dict, thread_id: str = None, title: str = None):
    """Insert one query/response exchange into the ``conversation_history`` table.

    Args:
        user_id: Owner of the conversation.
        query: The user's query text.
        response: Structured response; stored JSON-encoded in the ``response`` column.
        thread_id: Thread to append to. A new UUID4 is generated when falsy.
        title: Row title. When falsy, it defaults to the query, truncated to
            50 characters plus "..." if longer.

    Returns:
        The thread_id the row was written under.

    Raises:
        Re-raises any exception from building or inserting the row, after printing it.
    """
    try:
        # If no thread_id is provided, generate a new one
        if not thread_id:
            thread_id = str(uuid.uuid4())
        # If no title is provided, generate a default one based on the query
        if not title:
            title = query[:50] + "..." if len(query) > 50 else query
        # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC time
        # and drop tzinfo so the stored string keeps the same naive ISO-8601 form
        # as rows written before this change (no "+00:00" suffix).
        timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        conversation = {
            "user_id": user_id,
            "query": query,
            "response": json.dumps(response),  # Convert response Dict to string
            "timestamp": timestamp,
            "thread_id": thread_id,
            "title": title,
        }
        supabase.table("conversation_history").insert(conversation).execute()
        return thread_id
    except Exception as e:
        print(f"Error saving conversation: {e}")
        raise
# At most this many earlier exchanges are replayed into the prompt. Without a cap,
# the context grows on every turn until it no longer fits the model's context
# window. When no thread_id is given, it would also span every thread the user has.
_MAX_HISTORY_TURNS = 10

# Key of the structured-response entry that holds the model's actual answer; this
# is the part replayed as "Previous Response" in later prompts.
_ANSWER_KEY = "Use Clear Language: Avoid ambiguity and complex wording"


def _extract_answer(stored_response) -> Optional[str]:
    """Return the model answer from a stored JSON-encoded response, or None.

    Rows that are not valid JSON objects, or that lack the answer key, yield
    None instead of raising. One malformed row therefore cannot make every
    later search in the thread fail.
    """
    try:
        payload = json.loads(stored_response)
    except (TypeError, ValueError):
        return None
    if not isinstance(payload, dict) or _ANSWER_KEY not in payload:
        return None
    return str(payload[_ANSWER_KEY])


def _build_history_context(history: List[Dict]) -> str:
    """Render earlier exchanges as "Previous Query / Previous Response" pairs.

    *history* arrives newest-first (get_conversation_history orders by timestamp
    descending). The newest _MAX_HISTORY_TURNS rows are kept and emitted
    oldest-first, so the prompt reads in conversation order.
    """
    pairs = []
    for item in reversed(history[:_MAX_HISTORY_TURNS]):
        answer = _extract_answer(item.get("response"))
        if answer is None:
            continue
        pairs.append(f"Previous Query: {item.get('query', '')}\nPrevious Response: {answer}")
    return "\n".join(pairs)


def _run_search(search_query: SearchQuery) -> Dict:
    """Do the blocking part of /search: build the context, query the LLM, persist.

    Runs in a worker thread because the Supabase client and LLMChain.run are
    synchronous and would otherwise stall the event loop.
    """
    # Set default context if not provided
    base_context = search_query.context or "You are a cybersecurity expert."
    # Earlier exchanges: same thread when thread_id is given, otherwise all of the user's rows
    history = get_conversation_history(search_query.user_id, search_query.thread_id)
    history_context = _build_history_context(history)
    full_context = f"{base_context}\n{history_context}" if history_context else base_context
    # Default to the "general" prompt template
    query_type = "general"
    raw_response = chains[query_type].run(query=search_query.query, context=full_context)
    # Structure the response according to the desired format
    structured_response = {
        "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
        "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{full_context}', guiding the response to align with cybersecurity expertise.",
        _ANSWER_KEY: raw_response.strip(),
        "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity.",
    }
    # Save the conversation to Supabase (append to existing thread or create new)
    thread_id = save_conversation(
        user_id=search_query.user_id,
        query=search_query.query,
        response=structured_response,
        thread_id=search_query.thread_id,
    )
    return {
        "status": "success",
        "response": structured_response,
        "thread_id": thread_id,
        "classified_type": query_type,
    }


@app.post("/search")
async def process_search(search_query: SearchQuery):
    """Answer a cybersecurity query and record the exchange in the user's thread.

    Raises:
        HTTPException(400): user_id, or a supplied thread_id, is not a valid UUID.
        HTTPException(500): any failure while querying the model or the database.
    """
    # Validation runs outside the generic handler below. HTTPException is an
    # Exception subclass, so catching it there would turn these 400s into 500s.
    try:
        uuid.UUID(search_query.user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
    if search_query.thread_id:
        try:
            uuid.UUID(search_query.thread_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None
    try:
        return await run_in_threadpool(_run_search, search_query)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/history/{user_id}")
async def get_history(user_id: str, thread_id: str = None):
    """Return the stored conversation rows of *user_id*, newest first.

    If the *thread_id* query parameter is given, rows are limited to that thread.

    Raises:
        HTTPException(400): user_id, or a supplied thread_id, is not a valid UUID.
        HTTPException(500): unexpected failure while reading history.
    """
    # Validate before the generic handler: HTTPException subclasses Exception, so
    # catching it there would report these client errors as 500s.
    try:
        uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
    if thread_id:
        try:
            uuid.UUID(thread_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None
    try:
        # The Supabase client is synchronous; run it off the event loop.
        history = await run_in_threadpool(get_conversation_history, user_id, thread_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "status": "success",
        "history": history,
    }
@app.get("/threads/{user_id}")
async def get_threads(user_id: str):
    """Return the distinct threads of *user_id*, most recently active first.

    Raises:
        HTTPException(400): user_id is not a valid UUID.
        HTTPException(500): unexpected failure while reading threads.
    """
    # Validate before the generic handler: HTTPException subclasses Exception, so
    # catching it there would report this client error as a 500.
    try:
        uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
    try:
        # The Supabase client is synchronous; run it off the event loop.
        threads = await run_in_threadpool(get_user_threads, user_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "status": "success",
        "threads": threads,
    }
@app.get("/")
async def root():
    """Health-check endpoint that returns a fixed status banner."""
    banner = "Search API with structured response, history, and threads is running"
    return dict(message=banner)
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly as a script.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)