"""Cybersecurity Q&A search API.

Exposes a FastAPI app that answers user queries with a Groq-hosted LLM (via
LangChain), stores every exchange in the Supabase ``conversation_history``
table, and lets clients list past exchanges and conversation threads per user.
"""

import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel
from supabase import Client, create_client

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# --- Configuration -----------------------------------------------------------
# Credentials are read from the environment rather than committed to source.
# NOTE(review): a Groq API key used to be hard-coded in this file; treat it as
# leaked and rotate it.
# The Supabase URL and the role=anon key (designed to be public) remain as
# fallbacks so existing deployments keep working; override them via env vars.
SUPABASE_URL = os.getenv("SUPABASE_URL", "https://ykkbxlbonywjmvbyfvwo.supabase.co")
SUPABASE_KEY = os.getenv(
    "SUPABASE_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU",
)
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable is not set")
# Model name is overridable so a decommissioned model can be swapped without a code change.
GROQ_MODEL = os.getenv("GROQ_MODEL", "mixtral-8x7b-32768")

# Maximum number of previous exchanges pasted into the prompt. Without a cap the
# prompt grows with every stored row and eventually exceeds the model's context.
HISTORY_CONTEXT_LIMIT = 10

# Key of the structured response that holds the raw model answer. It is written
# by /search and read back when rebuilding history context, so it is defined once.
RESPONSE_TEXT_KEY = "Use Clear Language: Avoid ambiguity and complex wording"

supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)


class SearchQuery(BaseModel):
    """Request body for ``POST /search``."""

    query: str
    context: Optional[str] = None  # Optional extra context; a default is used when empty
    user_id: str  # UUID string identifying the user whose history is read/written
    thread_id: Optional[str] = None  # Optional UUID of an existing thread to append to


class ConversationHistory(BaseModel):
    """Shape of one ``conversation_history`` row (not referenced elsewhere in this module)."""

    id: str  # UUID as string
    user_id: str  # UUID as string
    query: str
    response: str  # Structured response serialized as JSON text
    timestamp: str
    thread_id: str  # UUID as string
    title: Optional[str] = None  # Optional title


# Initialize LangChain with Groq
llm = ChatGroq(
    temperature=0.7,
    model_name=GROQ_MODEL,
    groq_api_key=GROQ_API_KEY,
)

# Prompt templates keyed by query type. Only "general" is currently selected by
# /search; "common_threats" is built but not routed to.
prompt_templates = {
    "common_threats": PromptTemplate(
        input_variables=["query", "context"],
        template="""
Context: {context}

Query: {query}

Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
""",
    ),
    "general": PromptTemplate(
        input_variables=["query", "context"],
        template="""
Context: You are a cybersecurity expert with extensive experience in all sub-streams of the industry, including but not limited to network security, application security, cloud security, threat intelligence, penetration testing, and incident response.
{context}

Query: {query}

Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
""",
    ),
}

# Initialize chains for each prompt
chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}


def _validate_uuid(value: str, field_name: str) -> None:
    """Raise HTTP 400 unless *value* parses as a UUID.

    Raises:
        HTTPException: status 400 with detail "Invalid <field_name>: must be a valid UUID".
    """
    try:
        uuid.UUID(value)
    except ValueError:
        raise HTTPException(
            status_code=400, detail=f"Invalid {field_name}: must be a valid UUID"
        ) from None


def get_conversation_history(
    user_id: str, thread_id: Optional[str] = None, limit: Optional[int] = None
) -> List[Dict]:
    """Return a user's stored exchanges, newest first.

    Args:
        user_id: UUID string of the user.
        thread_id: If given, only rows from this thread are returned.
        limit: If given, at most this many (most recent) rows are returned.

    Returns:
        The matching rows, or an empty list if the query fails. This is
        best-effort; the error is logged.
    """
    try:
        query = supabase.table("conversation_history").select("*").eq("user_id", user_id)
        if thread_id:
            query = query.eq("thread_id", thread_id)
        query = query.order("timestamp", desc=True)
        if limit is not None:
            query = query.limit(limit)
        return query.execute().data
    except Exception:
        logger.exception("Error retrieving history")
        return []


def get_user_threads(user_id: str) -> List[Dict]:
    """Return the user's distinct threads, most recently active first.

    Each entry is ``{"thread_id": ..., "title": ...}``. Rows without a title
    get a positional placeholder ("Thread N"). Returns an empty list if the
    query fails; the error is logged.
    """
    try:
        response = (
            supabase.table("conversation_history")
            .select("thread_id, title")
            .eq("user_id", user_id)
            .order("timestamp", desc=True)
            .execute()
        )
        # Rows are newest-first, so the first row seen per thread is its latest.
        seen = set()
        threads = []
        for item in response.data:
            if item["thread_id"] not in seen:
                seen.add(item["thread_id"])
                threads.append({
                    "thread_id": item["thread_id"],
                    "title": item["title"] or f"Thread {len(threads) + 1}",
                })
        return threads
    except Exception:
        logger.exception("Error retrieving threads")
        return []


