Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
|
|
3 |
from langchain_groq import ChatGroq
|
4 |
from langchain.chains import LLMChain
|
5 |
from langchain.prompts import PromptTemplate
|
|
|
6 |
import os
|
7 |
|
8 |
# Initialize FastAPI app
|
@@ -22,7 +23,6 @@ llm = ChatGroq(
|
|
22 |
|
23 |
# Define all prompt templates
|
24 |
prompt_templates = {
|
25 |
-
# Cybersecurity Threats and Mitigation
|
26 |
"common_threats": PromptTemplate(
|
27 |
input_variables=["query", "context"],
|
28 |
template="""
|
@@ -199,7 +199,6 @@ prompt_templates = {
|
|
199 |
Provide a comprehensive guide on securing wireless networks and IoT devices, including the use of encryption, network segmentation, and regular vulnerability assessments.
|
200 |
"""
|
201 |
),
|
202 |
-
# General Prompt
|
203 |
"general": PromptTemplate(
|
204 |
input_variables=["query", "context"],
|
205 |
template="""
|
@@ -213,56 +212,35 @@ prompt_templates = {
|
|
213 |
# Initialize chains for each prompt
|
214 |
chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}
|
215 |
|
216 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
def classify_query(query: str) -> str:
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
return "
|
239 |
-
elif "disaster recovery" in query.lower():
|
240 |
-
return "disaster_recovery"
|
241 |
-
elif "user access management" in query.lower():
|
242 |
-
return "user_access_management"
|
243 |
-
elif "cloud security" in query.lower():
|
244 |
-
return "cloud_security"
|
245 |
-
elif "security KPIs" in query.lower():
|
246 |
-
return "security_kpis"
|
247 |
-
elif "employee security education" in query.lower():
|
248 |
-
return "employee_security_education"
|
249 |
-
elif "common challenges" in query.lower():
|
250 |
-
return "common_challenges"
|
251 |
-
elif "compliance standards" in query.lower():
|
252 |
-
return "compliance_standards"
|
253 |
-
elif "encryption role" in query.lower():
|
254 |
-
return "encryption_role"
|
255 |
-
elif "mobile device security" in query.lower():
|
256 |
-
return "mobile_device_security"
|
257 |
-
elif "security audits" in query.lower():
|
258 |
-
return "security_audits"
|
259 |
-
elif "patch management" in query.lower():
|
260 |
-
return "patch_management"
|
261 |
-
elif "wireless and IoT security" in query.lower():
|
262 |
-
return "wireless_iot_security"
|
263 |
-
# Default to the general prompt
|
264 |
-
else:
|
265 |
-
return "general"
|
266 |
|
267 |
@app.post("/search")
|
268 |
async def process_search(search_query: SearchQuery):
|
@@ -270,22 +248,36 @@ async def process_search(search_query: SearchQuery):
|
|
270 |
# Set default context if not provided
|
271 |
context = search_query.context or "You are a cybersecurity expert."
|
272 |
|
273 |
-
# Classify the query
|
274 |
query_type = classify_query(search_query.query)
|
275 |
|
276 |
# Process the query using the appropriate chain
|
277 |
if query_type in chains:
|
278 |
-
|
279 |
else:
|
280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
return {
|
283 |
"status": "success",
|
284 |
-
"response":
|
|
|
285 |
}
|
286 |
except Exception as e:
|
287 |
raise HTTPException(status_code=500, detail=str(e))
|
288 |
|
289 |
@app.get("/")
|
290 |
async def root():
|
291 |
-
return {"message": "Search API is running"}
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from langchain_groq import ChatGroq
|
4 |
from langchain.chains import LLMChain
|
5 |
from langchain.prompts import PromptTemplate
|
6 |
+
from transformers import pipeline
|
7 |
import os
|
8 |
|
9 |
# Initialize FastAPI app
|
|
|
23 |
|
24 |
# Define all prompt templates
|
25 |
prompt_templates = {
|
|
|
26 |
"common_threats": PromptTemplate(
|
27 |
input_variables=["query", "context"],
|
28 |
template="""
|
|
|
199 |
Provide a comprehensive guide on securing wireless networks and IoT devices, including the use of encryption, network segmentation, and regular vulnerability assessments.
|
200 |
"""
|
201 |
),
|
|
|
202 |
"general": PromptTemplate(
|
203 |
input_variables=["query", "context"],
|
204 |
template="""
|
|
|
212 |
# Initialize chains for each prompt
|
213 |
chains = {key: LLMChain(llm=llm, prompt=prompt) for key, prompt in prompt_templates.items()}
|
214 |
|
215 |
+
# Initialize the zero-shot classifier
|
216 |
+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
217 |
+
|
218 |
+
# Define the possible question types (labels) based on your prompt templates
|
219 |
+
question_types = list(prompt_templates.keys())
|
220 |
+
|
221 |
+
# Classifier function using the model
|
222 |
def classify_query(query: str) -> str:
|
223 |
+
"""
|
224 |
+
Classify the query using a zero-shot classification model.
|
225 |
+
Returns the most likely question type from the prompt templates.
|
226 |
+
"""
|
227 |
+
try:
|
228 |
+
# Perform zero-shot classification
|
229 |
+
result = classifier(query, candidate_labels=question_types)
|
230 |
+
|
231 |
+
# Get the label with the highest score
|
232 |
+
predicted_type = result["labels"][0]
|
233 |
+
confidence = result["scores"][0]
|
234 |
+
|
235 |
+
# If confidence is too low (e.g., < 0.5), fallback to 'general'
|
236 |
+
if confidence < 0.5:
|
237 |
+
print(f"Low confidence ({confidence}) for query '{query}', falling back to 'general'")
|
238 |
+
return "general"
|
239 |
+
|
240 |
+
return predicted_type
|
241 |
+
except Exception as e:
|
242 |
+
print(f"Error in classification: {e}")
|
243 |
+
return "general" # Fallback to general in case of errors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
@app.post("/search")
|
246 |
async def process_search(search_query: SearchQuery):
|
|
|
248 |
# Set default context if not provided
|
249 |
context = search_query.context or "You are a cybersecurity expert."
|
250 |
|
251 |
+
# Classify the query using the model
|
252 |
query_type = classify_query(search_query.query)
|
253 |
|
254 |
# Process the query using the appropriate chain
|
255 |
if query_type in chains:
|
256 |
+
raw_response = chains[query_type].run(query=search_query.query, context=context)
|
257 |
else:
|
258 |
+
raw_response = chains["general"].run(query=search_query.query, context=context)
|
259 |
+
|
260 |
+
# Structure the response according to the desired format
|
261 |
+
structured_response = {
|
262 |
+
"Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
|
263 |
+
"Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{context}', guiding the response to align with cybersecurity expertise.",
|
264 |
+
"Use Clear Language: Avoid ambiguity and complex wording": raw_response.strip(), # The raw response from Grok, cleaned up
|
265 |
+
"Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity."
|
266 |
+
}
|
267 |
|
268 |
return {
|
269 |
"status": "success",
|
270 |
+
"response": structured_response,
|
271 |
+
"classified_type": query_type # Optional: return the classified type for debugging
|
272 |
}
|
273 |
except Exception as e:
|
274 |
raise HTTPException(status_code=500, detail=str(e))
|
275 |
|
276 |
@app.get("/")
|
277 |
async def root():
|
278 |
+
return {"message": "Search API with structured response is running"}
|
279 |
+
|
280 |
+
# Run the app (optional, for local testing)
|
281 |
+
if __name__ == "__main__":
|
282 |
+
import uvicorn
|
283 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|