rajrakeshdr commited on
Commit
d649e07
·
verified ·
1 Parent(s): 0caf559

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -211
app.py CHANGED
@@ -3,16 +3,29 @@ from pydantic import BaseModel
3
  from langchain_groq import ChatGroq
4
  from langchain.chains import LLMChain
5
  from langchain.prompts import PromptTemplate
6
- from transformers import pipeline
7
- import os
 
8
 
9
  # Initialize FastAPI app
10
  app = FastAPI()
11
 
12
- # Create a request model with context
 
 
 
 
 
13
  class SearchQuery(BaseModel):
14
  query: str
15
  context: str = None # Optional context field
 
 
 
 
 
 
 
16
 
17
  # Initialize LangChain with Groq
18
  llm = ChatGroq(
@@ -21,7 +34,7 @@ llm = ChatGroq(
21
  groq_api_key="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" # Replace with your actual Groq API key
22
  )
23
 
24
- # Define all prompt templates
25
  prompt_templates = {
26
  "common_threats": PromptTemplate(
27
  input_variables=["query", "context"],
@@ -31,174 +44,6 @@ prompt_templates = {
31
  Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
32
  """
33
  ),
34
- "task_prioritization": PromptTemplate(
35
- input_variables=["query", "context"],
36
- template="""
37
- Context: {context}
38
- Query: {query}
39
- Provide a guide on how cybersecurity professionals prioritize their tasks and responsibilities, focusing on the most critical areas such as threat detection, response times, and resource allocation.
40
- """
41
- ),
42
- "network_traffic_tools": PromptTemplate(
43
- input_variables=["query", "context"],
44
- template="""
45
- Context: {context}
46
- Query: {query}
47
- List and describe the most effective tools and software used for monitoring network traffic. Include tools for real-time analysis, anomaly detection, and reporting.
48
- """
49
- ),
50
- "vulnerability_assessments": PromptTemplate(
51
- input_variables=["query", "context"],
52
- template="""
53
- Context: {context}
54
- Query: {query}
55
- Provide best practices for conducting vulnerability assessments and penetration tests, including recommended frequencies and methodologies to ensure systems are adequately tested for vulnerabilities.
56
- """
57
- ),
58
- "security_policies": PromptTemplate(
59
- input_variables=["query", "context"],
60
- template="""
61
- Context: {context}
62
- Query: {query}
63
- Explain the role cybersecurity professionals have in developing, updating, and enforcing security policies within an organization. Include considerations for evolving threats and compliance requirements.
64
- """
65
- ),
66
- "staying_updated": PromptTemplate(
67
- input_variables=["query", "context"],
68
- template="""
69
- Context: {context}
70
- Query: {query}
71
- Describe the methods and tools cybersecurity professionals use to stay up-to-date on the latest cybersecurity threats, trends, and vulnerabilities, including ongoing education and industry resources.
72
- """
73
- ),
74
- "immediate_incidents": PromptTemplate(
75
- input_variables=["query", "context"],
76
- template="""
77
- Context: {context}
78
- Query: {query}
79
- Identify and describe the types of cybersecurity incidents that require immediate attention, such as data breaches, malware attacks, and denial-of-service attacks. Provide guidance on how to respond to each incident type.
80
- """
81
- ),
82
- "collaboration_it_teams": PromptTemplate(
83
- input_variables=["query", "context"],
84
- template="""
85
- Context: {context}
86
- Query: {query}
87
- Discuss how cybersecurity professionals work with IT teams to ensure system security, focusing on areas such as patch management, incident response, and ongoing risk management.
88
- """
89
- ),
90
- "incident_investigation": PromptTemplate(
91
- input_variables=["query", "context"],
92
- template="""
93
- Context: {context}
94
- Query: {query}
95
- Outline the steps involved in investigating and resolving a security incident, including initial detection, containment, root cause analysis, and reporting.
96
- """
97
- ),
98
- "securing_remote_workers": PromptTemplate(
99
- input_variables=["query", "context"],
100
- template="""
101
- Context: {context}
102
- Query: {query}
103
- Provide strategies for securing remote workers and their devices, including the use of VPNs, multi-factor authentication, and endpoint protection measures.
104
- """
105
- ),
106
- "disaster_recovery": PromptTemplate(
107
- input_variables=["query", "context"],
108
- template="""
109
- Context: {context}
110
- Query: {query}
111
- Explain the responsibilities of cybersecurity professionals in ensuring that disaster recovery and business continuity plans are developed, tested, and maintained to address security challenges.
112
- """
113
- ),
114
- "user_access_management": PromptTemplate(
115
- input_variables=["query", "context"],
116
- template="""
117
- Context: {context}
118
- Query: {query}
119
- Describe the best practices for managing user access and privileges, including role-based access control (RBAC), least privilege principles, and audit trails for sensitive systems.
120
- """
121
- ),
122
- "cloud_security": PromptTemplate(
123
- input_variables=["query", "context"],
124
- template="""
125
- Context: {context}
126
- Query: {query}
127
- Provide a list of best practices for securing cloud-based infrastructure, including the use of strong authentication, data encryption, and continuous monitoring.
128
- """
129
- ),
130
- "security_kpis": PromptTemplate(
131
- input_variables=["query", "context"],
132
- template="""
133
- Context: {context}
134
- Query: {query}
135
- Discuss the key performance indicators (KPIs) used by cybersecurity professionals to measure the effectiveness of security programs, such as incident response times, patching cycles, and vulnerability remediation rates.
136
- """
137
- ),
138
- "employee_security_education": PromptTemplate(
139
- input_variables=["query", "context"],
140
- template="""
141
- Context: {context}
142
- Query: {query}
143
- Describe the methods used by cybersecurity professionals to educate employees on security best practices, including training programs, phishing simulations, and awareness campaigns.
144
- """
145
- ),
146
- "common_challenges": PromptTemplate(
147
- input_variables=["query", "context"],
148
- template="""
149
- Context: {context}
150
- Query: {query}
151
- Identify and discuss the common challenges that cybersecurity professionals face, including resource limitations, evolving threats, and the complexities of compliance.
152
- """
153
- ),
154
- "compliance_standards": PromptTemplate(
155
- input_variables=["query", "context"],
156
- template="""
157
- Context: {context}
158
- Query: {query}
159
- Provide an overview of how cybersecurity professionals ensure compliance with industry standards and regulations, such as GDPR, HIPAA, and PCI DSS, including regular audits and reporting.
160
- """
161
- ),
162
- "encryption_role": PromptTemplate(
163
- input_variables=["query", "context"],
164
- template="""
165
- Context: {context}
166
- Query: {query}
167
- Explain the role of encryption in protecting sensitive data, focusing on encryption methods, data-at-rest vs. data-in-transit, and how encryption helps mitigate the risks of data breaches.
168
- """
169
- ),
170
- "mobile_device_security": PromptTemplate(
171
- input_variables=["query", "context"],
172
- template="""
173
- Context: {context}
174
- Query: {query}
175
- Provide strategies for managing and securing mobile devices and applications, including mobile device management (MDM), app whitelisting, and secure communication methods.
176
- """
177
- ),
178
- "security_audits": PromptTemplate(
179
- input_variables=["query", "context"],
180
- template="""
181
- Context: {context}
182
- Query: {query}
183
- Outline the steps involved in conducting security audits and risk assessments, including identifying potential threats, assessing vulnerabilities, and recommending mitigation strategies.
184
- """
185
- ),
186
- "patch_management": PromptTemplate(
187
- input_variables=["query", "context"],
188
- template="""
189
- Context: {context}
190
- Query: {query}
191
- Describe the best practices for managing patch updates and ensuring software security, including patch management policies, vulnerability scanning, and prioritizing patches based on risk.
192
- """
193
- ),
194
- "wireless_iot_security": PromptTemplate(
195
- input_variables=["query", "context"],
196
- template="""
197
- Context: {context}
198
- Query: {query}
199
- Provide a comprehensive guide on securing wireless networks and IoT devices, including the use of encryption, network segmentation, and regular vulnerability assessments.
200
- """
201
- ),
202
  "general": PromptTemplate(
203
  input_variables=["query", "context"],
204
  template="""
@@ -207,77 +52,85 @@ prompt_templates = {
207
  Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
208
  """
209
  ),
 
210
  }
211
 
212
  # Initialize chains for each prompt
213
  chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}
214
 
215
- # Initialize the zero-shot classifier
216
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
217
-
218
- # Define the possible question types (labels) based on your prompt templates
219
- question_types = list(prompt_templates.keys())
 
 
 
220
 
221
- # Classifier function using the model
222
- def classify_query(query: str) -> str:
223
- """
224
- Classify the query using a zero-shot classification model.
225
- Returns the most likely question type from the prompt templates.
226
- """
227
  try:
228
- # Perform zero-shot classification
229
- result = classifier(query, candidate_labels=question_types)
230
-
231
- # Get the label with the highest score
232
- predicted_type = result["labels"][0]
233
- confidence = result["scores"][0]
234
-
235
- # If confidence is too low (e.g., < 0.5), fallback to 'general'
236
- if confidence < 0.5:
237
- print(f"Low confidence ({confidence}) for query '{query}', falling back to 'general'")
238
- return "general"
239
-
240
- return predicted_type
241
  except Exception as e:
242
- print(f"Error in classification: {e}")
243
- return "general" # Fallback to general in case of errors
244
 
245
  @app.post("/search")
246
  async def process_search(search_query: SearchQuery):
247
  try:
248
  # Set default context if not provided
249
- context = search_query.context or "You are a cybersecurity expert."
 
 
 
 
 
250
 
251
- # Classify the query using the model
252
- query_type = classify_query(search_query.query)
253
 
254
- # Process the query using the appropriate chain
255
- if query_type in chains:
256
- raw_response = chains[query_type].run(query=search_query.query, context=context)
257
- else:
258
- raw_response = chains["general"].run(query=search_query.query, context=context)
259
 
260
  # Structure the response according to the desired format
261
  structured_response = {
262
  "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
263
- "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{context}', guiding the response to align with cybersecurity expertise.",
264
- "Use Clear Language: Avoid ambiguity and complex wording": raw_response.strip(), # The raw response from Grok, cleaned up
265
  "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity."
266
  }
267
 
 
 
 
268
  return {
269
  "status": "success",
270
  "response": structured_response,
271
- "classified_type": query_type # Optional: return the classified type for debugging
 
 
 
 
 
 
 
 
 
 
 
272
  }
273
  except Exception as e:
274
  raise HTTPException(status_code=500, detail=str(e))
275
 
276
  @app.get("/")
277
  async def root():
278
- return {"message": "Search API with structured response is running"}
279
 
280
- # Run the app (optional, for local testing)
281
  if __name__ == "__main__":
282
  import uvicorn
283
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
3
  from langchain_groq import ChatGroq
4
  from langchain.chains import LLMChain
5
  from langchain.prompts import PromptTemplate
6
+ from supabase import create_client, Client
7
+ from datetime import datetime
8
+ from typing import List, Dict
9
 
10
  # Initialize FastAPI app
11
  app = FastAPI()
12
 
13
+ # Supabase setup (replace with your Supabase URL and key)
14
+ SUPABASE_URL = "https://ykkbxlbonywjmvbyfvwo.supabase.co"
15
+ SUPABASE_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU"
16
+ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
17
+
18
+ # Create a request model with context and user_id
19
  class SearchQuery(BaseModel):
20
  query: str
21
  context: str = None # Optional context field
22
+ user_id: str # Required to identify the user for storing history
23
+
24
+ # Create a response model for history
25
+ class ConversationHistory(BaseModel):
26
+ query: str
27
+ response: Dict
28
+ timestamp: str
29
 
30
  # Initialize LangChain with Groq
31
  llm = ChatGroq(
 
34
  groq_api_key="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" # Replace with your actual Groq API key
35
  )
36
 
37
+ # Define prompt templates (keeping all for future flexibility, but defaulting to "general")
38
  prompt_templates = {
39
  "common_threats": PromptTemplate(
40
  input_variables=["query", "context"],
 
44
  Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
45
  """
46
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  "general": PromptTemplate(
48
  input_variables=["query", "context"],
49
  template="""
 
52
  Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
53
  """
54
  ),
55
+ # You can keep other templates here if you want to manually select them later
56
  }
57
 
58
  # Initialize chains for each prompt
59
  chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}
60
 
61
+ # Helper function to get conversation history for a user
62
+ def get_conversation_history(user_id: str) -> List[Dict]:
63
+ try:
64
+ response = supabase.table("conversation_history").select("*").eq("user_id", user_id).order("timestamp", desc=True).execute()
65
+ return response.data
66
+ except Exception as e:
67
+ print(f"Error retrieving history: {e}")
68
+ return []
69
 
70
+ # Helper function to save conversation to Supabase
71
+ def save_conversation(user_id: str, query: str, response: Dict):
 
 
 
 
72
  try:
73
+ conversation = {
74
+ "user_id": user_id,
75
+ "query": query,
76
+ "response": response,
77
+ "timestamp": datetime.utcnow().isoformat()
78
+ }
79
+ supabase.table("conversation_history").insert(conversation).execute()
 
 
 
 
 
 
80
  except Exception as e:
81
+ print(f"Error saving conversation: {e}")
 
82
 
83
  @app.post("/search")
84
  async def process_search(search_query: SearchQuery):
85
  try:
86
  # Set default context if not provided
87
+ base_context = search_query.context or "You are a cybersecurity expert."
88
+
89
+ # Retrieve previous conversation history for context
90
+ history = get_conversation_history(search_query.user_id)
91
+ history_context = "\n".join([f"Previous Query: {item['query']}\nPrevious Response: {item['response']['Use Clear Language: Avoid ambiguity and complex wording']}" for item in history])
92
+ full_context = f"{base_context}\n{history_context}" if history_context else base_context
93
 
94
+ # Default to the "general" prompt template (no classification)
95
+ query_type = "general"
96
 
97
+ # Process the query using the general chain
98
+ raw_response = chains[query_type].run(query=search_query.query, context=full_context)
 
 
 
99
 
100
  # Structure the response according to the desired format
101
  structured_response = {
102
  "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
103
+ "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{full_context}', guiding the response to align with cybersecurity expertise.",
104
+ "Use Clear Language: Avoid ambiguity and complex wording": raw_response.strip(),
105
  "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity."
106
  }
107
 
108
+ # Save the conversation to Supabase
109
+ save_conversation(search_query.user_id, search_query.query, structured_response)
110
+
111
  return {
112
  "status": "success",
113
  "response": structured_response,
114
+ "classified_type": query_type
115
+ }
116
+ except Exception as e:
117
+ raise HTTPException(status_code=500, detail=str(e))
118
+
119
+ @app.get("/history/{user_id}")
120
+ async def get_history(user_id: str):
121
+ try:
122
+ history = get_conversation_history(user_id)
123
+ return {
124
+ "status": "success",
125
+ "history": history
126
  }
127
  except Exception as e:
128
  raise HTTPException(status_code=500, detail=str(e))
129
 
130
  @app.get("/")
131
  async def root():
132
+ return {"message": "Search API with structured response and history is running"}
133
 
 
134
  if __name__ == "__main__":
135
  import uvicorn
136
  uvicorn.run(app, host="0.0.0.0", port=8000)