File size: 9,694 Bytes
85807eb
 
 
04b4725
85807eb
d649e07
 
 
6af12dc
 
85807eb
d17c60a
1476c30
85807eb
d649e07
 
 
 
 
6af12dc
85807eb
 
 
6af12dc
 
d649e07
 
 
6af12dc
 
d649e07
6af12dc
d649e07
6af12dc
 
85807eb
d17c60a
85807eb
 
 
 
 
 
6af12dc
2bf42e1
 
 
 
 
 
 
 
 
c908d2d
 
 
 
 
 
 
 
2bf42e1
85807eb
2bf42e1
 
85807eb
6af12dc
 
d649e07
6af12dc
 
 
 
d649e07
 
 
 
dc842ab
6af12dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d649e07
6af12dc
dc842ab
6af12dc
 
 
 
 
 
 
 
d649e07
 
 
6af12dc
 
 
 
d649e07
 
6af12dc
dc842ab
d649e07
6af12dc
85807eb
 
 
 
6af12dc
 
 
 
 
 
 
 
 
 
 
 
 
38c100a
d649e07
 
6af12dc
 
 
d649e07
1071d26
6af12dc
d649e07
2bf42e1
d649e07
 
dc842ab
 
 
 
d649e07
 
dc842ab
 
d17c60a
6af12dc
 
 
 
 
 
 
d649e07
85807eb
 
dc842ab
6af12dc
d649e07
 
 
 
 
 
6af12dc
d649e07
6af12dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d649e07
 
 
1071d26
85807eb
 
 
6af12dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85807eb
 
6af12dc
dc842ab
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel
from supabase import Client, create_client

# Initialize FastAPI app
app = FastAPI()

# Supabase setup. The URL and key are read from the environment (SUPABASE_URL /
# SUPABASE_KEY) so a deployment can point at another project or rotate the key
# without editing source. The fallbacks below are the original project's URL and
# its anon key (the JWT payload's role claim is "anon"); never embed a
# service_role key here.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "https://ykkbxlbonywjmvbyfvwo.supabase.co")
SUPABASE_KEY = os.environ.get(
    "SUPABASE_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU",
)
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# Request body for POST /search.
class SearchQuery(BaseModel):
    """Request body for ``POST /search``.

    ``context`` and ``thread_id`` are typed ``Optional`` so that clients may
    either omit them or send an explicit JSON ``null``; with a plain ``str``
    annotation pydantic v2 rejects ``null`` even though the default is None.
    """

    query: str  # The user's question
    context: Optional[str] = None  # Optional extra context; may be omitted or null
    user_id: str  # UUID string identifying the user whose history is stored/read
    thread_id: Optional[str] = None  # Optional UUID of an existing thread to append to

# Response model describing one stored conversation_history row.
class ConversationHistory(BaseModel):
    """One stored exchange, mirroring the columns written by ``save_conversation``.

    ``title`` is ``Optional`` so rows with a NULL title validate; a plain
    ``str`` annotation would reject an explicit None under pydantic v2.
    """

    id: str  # Row id (UUID as string)
    user_id: str  # UUID as string
    query: str
    response: str  # Structured response serialized as JSON text
    timestamp: str  # ISO-8601 timestamp string
    thread_id: str  # UUID as string
    title: Optional[str] = None  # Optional thread title

# Initialize LangChain with Groq.
# The API key is intentionally NOT in source: when groq_api_key is not passed,
# ChatGroq reads it from the GROQ_API_KEY environment variable. The key that was
# previously hard-coded here was committed to source and should be treated as
# compromised and revoked.
# The model can be overridden with GROQ_MODEL; the default is unchanged.
llm = ChatGroq(
    temperature=0.7,
    model_name=os.environ.get("GROQ_MODEL", "mixtral-8x7b-32768"),
)

# Prompt templates, keyed by name. Each one is rendered with the user's
# {query} and an assembled {context} string.
# NOTE: the triple-quoted template text is sent verbatim, including its
# leading newline and the indentation spaces on every line.
prompt_templates = {
    # Fixed "common threats" prompt. Only {query} and {context} are template
    # variables, so "threat_1, threat_2, and threat_3" reach the model as
    # literal text. NOTE(review): looks like unfilled placeholders; confirm
    # intent. The /search handler in this file always selects "general".
    "common_threats": PromptTemplate(
        input_variables=["query", "context"],
        template="""
        Context: {context}
        Query: {query}
        Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
        """
    ),
    # Default template: a cybersecurity-expert persona followed by the
    # caller-supplied context (which may include earlier conversation turns).
    "general": PromptTemplate(
        input_variables=["query", "context"],
        template="""
        Context: You are a cybersecurity expert with extensive experience in all sub-streams of the industry, including but not limited to network security, application security, cloud security, threat intelligence, penetration testing, and incident response. {context}
        Query: {query}
        Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
        """
    ),
}

# One LLMChain per named prompt template, all sharing the same Groq model;
# request handlers select a chain by its template name.
chains = {
    template_name: LLMChain(prompt=template, llm=llm)
    for template_name, template in prompt_templates.items()
}

# Helper: fetch a user's stored exchanges, optionally limited to one thread.
def get_conversation_history(user_id: str, thread_id: Optional[str] = None) -> List[Dict]:
    """Return a user's rows from ``conversation_history``, newest first.

    Args:
        user_id: UUID string of the user whose rows are fetched.
        thread_id: When given, restrict the result to that thread.

    Returns:
        The matching rows ordered by ``timestamp`` descending. Returns an empty
        list if the query fails or yields no data. Failures are logged with a
        traceback rather than raised, so callers can treat history as
        best-effort.
    """
    try:
        query = supabase.table("conversation_history").select("*").eq("user_id", user_id)
        if thread_id:
            query = query.eq("thread_id", thread_id)
        response = query.order("timestamp", desc=True).execute()
    except Exception:
        logging.getLogger(__name__).exception("Error retrieving history for user %s", user_id)
        return []
    # Normalize a missing payload to an empty list so callers can always iterate.
    return response.data or []

# Helper: list a user's distinct threads with their titles.
def get_user_threads(user_id: str) -> List[Dict]:
    """Return one entry per thread for ``user_id``, most recently active first.

    Rows are fetched newest-first. The first row seen for each ``thread_id``
    is therefore its most recent one, and that row supplies the title. A
    missing or empty title falls back to ``"Thread <position>"``.

    Returns:
        A list of ``{"thread_id": ..., "title": ...}`` dicts. Returns an empty
        list if the query fails; the failure is logged with a traceback.
    """
    try:
        response = (
            supabase.table("conversation_history")
            .select("thread_id, title")
            .eq("user_id", user_id)
            .order("timestamp", desc=True)
            .execute()
        )
    except Exception:
        logging.getLogger(__name__).exception("Error retrieving threads for user %s", user_id)
        return []

    # Deduplicate while keeping first-seen (most recent) order.
    seen = set()
    threads = []
    for item in response.data or []:
        thread_id = item["thread_id"]
        if thread_id in seen:
            continue
        seen.add(thread_id)
        threads.append({
            "thread_id": thread_id,
            "title": item.get("title") or f"Thread {len(threads) + 1}",
        })
    return threads

# Helper: persist one query/response exchange to Supabase.
def save_conversation(
    user_id: str,
    query: str,
    response: Dict,
    thread_id: Optional[str] = None,
    title: Optional[str] = None,
):
    """Insert one exchange into ``conversation_history`` and return its thread id.

    Args:
        user_id: UUID string of the owning user.
        query: The user's query text.
        response: Structured response; stored as JSON text.
        thread_id: Existing thread to append to. A new UUID is generated when
            this is empty.
        title: Thread title. Defaults to the query, truncated to 50
            characters plus "..." when longer.

    Returns:
        The thread id the row was saved under.

    Raises:
        Exception: Whatever the Supabase client raises on insert. It is logged
            with a traceback and then re-raised.
    """
    if not thread_id:
        thread_id = str(uuid.uuid4())
    if not title:
        title = query[:50] + "..." if len(query) > 50 else query

    conversation = {
        "user_id": user_id,
        "query": query,
        "response": json.dumps(response),  # Serialize the structured response to text
        # Timezone-aware, so the ISO string carries an explicit UTC offset.
        # datetime.utcnow() returns a naive value and is deprecated since Python 3.12.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "thread_id": thread_id,
        "title": title,
    }
    try:
        supabase.table("conversation_history").insert(conversation).execute()
    except Exception:
        logging.getLogger(__name__).exception("Error saving conversation for user %s", user_id)
        raise
    return thread_id

# Upper bound on earlier exchanges replayed into the prompt. Without a cap the
# replayed history grows with every exchange and eventually exceeds the
# model's context window, after which every request fails.
_MAX_HISTORY_TURNS = 10

# Structured-response key that holds the model's answer. It is also the field
# read back from stored rows to rebuild conversational context.
_ANSWER_KEY = "Use Clear Language: Avoid ambiguity and complex wording"


