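# Spaces: Sleeping
# (Page-header text captured together with this source. It is not Python code,
# so it is kept only as a comment.)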
import logging
import os
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel
from supabase import Client, create_client
# Initialize FastAPI app
app = FastAPI()

# Supabase setup. URL and key can be supplied via the SUPABASE_URL and
# SUPABASE_KEY environment variables so they can be rotated without a code
# change. The previously hard-coded values remain as fallbacks so existing
# deployments keep working.
# NOTE(review): the fallback key is a Supabase JWT whose payload carries
# "role": "anon". Anon keys are meant to be public, so any protection of
# `conversation_history` rows must come from Row Level Security. Confirm RLS
# is enabled for that table.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "https://ykkbxlbonywjmvbyfvwo.supabase.co")
SUPABASE_KEY = os.environ.get(
    "SUPABASE_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU",
)
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# Request body accepted by process_search.
class SearchQuery(BaseModel):
    """Incoming search request.

    ``user_id`` is required because each exchange is stored and looked up
    per user. ``context`` is optional extra text prepended to the prompt.
    """

    query: str  # the question, inserted into the prompt as {query}
    # Declared Optional[str] rather than a bare ``str = None`` so that an
    # explicit JSON ``null`` validates under pydantic v2 instead of being
    # rejected with a 422. Omitting the field still yields None, as before.
    context: Optional[str] = None
    user_id: str  # required; identifies whose history is read and written
# Response model describing one stored conversation entry.
# NOTE(review): no endpoint in the reviewed source declares this as a
# response_model; confirm whether it is used elsewhere or can be removed.
class ConversationHistory(BaseModel):
    """One conversation entry: the query, its structured response and when it was saved.

    Presumably mirrors the rows written by ``save_conversation`` (minus
    ``user_id``). Verify against the ``conversation_history`` table schema.
    """

    query: str  # the user's original query text
    response: Dict  # structured response dict; presumably the one built by process_search
    timestamp: str  # save time as a string (save_conversation writes an ISO-8601 value)
# Initialize LangChain with Groq.
# The API key is read from the GROQ_API_KEY environment variable instead of
# being committed to source. If the variable is missing, import fails fast with
# KeyError('GROQ_API_KEY').
# NOTE(review): a Groq key was previously hard-coded here and must be treated
# as leaked. Revoke it in the Groq console.
# GROQ_MODEL allows swapping the model without a code change. The default is
# the model this service has always used.
# NOTE(review): Groq has announced retirement of mixtral-8x7b-32768. Confirm
# on Groq's deprecations page that it is still served.
llm = ChatGroq(
    temperature=0.7,
    model_name=os.environ.get("GROQ_MODEL", "mixtral-8x7b-32768"),
    groq_api_key=os.environ["GROQ_API_KEY"],
)
# Prompt template texts. Both take exactly two inputs, {query} and {context}.
# Only "general" is selected by the search handler today; "common_threats" is
# kept so it can be chosen manually later.
_COMMON_THREATS_TEXT = """
Context: {context}
Query: {query}
Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
"""

_GENERAL_TEXT = """
Context: You are a cybersecurity expert with extensive experience in all sub-streams of the industry, including but not limited to network security, application security, cloud security, threat intelligence, penetration testing, and incident response. {context}
Query: {query}
Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
"""

# Name -> PromptTemplate. A fresh input_variables list is built per template.
prompt_templates = {
    template_name: PromptTemplate(input_variables=["query", "context"], template=template_text)
    for template_name, template_text in (
        ("common_threats", _COMMON_THREATS_TEXT),
        ("general", _GENERAL_TEXT),
    )
}

# Name -> LLMChain, every chain sharing the same Groq-backed `llm`.
chains = {
    template_name: LLMChain(llm=llm, prompt=template)
    for template_name, template in prompt_templates.items()
}
# Helper: read a user's stored conversation rows from Supabase.
def get_conversation_history(user_id: str, limit: Optional[int] = None) -> List[Dict]:
    """Return stored conversation rows for ``user_id``, newest first.

    Args:
        user_id: Value matched against the ``user_id`` column of the
            ``conversation_history`` table.
        limit: If given, return at most this many rows (the most recent
            ones). ``None`` keeps the original behavior of returning all rows.

    Returns:
        The row dicts returned by Supabase, or an empty list if the query
        fails or returns no data. Failures are logged with a traceback rather
        than raised, so callers degrade to "no history".
    """
    try:
        request = (
            supabase.table("conversation_history")
            .select("*")
            .eq("user_id", user_id)
            .order("timestamp", desc=True)
        )
        if limit is not None:
            request = request.limit(limit)
        response = request.execute()
    except Exception:
        logging.getLogger(__name__).exception("Error retrieving history for user %s", user_id)
        return []
    return response.data or []
# Helper: persist one query/response exchange to Supabase.
def save_conversation(user_id: str, query: str, response: Dict):
    """Insert one row into the ``conversation_history`` table.

    Best-effort: a failed insert is logged with its traceback and otherwise
    ignored, so a storage outage does not fail the request that produced
    the response.

    Args:
        user_id: Owner of the exchange.
        query: The user's query text.
        response: The structured response dict to store.
    """
    conversation = {
        "user_id": user_id,
        "query": query,
        "response": response,
        # Timezone-aware UTC. datetime.utcnow() is deprecated since Python
        # 3.12 and produces a naive value with no offset.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    try:
        supabase.table("conversation_history").insert(conversation).execute()
    except Exception:
        logging.getLogger(__name__).exception("Error saving conversation for user %s", user_id)
# Maximum number of earlier exchanges replayed into the prompt. Without a
# bound, every stored exchange is concatenated into the context, so the prompt
# grows with each request until it exceeds the model's context window.
_MAX_HISTORY_ITEMS = 5

# Key of the structured-response entry holding the model's answer. Earlier
# answers are read back from stored history under the same key.
_ANSWER_KEY = "Use Clear Language: Avoid ambiguity and complex wording"


def _format_history(history: List[Dict]) -> str:
    """Render stored rows (given newest first) as chronological prompt context.

    Rows whose ``response`` is not a dict containing the answer key are
    skipped, so one malformed row cannot fail every later search with a
    KeyError.
    """
    entries = []
    for item in reversed(history):
        stored = item.get("response")
        if not isinstance(stored, dict) or _ANSWER_KEY not in stored:
            continue
        entries.append(
            f"Previous Query: {item.get('query', '')}\n"
            f"Previous Response: {stored[_ANSWER_KEY]}"
        )
    return "\n".join(entries)


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for
# this handler in the reviewed source; confirm it is registered on ``app``.
async def process_search(search_query: SearchQuery):
    """Answer ``search_query`` with the "general" chain and persist the exchange.

    The prompt context is the caller's ``context`` (or a default persona),
    followed by up to ``_MAX_HISTORY_ITEMS`` of the user's most recent
    exchanges, oldest first.

    Returns:
        ``{"status": "success", "response": <structured dict>,
        "classified_type": "general"}``.

    Raises:
        HTTPException: status 500, with the underlying error text as
            ``detail``, if any step fails.
    """
    try:
        # Set default context if not provided.
        base_context = search_query.context or "You are a cybersecurity expert."

        # The Supabase client and LLMChain.run are synchronous. Running them in
        # the threadpool keeps this async handler from blocking the event loop.
        history = await run_in_threadpool(get_conversation_history, search_query.user_id)
        history_context = _format_history((history or [])[:_MAX_HISTORY_ITEMS])
        full_context = f"{base_context}\n{history_context}" if history_context else base_context

        # Always use the "general" template (no query classification yet).
        query_type = "general"
        raw_response = await run_in_threadpool(
            chains[query_type].run, query=search_query.query, context=full_context
        )

        # Structure the response according to the desired format.
        structured_response = {
            "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
            "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{full_context}', guiding the response to align with cybersecurity expertise.",
            _ANSWER_KEY: raw_response.strip(),
            "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity.",
        }

        await run_in_threadpool(
            save_conversation, search_query.user_id, search_query.query, structured_response
        )
        return {
            "status": "success",
            "response": structured_response,
            "classified_type": query_type,
        }
    except Exception as e:
        logging.getLogger(__name__).exception("Search request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible for this handler in the reviewed
# source; confirm it is registered on ``app``.
async def get_history(user_id: str):
    """Return every stored conversation row for ``user_id``, newest first.

    Returns:
        ``{"status": "success", "history": [<row dict>, ...]}``.

    Raises:
        HTTPException: status 500 with the error text if the lookup raises.

    NOTE(review): get_conversation_history catches its own errors and returns
    an empty list, so a Supabase failure currently appears here as an empty
    history rather than a 500.
    """
    try:
        # The lookup performs blocking network I/O; keep it off the event loop.
        history = await run_in_threadpool(get_conversation_history, user_id)
        return {
            "status": "success",
            "history": history,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def root():
    """Return a fixed status message showing the service is running."""
    status_message = "Search API with structured response and history is running"
    return dict(message=status_message)
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn.
    # HOST and PORT may be overridden from the environment (e.g. by a hosting
    # platform that assigns the port); the defaults match the original values.
    import os

    import uvicorn

    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )