rajrakeshdr committed
Commit dc842ab · verified · 1 Parent(s): c908d2d

Update app.py

Files changed (1)
app.py: +48 -56
app.py CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
 from langchain_groq import ChatGroq
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
+from transformers import pipeline
 import os
 
 # Initialize FastAPI app
@@ -22,7 +23,6 @@ llm = ChatGroq(
 
 # Define all prompt templates
 prompt_templates = {
-    # Cybersecurity Threats and Mitigation
     "common_threats": PromptTemplate(
         input_variables=["query", "context"],
         template="""
@@ -199,7 +199,6 @@ prompt_templates = {
         Provide a comprehensive guide on securing wireless networks and IoT devices, including the use of encryption, network segmentation, and regular vulnerability assessments.
         """
     ),
-    # General Prompt
     "general": PromptTemplate(
         input_variables=["query", "context"],
         template="""
@@ -213,56 +212,35 @@ prompt_templates = {
 # Initialize chains for each prompt
 chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}
 
-# Classify user input to determine the appropriate prompt
+# Initialize the zero-shot classifier
+classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+
+# Define the possible question types (labels) based on your prompt templates
+question_types = list(prompt_templates.keys())
+
+# Classifier function using the model
 def classify_query(query: str) -> str:
-    # Simple keyword-based classification
-    if "common threats" in query.lower():
-        return "common_threats"
-    elif "task prioritization" in query.lower():
-        return "task_prioritization"
-    elif "network traffic tools" in query.lower():
-        return "network_traffic_tools"
-    elif "vulnerability assessments" in query.lower():
-        return "vulnerability_assessments"
-    elif "security policies" in query.lower():
-        return "security_policies"
-    elif "staying updated" in query.lower():
-        return "staying_updated"
-    elif "immediate incidents" in query.lower():
-        return "immediate_incidents"
-    elif "collaboration with IT teams" in query.lower():
-        return "collaboration_it_teams"
-    elif "incident investigation" in query.lower():
-        return "incident_investigation"
-    elif "securing remote workers" in query.lower():
-        return "securing_remote_workers"
-    elif "disaster recovery" in query.lower():
-        return "disaster_recovery"
-    elif "user access management" in query.lower():
-        return "user_access_management"
-    elif "cloud security" in query.lower():
-        return "cloud_security"
-    elif "security KPIs" in query.lower():
-        return "security_kpis"
-    elif "employee security education" in query.lower():
-        return "employee_security_education"
-    elif "common challenges" in query.lower():
-        return "common_challenges"
-    elif "compliance standards" in query.lower():
-        return "compliance_standards"
-    elif "encryption role" in query.lower():
-        return "encryption_role"
-    elif "mobile device security" in query.lower():
-        return "mobile_device_security"
-    elif "security audits" in query.lower():
-        return "security_audits"
-    elif "patch management" in query.lower():
-        return "patch_management"
-    elif "wireless and IoT security" in query.lower():
-        return "wireless_iot_security"
-    # Default to the general prompt
-    else:
-        return "general"
+    """
+    Classify the query using a zero-shot classification model.
+    Returns the most likely question type from the prompt templates.
+    """
+    try:
+        # Perform zero-shot classification
+        result = classifier(query, candidate_labels=question_types)
+
+        # Get the label with the highest score
+        predicted_type = result["labels"][0]
+        confidence = result["scores"][0]
+
+        # If confidence is too low (e.g., < 0.5), fallback to 'general'
+        if confidence < 0.5:
+            print(f"Low confidence ({confidence}) for query '{query}', falling back to 'general'")
+            return "general"
+
+        return predicted_type
+    except Exception as e:
+        print(f"Error in classification: {e}")
+        return "general"  # Fallback to general in case of errors
 
 @app.post("/search")
 async def process_search(search_query: SearchQuery):
@@ -270,22 +248,36 @@ async def process_search(search_query: SearchQuery):
         # Set default context if not provided
         context = search_query.context or "You are a cybersecurity expert."
 
-        # Classify the query
+        # Classify the query using the model
         query_type = classify_query(search_query.query)
 
         # Process the query using the appropriate chain
         if query_type in chains:
-            response = chains[query_type].run(query=search_query.query, context=context)
+            raw_response = chains[query_type].run(query=search_query.query, context=context)
        else:
-            response = chains["general"].run(query=search_query.query, context=context)
+            raw_response = chains["general"].run(query=search_query.query, context=context)
+
+        # Structure the response according to the desired format
+        structured_response = {
+            "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
+            "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{context}', guiding the response to align with cybersecurity expertise.",
+            "Use Clear Language: Avoid ambiguity and complex wording": raw_response.strip(),  # The raw response from Grok, cleaned up
+            "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity."
+        }
 
         return {
            "status": "success",
-            "response": response
+            "response": structured_response,
+            "classified_type": query_type  # Optional: return the classified type for debugging
         }
     except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/")
 async def root():
-    return {"message": "Search API is running"}
+    return {"message": "Search API with structured response is running"}
+
+# Run the app (optional, for local testing)
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
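
Note on the change: the keyword-matching classify_query is replaced by Hugging Face's zero-shot-classification pipeline (facebook/bart-large-mnli), which scores the prompt_templates keys as candidate labels and falls back to "general" when the top score is below 0.5. A minimal standalone sketch of that routing behavior is below; the candidate labels are a hypothetical subset and the example query is not part of the commit.

from transformers import pipeline

# Same task and model as in the commit; the labels here are an assumed subset of prompt_templates keys.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
candidate_labels = ["common_threats", "cloud_security", "patch_management", "general"]

result = classifier("How should we roll out monthly OS patches?", candidate_labels=candidate_labels)
top_label, top_score = result["labels"][0], result["scores"][0]

# Mirrors the fallback rule in classify_query(): low-confidence queries route to the general template.
query_type = top_label if top_score >= 0.5 else "general"
print(query_type, round(top_score, 3))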
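
For local testing via the new __main__ block, a request to /search might look like the sketch below. The query/context field names follow how SearchQuery is accessed in the diff, and the host/port match the uvicorn call; the actual field definitions live in the unchanged part of app.py, so treat them as assumptions.

import requests

payload = {
    "query": "What are common threats to small business networks?",
    "context": "You are a cybersecurity expert.",  # optional; the endpoint falls back to this default
}
resp = requests.post("http://localhost:8000/search", json=payload, timeout=120)
resp.raise_for_status()

body = resp.json()
print(body["status"])           # "success"
print(body["classified_type"])  # e.g. "common_threats"
print(body["response"])         # the four-part structured response dict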