def _format_history(history: List[Dict], max_turns: int = _MAX_HISTORY_TURNS) -> str:
    """Render stored rows as oldest-first "Previous Query/Response" pairs.

    ``history`` is expected newest-first (``get_conversation_history`` orders
    by timestamp descending). The newest ``max_turns`` rows are kept and then
    reversed into chronological order. A row whose stored ``response`` is not a
    JSON object containing ``_ANSWER_KEY`` contributes its raw stored value
    instead of failing the whole request.
    """
    entries = []
    for item in reversed(history[:max(max_turns, 0)]):
        stored = item.get("response")
        try:
            parsed = stored if isinstance(stored, dict) else json.loads(stored)
            answer = parsed[_ANSWER_KEY]
        except (TypeError, ValueError, KeyError):
            answer = stored
        entries.append(f"Previous Query: {item.get('query')}\nPrevious Response: {answer}")
    return "\n".join(entries)


@app.post("/search")
async def process_search(search_query: SearchQuery):
    """Answer a cybersecurity query with the LLM and persist the exchange.

    Steps:
    1. Validate ``user_id``, and ``thread_id`` when given, as UUIDs.
    2. Replay up to ``_MAX_HISTORY_TURNS`` earlier exchanges as context. They
       come from the given thread, or from all of the user's history when no
       thread is given.
    3. Run the "general" chain.
    4. Save the result. A new thread id is generated when none was supplied.

    Returns:
        Dict with ``status``, the structured ``response``, the ``thread_id``
        the exchange was saved under, and ``classified_type``.

    Raises:
        HTTPException: 400 for a malformed ``user_id``/``thread_id``; 500 for
            any failure while calling the model or saving the conversation.
    """
    # Validate before the catch-all below; otherwise these 400s are caught by
    # ``except Exception`` and re-raised as 500s.
    try:
        uuid.UUID(search_query.user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
    if search_query.thread_id:
        try:
            uuid.UUID(search_query.thread_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None

    try:
        # Set default context if not provided
        base_context = search_query.context or "You are a cybersecurity expert."

        # The helpers and LLMChain.run are blocking (non-async) calls. Run them
        # in a worker thread so one slow request does not stall the event loop.
        history = await run_in_threadpool(
            get_conversation_history, search_query.user_id, search_query.thread_id
        )
        history_context = _format_history(history)
        full_context = f"{base_context}\n{history_context}" if history_context else base_context

        # Default to the "general" prompt template
        query_type = "general"
        raw_response = await run_in_threadpool(
            chains[query_type].run, query=search_query.query, context=full_context
        )

        # Structure the response according to the desired format
        structured_response = {
            "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
            "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{full_context}', guiding the response to align with cybersecurity expertise.",
            _ANSWER_KEY: raw_response.strip(),
            "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity."
        }

        # Append to the existing thread, or create a new one.
        thread_id = await run_in_threadpool(
            save_conversation,
            user_id=search_query.user_id,
            query=search_query.query,
            response=structured_response,
            thread_id=search_query.thread_id,
        )
    except Exception as e:
        logging.getLogger(__name__).exception("Error processing search request")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "response": structured_response,
        "thread_id": thread_id,
        "classified_type": query_type
    }

@app.get("/history/{user_id}")
async def get_history(user_id: str, thread_id: Optional[str] = None):
    """Return a user's stored exchanges, newest first, optionally for one thread.

    Args:
        user_id: UUID string (path parameter).
        thread_id: Optional UUID string (query parameter) restricting the result.

    Raises:
        HTTPException: 400 if either id is not a valid UUID; 500 if fetching fails.
    """
    # Validate before the catch-all below so a 400 is not re-raised as a 500.
    try:
        uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
    if thread_id:
        try:
            uuid.UUID(thread_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None

    try:
        # get_conversation_history is a blocking call; keep it off the event loop.
        history = await run_in_threadpool(get_conversation_history, user_id, thread_id)
    except Exception as e:
        logging.getLogger(__name__).exception("Error fetching history for user %s", user_id)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "history": history
    }

@app.get("/threads/{user_id}")
async def get_threads(user_id: str):
    """List a user's conversation threads, most recently active first.

    Raises:
        HTTPException: 400 if ``user_id`` is not a valid UUID; 500 if fetching fails.
    """
    # Validate before the catch-all below so a 400 is not re-raised as a 500.
    try:
        uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None

    try:
        # get_user_threads is a blocking call; keep it off the event loop.
        threads = await run_in_threadpool(get_user_threads, user_id)
    except Exception as e:
        logging.getLogger(__name__).exception("Error listing threads for user %s", user_id)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "threads": threads
    }

@app.get("/")
async def root():
    """Liveness endpoint: report that the API is up."""
    status_message = "Search API with structured response, history, and threads is running"
    return dict(message=status_message)

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on 0.0.0.0:8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)