def save_conversation(
    user_id: str,
    query: str,
    response: Dict,
    thread_id: Optional[str] = None,
    title: Optional[str] = None,
) -> str:
    """Insert one exchange into ``conversation_history`` and return its thread id.

    A new thread UUID is generated when *thread_id* is empty. A title is derived
    from the query (truncated to 50 chars plus "...") when *title* is empty.
    The *response* dict is stored as JSON text.

    Raises:
        Exception: Whatever the Supabase client raises; the error is logged and re-raised.
    """
    try:
        # If no thread_id is provided, generate a new one
        if not thread_id:
            thread_id = str(uuid.uuid4())
        # If no title is provided, generate a default one based on the query
        if not title:
            title = query[:50] + "..." if len(query) > 50 else query
        conversation = {
            "user_id": user_id,
            "query": query,
            "response": json.dumps(response),
            # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "thread_id": thread_id,
            "title": title,
        }
        supabase.table("conversation_history").insert(conversation).execute()
        return thread_id
    except Exception:
        logger.exception("Error saving conversation")
        raise


def _extract_response_text(stored: object) -> str:
    """Return the model answer from a stored ``response`` column value.

    Rows written by :func:`save_conversation` hold a JSON object whose
    RESPONSE_TEXT_KEY entry is the answer. Any other value (non-JSON, another
    shape, missing key) is returned as text. A single legacy or malformed row
    therefore cannot make every later request for the user fail.
    """
    if isinstance(stored, dict):
        parsed = stored
    else:
        try:
            parsed = json.loads(stored)
        except (TypeError, ValueError):
            return "" if stored is None else str(stored)
    if isinstance(parsed, dict) and RESPONSE_TEXT_KEY in parsed:
        return str(parsed[RESPONSE_TEXT_KEY])
    return "" if stored is None else str(stored)


def _build_history_context(history: List[Dict]) -> str:
    """Render stored exchanges as prompt context, oldest first.

    *history* arrives newest-first from :func:`get_conversation_history`. It is
    reversed so the model reads the conversation in chronological order.
    """
    return "\n".join(
        f"Previous Query: {item.get('query', '')}\n"
        f"Previous Response: {_extract_response_text(item.get('response'))}"
        for item in reversed(history)
    )


# Endpoints that call the blocking LLM / Supabase clients are plain ``def`` so
# FastAPI runs them in its threadpool instead of blocking the event loop.
@app.post("/search")
def process_search(search_query: SearchQuery):
    """Answer a query with the LLM, store the exchange, and return it.

    Returns 400 for a malformed user_id/thread_id and 500 for any downstream
    failure (LLM or database).
    """
    # Validate outside the try so a 400 is not rewrapped as a 500 below.
    _validate_uuid(search_query.user_id, "user_id")
    if search_query.thread_id:
        _validate_uuid(search_query.thread_id, "thread_id")

    try:
        # Set default context if not provided
        base_context = search_query.context or "You are a cybersecurity expert."

        # Most recent exchanges, scoped to the thread when one is given.
        # NOTE(review): without a thread_id this draws on all of the user's
        # threads, even though the answer is saved to a brand-new thread.
        history = get_conversation_history(
            search_query.user_id, search_query.thread_id, limit=HISTORY_CONTEXT_LIMIT
        )
        history_context = _build_history_context(history)
        full_context = f"{base_context}\n{history_context}" if history_context else base_context

        # Default to the "general" prompt template
        query_type = "general"
        raw_response = chains[query_type].run(query=search_query.query, context=full_context)

        # Structure the response according to the desired format
        structured_response = {
            "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
            "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{full_context}', guiding the response to align with cybersecurity expertise.",
            RESPONSE_TEXT_KEY: raw_response.strip(),
            "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity.",
        }

        # Save the conversation (append to existing thread or create new)
        thread_id = save_conversation(
            user_id=search_query.user_id,
            query=search_query.query,
            response=structured_response,
            thread_id=search_query.thread_id,
        )

        return {
            "status": "success",
            "response": structured_response,
            "thread_id": thread_id,
            "classified_type": query_type,
        }
    except Exception as e:
        logger.exception("Search request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/history/{user_id}")
def get_history(user_id: str, thread_id: Optional[str] = None):
    """Return all stored exchanges for a user (optionally one thread), newest first."""
    _validate_uuid(user_id, "user_id")
    if thread_id:
        _validate_uuid(thread_id, "thread_id")

    try:
        history = get_conversation_history(user_id, thread_id)
        return {"status": "success", "history": history}
    except Exception as e:
        logger.exception("History request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/threads/{user_id}")
def get_threads(user_id: str):
    """Return the user's distinct threads with titles, most recent first."""
    _validate_uuid(user_id, "user_id")

    try:
        threads = get_user_threads(user_id)
        return {"status": "success", "threads": threads}
    except Exception as e:
        logger.exception("Threads request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
async def root():
    """Liveness message."""
    return {"message": "Search API with structured response, history, and threads is running"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)