Spaces:
Running
Running
Replit Deployment
commited on
Commit
·
89ae94f
1
Parent(s):
94b87a8
Deployment from Replit
Browse files- hf_database.py +6 -1
- security_hf.py +16 -0
- src/__init__.py +3 -0
- src/api/__init__.py +3 -0
- src/api/auth.py +105 -0
- src/api/database.py +73 -0
- src/api/main.py +73 -0
- src/api/routers/auth_router.py +74 -0
- src/api/routers/scraping_router.py +161 -0
- src/api/routers/threats_router.py +217 -0
- src/api/schemas.py +310 -0
- src/api/security.py +382 -0
- src/api/services/__init__.py +3 -0
- src/api/services/alert_service.py +316 -0
- src/api/services/dark_web_content_service.py +357 -0
- src/api/services/report_service.py +436 -0
- src/api/services/search_history_service.py +609 -0
- src/api/services/subscription_service.py +681 -0
- src/api/services/threat_service.py +411 -0
- src/api/services/user_service.py +166 -0
hf_database.py
CHANGED
@@ -10,7 +10,12 @@ from sqlalchemy.orm import sessionmaker
|
|
10 |
from sqlalchemy.pool import StaticPool
|
11 |
from src.models.base import Base
|
12 |
from src.models.user import User
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
# Configure logging
|
16 |
logging.basicConfig(level=logging.INFO)
|
|
|
10 |
from sqlalchemy.pool import StaticPool
|
11 |
from src.models.base import Base
|
12 |
from src.models.user import User
|
13 |
+
try:
|
14 |
+
# Try to import from src.api.security first (for local development)
|
15 |
+
from src.api.security import get_password_hash
|
16 |
+
except ImportError:
|
17 |
+
# Fall back to simplified security module for HF (copied during deployment)
|
18 |
+
from security_hf import get_password_hash
|
19 |
|
20 |
# Configure logging
|
21 |
logging.basicConfig(level=logging.INFO)
|
security_hf.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Simplified security module for Hugging Face deployment.
|
3 |
+
Contains only the essential functions needed for HF deployment.
|
4 |
+
"""
|
5 |
+
from passlib.context import CryptContext
|
6 |
+
|
7 |
+
# Set up password hashing
|
8 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
9 |
+
|
10 |
+
def get_password_hash(password: str) -> str:
|
11 |
+
"""Hash a password for storage"""
|
12 |
+
return pwd_context.hash(password)
|
13 |
+
|
14 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
15 |
+
"""Verify a password against a hash"""
|
16 |
+
return pwd_context.verify(plain_password, hashed_password)
|
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Package initialization for src.
|
3 |
+
"""
|
src/api/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Package initialization for API.
|
3 |
+
"""
|
src/api/auth.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import Depends, HTTPException, status
|
2 |
+
from fastapi.security import OAuth2PasswordBearer
|
3 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
4 |
+
from jose import JWTError, jwt
|
5 |
+
from datetime import datetime, timedelta
|
6 |
+
from typing import Optional, Dict, Any
|
7 |
+
import os
|
8 |
+
import logging
|
9 |
+
|
10 |
+
from src.api.database import get_db
|
11 |
+
from src.api.schemas import TokenData, UserInDB
|
12 |
+
from src.api.services.user_service import get_user_by_username
|
13 |
+
|
14 |
+
# Configure logger
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
# Constants for JWT
|
18 |
+
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production")
|
19 |
+
ALGORITHM = "HS256"
|
20 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
21 |
+
|
22 |
+
# OAuth2PasswordBearer for token extraction
|
23 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/token")
|
24 |
+
|
25 |
+
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
26 |
+
"""
|
27 |
+
Create a JWT access token.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
data: Dictionary of data to encode in the token
|
31 |
+
expires_delta: Optional expiration time delta
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
str: JWT token
|
35 |
+
"""
|
36 |
+
to_encode = data.copy()
|
37 |
+
|
38 |
+
if expires_delta:
|
39 |
+
expire = datetime.utcnow() + expires_delta
|
40 |
+
else:
|
41 |
+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
42 |
+
|
43 |
+
to_encode.update({"exp": expire})
|
44 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
45 |
+
|
46 |
+
return encoded_jwt
|
47 |
+
|
48 |
+
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> UserInDB:
|
49 |
+
"""
|
50 |
+
Get the current authenticated user based on the JWT token.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
token: JWT token
|
54 |
+
db: Database session
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
UserInDB: User data
|
58 |
+
|
59 |
+
Raises:
|
60 |
+
HTTPException: If authentication fails
|
61 |
+
"""
|
62 |
+
credentials_exception = HTTPException(
|
63 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
64 |
+
detail="Could not validate credentials",
|
65 |
+
headers={"WWW-Authenticate": "Bearer"},
|
66 |
+
)
|
67 |
+
|
68 |
+
try:
|
69 |
+
# Decode JWT
|
70 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
71 |
+
username: str = payload.get("sub")
|
72 |
+
|
73 |
+
if username is None:
|
74 |
+
raise credentials_exception
|
75 |
+
|
76 |
+
token_data = TokenData(username=username)
|
77 |
+
except JWTError as e:
|
78 |
+
logger.error(f"JWT error: {e}")
|
79 |
+
raise credentials_exception
|
80 |
+
|
81 |
+
# Get user from database
|
82 |
+
user = await get_user_by_username(db, username=token_data.username)
|
83 |
+
|
84 |
+
if user is None:
|
85 |
+
raise credentials_exception
|
86 |
+
|
87 |
+
return user
|
88 |
+
|
89 |
+
async def get_current_active_user(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
|
90 |
+
"""
|
91 |
+
Get the current active user.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
current_user: Current authenticated user
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
UserInDB: User data
|
98 |
+
|
99 |
+
Raises:
|
100 |
+
HTTPException: If user is inactive
|
101 |
+
"""
|
102 |
+
if not current_user.is_active:
|
103 |
+
raise HTTPException(status_code=400, detail="Inactive user")
|
104 |
+
|
105 |
+
return current_user
|
src/api/database.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Database configuration and setup for API.
|
3 |
+
"""
|
4 |
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
5 |
+
from sqlalchemy.orm import sessionmaker
|
6 |
+
import os
|
7 |
+
from typing import AsyncGenerator
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# Configure logger
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
# Get database URL from environment (convert synchronous to async URL)
|
14 |
+
db_url = os.getenv("DATABASE_URL", "")
|
15 |
+
if db_url.startswith("postgresql://"):
|
16 |
+
# Remove sslmode parameter if present which causes issues with asyncpg
|
17 |
+
if "?" in db_url:
|
18 |
+
base_url, params = db_url.split("?", 1)
|
19 |
+
param_list = params.split("&")
|
20 |
+
filtered_params = [p for p in param_list if not p.startswith("sslmode=")]
|
21 |
+
if filtered_params:
|
22 |
+
db_url = f"{base_url}?{'&'.join(filtered_params)}"
|
23 |
+
else:
|
24 |
+
db_url = base_url
|
25 |
+
|
26 |
+
ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
27 |
+
else:
|
28 |
+
ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
29 |
+
|
30 |
+
# Create async engine
|
31 |
+
engine = create_async_engine(
|
32 |
+
ASYNC_DATABASE_URL,
|
33 |
+
echo=False, # Set to True for debugging
|
34 |
+
future=True,
|
35 |
+
pool_size=5,
|
36 |
+
max_overflow=10
|
37 |
+
)
|
38 |
+
|
39 |
+
# Create async session factory
|
40 |
+
async_session = sessionmaker(
|
41 |
+
engine,
|
42 |
+
class_=AsyncSession,
|
43 |
+
expire_on_commit=False
|
44 |
+
)
|
45 |
+
|
46 |
+
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
47 |
+
"""
|
48 |
+
Get database session generator.
|
49 |
+
|
50 |
+
Yields:
|
51 |
+
AsyncSession: Database session
|
52 |
+
"""
|
53 |
+
session = async_session()
|
54 |
+
try:
|
55 |
+
yield session
|
56 |
+
await session.commit()
|
57 |
+
except Exception as e:
|
58 |
+
await session.rollback()
|
59 |
+
logger.error(f"Database error: {e}")
|
60 |
+
raise
|
61 |
+
finally:
|
62 |
+
await session.close()
|
63 |
+
|
64 |
+
# Dependency for getting DB session
|
65 |
+
async def get_db_session() -> AsyncSession:
|
66 |
+
"""
|
67 |
+
Get database session.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
AsyncSession: Database session
|
71 |
+
"""
|
72 |
+
async with async_session() as session:
|
73 |
+
return session
|
src/api/main.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, Depends, HTTPException, status
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
import logging
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from src.api.database import get_db
|
7 |
+
from src.api.auth import get_current_user
|
8 |
+
from src.api.routers import threats_router, indicators_router, auth_router, admin_router
|
9 |
+
|
10 |
+
# Configure logging
|
11 |
+
logging.basicConfig(
|
12 |
+
level=logging.INFO,
|
13 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
14 |
+
handlers=[
|
15 |
+
logging.StreamHandler(),
|
16 |
+
logging.FileHandler("app.log")
|
17 |
+
]
|
18 |
+
)
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
# Create FastAPI app
|
22 |
+
app = FastAPI(
|
23 |
+
title="CyberForge OSINT API",
|
24 |
+
description="API for Dark Web OSINT platform",
|
25 |
+
version="1.0.0"
|
26 |
+
)
|
27 |
+
|
28 |
+
# Add CORS middleware
|
29 |
+
app.add_middleware(
|
30 |
+
CORSMiddleware,
|
31 |
+
allow_origins=["*"], # Update for production
|
32 |
+
allow_credentials=True,
|
33 |
+
allow_methods=["*"],
|
34 |
+
allow_headers=["*"],
|
35 |
+
)
|
36 |
+
|
37 |
+
# Include routers for different endpoints
|
38 |
+
app.include_router(auth_router.router, prefix="/api/v1")
|
39 |
+
app.include_router(
|
40 |
+
threats_router.router,
|
41 |
+
prefix="/api/v1/threats",
|
42 |
+
tags=["threats"],
|
43 |
+
dependencies=[Depends(get_current_user)]
|
44 |
+
)
|
45 |
+
app.include_router(
|
46 |
+
indicators_router.router,
|
47 |
+
prefix="/api/v1/indicators",
|
48 |
+
tags=["indicators"],
|
49 |
+
dependencies=[Depends(get_current_user)]
|
50 |
+
)
|
51 |
+
app.include_router(
|
52 |
+
admin_router.router,
|
53 |
+
prefix="/api/v1/admin",
|
54 |
+
tags=["admin"],
|
55 |
+
dependencies=[Depends(get_current_user)]
|
56 |
+
)
|
57 |
+
|
58 |
+
@app.get("/api/health")
|
59 |
+
async def health_check():
|
60 |
+
"""Health check endpoint for monitoring."""
|
61 |
+
return {"status": "healthy", "version": "1.0.0"}
|
62 |
+
|
63 |
+
@app.on_event("startup")
|
64 |
+
async def startup_event():
|
65 |
+
"""Event handler for application startup."""
|
66 |
+
logger.info("Starting the CyberForge OSINT API")
|
67 |
+
# Add any startup tasks here (database connection, cache warming, etc.)
|
68 |
+
|
69 |
+
@app.on_event("shutdown")
|
70 |
+
async def shutdown_event():
|
71 |
+
"""Event handler for application shutdown."""
|
72 |
+
logger.info("Shutting down the CyberForge OSINT API")
|
73 |
+
# Add any cleanup tasks here (close connections, save state, etc.)
|
src/api/routers/auth_router.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Authentication router.
|
3 |
+
|
4 |
+
This module provides authentication endpoints for the API.
|
5 |
+
"""
|
6 |
+
from datetime import datetime, timedelta
|
7 |
+
from typing import Any, Dict, Optional
|
8 |
+
|
9 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
10 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
11 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
12 |
+
|
13 |
+
from src.api.database import get_db
|
14 |
+
from src.api.security import (
|
15 |
+
ACCESS_TOKEN_EXPIRE_MINUTES,
|
16 |
+
Token,
|
17 |
+
UserInDB,
|
18 |
+
authenticate_user,
|
19 |
+
create_access_token,
|
20 |
+
get_current_active_user,
|
21 |
+
)
|
22 |
+
|
23 |
+
router = APIRouter(tags=["authentication"])
|
24 |
+
|
25 |
+
@router.post("/token", response_model=Token)
|
26 |
+
async def login_for_access_token(
|
27 |
+
form_data: OAuth2PasswordRequestForm = Depends(),
|
28 |
+
db: AsyncSession = Depends(get_db)
|
29 |
+
) -> Dict[str, Any]:
|
30 |
+
"""
|
31 |
+
OAuth2 compatible token login, get an access token for future requests.
|
32 |
+
"""
|
33 |
+
user = await authenticate_user(db, form_data.username, form_data.password)
|
34 |
+
if not user:
|
35 |
+
raise HTTPException(
|
36 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
37 |
+
detail="Incorrect username or password",
|
38 |
+
headers={"WWW-Authenticate": "Bearer"},
|
39 |
+
)
|
40 |
+
|
41 |
+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
42 |
+
access_token = create_access_token(
|
43 |
+
data={"sub": user.username, "scopes": user.scopes},
|
44 |
+
expires_delta=access_token_expires
|
45 |
+
)
|
46 |
+
|
47 |
+
expires_at = datetime.utcnow() + access_token_expires
|
48 |
+
|
49 |
+
return {
|
50 |
+
"access_token": access_token,
|
51 |
+
"token_type": "bearer",
|
52 |
+
"expires_at": expires_at
|
53 |
+
}
|
54 |
+
|
55 |
+
@router.get("/users/me", response_model=UserInDB)
|
56 |
+
async def read_users_me(
|
57 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
58 |
+
) -> UserInDB:
|
59 |
+
"""
|
60 |
+
Get current user.
|
61 |
+
"""
|
62 |
+
return current_user
|
63 |
+
|
64 |
+
@router.get("/users/me/scopes")
|
65 |
+
async def read_own_scopes(
|
66 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
67 |
+
) -> Dict[str, Any]:
|
68 |
+
"""
|
69 |
+
Get current user's scopes.
|
70 |
+
"""
|
71 |
+
return {
|
72 |
+
"username": current_user.username,
|
73 |
+
"scopes": current_user.scopes
|
74 |
+
}
|
src/api/routers/scraping_router.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Body, BackgroundTasks
|
2 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
3 |
+
from typing import List, Optional, Dict, Any
|
4 |
+
import logging
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
from src.api.database import get_db
|
8 |
+
from src.api.auth import get_current_user
|
9 |
+
from src.api.schemas import User, CrawlRequest, CrawlResult
|
10 |
+
from src.services.scraper import WebScraper, ScraperError
|
11 |
+
from src.services.tor_proxy import TorProxyService, TorProxyError
|
12 |
+
|
13 |
+
# Configure logger
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
router = APIRouter(
|
17 |
+
prefix="/scraping",
|
18 |
+
tags=["scraping"],
|
19 |
+
responses={404: {"description": "Not found"}}
|
20 |
+
)
|
21 |
+
|
22 |
+
# Initialize services
|
23 |
+
scraper = WebScraper()
|
24 |
+
|
25 |
+
@router.post("/test-tor", response_model=Dict[str, Any])
|
26 |
+
async def test_tor_connection(
|
27 |
+
current_user: User = Depends(get_current_user)
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
Test Tor connection.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
current_user: Current authenticated user
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Dict[str, Any]: Connection status
|
37 |
+
"""
|
38 |
+
try:
|
39 |
+
tor_proxy = TorProxyService()
|
40 |
+
is_connected = await tor_proxy.check_connection()
|
41 |
+
|
42 |
+
return {
|
43 |
+
"status": "success",
|
44 |
+
"is_connected": is_connected,
|
45 |
+
"timestamp": datetime.utcnow().isoformat()
|
46 |
+
}
|
47 |
+
except TorProxyError as e:
|
48 |
+
logger.error(f"Tor proxy error: {e}")
|
49 |
+
raise HTTPException(
|
50 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
51 |
+
detail=f"Tor proxy error: {str(e)}"
|
52 |
+
)
|
53 |
+
except Exception as e:
|
54 |
+
logger.error(f"Error testing Tor connection: {e}")
|
55 |
+
raise HTTPException(
|
56 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
57 |
+
detail=f"An error occurred: {str(e)}"
|
58 |
+
)
|
59 |
+
|
60 |
+
@router.post("/scrape", response_model=Dict[str, Any])
|
61 |
+
async def scrape_page(
|
62 |
+
url: str,
|
63 |
+
use_tor: bool = Body(False),
|
64 |
+
current_user: User = Depends(get_current_user)
|
65 |
+
):
|
66 |
+
"""
|
67 |
+
Scrape a single page.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
url: URL to scrape
|
71 |
+
use_tor: Whether to use Tor proxy
|
72 |
+
current_user: Current authenticated user
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Dict[str, Any]: Scraped content
|
76 |
+
"""
|
77 |
+
try:
|
78 |
+
result = await scraper.extract_content(url, use_tor=use_tor)
|
79 |
+
|
80 |
+
return {
|
81 |
+
"status": "success",
|
82 |
+
"data": result,
|
83 |
+
"timestamp": datetime.utcnow().isoformat()
|
84 |
+
}
|
85 |
+
except ScraperError as e:
|
86 |
+
logger.error(f"Scraper error: {e}")
|
87 |
+
raise HTTPException(
|
88 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
89 |
+
detail=f"Scraper error: {str(e)}"
|
90 |
+
)
|
91 |
+
except Exception as e:
|
92 |
+
logger.error(f"Error scraping page: {e}")
|
93 |
+
raise HTTPException(
|
94 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
95 |
+
detail=f"An error occurred: {str(e)}"
|
96 |
+
)
|
97 |
+
|
98 |
+
@router.post("/crawl", response_model=Dict[str, Any])
|
99 |
+
async def crawl_site(
|
100 |
+
crawl_request: CrawlRequest,
|
101 |
+
background_tasks: BackgroundTasks,
|
102 |
+
current_user: User = Depends(get_current_user)
|
103 |
+
):
|
104 |
+
"""
|
105 |
+
Crawl a site.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
crawl_request: Crawl request data
|
109 |
+
background_tasks: Background tasks
|
110 |
+
current_user: Current authenticated user
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Dict[str, Any]: Crawl status
|
114 |
+
"""
|
115 |
+
# For longer crawls, we add them as background tasks
|
116 |
+
# This prevents timeouts on the API request
|
117 |
+
|
118 |
+
# Start crawl in background
|
119 |
+
if crawl_request.max_depth > 1 or '.onion' in crawl_request.url:
|
120 |
+
# Add to background tasks
|
121 |
+
background_tasks.add_task(
|
122 |
+
scraper.crawl,
|
123 |
+
crawl_request.url,
|
124 |
+
max_depth=crawl_request.max_depth,
|
125 |
+
max_pages=50,
|
126 |
+
keyword_filter=crawl_request.keywords
|
127 |
+
)
|
128 |
+
|
129 |
+
return {
|
130 |
+
"status": "started",
|
131 |
+
"message": "Crawl started in background",
|
132 |
+
"timestamp": datetime.utcnow().isoformat()
|
133 |
+
}
|
134 |
+
else:
|
135 |
+
# For simple crawls, we perform them synchronously
|
136 |
+
try:
|
137 |
+
results = await scraper.crawl(
|
138 |
+
crawl_request.url,
|
139 |
+
max_depth=crawl_request.max_depth,
|
140 |
+
max_pages=10,
|
141 |
+
keyword_filter=crawl_request.keywords
|
142 |
+
)
|
143 |
+
|
144 |
+
return {
|
145 |
+
"status": "completed",
|
146 |
+
"results": results,
|
147 |
+
"count": len(results),
|
148 |
+
"timestamp": datetime.utcnow().isoformat()
|
149 |
+
}
|
150 |
+
except ScraperError as e:
|
151 |
+
logger.error(f"Scraper error: {e}")
|
152 |
+
raise HTTPException(
|
153 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
154 |
+
detail=f"Scraper error: {str(e)}"
|
155 |
+
)
|
156 |
+
except Exception as e:
|
157 |
+
logger.error(f"Error crawling site: {e}")
|
158 |
+
raise HTTPException(
|
159 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
160 |
+
detail=f"An error occurred: {str(e)}"
|
161 |
+
)
|
src/api/routers/threats_router.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query
|
2 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
3 |
+
from typing import List, Optional
|
4 |
+
import logging
|
5 |
+
|
6 |
+
from src.api.database import get_db
|
7 |
+
from src.api.auth import get_current_user
|
8 |
+
from src.api.schemas import (
|
9 |
+
Threat, ThreatCreate, ThreatUpdate, ThreatFilter,
|
10 |
+
PaginationParams, User
|
11 |
+
)
|
12 |
+
from src.api.services.threat_service import (
|
13 |
+
create_threat, get_threat_by_id, update_threat,
|
14 |
+
delete_threat, get_threats, get_threat_statistics
|
15 |
+
)
|
16 |
+
|
17 |
+
# Configure logger
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
router = APIRouter(
|
21 |
+
tags=["threats"],
|
22 |
+
responses={404: {"description": "Not found"}}
|
23 |
+
)
|
24 |
+
|
25 |
+
@router.post("/", response_model=Threat, status_code=status.HTTP_201_CREATED)
|
26 |
+
async def create_threat_endpoint(
|
27 |
+
threat_data: ThreatCreate,
|
28 |
+
db: AsyncSession = Depends(get_db),
|
29 |
+
current_user: User = Depends(get_current_user)
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Create a new threat.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
threat_data: Threat data
|
36 |
+
db: Database session
|
37 |
+
current_user: Current authenticated user
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
Threat: Created threat
|
41 |
+
"""
|
42 |
+
try:
|
43 |
+
threat = await create_threat(db, threat_data)
|
44 |
+
|
45 |
+
if not threat:
|
46 |
+
raise HTTPException(
|
47 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
48 |
+
detail="Failed to create threat"
|
49 |
+
)
|
50 |
+
|
51 |
+
return threat
|
52 |
+
except Exception as e:
|
53 |
+
logger.error(f"Error creating threat: {e}")
|
54 |
+
raise HTTPException(
|
55 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
56 |
+
detail=f"An error occurred: {str(e)}"
|
57 |
+
)
|
58 |
+
|
59 |
+
@router.get("/{threat_id}", response_model=Threat)
|
60 |
+
async def get_threat_endpoint(
|
61 |
+
threat_id: int,
|
62 |
+
db: AsyncSession = Depends(get_db),
|
63 |
+
current_user: User = Depends(get_current_user)
|
64 |
+
):
|
65 |
+
"""
|
66 |
+
Get threat by ID.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
threat_id: Threat ID
|
70 |
+
db: Database session
|
71 |
+
current_user: Current authenticated user
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
Threat: Threat data
|
75 |
+
"""
|
76 |
+
threat = await get_threat_by_id(db, threat_id)
|
77 |
+
|
78 |
+
if not threat:
|
79 |
+
raise HTTPException(
|
80 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
81 |
+
detail=f"Threat with ID {threat_id} not found"
|
82 |
+
)
|
83 |
+
|
84 |
+
return threat
|
85 |
+
|
86 |
+
@router.put("/{threat_id}", response_model=Threat)
|
87 |
+
async def update_threat_endpoint(
|
88 |
+
threat_id: int,
|
89 |
+
threat_data: ThreatUpdate,
|
90 |
+
db: AsyncSession = Depends(get_db),
|
91 |
+
current_user: User = Depends(get_current_user)
|
92 |
+
):
|
93 |
+
"""
|
94 |
+
Update threat.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
threat_id: Threat ID
|
98 |
+
threat_data: Threat data
|
99 |
+
db: Database session
|
100 |
+
current_user: Current authenticated user
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
Threat: Updated threat
|
104 |
+
"""
|
105 |
+
# Check if threat exists
|
106 |
+
threat = await get_threat_by_id(db, threat_id)
|
107 |
+
|
108 |
+
if not threat:
|
109 |
+
raise HTTPException(
|
110 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
111 |
+
detail=f"Threat with ID {threat_id} not found"
|
112 |
+
)
|
113 |
+
|
114 |
+
# Update threat
|
115 |
+
updated_threat = await update_threat(db, threat_id, threat_data)
|
116 |
+
|
117 |
+
if not updated_threat:
|
118 |
+
raise HTTPException(
|
119 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
120 |
+
detail="Failed to update threat"
|
121 |
+
)
|
122 |
+
|
123 |
+
return updated_threat
|
124 |
+
|
125 |
+
@router.delete("/{threat_id}", status_code=status.HTTP_204_NO_CONTENT)
|
126 |
+
async def delete_threat_endpoint(
|
127 |
+
threat_id: int,
|
128 |
+
db: AsyncSession = Depends(get_db),
|
129 |
+
current_user: User = Depends(get_current_user)
|
130 |
+
):
|
131 |
+
"""
|
132 |
+
Delete threat.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
threat_id: Threat ID
|
136 |
+
db: Database session
|
137 |
+
current_user: Current authenticated user
|
138 |
+
"""
|
139 |
+
# Check if threat exists
|
140 |
+
threat = await get_threat_by_id(db, threat_id)
|
141 |
+
|
142 |
+
if not threat:
|
143 |
+
raise HTTPException(
|
144 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
145 |
+
detail=f"Threat with ID {threat_id} not found"
|
146 |
+
)
|
147 |
+
|
148 |
+
# Delete threat
|
149 |
+
deleted = await delete_threat(db, threat_id)
|
150 |
+
|
151 |
+
if not deleted:
|
152 |
+
raise HTTPException(
|
153 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
154 |
+
detail="Failed to delete threat"
|
155 |
+
)
|
156 |
+
|
157 |
+
@router.get("/", response_model=List[Threat])
|
158 |
+
async def get_threats_endpoint(
|
159 |
+
pagination: PaginationParams = Depends(),
|
160 |
+
severity: Optional[List[str]] = Query(None),
|
161 |
+
status: Optional[List[str]] = Query(None),
|
162 |
+
category: Optional[List[str]] = Query(None),
|
163 |
+
search: Optional[str] = Query(None),
|
164 |
+
from_date: Optional[str] = Query(None),
|
165 |
+
to_date: Optional[str] = Query(None),
|
166 |
+
db: AsyncSession = Depends(get_db),
|
167 |
+
current_user: User = Depends(get_current_user)
|
168 |
+
):
|
169 |
+
"""
|
170 |
+
Get threats with filtering and pagination.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
pagination: Pagination parameters
|
174 |
+
severity: Filter by severity
|
175 |
+
status: Filter by status
|
176 |
+
category: Filter by category
|
177 |
+
search: Search in title and description
|
178 |
+
from_date: Filter from date
|
179 |
+
to_date: Filter to date
|
180 |
+
db: Database session
|
181 |
+
current_user: Current authenticated user
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
List[Threat]: List of threats
|
185 |
+
"""
|
186 |
+
# Create filter params
|
187 |
+
filter_params = ThreatFilter(
|
188 |
+
severity=severity,
|
189 |
+
status=status,
|
190 |
+
category=category,
|
191 |
+
search=search,
|
192 |
+
from_date=from_date,
|
193 |
+
to_date=to_date
|
194 |
+
)
|
195 |
+
|
196 |
+
# Get threats
|
197 |
+
threats, total = await get_threats(db, filter_params, pagination)
|
198 |
+
|
199 |
+
return threats
|
200 |
+
|
201 |
+
@router.get("/statistics", response_model=dict)
|
202 |
+
async def get_threat_statistics_endpoint(
|
203 |
+
db: AsyncSession = Depends(get_db),
|
204 |
+
current_user: User = Depends(get_current_user)
|
205 |
+
):
|
206 |
+
"""
|
207 |
+
Get threat statistics.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
db: Database session
|
211 |
+
current_user: Current authenticated user
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
dict: Threat statistics
|
215 |
+
"""
|
216 |
+
statistics = await get_threat_statistics(db)
|
217 |
+
return statistics
|
src/api/schemas.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
API schemas for data validation and serialization.
|
3 |
+
"""
|
4 |
+
from pydantic import BaseModel, Field, validator, EmailStr
|
5 |
+
from typing import Optional, List, Dict, Any, Union
|
6 |
+
from datetime import datetime
|
7 |
+
from enum import Enum
|
8 |
+
|
9 |
+
# Pagination
|
10 |
+
class PaginationParams(BaseModel):
|
11 |
+
"""Pagination parameters."""
|
12 |
+
page: int = Field(1, ge=1, description="Page number")
|
13 |
+
size: int = Field(10, ge=1, le=100, description="Items per page")
|
14 |
+
|
15 |
+
# User schemas
|
16 |
+
class UserBase(BaseModel):
|
17 |
+
"""Base user schema."""
|
18 |
+
username: str
|
19 |
+
email: EmailStr
|
20 |
+
full_name: Optional[str] = None
|
21 |
+
is_active: bool = True
|
22 |
+
|
23 |
+
class UserCreate(UserBase):
|
24 |
+
"""User creation schema."""
|
25 |
+
password: str
|
26 |
+
|
27 |
+
class UserUpdate(BaseModel):
|
28 |
+
"""User update schema."""
|
29 |
+
username: Optional[str] = None
|
30 |
+
email: Optional[EmailStr] = None
|
31 |
+
full_name: Optional[str] = None
|
32 |
+
is_active: Optional[bool] = None
|
33 |
+
password: Optional[str] = None
|
34 |
+
|
35 |
+
class UserResponse(UserBase):
|
36 |
+
"""User response schema."""
|
37 |
+
id: int
|
38 |
+
is_superuser: bool = False
|
39 |
+
|
40 |
+
class Config:
|
41 |
+
orm_mode = True
|
42 |
+
|
43 |
+
# Token schemas
|
44 |
+
class Token(BaseModel):
|
45 |
+
"""Token schema."""
|
46 |
+
access_token: str
|
47 |
+
token_type: str = "bearer"
|
48 |
+
|
49 |
+
class TokenPayload(BaseModel):
|
50 |
+
"""Token payload schema."""
|
51 |
+
sub: Optional[int] = None
|
52 |
+
|
53 |
+
# Dark Web Content schemas
|
54 |
+
class DarkWebContentBase(BaseModel):
|
55 |
+
"""Base schema for dark web content."""
|
56 |
+
url: str
|
57 |
+
title: Optional[str] = None
|
58 |
+
content: str
|
59 |
+
content_type: str
|
60 |
+
source_name: Optional[str] = None
|
61 |
+
source_type: Optional[str] = None
|
62 |
+
language: Optional[str] = None
|
63 |
+
|
64 |
+
class DarkWebContentCreate(DarkWebContentBase):
|
65 |
+
"""Schema for creating dark web content."""
|
66 |
+
relevance_score: float = 0.0
|
67 |
+
sentiment_score: float = 0.0
|
68 |
+
entity_data: Optional[str] = None
|
69 |
+
|
70 |
+
class DarkWebContentUpdate(BaseModel):
|
71 |
+
"""Schema for updating dark web content."""
|
72 |
+
title: Optional[str] = None
|
73 |
+
content_status: Optional[str] = None
|
74 |
+
relevance_score: Optional[float] = None
|
75 |
+
sentiment_score: Optional[float] = None
|
76 |
+
entity_data: Optional[str] = None
|
77 |
+
|
78 |
+
class DarkWebContentResponse(DarkWebContentBase):
|
79 |
+
"""Schema for dark web content response."""
|
80 |
+
id: int
|
81 |
+
domain: Optional[str] = None
|
82 |
+
content_status: str
|
83 |
+
scraped_at: datetime
|
84 |
+
relevance_score: float
|
85 |
+
sentiment_score: float
|
86 |
+
entity_data: Optional[str] = None
|
87 |
+
|
88 |
+
class Config:
|
89 |
+
orm_mode = True
|
90 |
+
|
91 |
+
# Dark Web Mention schemas
|
92 |
+
class DarkWebMentionBase(BaseModel):
|
93 |
+
"""Base schema for dark web mention."""
|
94 |
+
content_id: int
|
95 |
+
keyword: str
|
96 |
+
keyword_category: Optional[str] = None
|
97 |
+
context: Optional[str] = None
|
98 |
+
snippet: Optional[str] = None
|
99 |
+
mention_type: Optional[str] = None
|
100 |
+
|
101 |
+
class DarkWebMentionCreate(DarkWebMentionBase):
|
102 |
+
"""Schema for creating dark web mention."""
|
103 |
+
confidence: float = 0.0
|
104 |
+
is_verified: bool = False
|
105 |
+
|
106 |
+
class DarkWebMentionUpdate(BaseModel):
|
107 |
+
"""Schema for updating dark web mention."""
|
108 |
+
keyword_category: Optional[str] = None
|
109 |
+
mention_type: Optional[str] = None
|
110 |
+
confidence: Optional[float] = None
|
111 |
+
is_verified: Optional[bool] = None
|
112 |
+
|
113 |
+
class DarkWebMentionResponse(DarkWebMentionBase):
|
114 |
+
"""Schema for dark web mention response."""
|
115 |
+
id: int
|
116 |
+
confidence: float
|
117 |
+
is_verified: bool
|
118 |
+
created_at: datetime
|
119 |
+
|
120 |
+
class Config:
|
121 |
+
orm_mode = True
|
122 |
+
|
123 |
+
# Threat schemas
|
124 |
+
class ThreatBase(BaseModel):
|
125 |
+
"""Base schema for threat."""
|
126 |
+
title: str
|
127 |
+
description: str
|
128 |
+
severity: str
|
129 |
+
category: str
|
130 |
+
|
131 |
+
class ThreatCreate(ThreatBase):
|
132 |
+
"""Schema for creating threat."""
|
133 |
+
status: str = "New"
|
134 |
+
source_url: Optional[str] = None
|
135 |
+
source_name: Optional[str] = None
|
136 |
+
source_type: Optional[str] = None
|
137 |
+
affected_entity: Optional[str] = None
|
138 |
+
affected_entity_type: Optional[str] = None
|
139 |
+
confidence_score: float = 0.0
|
140 |
+
risk_score: float = 0.0
|
141 |
+
|
142 |
+
class ThreatUpdate(BaseModel):
|
143 |
+
"""Schema for updating threat."""
|
144 |
+
title: Optional[str] = None
|
145 |
+
description: Optional[str] = None
|
146 |
+
severity: Optional[str] = None
|
147 |
+
status: Optional[str] = None
|
148 |
+
category: Optional[str] = None
|
149 |
+
affected_entity: Optional[str] = None
|
150 |
+
affected_entity_type: Optional[str] = None
|
151 |
+
confidence_score: Optional[float] = None
|
152 |
+
risk_score: Optional[float] = None
|
153 |
+
|
154 |
+
class ThreatResponse(ThreatBase):
|
155 |
+
"""Schema for threat response."""
|
156 |
+
id: int
|
157 |
+
status: str
|
158 |
+
source_url: Optional[str] = None
|
159 |
+
source_name: Optional[str] = None
|
160 |
+
source_type: Optional[str] = None
|
161 |
+
discovered_at: datetime
|
162 |
+
affected_entity: Optional[str] = None
|
163 |
+
affected_entity_type: Optional[str] = None
|
164 |
+
confidence_score: float
|
165 |
+
risk_score: float
|
166 |
+
|
167 |
+
class Config:
|
168 |
+
orm_mode = True
|
169 |
+
|
170 |
+
# Indicator schemas
|
171 |
+
class IndicatorBase(BaseModel):
|
172 |
+
"""Base schema for indicator."""
|
173 |
+
threat_id: int
|
174 |
+
value: str
|
175 |
+
indicator_type: str
|
176 |
+
description: Optional[str] = None
|
177 |
+
|
178 |
+
class IndicatorCreate(IndicatorBase):
|
179 |
+
"""Schema for creating indicator."""
|
180 |
+
is_verified: bool = False
|
181 |
+
context: Optional[str] = None
|
182 |
+
source: Optional[str] = None
|
183 |
+
confidence_score: float = 0.0
|
184 |
+
|
185 |
+
class IndicatorUpdate(BaseModel):
|
186 |
+
"""Schema for updating indicator."""
|
187 |
+
description: Optional[str] = None
|
188 |
+
is_verified: Optional[bool] = None
|
189 |
+
context: Optional[str] = None
|
190 |
+
source: Optional[str] = None
|
191 |
+
confidence_score: Optional[float] = None
|
192 |
+
|
193 |
+
class IndicatorResponse(IndicatorBase):
|
194 |
+
"""Schema for indicator response."""
|
195 |
+
id: int
|
196 |
+
is_verified: bool
|
197 |
+
context: Optional[str] = None
|
198 |
+
source: Optional[str] = None
|
199 |
+
confidence_score: float
|
200 |
+
first_seen: datetime
|
201 |
+
last_seen: datetime
|
202 |
+
|
203 |
+
class Config:
|
204 |
+
orm_mode = True
|
205 |
+
|
206 |
+
# Alert schemas
|
207 |
+
class AlertBase(BaseModel):
|
208 |
+
"""Base schema for alert."""
|
209 |
+
title: str
|
210 |
+
description: str
|
211 |
+
severity: str
|
212 |
+
category: str
|
213 |
+
|
214 |
+
class AlertCreate(AlertBase):
|
215 |
+
"""Schema for creating alert."""
|
216 |
+
source_url: Optional[str] = None
|
217 |
+
threat_id: Optional[int] = None
|
218 |
+
mention_id: Optional[int] = None
|
219 |
+
|
220 |
+
class AlertUpdate(BaseModel):
|
221 |
+
"""Schema for updating alert."""
|
222 |
+
status: str
|
223 |
+
action_taken: Optional[str] = None
|
224 |
+
assigned_to_id: Optional[int] = None
|
225 |
+
is_read: Optional[bool] = None
|
226 |
+
|
227 |
+
class AlertResponse(AlertBase):
|
228 |
+
"""Schema for alert response."""
|
229 |
+
id: int
|
230 |
+
status: str
|
231 |
+
generated_at: datetime
|
232 |
+
source_url: Optional[str] = None
|
233 |
+
is_read: bool
|
234 |
+
threat_id: Optional[int] = None
|
235 |
+
mention_id: Optional[int] = None
|
236 |
+
assigned_to_id: Optional[int] = None
|
237 |
+
action_taken: Optional[str] = None
|
238 |
+
resolved_at: Optional[datetime] = None
|
239 |
+
|
240 |
+
class Config:
|
241 |
+
orm_mode = True
|
242 |
+
|
243 |
+
# Report schemas
|
244 |
+
class ReportBase(BaseModel):
|
245 |
+
"""Base schema for report."""
|
246 |
+
report_id: str
|
247 |
+
title: str
|
248 |
+
summary: str
|
249 |
+
content: str
|
250 |
+
report_type: str
|
251 |
+
|
252 |
+
class ReportCreate(ReportBase):
|
253 |
+
"""Schema for creating report."""
|
254 |
+
status: str = "Draft"
|
255 |
+
severity: Optional[str] = None
|
256 |
+
publish_date: Optional[datetime] = None
|
257 |
+
time_period_start: Optional[datetime] = None
|
258 |
+
time_period_end: Optional[datetime] = None
|
259 |
+
keywords: Optional[str] = None
|
260 |
+
author_id: int
|
261 |
+
threat_ids: List[int] = []
|
262 |
+
|
263 |
+
class ReportUpdate(BaseModel):
|
264 |
+
"""Schema for updating report."""
|
265 |
+
title: Optional[str] = None
|
266 |
+
summary: Optional[str] = None
|
267 |
+
content: Optional[str] = None
|
268 |
+
report_type: Optional[str] = None
|
269 |
+
status: Optional[str] = None
|
270 |
+
severity: Optional[str] = None
|
271 |
+
publish_date: Optional[datetime] = None
|
272 |
+
time_period_start: Optional[datetime] = None
|
273 |
+
time_period_end: Optional[datetime] = None
|
274 |
+
keywords: Optional[str] = None
|
275 |
+
threat_ids: Optional[List[int]] = None
|
276 |
+
|
277 |
+
class ReportResponse(ReportBase):
|
278 |
+
"""Schema for report response."""
|
279 |
+
id: int
|
280 |
+
status: str
|
281 |
+
severity: Optional[str] = None
|
282 |
+
publish_date: Optional[datetime] = None
|
283 |
+
time_period_start: Optional[datetime] = None
|
284 |
+
time_period_end: Optional[datetime] = None
|
285 |
+
keywords: Optional[str] = None
|
286 |
+
author_id: int
|
287 |
+
|
288 |
+
class Config:
|
289 |
+
orm_mode = True
|
290 |
+
|
291 |
+
# Statistics response schemas
|
292 |
+
class ThreatStatisticsResponse(BaseModel):
|
293 |
+
"""Schema for threat statistics response."""
|
294 |
+
total_count: int
|
295 |
+
severity_counts: Dict[str, int]
|
296 |
+
status_counts: Dict[str, int]
|
297 |
+
category_counts: Dict[str, int]
|
298 |
+
time_series: List[Dict[str, Any]]
|
299 |
+
from_date: str
|
300 |
+
to_date: str
|
301 |
+
|
302 |
+
class ContentStatisticsResponse(BaseModel):
|
303 |
+
"""Schema for content statistics response."""
|
304 |
+
total_count: int
|
305 |
+
content_type_counts: Dict[str, int]
|
306 |
+
content_status_counts: Dict[str, int]
|
307 |
+
source_counts: Dict[str, int]
|
308 |
+
time_series: List[Dict[str, Any]]
|
309 |
+
from_date: str
|
310 |
+
to_date: str
|
src/api/security.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
API Security Module
|
3 |
+
|
4 |
+
This module provides security features for the API, including:
|
5 |
+
1. Authentication using JWT tokens
|
6 |
+
2. Rate limiting to prevent abuse
|
7 |
+
3. Role-based access control
|
8 |
+
4. Request validation
|
9 |
+
5. Audit logging
|
10 |
+
"""
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import logging
|
14 |
+
import secrets
|
15 |
+
from datetime import datetime, timedelta
|
16 |
+
from typing import Dict, List, Optional, Union, Any, Callable
|
17 |
+
|
18 |
+
from fastapi import Depends, HTTPException, Security, status, Request
|
19 |
+
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
|
20 |
+
from jose import JWTError, jwt
|
21 |
+
from passlib.context import CryptContext
|
22 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
23 |
+
from sqlalchemy.future import select
|
24 |
+
from pydantic import BaseModel, EmailStr
|
25 |
+
|
26 |
+
from src.models.user import User
|
27 |
+
from src.api.database import get_db
|
28 |
+
|
29 |
+
# Configure logging
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
# Security configuration
|
33 |
+
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_hex(32))
|
34 |
+
ALGORITHM = "HS256"
|
35 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
36 |
+
API_KEY_NAME = "X-API-Key"
|
37 |
+
|
38 |
+
# Set up password hashing
|
39 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
40 |
+
|
41 |
+
# Set up security schemes
|
42 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
43 |
+
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
44 |
+
|
45 |
+
# User models
|
46 |
+
class Token(BaseModel):
|
47 |
+
access_token: str
|
48 |
+
token_type: str
|
49 |
+
expires_at: datetime
|
50 |
+
|
51 |
+
class TokenData(BaseModel):
|
52 |
+
username: Optional[str] = None
|
53 |
+
scopes: List[str] = []
|
54 |
+
|
55 |
+
class UserInDB(BaseModel):
|
56 |
+
id: int
|
57 |
+
username: str
|
58 |
+
email: EmailStr
|
59 |
+
full_name: Optional[str] = None
|
60 |
+
is_active: bool = True
|
61 |
+
is_superuser: bool = False
|
62 |
+
scopes: List[str] = []
|
63 |
+
|
64 |
+
class Config:
|
65 |
+
from_attributes = True
|
66 |
+
|
67 |
+
# Rate limiting
|
68 |
+
class RateLimiter:
|
69 |
+
"""Simple in-memory rate limiter"""
|
70 |
+
|
71 |
+
def __init__(self, rate_limit: int = 100, time_window: int = 60):
|
72 |
+
"""
|
73 |
+
Initialize rate limiter.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
rate_limit: Maximum number of requests per time window
|
77 |
+
time_window: Time window in seconds
|
78 |
+
"""
|
79 |
+
self.rate_limit = rate_limit
|
80 |
+
self.time_window = time_window
|
81 |
+
self.requests = {}
|
82 |
+
|
83 |
+
def is_rate_limited(self, key: str) -> bool:
|
84 |
+
"""
|
85 |
+
Check if a key is rate limited.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
key: Identifier for the client (IP address, API key, etc.)
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
True if rate limited, False otherwise
|
92 |
+
"""
|
93 |
+
current_time = time.time()
|
94 |
+
|
95 |
+
# Initialize or clean up old requests
|
96 |
+
if key not in self.requests:
|
97 |
+
self.requests[key] = []
|
98 |
+
else:
|
99 |
+
# Remove requests outside the time window
|
100 |
+
self.requests[key] = [t for t in self.requests[key] if t > current_time - self.time_window]
|
101 |
+
|
102 |
+
# Check if rate limit is exceeded
|
103 |
+
if len(self.requests[key]) >= self.rate_limit:
|
104 |
+
return True
|
105 |
+
|
106 |
+
# Add the current request
|
107 |
+
self.requests[key].append(current_time)
|
108 |
+
return False
|
109 |
+
|
110 |
+
# Create global rate limiter instance
|
111 |
+
rate_limiter = RateLimiter()
|
112 |
+
|
113 |
+
# Role-based access control
|
114 |
+
# Define roles and permissions
|
115 |
+
ROLES = {
|
116 |
+
"admin": ["read:all", "write:all", "delete:all"],
|
117 |
+
"analyst": ["read:all", "write:threats", "write:indicators", "write:reports"],
|
118 |
+
"user": ["read:threats", "read:reports", "read:dashboard"],
|
119 |
+
"api": ["read:all", "write:threats", "write:indicators"]
|
120 |
+
}
|
121 |
+
|
122 |
+
# Security utility functions
|
123 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
124 |
+
"""Verify a password against a hash"""
|
125 |
+
return pwd_context.verify(plain_password, hashed_password)
|
126 |
+
|
127 |
+
def get_password_hash(password: str) -> str:
|
128 |
+
"""Hash a password for storage"""
|
129 |
+
return pwd_context.hash(password)
|
130 |
+
|
131 |
+
async def get_user(db: AsyncSession, username: str) -> Optional[UserInDB]:
|
132 |
+
"""Get a user from the database by username"""
|
133 |
+
result = await db.execute(select(User).filter(User.username == username))
|
134 |
+
user_db = result.scalars().first()
|
135 |
+
|
136 |
+
if not user_db:
|
137 |
+
return None
|
138 |
+
|
139 |
+
# Get user roles and scopes
|
140 |
+
scopes = []
|
141 |
+
if user_db.is_superuser:
|
142 |
+
scopes = ROLES["admin"]
|
143 |
+
else:
|
144 |
+
# In a real application, you would look up user roles in a database
|
145 |
+
# For simplicity, we'll assume non-superusers have the "user" role
|
146 |
+
scopes = ROLES["user"]
|
147 |
+
|
148 |
+
return UserInDB(
|
149 |
+
id=user_db.id,
|
150 |
+
username=user_db.username,
|
151 |
+
email=user_db.email,
|
152 |
+
full_name=user_db.full_name,
|
153 |
+
is_active=user_db.is_active,
|
154 |
+
is_superuser=user_db.is_superuser,
|
155 |
+
scopes=scopes
|
156 |
+
)
|
157 |
+
|
158 |
+
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
|
159 |
+
"""Authenticate a user with username and password"""
|
160 |
+
user = await get_user(db, username)
|
161 |
+
if not user:
|
162 |
+
return None
|
163 |
+
|
164 |
+
# Get the user from the database again to get the hashed password
|
165 |
+
result = await db.execute(select(User).filter(User.username == username))
|
166 |
+
user_db = result.scalars().first()
|
167 |
+
|
168 |
+
if not verify_password(password, user_db.hashed_password):
|
169 |
+
return None
|
170 |
+
|
171 |
+
return user
|
172 |
+
|
173 |
+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
174 |
+
"""Create a JWT access token"""
|
175 |
+
to_encode = data.copy()
|
176 |
+
|
177 |
+
if expires_delta:
|
178 |
+
expire = datetime.utcnow() + expires_delta
|
179 |
+
else:
|
180 |
+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
181 |
+
|
182 |
+
to_encode.update({"exp": expire})
|
183 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
184 |
+
return encoded_jwt
|
185 |
+
|
186 |
+
async def get_api_key_user(
|
187 |
+
api_key: str,
|
188 |
+
db: AsyncSession
|
189 |
+
) -> Optional[UserInDB]:
|
190 |
+
"""Get user associated with an API key"""
|
191 |
+
# In a real application, you would look up API keys in a database
|
192 |
+
# For simplicity, we'll use a simple hardcoded mapping
|
193 |
+
# TODO: Replace with database-backed API key storage
|
194 |
+
API_KEYS = {
|
195 |
+
"test-api-key": "api_user",
|
196 |
+
# Add more API keys here
|
197 |
+
}
|
198 |
+
|
199 |
+
if api_key not in API_KEYS:
|
200 |
+
return None
|
201 |
+
|
202 |
+
username = API_KEYS[api_key]
|
203 |
+
user = await get_user(db, username)
|
204 |
+
|
205 |
+
if not user:
|
206 |
+
return None
|
207 |
+
|
208 |
+
# Override scopes with API role scopes
|
209 |
+
user.scopes = ROLES["api"]
|
210 |
+
|
211 |
+
return user
|
212 |
+
|
213 |
+
# Dependencies for FastAPI
|
214 |
+
async def rate_limit(request: Request):
|
215 |
+
"""Rate limiting dependency"""
|
216 |
+
# Use API key or IP address as the rate limit key
|
217 |
+
client_key = request.headers.get(API_KEY_NAME) or request.client.host
|
218 |
+
|
219 |
+
if rate_limiter.is_rate_limited(client_key):
|
220 |
+
logger.warning(f"Rate limit exceeded for {client_key}")
|
221 |
+
raise HTTPException(
|
222 |
+
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
223 |
+
detail="Rate limit exceeded. Please try again later."
|
224 |
+
)
|
225 |
+
|
226 |
+
return True
|
227 |
+
|
228 |
+
async def get_current_user(
|
229 |
+
token: str = Depends(oauth2_scheme),
|
230 |
+
api_key: str = Security(api_key_header),
|
231 |
+
db: AsyncSession = Depends(get_db)
|
232 |
+
) -> UserInDB:
|
233 |
+
"""
|
234 |
+
Get the current user from either JWT token or API key.
|
235 |
+
|
236 |
+
This dependency can be used to require authentication for endpoints.
|
237 |
+
"""
|
238 |
+
credentials_exception = HTTPException(
|
239 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
240 |
+
detail="Could not validate credentials",
|
241 |
+
headers={"WWW-Authenticate": "Bearer"},
|
242 |
+
)
|
243 |
+
|
244 |
+
# Check API key first
|
245 |
+
if api_key:
|
246 |
+
user = await get_api_key_user(api_key, db)
|
247 |
+
if user:
|
248 |
+
return user
|
249 |
+
|
250 |
+
# Then check JWT token
|
251 |
+
try:
|
252 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
253 |
+
username = payload.get("sub")
|
254 |
+
if username is None:
|
255 |
+
raise credentials_exception
|
256 |
+
|
257 |
+
token_data = TokenData(
|
258 |
+
username=username,
|
259 |
+
scopes=payload.get("scopes", [])
|
260 |
+
)
|
261 |
+
except JWTError:
|
262 |
+
raise credentials_exception
|
263 |
+
|
264 |
+
user = await get_user(db, username=token_data.username)
|
265 |
+
if user is None:
|
266 |
+
raise credentials_exception
|
267 |
+
|
268 |
+
return user
|
269 |
+
|
270 |
+
async def get_current_active_user(
|
271 |
+
current_user: UserInDB = Depends(get_current_user)
|
272 |
+
) -> UserInDB:
|
273 |
+
"""
|
274 |
+
Get the current active user.
|
275 |
+
|
276 |
+
This dependency can be used to require an active user for endpoints.
|
277 |
+
"""
|
278 |
+
if not current_user.is_active:
|
279 |
+
raise HTTPException(status_code=400, detail="Inactive user")
|
280 |
+
|
281 |
+
return current_user
|
282 |
+
|
283 |
+
def has_scope(required_scopes: List[str]):
|
284 |
+
"""
|
285 |
+
Create a dependency that checks if the user has the required scopes.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
required_scopes: List of required scopes
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
A dependency function that checks if the user has the required scopes
|
292 |
+
"""
|
293 |
+
async def _has_scope(
|
294 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
295 |
+
) -> UserInDB:
|
296 |
+
for scope in required_scopes:
|
297 |
+
if scope not in current_user.scopes:
|
298 |
+
raise HTTPException(
|
299 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
300 |
+
detail=f"Permission denied. Required scope: {scope}"
|
301 |
+
)
|
302 |
+
|
303 |
+
return current_user
|
304 |
+
|
305 |
+
return _has_scope
|
306 |
+
|
307 |
+
def admin_only(
|
308 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
309 |
+
) -> UserInDB:
|
310 |
+
"""
|
311 |
+
Dependency that requires an admin user.
|
312 |
+
"""
|
313 |
+
if not current_user.is_superuser:
|
314 |
+
raise HTTPException(
|
315 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
316 |
+
detail="Permission denied. Admin access required."
|
317 |
+
)
|
318 |
+
|
319 |
+
return current_user
|
320 |
+
|
321 |
+
# Audit logging middleware
|
322 |
+
async def audit_log_middleware(request: Request, call_next):
|
323 |
+
"""
|
324 |
+
Middleware for audit logging.
|
325 |
+
|
326 |
+
Records details about API requests.
|
327 |
+
"""
|
328 |
+
# Get request details
|
329 |
+
method = request.method
|
330 |
+
path = request.url.path
|
331 |
+
client_host = request.client.host
|
332 |
+
user_agent = request.headers.get("User-Agent", "Unknown")
|
333 |
+
|
334 |
+
# Get user details if available
|
335 |
+
user = getattr(request.state, "user", None)
|
336 |
+
username = user.username if user else "Anonymous"
|
337 |
+
|
338 |
+
# Log request
|
339 |
+
logger.info(
|
340 |
+
f"API Request: {method} {path} | User: {username} | "
|
341 |
+
f"Client: {client_host} | User-Agent: {user_agent}"
|
342 |
+
)
|
343 |
+
|
344 |
+
# Process the request
|
345 |
+
start_time = time.time()
|
346 |
+
response = await call_next(request)
|
347 |
+
process_time = time.time() - start_time
|
348 |
+
|
349 |
+
# Log response
|
350 |
+
logger.info(
|
351 |
+
f"API Response: {method} {path} | Status: {response.status_code} | "
|
352 |
+
f"Time: {process_time:.4f}s | User: {username}"
|
353 |
+
)
|
354 |
+
|
355 |
+
return response
|
356 |
+
|
357 |
+
# API key validation middleware
|
358 |
+
def validate_api_key(request: Request):
|
359 |
+
"""
|
360 |
+
Middleware function to validate API keys.
|
361 |
+
|
362 |
+
This can be used as a dependency for FastAPI routes.
|
363 |
+
"""
|
364 |
+
api_key = request.headers.get(API_KEY_NAME)
|
365 |
+
if not api_key:
|
366 |
+
raise HTTPException(
|
367 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
368 |
+
detail="API key required",
|
369 |
+
headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
|
370 |
+
)
|
371 |
+
|
372 |
+
# In a real application, you would validate the API key against a database
|
373 |
+
# For simplicity, we'll use a hardcoded list
|
374 |
+
valid_keys = ["test-api-key"] # Replace with database lookup
|
375 |
+
if api_key not in valid_keys:
|
376 |
+
raise HTTPException(
|
377 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
378 |
+
detail="Invalid API key",
|
379 |
+
headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
|
380 |
+
)
|
381 |
+
|
382 |
+
return True
|
src/api/services/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Package initialization for API services.
|
3 |
+
"""
|
src/api/services/alert_service.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Service for alert operations.
|
3 |
+
"""
|
4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
5 |
+
from sqlalchemy.future import select
|
6 |
+
from sqlalchemy import func, or_, and_
|
7 |
+
from datetime import datetime
|
8 |
+
from typing import List, Optional, Dict, Any, Union
|
9 |
+
|
10 |
+
from src.models.alert import Alert, AlertStatus, AlertCategory
|
11 |
+
from src.models.threat import ThreatSeverity
|
12 |
+
from src.api.schemas import PaginationParams
|
13 |
+
|
14 |
+
async def create_alert(
|
15 |
+
db: AsyncSession,
|
16 |
+
title: str,
|
17 |
+
description: str,
|
18 |
+
severity: ThreatSeverity,
|
19 |
+
category: AlertCategory,
|
20 |
+
source_url: Optional[str] = None,
|
21 |
+
threat_id: Optional[int] = None,
|
22 |
+
mention_id: Optional[int] = None,
|
23 |
+
) -> Alert:
|
24 |
+
"""
|
25 |
+
Create a new alert.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
db: Database session
|
29 |
+
title: Alert title
|
30 |
+
description: Alert description
|
31 |
+
severity: Alert severity
|
32 |
+
category: Alert category
|
33 |
+
source_url: Source URL for the alert
|
34 |
+
threat_id: ID of related threat
|
35 |
+
mention_id: ID of related dark web mention
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
Alert: Created alert
|
39 |
+
"""
|
40 |
+
db_alert = Alert(
|
41 |
+
title=title,
|
42 |
+
description=description,
|
43 |
+
severity=severity,
|
44 |
+
status=AlertStatus.NEW,
|
45 |
+
category=category,
|
46 |
+
generated_at=datetime.utcnow(),
|
47 |
+
source_url=source_url,
|
48 |
+
is_read=False,
|
49 |
+
threat_id=threat_id,
|
50 |
+
mention_id=mention_id,
|
51 |
+
)
|
52 |
+
|
53 |
+
db.add(db_alert)
|
54 |
+
await db.commit()
|
55 |
+
await db.refresh(db_alert)
|
56 |
+
|
57 |
+
return db_alert
|
58 |
+
|
59 |
+
async def get_alert_by_id(db: AsyncSession, alert_id: int) -> Optional[Alert]:
|
60 |
+
"""
|
61 |
+
Get alert by ID.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
db: Database session
|
65 |
+
alert_id: Alert ID
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Optional[Alert]: Alert or None if not found
|
69 |
+
"""
|
70 |
+
result = await db.execute(select(Alert).filter(Alert.id == alert_id))
|
71 |
+
return result.scalars().first()
|
72 |
+
|
73 |
+
async def get_alerts(
|
74 |
+
db: AsyncSession,
|
75 |
+
pagination: PaginationParams,
|
76 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
77 |
+
status: Optional[List[AlertStatus]] = None,
|
78 |
+
category: Optional[List[AlertCategory]] = None,
|
79 |
+
is_read: Optional[bool] = None,
|
80 |
+
search_query: Optional[str] = None,
|
81 |
+
from_date: Optional[datetime] = None,
|
82 |
+
to_date: Optional[datetime] = None,
|
83 |
+
) -> List[Alert]:
|
84 |
+
"""
|
85 |
+
Get alerts with filtering and pagination.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
db: Database session
|
89 |
+
pagination: Pagination parameters
|
90 |
+
severity: Filter by severity
|
91 |
+
status: Filter by status
|
92 |
+
category: Filter by category
|
93 |
+
is_read: Filter by read status
|
94 |
+
search_query: Search in title and description
|
95 |
+
from_date: Filter by generated_at >= from_date
|
96 |
+
to_date: Filter by generated_at <= to_date
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
List[Alert]: List of alerts
|
100 |
+
"""
|
101 |
+
query = select(Alert)
|
102 |
+
|
103 |
+
# Apply filters
|
104 |
+
if severity:
|
105 |
+
query = query.filter(Alert.severity.in_(severity))
|
106 |
+
|
107 |
+
if status:
|
108 |
+
query = query.filter(Alert.status.in_(status))
|
109 |
+
|
110 |
+
if category:
|
111 |
+
query = query.filter(Alert.category.in_(category))
|
112 |
+
|
113 |
+
if is_read is not None:
|
114 |
+
query = query.filter(Alert.is_read == is_read)
|
115 |
+
|
116 |
+
if search_query:
|
117 |
+
search_filter = or_(
|
118 |
+
Alert.title.ilike(f"%{search_query}%"),
|
119 |
+
Alert.description.ilike(f"%{search_query}%")
|
120 |
+
)
|
121 |
+
query = query.filter(search_filter)
|
122 |
+
|
123 |
+
if from_date:
|
124 |
+
query = query.filter(Alert.generated_at >= from_date)
|
125 |
+
|
126 |
+
if to_date:
|
127 |
+
query = query.filter(Alert.generated_at <= to_date)
|
128 |
+
|
129 |
+
# Apply pagination
|
130 |
+
query = query.order_by(Alert.generated_at.desc())
|
131 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
132 |
+
|
133 |
+
result = await db.execute(query)
|
134 |
+
return result.scalars().all()
|
135 |
+
|
136 |
+
async def count_alerts(
|
137 |
+
db: AsyncSession,
|
138 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
139 |
+
status: Optional[List[AlertStatus]] = None,
|
140 |
+
category: Optional[List[AlertCategory]] = None,
|
141 |
+
is_read: Optional[bool] = None,
|
142 |
+
search_query: Optional[str] = None,
|
143 |
+
from_date: Optional[datetime] = None,
|
144 |
+
to_date: Optional[datetime] = None,
|
145 |
+
) -> int:
|
146 |
+
"""
|
147 |
+
Count alerts with filtering.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
db: Database session
|
151 |
+
severity: Filter by severity
|
152 |
+
status: Filter by status
|
153 |
+
category: Filter by category
|
154 |
+
is_read: Filter by read status
|
155 |
+
search_query: Search in title and description
|
156 |
+
from_date: Filter by generated_at >= from_date
|
157 |
+
to_date: Filter by generated_at <= to_date
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
int: Count of alerts
|
161 |
+
"""
|
162 |
+
query = select(func.count(Alert.id))
|
163 |
+
|
164 |
+
# Apply filters (same as in get_alerts)
|
165 |
+
if severity:
|
166 |
+
query = query.filter(Alert.severity.in_(severity))
|
167 |
+
|
168 |
+
if status:
|
169 |
+
query = query.filter(Alert.status.in_(status))
|
170 |
+
|
171 |
+
if category:
|
172 |
+
query = query.filter(Alert.category.in_(category))
|
173 |
+
|
174 |
+
if is_read is not None:
|
175 |
+
query = query.filter(Alert.is_read == is_read)
|
176 |
+
|
177 |
+
if search_query:
|
178 |
+
search_filter = or_(
|
179 |
+
Alert.title.ilike(f"%{search_query}%"),
|
180 |
+
Alert.description.ilike(f"%{search_query}%")
|
181 |
+
)
|
182 |
+
query = query.filter(search_filter)
|
183 |
+
|
184 |
+
if from_date:
|
185 |
+
query = query.filter(Alert.generated_at >= from_date)
|
186 |
+
|
187 |
+
if to_date:
|
188 |
+
query = query.filter(Alert.generated_at <= to_date)
|
189 |
+
|
190 |
+
result = await db.execute(query)
|
191 |
+
return result.scalar()
|
192 |
+
|
193 |
+
async def update_alert_status(
|
194 |
+
db: AsyncSession,
|
195 |
+
alert_id: int,
|
196 |
+
status: AlertStatus,
|
197 |
+
action_taken: Optional[str] = None,
|
198 |
+
) -> Optional[Alert]:
|
199 |
+
"""
|
200 |
+
Update alert status.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
db: Database session
|
204 |
+
alert_id: Alert ID
|
205 |
+
status: New status
|
206 |
+
action_taken: Description of action taken
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
Optional[Alert]: Updated alert or None if not found
|
210 |
+
"""
|
211 |
+
alert = await get_alert_by_id(db, alert_id)
|
212 |
+
if not alert:
|
213 |
+
return None
|
214 |
+
|
215 |
+
alert.status = status
|
216 |
+
|
217 |
+
if action_taken:
|
218 |
+
alert.action_taken = action_taken
|
219 |
+
|
220 |
+
if status == AlertStatus.RESOLVED:
|
221 |
+
alert.resolved_at = datetime.utcnow()
|
222 |
+
|
223 |
+
alert.updated_at = datetime.utcnow()
|
224 |
+
|
225 |
+
await db.commit()
|
226 |
+
await db.refresh(alert)
|
227 |
+
|
228 |
+
return alert
|
229 |
+
|
230 |
+
async def mark_alert_as_read(
|
231 |
+
db: AsyncSession,
|
232 |
+
alert_id: int,
|
233 |
+
) -> Optional[Alert]:
|
234 |
+
"""
|
235 |
+
Mark alert as read.
|
236 |
+
|
237 |
+
Args:
|
238 |
+
db: Database session
|
239 |
+
alert_id: Alert ID
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
Optional[Alert]: Updated alert or None if not found
|
243 |
+
"""
|
244 |
+
alert = await get_alert_by_id(db, alert_id)
|
245 |
+
if not alert:
|
246 |
+
return None
|
247 |
+
|
248 |
+
alert.is_read = True
|
249 |
+
alert.updated_at = datetime.utcnow()
|
250 |
+
|
251 |
+
await db.commit()
|
252 |
+
await db.refresh(alert)
|
253 |
+
|
254 |
+
return alert
|
255 |
+
|
256 |
+
async def assign_alert(
|
257 |
+
db: AsyncSession,
|
258 |
+
alert_id: int,
|
259 |
+
user_id: int,
|
260 |
+
) -> Optional[Alert]:
|
261 |
+
"""
|
262 |
+
Assign alert to a user.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
db: Database session
|
266 |
+
alert_id: Alert ID
|
267 |
+
user_id: User ID to assign to
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
Optional[Alert]: Updated alert or None if not found
|
271 |
+
"""
|
272 |
+
alert = await get_alert_by_id(db, alert_id)
|
273 |
+
if not alert:
|
274 |
+
return None
|
275 |
+
|
276 |
+
alert.assigned_to_id = user_id
|
277 |
+
alert.status = AlertStatus.ASSIGNED
|
278 |
+
alert.updated_at = datetime.utcnow()
|
279 |
+
|
280 |
+
await db.commit()
|
281 |
+
await db.refresh(alert)
|
282 |
+
|
283 |
+
return alert
|
284 |
+
|
285 |
+
async def get_alert_counts_by_severity(
|
286 |
+
db: AsyncSession,
|
287 |
+
from_date: Optional[datetime] = None,
|
288 |
+
to_date: Optional[datetime] = None,
|
289 |
+
) -> Dict[str, int]:
|
290 |
+
"""
|
291 |
+
Get count of alerts by severity.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
db: Database session
|
295 |
+
from_date: Filter by generated_at >= from_date
|
296 |
+
to_date: Filter by generated_at <= to_date
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
Dict[str, int]: Mapping of severity to count
|
300 |
+
"""
|
301 |
+
result = {}
|
302 |
+
|
303 |
+
for severity in ThreatSeverity:
|
304 |
+
query = select(func.count(Alert.id)).filter(Alert.severity == severity)
|
305 |
+
|
306 |
+
if from_date:
|
307 |
+
query = query.filter(Alert.generated_at >= from_date)
|
308 |
+
|
309 |
+
if to_date:
|
310 |
+
query = query.filter(Alert.generated_at <= to_date)
|
311 |
+
|
312 |
+
count_result = await db.execute(query)
|
313 |
+
count = count_result.scalar() or 0
|
314 |
+
result[severity.value] = count
|
315 |
+
|
316 |
+
return result
|
src/api/services/dark_web_content_service.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Service for dark web content operations.
|
3 |
+
"""
|
4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
5 |
+
from sqlalchemy.future import select
|
6 |
+
from sqlalchemy import func, or_, text
|
7 |
+
from datetime import datetime
|
8 |
+
from typing import List, Optional, Dict, Any, Union
|
9 |
+
|
10 |
+
from src.models.dark_web_content import DarkWebContent, DarkWebMention, ContentType, ContentStatus
|
11 |
+
from src.models.threat import Threat, ThreatCategory, ThreatSeverity, ThreatStatus
|
12 |
+
from src.api.schemas import PaginationParams
|
13 |
+
|
14 |
+
async def create_content(
|
15 |
+
db: AsyncSession,
|
16 |
+
url: str,
|
17 |
+
content: str,
|
18 |
+
title: Optional[str] = None,
|
19 |
+
content_type: ContentType = ContentType.OTHER,
|
20 |
+
content_status: ContentStatus = ContentStatus.NEW,
|
21 |
+
source_name: Optional[str] = None,
|
22 |
+
source_type: Optional[str] = None,
|
23 |
+
language: Optional[str] = None,
|
24 |
+
relevance_score: float = 0.0,
|
25 |
+
sentiment_score: float = 0.0,
|
26 |
+
entity_data: Optional[str] = None,
|
27 |
+
) -> DarkWebContent:
|
28 |
+
"""
|
29 |
+
Create a new dark web content entry.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
db: Database session
|
33 |
+
url: URL of the content
|
34 |
+
content: Text content
|
35 |
+
title: Title of the content
|
36 |
+
content_type: Type of content
|
37 |
+
content_status: Status of content
|
38 |
+
source_name: Name of the source
|
39 |
+
source_type: Type of source
|
40 |
+
language: Language of the content
|
41 |
+
relevance_score: Relevance score (0-1)
|
42 |
+
sentiment_score: Sentiment score (-1 to 1)
|
43 |
+
entity_data: JSON string of extracted entities
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
DarkWebContent: Created content
|
47 |
+
"""
|
48 |
+
# Extract domain from URL if possible
|
49 |
+
domain = None
|
50 |
+
if url:
|
51 |
+
try:
|
52 |
+
from urllib.parse import urlparse
|
53 |
+
parsed_url = urlparse(url)
|
54 |
+
domain = parsed_url.netloc
|
55 |
+
except:
|
56 |
+
pass
|
57 |
+
|
58 |
+
db_content = DarkWebContent(
|
59 |
+
url=url,
|
60 |
+
domain=domain,
|
61 |
+
title=title,
|
62 |
+
content=content,
|
63 |
+
content_type=content_type,
|
64 |
+
content_status=content_status,
|
65 |
+
source_name=source_name,
|
66 |
+
source_type=source_type,
|
67 |
+
language=language,
|
68 |
+
scraped_at=datetime.utcnow(),
|
69 |
+
relevance_score=relevance_score,
|
70 |
+
sentiment_score=sentiment_score,
|
71 |
+
entity_data=entity_data,
|
72 |
+
)
|
73 |
+
|
74 |
+
db.add(db_content)
|
75 |
+
await db.commit()
|
76 |
+
await db.refresh(db_content)
|
77 |
+
|
78 |
+
return db_content
|
79 |
+
|
80 |
+
async def get_content_by_id(db: AsyncSession, content_id: int) -> Optional[DarkWebContent]:
|
81 |
+
"""
|
82 |
+
Get dark web content by ID.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
db: Database session
|
86 |
+
content_id: Content ID
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Optional[DarkWebContent]: Content or None if not found
|
90 |
+
"""
|
91 |
+
result = await db.execute(select(DarkWebContent).filter(DarkWebContent.id == content_id))
|
92 |
+
return result.scalars().first()
|
93 |
+
|
94 |
+
async def get_contents(
|
95 |
+
db: AsyncSession,
|
96 |
+
pagination: PaginationParams,
|
97 |
+
content_type: Optional[List[ContentType]] = None,
|
98 |
+
content_status: Optional[List[ContentStatus]] = None,
|
99 |
+
source_name: Optional[str] = None,
|
100 |
+
search_query: Optional[str] = None,
|
101 |
+
from_date: Optional[datetime] = None,
|
102 |
+
to_date: Optional[datetime] = None,
|
103 |
+
) -> List[DarkWebContent]:
|
104 |
+
"""
|
105 |
+
Get dark web contents with filtering and pagination.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
db: Database session
|
109 |
+
pagination: Pagination parameters
|
110 |
+
content_type: Filter by content type
|
111 |
+
content_status: Filter by content status
|
112 |
+
source_name: Filter by source name
|
113 |
+
search_query: Search in title and content
|
114 |
+
from_date: Filter by scraped_at >= from_date
|
115 |
+
to_date: Filter by scraped_at <= to_date
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
List[DarkWebContent]: List of dark web contents
|
119 |
+
"""
|
120 |
+
query = select(DarkWebContent)
|
121 |
+
|
122 |
+
# Apply filters
|
123 |
+
if content_type:
|
124 |
+
query = query.filter(DarkWebContent.content_type.in_(content_type))
|
125 |
+
|
126 |
+
if content_status:
|
127 |
+
query = query.filter(DarkWebContent.content_status.in_(content_status))
|
128 |
+
|
129 |
+
if source_name:
|
130 |
+
query = query.filter(DarkWebContent.source_name == source_name)
|
131 |
+
|
132 |
+
if search_query:
|
133 |
+
search_filter = or_(
|
134 |
+
DarkWebContent.title.ilike(f"%{search_query}%"),
|
135 |
+
DarkWebContent.content.ilike(f"%{search_query}%")
|
136 |
+
)
|
137 |
+
query = query.filter(search_filter)
|
138 |
+
|
139 |
+
if from_date:
|
140 |
+
query = query.filter(DarkWebContent.scraped_at >= from_date)
|
141 |
+
|
142 |
+
if to_date:
|
143 |
+
query = query.filter(DarkWebContent.scraped_at <= to_date)
|
144 |
+
|
145 |
+
# Apply pagination
|
146 |
+
query = query.order_by(DarkWebContent.scraped_at.desc())
|
147 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
148 |
+
|
149 |
+
result = await db.execute(query)
|
150 |
+
return result.scalars().all()
|
151 |
+
|
152 |
+
async def count_contents(
|
153 |
+
db: AsyncSession,
|
154 |
+
content_type: Optional[List[ContentType]] = None,
|
155 |
+
content_status: Optional[List[ContentStatus]] = None,
|
156 |
+
source_name: Optional[str] = None,
|
157 |
+
search_query: Optional[str] = None,
|
158 |
+
from_date: Optional[datetime] = None,
|
159 |
+
to_date: Optional[datetime] = None,
|
160 |
+
) -> int:
|
161 |
+
"""
|
162 |
+
Count dark web contents with filtering.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
db: Database session
|
166 |
+
content_type: Filter by content type
|
167 |
+
content_status: Filter by content status
|
168 |
+
source_name: Filter by source name
|
169 |
+
search_query: Search in title and content
|
170 |
+
from_date: Filter by scraped_at >= from_date
|
171 |
+
to_date: Filter by scraped_at <= to_date
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
int: Count of dark web contents
|
175 |
+
"""
|
176 |
+
query = select(func.count(DarkWebContent.id))
|
177 |
+
|
178 |
+
# Apply filters (same as in get_contents)
|
179 |
+
if content_type:
|
180 |
+
query = query.filter(DarkWebContent.content_type.in_(content_type))
|
181 |
+
|
182 |
+
if content_status:
|
183 |
+
query = query.filter(DarkWebContent.content_status.in_(content_status))
|
184 |
+
|
185 |
+
if source_name:
|
186 |
+
query = query.filter(DarkWebContent.source_name == source_name)
|
187 |
+
|
188 |
+
if search_query:
|
189 |
+
search_filter = or_(
|
190 |
+
DarkWebContent.title.ilike(f"%{search_query}%"),
|
191 |
+
DarkWebContent.content.ilike(f"%{search_query}%")
|
192 |
+
)
|
193 |
+
query = query.filter(search_filter)
|
194 |
+
|
195 |
+
if from_date:
|
196 |
+
query = query.filter(DarkWebContent.scraped_at >= from_date)
|
197 |
+
|
198 |
+
if to_date:
|
199 |
+
query = query.filter(DarkWebContent.scraped_at <= to_date)
|
200 |
+
|
201 |
+
result = await db.execute(query)
|
202 |
+
return result.scalar()
|
203 |
+
|
204 |
+
async def create_mention(
|
205 |
+
db: AsyncSession,
|
206 |
+
content_id: int,
|
207 |
+
keyword: str,
|
208 |
+
keyword_category: Optional[str] = None,
|
209 |
+
context: Optional[str] = None,
|
210 |
+
snippet: Optional[str] = None,
|
211 |
+
mention_type: Optional[str] = None,
|
212 |
+
confidence: float = 0.0,
|
213 |
+
is_verified: bool = False,
|
214 |
+
) -> DarkWebMention:
|
215 |
+
"""
|
216 |
+
Create a new dark web mention.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
db: Database session
|
220 |
+
content_id: ID of the content where the mention was found
|
221 |
+
keyword: Keyword that was mentioned
|
222 |
+
keyword_category: Category of the keyword
|
223 |
+
context: Text surrounding the mention
|
224 |
+
snippet: Extract of text containing the mention
|
225 |
+
mention_type: Type of mention
|
226 |
+
confidence: Confidence score (0-1)
|
227 |
+
is_verified: Whether the mention is verified
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
DarkWebMention: Created mention
|
231 |
+
"""
|
232 |
+
db_mention = DarkWebMention(
|
233 |
+
content_id=content_id,
|
234 |
+
keyword=keyword,
|
235 |
+
keyword_category=keyword_category,
|
236 |
+
context=context,
|
237 |
+
snippet=snippet,
|
238 |
+
mention_type=mention_type,
|
239 |
+
confidence=confidence,
|
240 |
+
is_verified=is_verified,
|
241 |
+
)
|
242 |
+
|
243 |
+
db.add(db_mention)
|
244 |
+
await db.commit()
|
245 |
+
await db.refresh(db_mention)
|
246 |
+
|
247 |
+
return db_mention
|
248 |
+
|
249 |
+
async def get_mention_by_id(db: AsyncSession, mention_id: int) -> Optional[DarkWebMention]:
|
250 |
+
"""
|
251 |
+
Get dark web mention by ID.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
db: Database session
|
255 |
+
mention_id: Mention ID
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
Optional[DarkWebMention]: Mention or None if not found
|
259 |
+
"""
|
260 |
+
result = await db.execute(select(DarkWebMention).filter(DarkWebMention.id == mention_id))
|
261 |
+
return result.scalars().first()
|
262 |
+
|
263 |
+
async def get_mentions(
|
264 |
+
db: AsyncSession,
|
265 |
+
pagination: PaginationParams,
|
266 |
+
keyword: Optional[str] = None,
|
267 |
+
content_id: Optional[int] = None,
|
268 |
+
is_verified: Optional[bool] = None,
|
269 |
+
from_date: Optional[datetime] = None,
|
270 |
+
to_date: Optional[datetime] = None,
|
271 |
+
) -> List[DarkWebMention]:
|
272 |
+
"""
|
273 |
+
Get dark web mentions with filtering and pagination.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
db: Database session
|
277 |
+
pagination: Pagination parameters
|
278 |
+
keyword: Filter by keyword
|
279 |
+
content_id: Filter by content ID
|
280 |
+
is_verified: Filter by verification status
|
281 |
+
from_date: Filter by created_at >= from_date
|
282 |
+
to_date: Filter by created_at <= to_date
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
List[DarkWebMention]: List of dark web mentions
|
286 |
+
"""
|
287 |
+
query = select(DarkWebMention)
|
288 |
+
|
289 |
+
# Apply filters
|
290 |
+
if keyword:
|
291 |
+
query = query.filter(DarkWebMention.keyword.ilike(f"%{keyword}%"))
|
292 |
+
|
293 |
+
if content_id:
|
294 |
+
query = query.filter(DarkWebMention.content_id == content_id)
|
295 |
+
|
296 |
+
if is_verified is not None:
|
297 |
+
query = query.filter(DarkWebMention.is_verified == is_verified)
|
298 |
+
|
299 |
+
if from_date:
|
300 |
+
query = query.filter(DarkWebMention.created_at >= from_date)
|
301 |
+
|
302 |
+
if to_date:
|
303 |
+
query = query.filter(DarkWebMention.created_at <= to_date)
|
304 |
+
|
305 |
+
# Apply pagination
|
306 |
+
query = query.order_by(DarkWebMention.created_at.desc())
|
307 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
308 |
+
|
309 |
+
result = await db.execute(query)
|
310 |
+
return result.scalars().all()
|
311 |
+
|
312 |
+
async def create_threat_from_content(
|
313 |
+
db: AsyncSession,
|
314 |
+
content_id: int,
|
315 |
+
title: str,
|
316 |
+
description: str,
|
317 |
+
severity: ThreatSeverity,
|
318 |
+
category: ThreatCategory,
|
319 |
+
confidence_score: float = 0.0,
|
320 |
+
) -> Threat:
|
321 |
+
"""
|
322 |
+
Create a threat from dark web content.
|
323 |
+
|
324 |
+
Args:
|
325 |
+
db: Database session
|
326 |
+
content_id: ID of the content
|
327 |
+
title: Threat title
|
328 |
+
description: Threat description
|
329 |
+
severity: Threat severity
|
330 |
+
category: Threat category
|
331 |
+
confidence_score: Confidence score (0-1)
|
332 |
+
|
333 |
+
Returns:
|
334 |
+
Threat: Created threat
|
335 |
+
"""
|
336 |
+
# Get the content
|
337 |
+
content = await get_content_by_id(db, content_id)
|
338 |
+
if not content:
|
339 |
+
raise ValueError(f"Content with ID {content_id} not found")
|
340 |
+
|
341 |
+
# Create the threat
|
342 |
+
from src.api.services.threat_service import create_threat
|
343 |
+
|
344 |
+
threat = await create_threat(
|
345 |
+
db=db,
|
346 |
+
title=title,
|
347 |
+
description=description,
|
348 |
+
severity=severity,
|
349 |
+
category=category,
|
350 |
+
status=ThreatStatus.NEW,
|
351 |
+
source_url=content.url,
|
352 |
+
source_name=content.source_name,
|
353 |
+
source_type=content.source_type,
|
354 |
+
confidence_score=confidence_score,
|
355 |
+
)
|
356 |
+
|
357 |
+
return threat
|
src/api/services/report_service.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Service for working with intelligence reports.
|
3 |
+
"""
|
4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
5 |
+
from sqlalchemy.future import select
|
6 |
+
from sqlalchemy import update, delete, func, desc, and_, or_
|
7 |
+
from typing import List, Optional, Dict, Any, Union
|
8 |
+
import logging
|
9 |
+
from datetime import datetime
|
10 |
+
|
11 |
+
from src.models.report import Report, ReportType, ReportStatus
|
12 |
+
from src.models.threat import ThreatSeverity
|
13 |
+
from src.api.schemas import PaginationParams
|
14 |
+
|
15 |
+
# Configure logger
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
async def create_report(
|
20 |
+
db: AsyncSession,
|
21 |
+
title: str,
|
22 |
+
summary: str,
|
23 |
+
content: str,
|
24 |
+
report_type: ReportType,
|
25 |
+
report_id: str,
|
26 |
+
status: ReportStatus = ReportStatus.DRAFT,
|
27 |
+
severity: Optional[ThreatSeverity] = None,
|
28 |
+
publish_date: Optional[datetime] = None,
|
29 |
+
time_period_start: Optional[datetime] = None,
|
30 |
+
time_period_end: Optional[datetime] = None,
|
31 |
+
keywords: Optional[List[str]] = None,
|
32 |
+
source_data: Optional[Dict[str, Any]] = None,
|
33 |
+
author_id: Optional[int] = None,
|
34 |
+
) -> Report:
|
35 |
+
"""
|
36 |
+
Create a new intelligence report.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
db: Database session
|
40 |
+
title: Report title
|
41 |
+
summary: Report summary
|
42 |
+
content: Report content
|
43 |
+
report_type: Type of report
|
44 |
+
report_id: Custom ID for the report
|
45 |
+
status: Report status
|
46 |
+
severity: Report severity
|
47 |
+
publish_date: Publication date
|
48 |
+
time_period_start: Start of time period covered
|
49 |
+
time_period_end: End of time period covered
|
50 |
+
keywords: List of keywords related to the report
|
51 |
+
source_data: Sources and references
|
52 |
+
author_id: ID of the report author
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
Report: Created report
|
56 |
+
"""
|
57 |
+
report = Report(
|
58 |
+
title=title,
|
59 |
+
summary=summary,
|
60 |
+
content=content,
|
61 |
+
report_type=report_type,
|
62 |
+
report_id=report_id,
|
63 |
+
status=status,
|
64 |
+
severity=severity,
|
65 |
+
publish_date=publish_date,
|
66 |
+
time_period_start=time_period_start,
|
67 |
+
time_period_end=time_period_end,
|
68 |
+
keywords=keywords or [],
|
69 |
+
source_data=source_data or {},
|
70 |
+
author_id=author_id,
|
71 |
+
)
|
72 |
+
|
73 |
+
db.add(report)
|
74 |
+
await db.commit()
|
75 |
+
await db.refresh(report)
|
76 |
+
|
77 |
+
return report
|
78 |
+
|
79 |
+
|
80 |
+
async def get_report_by_id(db: AsyncSession, report_id: int) -> Optional[Report]:
|
81 |
+
"""
|
82 |
+
Get report by ID.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
db: Database session
|
86 |
+
report_id: Report ID
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Optional[Report]: Found report or None
|
90 |
+
"""
|
91 |
+
result = await db.execute(
|
92 |
+
select(Report).where(Report.id == report_id)
|
93 |
+
)
|
94 |
+
return result.scalars().first()
|
95 |
+
|
96 |
+
|
97 |
+
async def get_report_by_custom_id(db: AsyncSession, custom_id: str) -> Optional[Report]:
|
98 |
+
"""
|
99 |
+
Get report by custom ID.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
db: Database session
|
103 |
+
custom_id: Custom report ID (e.g., RPT-2023-0001)
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
Optional[Report]: Found report or None
|
107 |
+
"""
|
108 |
+
result = await db.execute(
|
109 |
+
select(Report).where(Report.report_id == custom_id)
|
110 |
+
)
|
111 |
+
return result.scalars().first()
|
112 |
+
|
113 |
+
|
114 |
+
async def get_reports(
|
115 |
+
db: AsyncSession,
|
116 |
+
pagination: PaginationParams,
|
117 |
+
report_type: Optional[List[ReportType]] = None,
|
118 |
+
status: Optional[List[ReportStatus]] = None,
|
119 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
120 |
+
search_query: Optional[str] = None,
|
121 |
+
keywords: Optional[List[str]] = None,
|
122 |
+
author_id: Optional[int] = None,
|
123 |
+
from_date: Optional[datetime] = None,
|
124 |
+
to_date: Optional[datetime] = None,
|
125 |
+
) -> List[Report]:
|
126 |
+
"""
|
127 |
+
Get reports with filtering.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
db: Database session
|
131 |
+
pagination: Pagination parameters
|
132 |
+
report_type: Filter by report type
|
133 |
+
status: Filter by status
|
134 |
+
severity: Filter by severity
|
135 |
+
search_query: Search in title and summary
|
136 |
+
keywords: Filter by keywords
|
137 |
+
author_id: Filter by author ID
|
138 |
+
from_date: Filter by created_at >= from_date
|
139 |
+
to_date: Filter by created_at <= to_date
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
List[Report]: List of reports
|
143 |
+
"""
|
144 |
+
query = select(Report)
|
145 |
+
|
146 |
+
# Apply filters
|
147 |
+
if report_type:
|
148 |
+
query = query.where(Report.report_type.in_(report_type))
|
149 |
+
|
150 |
+
if status:
|
151 |
+
query = query.where(Report.status.in_(status))
|
152 |
+
|
153 |
+
if severity:
|
154 |
+
query = query.where(Report.severity.in_(severity))
|
155 |
+
|
156 |
+
if search_query:
|
157 |
+
search_filter = or_(
|
158 |
+
Report.title.ilike(f"%{search_query}%"),
|
159 |
+
Report.summary.ilike(f"%{search_query}%"),
|
160 |
+
Report.content.ilike(f"%{search_query}%"),
|
161 |
+
)
|
162 |
+
query = query.where(search_filter)
|
163 |
+
|
164 |
+
if keywords:
|
165 |
+
# For JSON arrays, need to use a more complex query
|
166 |
+
for keyword in keywords:
|
167 |
+
query = query.where(Report.keywords.contains([keyword]))
|
168 |
+
|
169 |
+
if author_id:
|
170 |
+
query = query.where(Report.author_id == author_id)
|
171 |
+
|
172 |
+
if from_date:
|
173 |
+
query = query.where(Report.created_at >= from_date)
|
174 |
+
|
175 |
+
if to_date:
|
176 |
+
query = query.where(Report.created_at <= to_date)
|
177 |
+
|
178 |
+
# Apply pagination
|
179 |
+
query = query.order_by(desc(Report.created_at))
|
180 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
181 |
+
|
182 |
+
result = await db.execute(query)
|
183 |
+
return result.scalars().all()
|
184 |
+
|
185 |
+
|
186 |
+
async def count_reports(
|
187 |
+
db: AsyncSession,
|
188 |
+
report_type: Optional[List[ReportType]] = None,
|
189 |
+
status: Optional[List[ReportStatus]] = None,
|
190 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
191 |
+
search_query: Optional[str] = None,
|
192 |
+
keywords: Optional[List[str]] = None,
|
193 |
+
author_id: Optional[int] = None,
|
194 |
+
from_date: Optional[datetime] = None,
|
195 |
+
to_date: Optional[datetime] = None,
|
196 |
+
) -> int:
|
197 |
+
"""
|
198 |
+
Count reports with filtering.
|
199 |
+
|
200 |
+
Args are the same as get_reports, except pagination.
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
int: Count of matching reports
|
204 |
+
"""
|
205 |
+
query = select(func.count(Report.id))
|
206 |
+
|
207 |
+
# Apply filters
|
208 |
+
if report_type:
|
209 |
+
query = query.where(Report.report_type.in_(report_type))
|
210 |
+
|
211 |
+
if status:
|
212 |
+
query = query.where(Report.status.in_(status))
|
213 |
+
|
214 |
+
if severity:
|
215 |
+
query = query.where(Report.severity.in_(severity))
|
216 |
+
|
217 |
+
if search_query:
|
218 |
+
search_filter = or_(
|
219 |
+
Report.title.ilike(f"%{search_query}%"),
|
220 |
+
Report.summary.ilike(f"%{search_query}%"),
|
221 |
+
Report.content.ilike(f"%{search_query}%"),
|
222 |
+
)
|
223 |
+
query = query.where(search_filter)
|
224 |
+
|
225 |
+
if keywords:
|
226 |
+
# For JSON arrays, need to use a more complex query
|
227 |
+
for keyword in keywords:
|
228 |
+
query = query.where(Report.keywords.contains([keyword]))
|
229 |
+
|
230 |
+
if author_id:
|
231 |
+
query = query.where(Report.author_id == author_id)
|
232 |
+
|
233 |
+
if from_date:
|
234 |
+
query = query.where(Report.created_at >= from_date)
|
235 |
+
|
236 |
+
if to_date:
|
237 |
+
query = query.where(Report.created_at <= to_date)
|
238 |
+
|
239 |
+
result = await db.execute(query)
|
240 |
+
return result.scalar()
|
241 |
+
|
242 |
+
|
243 |
+
async def update_report(
|
244 |
+
db: AsyncSession,
|
245 |
+
report_id: int,
|
246 |
+
title: Optional[str] = None,
|
247 |
+
summary: Optional[str] = None,
|
248 |
+
content: Optional[str] = None,
|
249 |
+
report_type: Optional[ReportType] = None,
|
250 |
+
status: Optional[ReportStatus] = None,
|
251 |
+
severity: Optional[ThreatSeverity] = None,
|
252 |
+
publish_date: Optional[datetime] = None,
|
253 |
+
time_period_start: Optional[datetime] = None,
|
254 |
+
time_period_end: Optional[datetime] = None,
|
255 |
+
keywords: Optional[List[str]] = None,
|
256 |
+
source_data: Optional[Dict[str, Any]] = None,
|
257 |
+
) -> Optional[Report]:
|
258 |
+
"""
|
259 |
+
Update a report.
|
260 |
+
|
261 |
+
Args:
|
262 |
+
db: Database session
|
263 |
+
report_id: Report ID
|
264 |
+
Other args: Fields to update
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
Optional[Report]: Updated report or None
|
268 |
+
"""
|
269 |
+
report = await get_report_by_id(db, report_id)
|
270 |
+
|
271 |
+
if not report:
|
272 |
+
return None
|
273 |
+
|
274 |
+
# Update fields if provided
|
275 |
+
if title is not None:
|
276 |
+
report.title = title
|
277 |
+
|
278 |
+
if summary is not None:
|
279 |
+
report.summary = summary
|
280 |
+
|
281 |
+
if content is not None:
|
282 |
+
report.content = content
|
283 |
+
|
284 |
+
if report_type is not None:
|
285 |
+
report.report_type = report_type
|
286 |
+
|
287 |
+
if status is not None:
|
288 |
+
report.status = status
|
289 |
+
|
290 |
+
if severity is not None:
|
291 |
+
report.severity = severity
|
292 |
+
|
293 |
+
if publish_date is not None:
|
294 |
+
report.publish_date = publish_date
|
295 |
+
|
296 |
+
if time_period_start is not None:
|
297 |
+
report.time_period_start = time_period_start
|
298 |
+
|
299 |
+
if time_period_end is not None:
|
300 |
+
report.time_period_end = time_period_end
|
301 |
+
|
302 |
+
if keywords is not None:
|
303 |
+
report.keywords = keywords
|
304 |
+
|
305 |
+
if source_data is not None:
|
306 |
+
report.source_data = source_data
|
307 |
+
|
308 |
+
await db.commit()
|
309 |
+
await db.refresh(report)
|
310 |
+
|
311 |
+
return report
|
312 |
+
|
313 |
+
|
314 |
+
async def add_threat_to_report(
|
315 |
+
db: AsyncSession,
|
316 |
+
report_id: int,
|
317 |
+
threat_id: int,
|
318 |
+
) -> Optional[Report]:
|
319 |
+
"""
|
320 |
+
Add a threat to a report.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
db: Database session
|
324 |
+
report_id: Report ID
|
325 |
+
threat_id: Threat ID
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
Optional[Report]: Updated report or None
|
329 |
+
"""
|
330 |
+
from src.api.services.threat_service import get_threat_by_id
|
331 |
+
|
332 |
+
# Get report and threat
|
333 |
+
report = await get_report_by_id(db, report_id)
|
334 |
+
threat = await get_threat_by_id(db, threat_id)
|
335 |
+
|
336 |
+
if not report or not threat:
|
337 |
+
return None
|
338 |
+
|
339 |
+
# Add threat to report
|
340 |
+
report.threats.append(threat)
|
341 |
+
await db.commit()
|
342 |
+
await db.refresh(report)
|
343 |
+
|
344 |
+
return report
|
345 |
+
|
346 |
+
|
347 |
+
async def remove_threat_from_report(
|
348 |
+
db: AsyncSession,
|
349 |
+
report_id: int,
|
350 |
+
threat_id: int,
|
351 |
+
) -> Optional[Report]:
|
352 |
+
"""
|
353 |
+
Remove a threat from a report.
|
354 |
+
|
355 |
+
Args:
|
356 |
+
db: Database session
|
357 |
+
report_id: Report ID
|
358 |
+
threat_id: Threat ID
|
359 |
+
|
360 |
+
Returns:
|
361 |
+
Optional[Report]: Updated report or None
|
362 |
+
"""
|
363 |
+
from src.api.services.threat_service import get_threat_by_id
|
364 |
+
|
365 |
+
# Get report and threat
|
366 |
+
report = await get_report_by_id(db, report_id)
|
367 |
+
threat = await get_threat_by_id(db, threat_id)
|
368 |
+
|
369 |
+
if not report or not threat:
|
370 |
+
return None
|
371 |
+
|
372 |
+
# Remove threat from report
|
373 |
+
if threat in report.threats:
|
374 |
+
report.threats.remove(threat)
|
375 |
+
await db.commit()
|
376 |
+
await db.refresh(report)
|
377 |
+
|
378 |
+
return report
|
379 |
+
|
380 |
+
|
381 |
+
async def publish_report(
|
382 |
+
db: AsyncSession,
|
383 |
+
report_id: int,
|
384 |
+
) -> Optional[Report]:
|
385 |
+
"""
|
386 |
+
Publish a report.
|
387 |
+
|
388 |
+
Args:
|
389 |
+
db: Database session
|
390 |
+
report_id: Report ID
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
Optional[Report]: Updated report or None
|
394 |
+
"""
|
395 |
+
report = await get_report_by_id(db, report_id)
|
396 |
+
|
397 |
+
if not report:
|
398 |
+
return None
|
399 |
+
|
400 |
+
# Update status and publish date
|
401 |
+
report.status = ReportStatus.PUBLISHED
|
402 |
+
|
403 |
+
if not report.publish_date:
|
404 |
+
report.publish_date = datetime.utcnow()
|
405 |
+
|
406 |
+
await db.commit()
|
407 |
+
await db.refresh(report)
|
408 |
+
|
409 |
+
return report
|
410 |
+
|
411 |
+
|
412 |
+
async def archive_report(
|
413 |
+
db: AsyncSession,
|
414 |
+
report_id: int,
|
415 |
+
) -> Optional[Report]:
|
416 |
+
"""
|
417 |
+
Archive a report.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
db: Database session
|
421 |
+
report_id: Report ID
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
Optional[Report]: Updated report or None
|
425 |
+
"""
|
426 |
+
report = await get_report_by_id(db, report_id)
|
427 |
+
|
428 |
+
if not report:
|
429 |
+
return None
|
430 |
+
|
431 |
+
# Update status
|
432 |
+
report.status = ReportStatus.ARCHIVED
|
433 |
+
await db.commit()
|
434 |
+
await db.refresh(report)
|
435 |
+
|
436 |
+
return report
|
src/api/services/search_history_service.py
ADDED
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Search History and Trends Service
|
3 |
+
|
4 |
+
This service manages search history, saved searches, and trend analysis.
|
5 |
+
"""
|
6 |
+
import logging
|
7 |
+
import json
|
8 |
+
from datetime import datetime, timedelta
|
9 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
10 |
+
|
11 |
+
from sqlalchemy import func, desc, and_, or_, text
|
12 |
+
from sqlalchemy.future import select
|
13 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
14 |
+
from sqlalchemy.orm import selectinload
|
15 |
+
|
16 |
+
from src.models.search_history import SearchHistory, SearchResult, SavedSearch, TrendTopic
|
17 |
+
from src.models.dark_web_content import DarkWebContent
|
18 |
+
from src.models.user import User
|
19 |
+
|
20 |
+
# Configure logging
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
async def add_search_history(
|
24 |
+
db: AsyncSession,
|
25 |
+
query: str,
|
26 |
+
user_id: Optional[int] = None,
|
27 |
+
result_count: int = 0,
|
28 |
+
category: Optional[str] = None,
|
29 |
+
is_saved: bool = False,
|
30 |
+
notes: Optional[str] = None,
|
31 |
+
tags: Optional[str] = None
|
32 |
+
) -> SearchHistory:
|
33 |
+
"""
|
34 |
+
Add a new search history entry.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
db: Database session
|
38 |
+
query: Search query
|
39 |
+
user_id: ID of the user who performed the search (optional)
|
40 |
+
result_count: Number of results returned
|
41 |
+
category: Category of the search
|
42 |
+
is_saved: Whether this is a saved search
|
43 |
+
notes: Optional notes
|
44 |
+
tags: Optional tags (comma-separated)
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
The created SearchHistory object
|
48 |
+
"""
|
49 |
+
search_history = SearchHistory(
|
50 |
+
query=query,
|
51 |
+
user_id=user_id,
|
52 |
+
result_count=result_count,
|
53 |
+
category=category,
|
54 |
+
is_saved=is_saved,
|
55 |
+
notes=notes,
|
56 |
+
tags=tags
|
57 |
+
)
|
58 |
+
|
59 |
+
db.add(search_history)
|
60 |
+
await db.commit()
|
61 |
+
await db.refresh(search_history)
|
62 |
+
|
63 |
+
# Update trend data
|
64 |
+
await update_trend_data(db, query, category)
|
65 |
+
|
66 |
+
return search_history
|
67 |
+
|
68 |
+
async def add_search_result(
|
69 |
+
db: AsyncSession,
|
70 |
+
search_id: int,
|
71 |
+
url: str,
|
72 |
+
title: Optional[str] = None,
|
73 |
+
snippet: Optional[str] = None,
|
74 |
+
source: Optional[str] = None,
|
75 |
+
relevance_score: float = 0.0,
|
76 |
+
content_id: Optional[int] = None
|
77 |
+
) -> SearchResult:
|
78 |
+
"""
|
79 |
+
Add a new search result.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
db: Database session
|
83 |
+
search_id: ID of the parent search
|
84 |
+
url: URL of the result
|
85 |
+
title: Title of the result
|
86 |
+
snippet: Text snippet from the result
|
87 |
+
source: Source of the result
|
88 |
+
relevance_score: Score indicating relevance to the search query
|
89 |
+
content_id: ID of the content in our database (if applicable)
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
The created SearchResult object
|
93 |
+
"""
|
94 |
+
search_result = SearchResult(
|
95 |
+
search_id=search_id,
|
96 |
+
url=url,
|
97 |
+
title=title,
|
98 |
+
snippet=snippet,
|
99 |
+
source=source,
|
100 |
+
relevance_score=relevance_score,
|
101 |
+
content_id=content_id
|
102 |
+
)
|
103 |
+
|
104 |
+
db.add(search_result)
|
105 |
+
await db.commit()
|
106 |
+
await db.refresh(search_result)
|
107 |
+
|
108 |
+
return search_result
|
109 |
+
|
110 |
+
async def get_search_history(
|
111 |
+
db: AsyncSession,
|
112 |
+
skip: int = 0,
|
113 |
+
limit: int = 100,
|
114 |
+
user_id: Optional[int] = None,
|
115 |
+
query_filter: Optional[str] = None,
|
116 |
+
date_from: Optional[datetime] = None,
|
117 |
+
date_to: Optional[datetime] = None,
|
118 |
+
category: Optional[str] = None,
|
119 |
+
is_saved: Optional[bool] = None,
|
120 |
+
include_results: bool = False
|
121 |
+
) -> List[SearchHistory]:
|
122 |
+
"""
|
123 |
+
Get search history with filtering options.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
db: Database session
|
127 |
+
skip: Number of items to skip
|
128 |
+
limit: Maximum number of items to return
|
129 |
+
user_id: Filter by user ID
|
130 |
+
query_filter: Filter by search query (partial match)
|
131 |
+
date_from: Filter by timestamp (from)
|
132 |
+
date_to: Filter by timestamp (to)
|
133 |
+
category: Filter by category
|
134 |
+
is_saved: Filter by saved status
|
135 |
+
include_results: Whether to include search results
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
List of SearchHistory objects
|
139 |
+
"""
|
140 |
+
statement = select(SearchHistory)
|
141 |
+
|
142 |
+
# Apply filters
|
143 |
+
if user_id is not None:
|
144 |
+
statement = statement.where(SearchHistory.user_id == user_id)
|
145 |
+
|
146 |
+
if query_filter:
|
147 |
+
statement = statement.where(SearchHistory.query.ilike(f"%{query_filter}%"))
|
148 |
+
|
149 |
+
if date_from:
|
150 |
+
statement = statement.where(SearchHistory.timestamp >= date_from)
|
151 |
+
|
152 |
+
if date_to:
|
153 |
+
statement = statement.where(SearchHistory.timestamp <= date_to)
|
154 |
+
|
155 |
+
if category:
|
156 |
+
statement = statement.where(SearchHistory.category == category)
|
157 |
+
|
158 |
+
if is_saved is not None:
|
159 |
+
statement = statement.where(SearchHistory.is_saved == is_saved)
|
160 |
+
|
161 |
+
# Load related data if requested
|
162 |
+
if include_results:
|
163 |
+
statement = statement.options(selectinload(SearchHistory.search_results))
|
164 |
+
|
165 |
+
# Apply pagination
|
166 |
+
statement = statement.order_by(desc(SearchHistory.timestamp)).offset(skip).limit(limit)
|
167 |
+
|
168 |
+
result = await db.execute(statement)
|
169 |
+
return result.scalars().all()
|
170 |
+
|
171 |
+
async def get_search_by_id(
|
172 |
+
db: AsyncSession,
|
173 |
+
search_id: int,
|
174 |
+
include_results: bool = False
|
175 |
+
) -> Optional[SearchHistory]:
|
176 |
+
"""
|
177 |
+
Get a search history entry by ID.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
db: Database session
|
181 |
+
search_id: Search history ID
|
182 |
+
include_results: Whether to include search results
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
SearchHistory object or None if not found
|
186 |
+
"""
|
187 |
+
statement = select(SearchHistory).where(SearchHistory.id == search_id)
|
188 |
+
|
189 |
+
if include_results:
|
190 |
+
statement = statement.options(selectinload(SearchHistory.search_results))
|
191 |
+
|
192 |
+
result = await db.execute(statement)
|
193 |
+
return result.scalars().first()
|
194 |
+
|
195 |
+
async def delete_search_history(db: AsyncSession, search_id: int) -> bool:
|
196 |
+
"""
|
197 |
+
Delete a search history entry.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
db: Database session
|
201 |
+
search_id: ID of the search to delete
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
True if successful, False otherwise
|
205 |
+
"""
|
206 |
+
search = await get_search_by_id(db, search_id)
|
207 |
+
if not search:
|
208 |
+
return False
|
209 |
+
|
210 |
+
await db.delete(search)
|
211 |
+
await db.commit()
|
212 |
+
return True
|
213 |
+
|
214 |
+
async def save_search(
|
215 |
+
db: AsyncSession,
|
216 |
+
search_id: int,
|
217 |
+
is_saved: bool = True,
|
218 |
+
notes: Optional[str] = None,
|
219 |
+
tags: Optional[str] = None
|
220 |
+
) -> Optional[SearchHistory]:
|
221 |
+
"""
|
222 |
+
Save or unsave a search history entry.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
db: Database session
|
226 |
+
search_id: ID of the search
|
227 |
+
is_saved: Whether to save or unsave
|
228 |
+
notes: Optional notes to add
|
229 |
+
tags: Optional tags to add (comma-separated)
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
Updated SearchHistory object or None if not found
|
233 |
+
"""
|
234 |
+
search = await get_search_by_id(db, search_id)
|
235 |
+
if not search:
|
236 |
+
return None
|
237 |
+
|
238 |
+
search.is_saved = is_saved
|
239 |
+
|
240 |
+
if notes:
|
241 |
+
search.notes = notes
|
242 |
+
|
243 |
+
if tags:
|
244 |
+
search.tags = tags
|
245 |
+
|
246 |
+
await db.commit()
|
247 |
+
await db.refresh(search)
|
248 |
+
return search
|
249 |
+
|
250 |
+
async def create_saved_search(
|
251 |
+
db: AsyncSession,
|
252 |
+
name: str,
|
253 |
+
query: str,
|
254 |
+
user_id: int,
|
255 |
+
frequency: int = 24,
|
256 |
+
notification_enabled: bool = True,
|
257 |
+
threshold: int = 1,
|
258 |
+
category: Optional[str] = None
|
259 |
+
) -> SavedSearch:
|
260 |
+
"""
|
261 |
+
Create a new saved search with periodic monitoring.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
db: Database session
|
265 |
+
name: Name of the saved search
|
266 |
+
query: Search query
|
267 |
+
user_id: ID of the user
|
268 |
+
frequency: How often to run this search (in hours, 0 for manual only)
|
269 |
+
notification_enabled: Whether to send notifications for new results
|
270 |
+
threshold: Minimum number of new results for notification
|
271 |
+
category: Category of the search
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
The created SavedSearch object
|
275 |
+
"""
|
276 |
+
saved_search = SavedSearch(
|
277 |
+
name=name,
|
278 |
+
query=query,
|
279 |
+
user_id=user_id,
|
280 |
+
frequency=frequency,
|
281 |
+
notification_enabled=notification_enabled,
|
282 |
+
threshold=threshold,
|
283 |
+
category=category
|
284 |
+
)
|
285 |
+
|
286 |
+
db.add(saved_search)
|
287 |
+
await db.commit()
|
288 |
+
await db.refresh(saved_search)
|
289 |
+
|
290 |
+
return saved_search
|
291 |
+
|
292 |
+
async def get_saved_searches(
|
293 |
+
db: AsyncSession,
|
294 |
+
user_id: Optional[int] = None,
|
295 |
+
is_active: Optional[bool] = None,
|
296 |
+
skip: int = 0,
|
297 |
+
limit: int = 100
|
298 |
+
) -> List[SavedSearch]:
|
299 |
+
"""
|
300 |
+
Get saved searches with filtering options.
|
301 |
+
|
302 |
+
Args:
|
303 |
+
db: Database session
|
304 |
+
user_id: Filter by user ID
|
305 |
+
is_active: Filter by active status
|
306 |
+
skip: Number of items to skip
|
307 |
+
limit: Maximum number of items to return
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
List of SavedSearch objects
|
311 |
+
"""
|
312 |
+
statement = select(SavedSearch)
|
313 |
+
|
314 |
+
# Apply filters
|
315 |
+
if user_id is not None:
|
316 |
+
statement = statement.where(SavedSearch.user_id == user_id)
|
317 |
+
|
318 |
+
if is_active is not None:
|
319 |
+
statement = statement.where(SavedSearch.is_active == is_active)
|
320 |
+
|
321 |
+
# Apply pagination
|
322 |
+
statement = statement.order_by(SavedSearch.name).offset(skip).limit(limit)
|
323 |
+
|
324 |
+
result = await db.execute(statement)
|
325 |
+
return result.scalars().all()
|
326 |
+
|
327 |
+
async def update_trend_data(
|
328 |
+
db: AsyncSession,
|
329 |
+
query: str,
|
330 |
+
category: Optional[str] = None
|
331 |
+
) -> None:
|
332 |
+
"""
|
333 |
+
Update trend data based on search queries.
|
334 |
+
|
335 |
+
Args:
|
336 |
+
db: Database session
|
337 |
+
query: Search query
|
338 |
+
category: Category of the search
|
339 |
+
"""
|
340 |
+
# Split query into individual terms/topics
|
341 |
+
topics = [t.strip() for t in query.split() if len(t.strip()) > 3]
|
342 |
+
|
343 |
+
# Process each topic
|
344 |
+
for topic in topics:
|
345 |
+
# Check if topic already exists
|
346 |
+
statement = select(TrendTopic).where(TrendTopic.topic == topic)
|
347 |
+
result = await db.execute(statement)
|
348 |
+
trend_topic = result.scalars().first()
|
349 |
+
|
350 |
+
if trend_topic:
|
351 |
+
# Update existing topic
|
352 |
+
trend_topic.last_seen = datetime.utcnow()
|
353 |
+
trend_topic.mention_count += 1
|
354 |
+
|
355 |
+
# Calculate growth rate (percentage change over the last 24 hours)
|
356 |
+
time_diff = (trend_topic.last_seen - trend_topic.first_seen).total_seconds() / 3600 # hours
|
357 |
+
if time_diff > 0:
|
358 |
+
hourly_rate = trend_topic.mention_count / time_diff
|
359 |
+
trend_topic.growth_rate = hourly_rate * 24 # daily growth rate
|
360 |
+
|
361 |
+
# Update category if provided and not already set
|
362 |
+
if category and not trend_topic.category:
|
363 |
+
trend_topic.category = category
|
364 |
+
else:
|
365 |
+
# Create a new trend topic
|
366 |
+
trend_topic = TrendTopic(
|
367 |
+
topic=topic,
|
368 |
+
category=category,
|
369 |
+
mention_count=1,
|
370 |
+
growth_rate=1.0 # Initial growth rate
|
371 |
+
)
|
372 |
+
db.add(trend_topic)
|
373 |
+
|
374 |
+
await db.commit()
|
375 |
+
|
376 |
+
async def get_trending_topics(
|
377 |
+
db: AsyncSession,
|
378 |
+
days: int = 7,
|
379 |
+
limit: int = 20,
|
380 |
+
category: Optional[str] = None,
|
381 |
+
min_mentions: int = 3
|
382 |
+
) -> List[TrendTopic]:
|
383 |
+
"""
|
384 |
+
Get trending topics over a specific time period.
|
385 |
+
|
386 |
+
Args:
|
387 |
+
db: Database session
|
388 |
+
days: Number of days to consider
|
389 |
+
limit: Maximum number of topics to return
|
390 |
+
category: Filter by category
|
391 |
+
min_mentions: Minimum number of mentions
|
392 |
+
|
393 |
+
Returns:
|
394 |
+
List of TrendTopic objects sorted by growth rate
|
395 |
+
"""
|
396 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
397 |
+
|
398 |
+
statement = select(TrendTopic).where(
|
399 |
+
and_(
|
400 |
+
TrendTopic.last_seen >= cutoff_date,
|
401 |
+
TrendTopic.mention_count >= min_mentions,
|
402 |
+
TrendTopic.is_active == True
|
403 |
+
)
|
404 |
+
)
|
405 |
+
|
406 |
+
if category:
|
407 |
+
statement = statement.where(TrendTopic.category == category)
|
408 |
+
|
409 |
+
statement = statement.order_by(desc(TrendTopic.growth_rate)).limit(limit)
|
410 |
+
|
411 |
+
result = await db.execute(statement)
|
412 |
+
return result.scalars().all()
|
413 |
+
|
414 |
+
async def get_search_frequency(
|
415 |
+
db: AsyncSession,
|
416 |
+
days: int = 30,
|
417 |
+
interval: str = 'day'
|
418 |
+
) -> List[Dict[str, Any]]:
|
419 |
+
"""
|
420 |
+
Get search frequency over time for visualization.
|
421 |
+
|
422 |
+
Args:
|
423 |
+
db: Database session
|
424 |
+
days: Number of days to analyze
|
425 |
+
interval: Time interval ('hour', 'day', 'week', 'month')
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
List of dictionaries with time intervals and search counts
|
429 |
+
"""
|
430 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
431 |
+
|
432 |
+
# SQL query depends on the interval
|
433 |
+
if interval == 'hour':
|
434 |
+
date_format = "YYYY-MM-DD HH24:00"
|
435 |
+
trunc_expr = func.date_trunc('hour', SearchHistory.timestamp)
|
436 |
+
elif interval == 'day':
|
437 |
+
date_format = "YYYY-MM-DD"
|
438 |
+
trunc_expr = func.date_trunc('day', SearchHistory.timestamp)
|
439 |
+
elif interval == 'week':
|
440 |
+
date_format = "YYYY-WW"
|
441 |
+
trunc_expr = func.date_trunc('week', SearchHistory.timestamp)
|
442 |
+
else: # month
|
443 |
+
date_format = "YYYY-MM"
|
444 |
+
trunc_expr = func.date_trunc('month', SearchHistory.timestamp)
|
445 |
+
|
446 |
+
# Query for search count by interval
|
447 |
+
statement = select(
|
448 |
+
trunc_expr.label('interval'),
|
449 |
+
func.count(SearchHistory.id).label('count')
|
450 |
+
).where(
|
451 |
+
SearchHistory.timestamp >= cutoff_date
|
452 |
+
).group_by(
|
453 |
+
'interval'
|
454 |
+
).order_by(
|
455 |
+
'interval'
|
456 |
+
)
|
457 |
+
|
458 |
+
result = await db.execute(statement)
|
459 |
+
rows = result.all()
|
460 |
+
|
461 |
+
# Convert to list of dictionaries
|
462 |
+
return [{"interval": row.interval, "count": row.count} for row in rows]
|
463 |
+
|
464 |
+
async def get_popular_searches(
|
465 |
+
db: AsyncSession,
|
466 |
+
days: int = 30,
|
467 |
+
limit: int = 10
|
468 |
+
) -> List[Dict[str, Any]]:
|
469 |
+
"""
|
470 |
+
Get the most popular search terms.
|
471 |
+
|
472 |
+
Args:
|
473 |
+
db: Database session
|
474 |
+
days: Number of days to analyze
|
475 |
+
limit: Maximum number of terms to return
|
476 |
+
|
477 |
+
Returns:
|
478 |
+
List of dictionaries with search queries and counts
|
479 |
+
"""
|
480 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
481 |
+
|
482 |
+
statement = select(
|
483 |
+
SearchHistory.query,
|
484 |
+
func.count(SearchHistory.id).label('count')
|
485 |
+
).where(
|
486 |
+
SearchHistory.timestamp >= cutoff_date
|
487 |
+
).group_by(
|
488 |
+
SearchHistory.query
|
489 |
+
).order_by(
|
490 |
+
desc('count')
|
491 |
+
).limit(limit)
|
492 |
+
|
493 |
+
result = await db.execute(statement)
|
494 |
+
rows = result.all()
|
495 |
+
|
496 |
+
return [{"query": row.query, "count": row.count} for row in rows]
|
497 |
+
|
498 |
+
async def get_search_categories(
|
499 |
+
db: AsyncSession,
|
500 |
+
days: int = 30
|
501 |
+
) -> List[Dict[str, Any]]:
|
502 |
+
"""
|
503 |
+
Get distribution of search categories.
|
504 |
+
|
505 |
+
Args:
|
506 |
+
db: Database session
|
507 |
+
days: Number of days to analyze
|
508 |
+
|
509 |
+
Returns:
|
510 |
+
List of dictionaries with categories and counts
|
511 |
+
"""
|
512 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
513 |
+
|
514 |
+
statement = select(
|
515 |
+
SearchHistory.category,
|
516 |
+
func.count(SearchHistory.id).label('count')
|
517 |
+
).where(
|
518 |
+
and_(
|
519 |
+
SearchHistory.timestamp >= cutoff_date,
|
520 |
+
SearchHistory.category.is_not(None)
|
521 |
+
)
|
522 |
+
).group_by(
|
523 |
+
SearchHistory.category
|
524 |
+
).order_by(
|
525 |
+
desc('count')
|
526 |
+
)
|
527 |
+
|
528 |
+
result = await db.execute(statement)
|
529 |
+
rows = result.all()
|
530 |
+
|
531 |
+
return [{"category": row.category or "Uncategorized", "count": row.count} for row in rows]
|
532 |
+
|
533 |
+
async def get_search_trend_analysis(
|
534 |
+
db: AsyncSession,
|
535 |
+
days: int = 90,
|
536 |
+
trend_days: int = 7,
|
537 |
+
limit: int = 10
|
538 |
+
) -> Dict[str, Any]:
|
539 |
+
"""
|
540 |
+
Get comprehensive analysis of search trends.
|
541 |
+
|
542 |
+
Args:
|
543 |
+
db: Database session
|
544 |
+
days: Total days to analyze
|
545 |
+
trend_days: Days to calculate short-term trends
|
546 |
+
limit: Maximum number of items in each category
|
547 |
+
|
548 |
+
Returns:
|
549 |
+
Dictionary with various trend analyses
|
550 |
+
"""
|
551 |
+
# Get overall search frequency
|
552 |
+
frequency = await get_search_frequency(db, days, 'day')
|
553 |
+
|
554 |
+
# Get popular searches
|
555 |
+
popular = await get_popular_searches(db, days, limit)
|
556 |
+
|
557 |
+
# Get recent trending topics
|
558 |
+
trending = await get_trending_topics(db, trend_days, limit)
|
559 |
+
|
560 |
+
# Get category distribution
|
561 |
+
categories = await get_search_categories(db, days)
|
562 |
+
|
563 |
+
# Get recent (last 24 hours) vs. overall popular terms
|
564 |
+
recent_popular = await get_popular_searches(db, 1, limit)
|
565 |
+
|
566 |
+
# Calculate velocity (rate of change)
|
567 |
+
# This compares the last 7 days to the previous 7 days
|
568 |
+
cutoff_recent = datetime.utcnow() - timedelta(days=trend_days)
|
569 |
+
cutoff_previous = cutoff_recent - timedelta(days=trend_days)
|
570 |
+
|
571 |
+
# Query for velocity calculation
|
572 |
+
statement_recent = select(func.count(SearchHistory.id)).where(
|
573 |
+
SearchHistory.timestamp >= cutoff_recent
|
574 |
+
)
|
575 |
+
statement_previous = select(func.count(SearchHistory.id)).where(
|
576 |
+
and_(
|
577 |
+
SearchHistory.timestamp >= cutoff_previous,
|
578 |
+
SearchHistory.timestamp < cutoff_recent
|
579 |
+
)
|
580 |
+
)
|
581 |
+
|
582 |
+
result_recent = await db.execute(statement_recent)
|
583 |
+
result_previous = await db.execute(statement_previous)
|
584 |
+
|
585 |
+
count_recent = result_recent.scalar() or 0
|
586 |
+
count_previous = result_previous.scalar() or 0
|
587 |
+
|
588 |
+
if count_previous > 0:
|
589 |
+
velocity = (count_recent - count_previous) / count_previous * 100 # percentage change
|
590 |
+
else:
|
591 |
+
velocity = 100.0 if count_recent > 0 else 0.0
|
592 |
+
|
593 |
+
# Compile the results
|
594 |
+
return {
|
595 |
+
"frequency": frequency,
|
596 |
+
"popular_searches": popular,
|
597 |
+
"trending_topics": [
|
598 |
+
{"topic": t.topic, "mentions": t.mention_count, "growth_rate": t.growth_rate}
|
599 |
+
for t in trending
|
600 |
+
],
|
601 |
+
"categories": categories,
|
602 |
+
"recent_popular": recent_popular,
|
603 |
+
"velocity": velocity,
|
604 |
+
"total_searches": {
|
605 |
+
"total": count_recent + count_previous,
|
606 |
+
"recent": count_recent,
|
607 |
+
"previous": count_previous
|
608 |
+
}
|
609 |
+
}
|
src/api/services/subscription_service.py
ADDED
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Subscription service.
|
3 |
+
|
4 |
+
This module provides functions for managing subscriptions.
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
import logging
|
8 |
+
from datetime import datetime, timedelta
|
9 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import stripe
|
12 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
13 |
+
from sqlalchemy import select, update, delete
|
14 |
+
from sqlalchemy.orm import joinedload
|
15 |
+
|
16 |
+
from src.models.subscription import (
|
17 |
+
SubscriptionPlan, UserSubscription, PaymentHistory,
|
18 |
+
SubscriptionTier, BillingPeriod, SubscriptionStatus, PaymentStatus
|
19 |
+
)
|
20 |
+
from src.models.user import User
|
21 |
+
|
22 |
+
# Set up Stripe API key
|
23 |
+
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY")
|
24 |
+
STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY")
|
25 |
+
|
26 |
+
# Set up logging
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
async def get_subscription_plans(
|
31 |
+
db: AsyncSession,
|
32 |
+
active_only: bool = True
|
33 |
+
) -> List[SubscriptionPlan]:
|
34 |
+
"""
|
35 |
+
Get all subscription plans.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
db: Database session
|
39 |
+
active_only: If True, only return active plans
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
List of subscription plans
|
43 |
+
"""
|
44 |
+
query = select(SubscriptionPlan)
|
45 |
+
|
46 |
+
if active_only:
|
47 |
+
query = query.where(SubscriptionPlan.is_active == True)
|
48 |
+
|
49 |
+
result = await db.execute(query)
|
50 |
+
plans = result.scalars().all()
|
51 |
+
|
52 |
+
return plans
|
53 |
+
|
54 |
+
|
55 |
+
async def get_subscription_plan_by_id(
|
56 |
+
db: AsyncSession,
|
57 |
+
plan_id: int
|
58 |
+
) -> Optional[SubscriptionPlan]:
|
59 |
+
"""
|
60 |
+
Get a subscription plan by ID.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
db: Database session
|
64 |
+
plan_id: ID of the plan to get
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
Subscription plan or None if not found
|
68 |
+
"""
|
69 |
+
query = select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id)
|
70 |
+
result = await db.execute(query)
|
71 |
+
plan = result.scalars().first()
|
72 |
+
|
73 |
+
return plan
|
74 |
+
|
75 |
+
|
76 |
+
async def get_subscription_plan_by_tier(
|
77 |
+
db: AsyncSession,
|
78 |
+
tier: SubscriptionTier
|
79 |
+
) -> Optional[SubscriptionPlan]:
|
80 |
+
"""
|
81 |
+
Get a subscription plan by tier.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
db: Database session
|
85 |
+
tier: Tier of the plan to get
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
Subscription plan or None if not found
|
89 |
+
"""
|
90 |
+
query = select(SubscriptionPlan).where(SubscriptionPlan.tier == tier)
|
91 |
+
result = await db.execute(query)
|
92 |
+
plan = result.scalars().first()
|
93 |
+
|
94 |
+
return plan
|
95 |
+
|
96 |
+
|
97 |
+
async def create_subscription_plan(
|
98 |
+
db: AsyncSession,
|
99 |
+
name: str,
|
100 |
+
tier: SubscriptionTier,
|
101 |
+
description: str,
|
102 |
+
price_monthly: float,
|
103 |
+
price_annually: float,
|
104 |
+
max_alerts: int = 10,
|
105 |
+
max_reports: int = 5,
|
106 |
+
max_searches_per_day: int = 20,
|
107 |
+
max_monitoring_keywords: int = 10,
|
108 |
+
max_data_retention_days: int = 30,
|
109 |
+
supports_api_access: bool = False,
|
110 |
+
supports_live_feed: bool = False,
|
111 |
+
supports_dark_web_monitoring: bool = False,
|
112 |
+
supports_export: bool = False,
|
113 |
+
supports_advanced_analytics: bool = False,
|
114 |
+
create_stripe_product: bool = True
|
115 |
+
) -> Optional[SubscriptionPlan]:
|
116 |
+
"""
|
117 |
+
Create a new subscription plan.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
db: Database session
|
121 |
+
name: Name of the plan
|
122 |
+
tier: Tier of the plan
|
123 |
+
description: Description of the plan
|
124 |
+
price_monthly: Monthly price of the plan
|
125 |
+
price_annually: Annual price of the plan
|
126 |
+
max_alerts: Maximum number of alerts allowed
|
127 |
+
max_reports: Maximum number of reports allowed
|
128 |
+
max_searches_per_day: Maximum number of searches per day
|
129 |
+
max_monitoring_keywords: Maximum number of monitoring keywords
|
130 |
+
max_data_retention_days: Maximum number of days to retain data
|
131 |
+
supports_api_access: Whether the plan supports API access
|
132 |
+
supports_live_feed: Whether the plan supports live feed
|
133 |
+
supports_dark_web_monitoring: Whether the plan supports dark web monitoring
|
134 |
+
supports_export: Whether the plan supports data export
|
135 |
+
supports_advanced_analytics: Whether the plan supports advanced analytics
|
136 |
+
create_stripe_product: Whether to create a Stripe product for this plan
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
Created subscription plan or None if creation failed
|
140 |
+
"""
|
141 |
+
# Check if plan with the same tier already exists
|
142 |
+
existing_plan = await get_subscription_plan_by_tier(db, tier)
|
143 |
+
|
144 |
+
if existing_plan:
|
145 |
+
logger.warning(f"Subscription plan with tier {tier} already exists")
|
146 |
+
return None
|
147 |
+
|
148 |
+
# Create Stripe product if requested
|
149 |
+
stripe_product_id = None
|
150 |
+
stripe_monthly_price_id = None
|
151 |
+
stripe_annual_price_id = None
|
152 |
+
|
153 |
+
if create_stripe_product and stripe.api_key:
|
154 |
+
try:
|
155 |
+
# Create Stripe product
|
156 |
+
product = stripe.Product.create(
|
157 |
+
name=name,
|
158 |
+
description=description,
|
159 |
+
metadata={
|
160 |
+
"tier": tier.value,
|
161 |
+
"max_alerts": max_alerts,
|
162 |
+
"max_reports": max_reports,
|
163 |
+
"max_searches_per_day": max_searches_per_day,
|
164 |
+
"max_monitoring_keywords": max_monitoring_keywords,
|
165 |
+
"max_data_retention_days": max_data_retention_days,
|
166 |
+
"supports_api_access": "yes" if supports_api_access else "no",
|
167 |
+
"supports_live_feed": "yes" if supports_live_feed else "no",
|
168 |
+
"supports_dark_web_monitoring": "yes" if supports_dark_web_monitoring else "no",
|
169 |
+
"supports_export": "yes" if supports_export else "no",
|
170 |
+
"supports_advanced_analytics": "yes" if supports_advanced_analytics else "no"
|
171 |
+
}
|
172 |
+
)
|
173 |
+
|
174 |
+
stripe_product_id = product.id
|
175 |
+
|
176 |
+
# Create monthly price
|
177 |
+
monthly_price = stripe.Price.create(
|
178 |
+
product=product.id,
|
179 |
+
unit_amount=int(price_monthly * 100), # Stripe uses cents
|
180 |
+
currency="usd",
|
181 |
+
recurring={"interval": "month"},
|
182 |
+
metadata={"billing_period": "monthly"}
|
183 |
+
)
|
184 |
+
|
185 |
+
stripe_monthly_price_id = monthly_price.id
|
186 |
+
|
187 |
+
# Create annual price
|
188 |
+
annual_price = stripe.Price.create(
|
189 |
+
product=product.id,
|
190 |
+
unit_amount=int(price_annually * 100), # Stripe uses cents
|
191 |
+
currency="usd",
|
192 |
+
recurring={"interval": "year"},
|
193 |
+
metadata={"billing_period": "annually"}
|
194 |
+
)
|
195 |
+
|
196 |
+
stripe_annual_price_id = annual_price.id
|
197 |
+
|
198 |
+
logger.info(f"Created Stripe product {product.id} for plan {name}")
|
199 |
+
except Exception as e:
|
200 |
+
logger.error(f"Failed to create Stripe product for plan {name}: {e}")
|
201 |
+
|
202 |
+
# Create plan in database
|
203 |
+
plan = SubscriptionPlan(
|
204 |
+
name=name,
|
205 |
+
tier=tier,
|
206 |
+
description=description,
|
207 |
+
price_monthly=price_monthly,
|
208 |
+
price_annually=price_annually,
|
209 |
+
max_alerts=max_alerts,
|
210 |
+
max_reports=max_reports,
|
211 |
+
max_searches_per_day=max_searches_per_day,
|
212 |
+
max_monitoring_keywords=max_monitoring_keywords,
|
213 |
+
max_data_retention_days=max_data_retention_days,
|
214 |
+
supports_api_access=supports_api_access,
|
215 |
+
supports_live_feed=supports_live_feed,
|
216 |
+
supports_dark_web_monitoring=supports_dark_web_monitoring,
|
217 |
+
supports_export=supports_export,
|
218 |
+
supports_advanced_analytics=supports_advanced_analytics,
|
219 |
+
stripe_product_id=stripe_product_id,
|
220 |
+
stripe_monthly_price_id=stripe_monthly_price_id,
|
221 |
+
stripe_annual_price_id=stripe_annual_price_id
|
222 |
+
)
|
223 |
+
|
224 |
+
db.add(plan)
|
225 |
+
await db.commit()
|
226 |
+
await db.refresh(plan)
|
227 |
+
|
228 |
+
return plan
|
229 |
+
|
230 |
+
|
231 |
+
async def update_subscription_plan(
|
232 |
+
db: AsyncSession,
|
233 |
+
plan_id: int,
|
234 |
+
name: Optional[str] = None,
|
235 |
+
description: Optional[str] = None,
|
236 |
+
price_monthly: Optional[float] = None,
|
237 |
+
price_annually: Optional[float] = None,
|
238 |
+
is_active: Optional[bool] = None,
|
239 |
+
max_alerts: Optional[int] = None,
|
240 |
+
max_reports: Optional[int] = None,
|
241 |
+
max_searches_per_day: Optional[int] = None,
|
242 |
+
max_monitoring_keywords: Optional[int] = None,
|
243 |
+
max_data_retention_days: Optional[int] = None,
|
244 |
+
supports_api_access: Optional[bool] = None,
|
245 |
+
supports_live_feed: Optional[bool] = None,
|
246 |
+
supports_dark_web_monitoring: Optional[bool] = None,
|
247 |
+
supports_export: Optional[bool] = None,
|
248 |
+
supports_advanced_analytics: Optional[bool] = None,
|
249 |
+
update_stripe_product: bool = True
|
250 |
+
) -> Optional[SubscriptionPlan]:
|
251 |
+
"""
|
252 |
+
Update a subscription plan.
|
253 |
+
|
254 |
+
Args:
|
255 |
+
db: Database session
|
256 |
+
plan_id: ID of the plan to update
|
257 |
+
name: New name of the plan
|
258 |
+
description: New description of the plan
|
259 |
+
price_monthly: New monthly price of the plan
|
260 |
+
price_annually: New annual price of the plan
|
261 |
+
is_active: New active status of the plan
|
262 |
+
max_alerts: New maximum number of alerts allowed
|
263 |
+
max_reports: New maximum number of reports allowed
|
264 |
+
max_searches_per_day: New maximum number of searches per day
|
265 |
+
max_monitoring_keywords: New maximum number of monitoring keywords
|
266 |
+
max_data_retention_days: New maximum number of days to retain data
|
267 |
+
supports_api_access: New API access support status
|
268 |
+
supports_live_feed: New live feed support status
|
269 |
+
supports_dark_web_monitoring: New dark web monitoring support status
|
270 |
+
supports_export: New data export support status
|
271 |
+
supports_advanced_analytics: New advanced analytics support status
|
272 |
+
update_stripe_product: Whether to update the Stripe product for this plan
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
Updated subscription plan or None if update failed
|
276 |
+
"""
|
277 |
+
# Get existing plan
|
278 |
+
plan = await get_subscription_plan_by_id(db, plan_id)
|
279 |
+
|
280 |
+
if not plan:
|
281 |
+
logger.warning(f"Subscription plan with ID {plan_id} not found")
|
282 |
+
return None
|
283 |
+
|
284 |
+
# Prepare update data
|
285 |
+
update_data = {}
|
286 |
+
|
287 |
+
if name is not None:
|
288 |
+
update_data["name"] = name
|
289 |
+
|
290 |
+
if description is not None:
|
291 |
+
update_data["description"] = description
|
292 |
+
|
293 |
+
if price_monthly is not None:
|
294 |
+
update_data["price_monthly"] = price_monthly
|
295 |
+
|
296 |
+
if price_annually is not None:
|
297 |
+
update_data["price_annually"] = price_annually
|
298 |
+
|
299 |
+
if is_active is not None:
|
300 |
+
update_data["is_active"] = is_active
|
301 |
+
|
302 |
+
if max_alerts is not None:
|
303 |
+
update_data["max_alerts"] = max_alerts
|
304 |
+
|
305 |
+
if max_reports is not None:
|
306 |
+
update_data["max_reports"] = max_reports
|
307 |
+
|
308 |
+
if max_searches_per_day is not None:
|
309 |
+
update_data["max_searches_per_day"] = max_searches_per_day
|
310 |
+
|
311 |
+
if max_monitoring_keywords is not None:
|
312 |
+
update_data["max_monitoring_keywords"] = max_monitoring_keywords
|
313 |
+
|
314 |
+
if max_data_retention_days is not None:
|
315 |
+
update_data["max_data_retention_days"] = max_data_retention_days
|
316 |
+
|
317 |
+
if supports_api_access is not None:
|
318 |
+
update_data["supports_api_access"] = supports_api_access
|
319 |
+
|
320 |
+
if supports_live_feed is not None:
|
321 |
+
update_data["supports_live_feed"] = supports_live_feed
|
322 |
+
|
323 |
+
if supports_dark_web_monitoring is not None:
|
324 |
+
update_data["supports_dark_web_monitoring"] = supports_dark_web_monitoring
|
325 |
+
|
326 |
+
if supports_export is not None:
|
327 |
+
update_data["supports_export"] = supports_export
|
328 |
+
|
329 |
+
if supports_advanced_analytics is not None:
|
330 |
+
update_data["supports_advanced_analytics"] = supports_advanced_analytics
|
331 |
+
|
332 |
+
# Update Stripe product if requested
|
333 |
+
if update_stripe_product and plan.stripe_product_id and stripe.api_key:
|
334 |
+
try:
|
335 |
+
# Update Stripe product
|
336 |
+
product_update_data = {}
|
337 |
+
|
338 |
+
if name is not None:
|
339 |
+
product_update_data["name"] = name
|
340 |
+
|
341 |
+
if description is not None:
|
342 |
+
product_update_data["description"] = description
|
343 |
+
|
344 |
+
metadata_update = {}
|
345 |
+
|
346 |
+
if max_alerts is not None:
|
347 |
+
metadata_update["max_alerts"] = max_alerts
|
348 |
+
|
349 |
+
if max_reports is not None:
|
350 |
+
metadata_update["max_reports"] = max_reports
|
351 |
+
|
352 |
+
if max_searches_per_day is not None:
|
353 |
+
metadata_update["max_searches_per_day"] = max_searches_per_day
|
354 |
+
|
355 |
+
if max_monitoring_keywords is not None:
|
356 |
+
metadata_update["max_monitoring_keywords"] = max_monitoring_keywords
|
357 |
+
|
358 |
+
if max_data_retention_days is not None:
|
359 |
+
metadata_update["max_data_retention_days"] = max_data_retention_days
|
360 |
+
|
361 |
+
if supports_api_access is not None:
|
362 |
+
metadata_update["supports_api_access"] = "yes" if supports_api_access else "no"
|
363 |
+
|
364 |
+
if supports_live_feed is not None:
|
365 |
+
metadata_update["supports_live_feed"] = "yes" if supports_live_feed else "no"
|
366 |
+
|
367 |
+
if supports_dark_web_monitoring is not None:
|
368 |
+
metadata_update["supports_dark_web_monitoring"] = "yes" if supports_dark_web_monitoring else "no"
|
369 |
+
|
370 |
+
if supports_export is not None:
|
371 |
+
metadata_update["supports_export"] = "yes" if supports_export else "no"
|
372 |
+
|
373 |
+
if supports_advanced_analytics is not None:
|
374 |
+
metadata_update["supports_advanced_analytics"] = "yes" if supports_advanced_analytics else "no"
|
375 |
+
|
376 |
+
if metadata_update:
|
377 |
+
product_update_data["metadata"] = metadata_update
|
378 |
+
|
379 |
+
if product_update_data:
|
380 |
+
stripe.Product.modify(plan.stripe_product_id, **product_update_data)
|
381 |
+
|
382 |
+
# Update prices if needed
|
383 |
+
if price_monthly is not None and plan.stripe_monthly_price_id:
|
384 |
+
# Can't update existing price in Stripe, create a new one
|
385 |
+
new_monthly_price = stripe.Price.create(
|
386 |
+
product=plan.stripe_product_id,
|
387 |
+
unit_amount=int(price_monthly * 100), # Stripe uses cents
|
388 |
+
currency="usd",
|
389 |
+
recurring={"interval": "month"},
|
390 |
+
metadata={"billing_period": "monthly"}
|
391 |
+
)
|
392 |
+
|
393 |
+
update_data["stripe_monthly_price_id"] = new_monthly_price.id
|
394 |
+
|
395 |
+
if price_annually is not None and plan.stripe_annual_price_id:
|
396 |
+
# Can't update existing price in Stripe, create a new one
|
397 |
+
new_annual_price = stripe.Price.create(
|
398 |
+
product=plan.stripe_product_id,
|
399 |
+
unit_amount=int(price_annually * 100), # Stripe uses cents
|
400 |
+
currency="usd",
|
401 |
+
recurring={"interval": "year"},
|
402 |
+
metadata={"billing_period": "annually"}
|
403 |
+
)
|
404 |
+
|
405 |
+
update_data["stripe_annual_price_id"] = new_annual_price.id
|
406 |
+
|
407 |
+
logger.info(f"Updated Stripe product {plan.stripe_product_id} for plan {plan.name}")
|
408 |
+
except Exception as e:
|
409 |
+
logger.error(f"Failed to update Stripe product for plan {plan.name}: {e}")
|
410 |
+
|
411 |
+
# Update plan in database
|
412 |
+
if update_data:
|
413 |
+
await db.execute(
|
414 |
+
update(SubscriptionPlan)
|
415 |
+
.where(SubscriptionPlan.id == plan_id)
|
416 |
+
.values(**update_data)
|
417 |
+
)
|
418 |
+
|
419 |
+
await db.commit()
|
420 |
+
|
421 |
+
# Refresh plan
|
422 |
+
plan = await get_subscription_plan_by_id(db, plan_id)
|
423 |
+
|
424 |
+
return plan
|
425 |
+
|
426 |
+
|
427 |
+
async def get_user_subscription(
|
428 |
+
db: AsyncSession,
|
429 |
+
user_id: int
|
430 |
+
) -> Optional[UserSubscription]:
|
431 |
+
"""
|
432 |
+
Get a user's active subscription.
|
433 |
+
|
434 |
+
Args:
|
435 |
+
db: Database session
|
436 |
+
user_id: ID of the user
|
437 |
+
|
438 |
+
Returns:
|
439 |
+
User subscription or None if not found
|
440 |
+
"""
|
441 |
+
query = (
|
442 |
+
select(UserSubscription)
|
443 |
+
.where(UserSubscription.user_id == user_id)
|
444 |
+
.where(UserSubscription.status != SubscriptionStatus.CANCELED)
|
445 |
+
.options(joinedload(UserSubscription.plan))
|
446 |
+
)
|
447 |
+
|
448 |
+
result = await db.execute(query)
|
449 |
+
subscription = result.scalars().first()
|
450 |
+
|
451 |
+
return subscription
|
452 |
+
|
453 |
+
|
454 |
+
async def get_user_subscription_by_id(
|
455 |
+
db: AsyncSession,
|
456 |
+
subscription_id: int
|
457 |
+
) -> Optional[UserSubscription]:
|
458 |
+
"""
|
459 |
+
Get a user subscription by ID.
|
460 |
+
|
461 |
+
Args:
|
462 |
+
db: Database session
|
463 |
+
subscription_id: ID of the subscription
|
464 |
+
|
465 |
+
Returns:
|
466 |
+
User subscription or None if not found
|
467 |
+
"""
|
468 |
+
query = (
|
469 |
+
select(UserSubscription)
|
470 |
+
.where(UserSubscription.id == subscription_id)
|
471 |
+
.options(joinedload(UserSubscription.plan))
|
472 |
+
)
|
473 |
+
|
474 |
+
result = await db.execute(query)
|
475 |
+
subscription = result.scalars().first()
|
476 |
+
|
477 |
+
return subscription
|
478 |
+
|
479 |
+
|
480 |
+
async def create_user_subscription(
|
481 |
+
db: AsyncSession,
|
482 |
+
user_id: int,
|
483 |
+
plan_id: int,
|
484 |
+
billing_period: BillingPeriod = BillingPeriod.MONTHLY,
|
485 |
+
create_stripe_subscription: bool = True,
|
486 |
+
payment_method_id: Optional[str] = None
|
487 |
+
) -> Optional[UserSubscription]:
|
488 |
+
"""
|
489 |
+
Create a new user subscription.
|
490 |
+
|
491 |
+
Args:
|
492 |
+
db: Database session
|
493 |
+
user_id: ID of the user
|
494 |
+
plan_id: ID of the subscription plan
|
495 |
+
billing_period: Billing period (monthly or annually)
|
496 |
+
create_stripe_subscription: Whether to create a Stripe subscription
|
497 |
+
payment_method_id: ID of the payment method to use (required if create_stripe_subscription is True)
|
498 |
+
|
499 |
+
Returns:
|
500 |
+
Created user subscription or None if creation failed
|
501 |
+
"""
|
502 |
+
# Check if user exists
|
503 |
+
query = select(User).where(User.id == user_id)
|
504 |
+
result = await db.execute(query)
|
505 |
+
user = result.scalars().first()
|
506 |
+
|
507 |
+
if not user:
|
508 |
+
logger.warning(f"User with ID {user_id} not found")
|
509 |
+
return None
|
510 |
+
|
511 |
+
# Check if plan exists
|
512 |
+
plan = await get_subscription_plan_by_id(db, plan_id)
|
513 |
+
|
514 |
+
if not plan:
|
515 |
+
logger.warning(f"Subscription plan with ID {plan_id} not found")
|
516 |
+
return None
|
517 |
+
|
518 |
+
# Check if user already has an active subscription
|
519 |
+
existing_subscription = await get_user_subscription(db, user_id)
|
520 |
+
|
521 |
+
if existing_subscription:
|
522 |
+
logger.warning(f"User with ID {user_id} already has an active subscription")
|
523 |
+
return None
|
524 |
+
|
525 |
+
# Calculate subscription period
|
526 |
+
now = datetime.utcnow()
|
527 |
+
|
528 |
+
if billing_period == BillingPeriod.MONTHLY:
|
529 |
+
current_period_end = now + timedelta(days=30)
|
530 |
+
price = plan.price_monthly
|
531 |
+
stripe_price_id = plan.stripe_monthly_price_id
|
532 |
+
elif billing_period == BillingPeriod.ANNUALLY:
|
533 |
+
current_period_end = now + timedelta(days=365)
|
534 |
+
price = plan.price_annually
|
535 |
+
stripe_price_id = plan.stripe_annual_price_id
|
536 |
+
else:
|
537 |
+
logger.warning(f"Invalid billing period: {billing_period}")
|
538 |
+
return None
|
539 |
+
|
540 |
+
# Create Stripe subscription if requested
|
541 |
+
stripe_subscription_id = None
|
542 |
+
stripe_customer_id = None
|
543 |
+
|
544 |
+
if create_stripe_subscription and stripe.api_key and plan.stripe_product_id:
|
545 |
+
if not payment_method_id:
|
546 |
+
logger.warning("Payment method ID is required to create a Stripe subscription")
|
547 |
+
return None
|
548 |
+
|
549 |
+
try:
|
550 |
+
# Create or retrieve Stripe customer
|
551 |
+
customers = stripe.Customer.list(email=user.email)
|
552 |
+
|
553 |
+
if customers.data:
|
554 |
+
customer = customers.data[0]
|
555 |
+
stripe_customer_id = customer.id
|
556 |
+
else:
|
557 |
+
customer = stripe.Customer.create(
|
558 |
+
email=user.email,
|
559 |
+
name=user.full_name,
|
560 |
+
metadata={"user_id": user_id}
|
561 |
+
)
|
562 |
+
|
563 |
+
stripe_customer_id = customer.id
|
564 |
+
|
565 |
+
# Attach payment method to customer
|
566 |
+
stripe.PaymentMethod.attach(
|
567 |
+
payment_method_id,
|
568 |
+
customer=stripe_customer_id
|
569 |
+
)
|
570 |
+
|
571 |
+
# Set as default payment method
|
572 |
+
stripe.Customer.modify(
|
573 |
+
stripe_customer_id,
|
574 |
+
invoice_settings={
|
575 |
+
"default_payment_method": payment_method_id
|
576 |
+
}
|
577 |
+
)
|
578 |
+
|
579 |
+
# Create subscription
|
580 |
+
subscription = stripe.Subscription.create(
|
581 |
+
customer=stripe_customer_id,
|
582 |
+
items=[
|
583 |
+
{"price": stripe_price_id}
|
584 |
+
],
|
585 |
+
expand=["latest_invoice.payment_intent"]
|
586 |
+
)
|
587 |
+
|
588 |
+
stripe_subscription_id = subscription.id
|
589 |
+
|
590 |
+
logger.info(f"Created Stripe subscription {subscription.id} for user {user_id}")
|
591 |
+
except Exception as e:
|
592 |
+
logger.error(f"Failed to create Stripe subscription for user {user_id}: {e}")
|
593 |
+
return None
|
594 |
+
|
595 |
+
# Create subscription in database
|
596 |
+
subscription = UserSubscription(
|
597 |
+
user_id=user_id,
|
598 |
+
plan_id=plan_id,
|
599 |
+
status=SubscriptionStatus.ACTIVE,
|
600 |
+
billing_period=billing_period,
|
601 |
+
current_period_start=now,
|
602 |
+
current_period_end=current_period_end,
|
603 |
+
stripe_subscription_id=stripe_subscription_id,
|
604 |
+
stripe_customer_id=stripe_customer_id
|
605 |
+
)
|
606 |
+
|
607 |
+
db.add(subscription)
|
608 |
+
await db.commit()
|
609 |
+
await db.refresh(subscription)
|
610 |
+
|
611 |
+
# Record payment
|
612 |
+
if subscription.id:
|
613 |
+
payment_status = PaymentStatus.SUCCEEDED if stripe_subscription_id else PaymentStatus.PENDING
|
614 |
+
|
615 |
+
payment = PaymentHistory(
|
616 |
+
user_id=user_id,
|
617 |
+
subscription_id=subscription.id,
|
618 |
+
amount=price,
|
619 |
+
currency="USD",
|
620 |
+
status=payment_status
|
621 |
+
)
|
622 |
+
|
623 |
+
db.add(payment)
|
624 |
+
await db.commit()
|
625 |
+
|
626 |
+
return subscription
|
627 |
+
|
628 |
+
|
629 |
+
async def cancel_user_subscription(
|
630 |
+
db: AsyncSession,
|
631 |
+
subscription_id: int,
|
632 |
+
cancel_stripe_subscription: bool = True
|
633 |
+
) -> Optional[UserSubscription]:
|
634 |
+
"""
|
635 |
+
Cancel a user subscription.
|
636 |
+
|
637 |
+
Args:
|
638 |
+
db: Database session
|
639 |
+
subscription_id: ID of the subscription to cancel
|
640 |
+
cancel_stripe_subscription: Whether to cancel the Stripe subscription
|
641 |
+
|
642 |
+
Returns:
|
643 |
+
Canceled user subscription or None if cancellation failed
|
644 |
+
"""
|
645 |
+
# Get subscription
|
646 |
+
subscription = await get_user_subscription_by_id(db, subscription_id)
|
647 |
+
|
648 |
+
if not subscription:
|
649 |
+
logger.warning(f"Subscription with ID {subscription_id} not found")
|
650 |
+
return None
|
651 |
+
|
652 |
+
# Cancel Stripe subscription if requested
|
653 |
+
if cancel_stripe_subscription and subscription.stripe_subscription_id and stripe.api_key:
|
654 |
+
try:
|
655 |
+
stripe.Subscription.modify(
|
656 |
+
subscription.stripe_subscription_id,
|
657 |
+
cancel_at_period_end=True
|
658 |
+
)
|
659 |
+
|
660 |
+
logger.info(f"Canceled Stripe subscription {subscription.stripe_subscription_id} at period end")
|
661 |
+
except Exception as e:
|
662 |
+
logger.error(f"Failed to cancel Stripe subscription {subscription.stripe_subscription_id}: {e}")
|
663 |
+
|
664 |
+
# Update subscription in database
|
665 |
+
now = datetime.utcnow()
|
666 |
+
|
667 |
+
await db.execute(
|
668 |
+
update(UserSubscription)
|
669 |
+
.where(UserSubscription.id == subscription_id)
|
670 |
+
.values(
|
671 |
+
status=SubscriptionStatus.CANCELED,
|
672 |
+
canceled_at=now
|
673 |
+
)
|
674 |
+
)
|
675 |
+
|
676 |
+
await db.commit()
|
677 |
+
|
678 |
+
# Refresh subscription
|
679 |
+
subscription = await get_user_subscription_by_id(db, subscription_id)
|
680 |
+
|
681 |
+
return subscription
|
src/api/services/threat_service.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Service for threat operations.
|
3 |
+
"""
|
4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
5 |
+
from sqlalchemy.future import select
|
6 |
+
from sqlalchemy import func, or_, and_
|
7 |
+
from datetime import datetime, timedelta
|
8 |
+
from typing import List, Optional, Dict, Any, Union
|
9 |
+
|
10 |
+
from src.models.threat import Threat, ThreatSeverity, ThreatStatus, ThreatCategory
|
11 |
+
from src.models.indicator import Indicator, IndicatorType
|
12 |
+
from src.api.schemas import PaginationParams
|
13 |
+
|
14 |
+
async def create_threat(
|
15 |
+
db: AsyncSession,
|
16 |
+
title: str,
|
17 |
+
description: str,
|
18 |
+
severity: ThreatSeverity,
|
19 |
+
category: ThreatCategory,
|
20 |
+
status: ThreatStatus = ThreatStatus.NEW,
|
21 |
+
source_url: Optional[str] = None,
|
22 |
+
source_name: Optional[str] = None,
|
23 |
+
source_type: Optional[str] = None,
|
24 |
+
affected_entity: Optional[str] = None,
|
25 |
+
affected_entity_type: Optional[str] = None,
|
26 |
+
confidence_score: float = 0.0,
|
27 |
+
risk_score: float = 0.0,
|
28 |
+
) -> Threat:
|
29 |
+
"""
|
30 |
+
Create a new threat.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
db: Database session
|
34 |
+
title: Threat title
|
35 |
+
description: Threat description
|
36 |
+
severity: Threat severity
|
37 |
+
category: Threat category
|
38 |
+
status: Threat status
|
39 |
+
source_url: URL of the source
|
40 |
+
source_name: Name of the source
|
41 |
+
source_type: Type of source
|
42 |
+
affected_entity: Name of affected entity
|
43 |
+
affected_entity_type: Type of affected entity
|
44 |
+
confidence_score: Confidence score (0-1)
|
45 |
+
risk_score: Risk score (0-1)
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
Threat: Created threat
|
49 |
+
"""
|
50 |
+
db_threat = Threat(
|
51 |
+
title=title,
|
52 |
+
description=description,
|
53 |
+
severity=severity,
|
54 |
+
category=category,
|
55 |
+
status=status,
|
56 |
+
source_url=source_url,
|
57 |
+
source_name=source_name,
|
58 |
+
source_type=source_type,
|
59 |
+
discovered_at=datetime.utcnow(),
|
60 |
+
affected_entity=affected_entity,
|
61 |
+
affected_entity_type=affected_entity_type,
|
62 |
+
confidence_score=confidence_score,
|
63 |
+
risk_score=risk_score,
|
64 |
+
)
|
65 |
+
|
66 |
+
db.add(db_threat)
|
67 |
+
await db.commit()
|
68 |
+
await db.refresh(db_threat)
|
69 |
+
|
70 |
+
return db_threat
|
71 |
+
|
72 |
+
async def get_threat_by_id(db: AsyncSession, threat_id: int) -> Optional[Threat]:
|
73 |
+
"""
|
74 |
+
Get threat by ID.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
db: Database session
|
78 |
+
threat_id: Threat ID
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
Optional[Threat]: Threat or None if not found
|
82 |
+
"""
|
83 |
+
result = await db.execute(select(Threat).filter(Threat.id == threat_id))
|
84 |
+
return result.scalars().first()
|
85 |
+
|
86 |
+
async def get_threats(
|
87 |
+
db: AsyncSession,
|
88 |
+
pagination: PaginationParams,
|
89 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
90 |
+
status: Optional[List[ThreatStatus]] = None,
|
91 |
+
category: Optional[List[ThreatCategory]] = None,
|
92 |
+
search_query: Optional[str] = None,
|
93 |
+
from_date: Optional[datetime] = None,
|
94 |
+
to_date: Optional[datetime] = None,
|
95 |
+
) -> List[Threat]:
|
96 |
+
"""
|
97 |
+
Get threats with filtering and pagination.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
db: Database session
|
101 |
+
pagination: Pagination parameters
|
102 |
+
severity: Filter by severity
|
103 |
+
status: Filter by status
|
104 |
+
category: Filter by category
|
105 |
+
search_query: Search in title and description
|
106 |
+
from_date: Filter by discovered_at >= from_date
|
107 |
+
to_date: Filter by discovered_at <= to_date
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
List[Threat]: List of threats
|
111 |
+
"""
|
112 |
+
query = select(Threat)
|
113 |
+
|
114 |
+
# Apply filters
|
115 |
+
if severity:
|
116 |
+
query = query.filter(Threat.severity.in_(severity))
|
117 |
+
|
118 |
+
if status:
|
119 |
+
query = query.filter(Threat.status.in_(status))
|
120 |
+
|
121 |
+
if category:
|
122 |
+
query = query.filter(Threat.category.in_(category))
|
123 |
+
|
124 |
+
if search_query:
|
125 |
+
search_filter = or_(
|
126 |
+
Threat.title.ilike(f"%{search_query}%"),
|
127 |
+
Threat.description.ilike(f"%{search_query}%")
|
128 |
+
)
|
129 |
+
query = query.filter(search_filter)
|
130 |
+
|
131 |
+
if from_date:
|
132 |
+
query = query.filter(Threat.discovered_at >= from_date)
|
133 |
+
|
134 |
+
if to_date:
|
135 |
+
query = query.filter(Threat.discovered_at <= to_date)
|
136 |
+
|
137 |
+
# Apply pagination
|
138 |
+
query = query.order_by(Threat.discovered_at.desc())
|
139 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
140 |
+
|
141 |
+
result = await db.execute(query)
|
142 |
+
return result.scalars().all()
|
143 |
+
|
144 |
+
async def count_threats(
|
145 |
+
db: AsyncSession,
|
146 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
147 |
+
status: Optional[List[ThreatStatus]] = None,
|
148 |
+
category: Optional[List[ThreatCategory]] = None,
|
149 |
+
search_query: Optional[str] = None,
|
150 |
+
from_date: Optional[datetime] = None,
|
151 |
+
to_date: Optional[datetime] = None,
|
152 |
+
) -> int:
|
153 |
+
"""
|
154 |
+
Count threats with filtering.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
db: Database session
|
158 |
+
severity: Filter by severity
|
159 |
+
status: Filter by status
|
160 |
+
category: Filter by category
|
161 |
+
search_query: Search in title and description
|
162 |
+
from_date: Filter by discovered_at >= from_date
|
163 |
+
to_date: Filter by discovered_at <= to_date
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
int: Count of threats
|
167 |
+
"""
|
168 |
+
query = select(func.count(Threat.id))
|
169 |
+
|
170 |
+
# Apply filters (same as in get_threats)
|
171 |
+
if severity:
|
172 |
+
query = query.filter(Threat.severity.in_(severity))
|
173 |
+
|
174 |
+
if status:
|
175 |
+
query = query.filter(Threat.status.in_(status))
|
176 |
+
|
177 |
+
if category:
|
178 |
+
query = query.filter(Threat.category.in_(category))
|
179 |
+
|
180 |
+
if search_query:
|
181 |
+
search_filter = or_(
|
182 |
+
Threat.title.ilike(f"%{search_query}%"),
|
183 |
+
Threat.description.ilike(f"%{search_query}%")
|
184 |
+
)
|
185 |
+
query = query.filter(search_filter)
|
186 |
+
|
187 |
+
if from_date:
|
188 |
+
query = query.filter(Threat.discovered_at >= from_date)
|
189 |
+
|
190 |
+
if to_date:
|
191 |
+
query = query.filter(Threat.discovered_at <= to_date)
|
192 |
+
|
193 |
+
result = await db.execute(query)
|
194 |
+
return result.scalar()
|
195 |
+
|
196 |
+
async def update_threat(
|
197 |
+
db: AsyncSession,
|
198 |
+
threat_id: int,
|
199 |
+
title: Optional[str] = None,
|
200 |
+
description: Optional[str] = None,
|
201 |
+
severity: Optional[ThreatSeverity] = None,
|
202 |
+
status: Optional[ThreatStatus] = None,
|
203 |
+
category: Optional[ThreatCategory] = None,
|
204 |
+
affected_entity: Optional[str] = None,
|
205 |
+
affected_entity_type: Optional[str] = None,
|
206 |
+
confidence_score: Optional[float] = None,
|
207 |
+
risk_score: Optional[float] = None,
|
208 |
+
) -> Optional[Threat]:
|
209 |
+
"""
|
210 |
+
Update threat.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
db: Database session
|
214 |
+
threat_id: Threat ID
|
215 |
+
title: New title
|
216 |
+
description: New description
|
217 |
+
severity: New severity
|
218 |
+
status: New status
|
219 |
+
category: New category
|
220 |
+
affected_entity: New affected entity
|
221 |
+
affected_entity_type: New affected entity type
|
222 |
+
confidence_score: New confidence score
|
223 |
+
risk_score: New risk score
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
Optional[Threat]: Updated threat or None if not found
|
227 |
+
"""
|
228 |
+
threat = await get_threat_by_id(db, threat_id)
|
229 |
+
if not threat:
|
230 |
+
return None
|
231 |
+
|
232 |
+
if title is not None:
|
233 |
+
threat.title = title
|
234 |
+
|
235 |
+
if description is not None:
|
236 |
+
threat.description = description
|
237 |
+
|
238 |
+
if severity is not None:
|
239 |
+
threat.severity = severity
|
240 |
+
|
241 |
+
if status is not None:
|
242 |
+
threat.status = status
|
243 |
+
|
244 |
+
if category is not None:
|
245 |
+
threat.category = category
|
246 |
+
|
247 |
+
if affected_entity is not None:
|
248 |
+
threat.affected_entity = affected_entity
|
249 |
+
|
250 |
+
if affected_entity_type is not None:
|
251 |
+
threat.affected_entity_type = affected_entity_type
|
252 |
+
|
253 |
+
if confidence_score is not None:
|
254 |
+
threat.confidence_score = confidence_score
|
255 |
+
|
256 |
+
if risk_score is not None:
|
257 |
+
threat.risk_score = risk_score
|
258 |
+
|
259 |
+
threat.updated_at = datetime.utcnow()
|
260 |
+
|
261 |
+
await db.commit()
|
262 |
+
await db.refresh(threat)
|
263 |
+
|
264 |
+
return threat
|
265 |
+
|
266 |
+
async def add_indicator_to_threat(
|
267 |
+
db: AsyncSession,
|
268 |
+
threat_id: int,
|
269 |
+
value: str,
|
270 |
+
indicator_type: IndicatorType,
|
271 |
+
description: Optional[str] = None,
|
272 |
+
is_verified: bool = False,
|
273 |
+
context: Optional[str] = None,
|
274 |
+
source: Optional[str] = None,
|
275 |
+
confidence_score: float = 0.0,
|
276 |
+
) -> Indicator:
|
277 |
+
"""
|
278 |
+
Add an indicator to a threat.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
db: Database session
|
282 |
+
threat_id: Threat ID
|
283 |
+
value: Indicator value
|
284 |
+
indicator_type: Indicator type
|
285 |
+
description: Description of the indicator
|
286 |
+
is_verified: Whether the indicator is verified
|
287 |
+
context: Context of the indicator
|
288 |
+
source: Source of the indicator
|
289 |
+
confidence_score: Confidence score (0-1)
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
Indicator: Created indicator
|
293 |
+
"""
|
294 |
+
# Check if threat exists
|
295 |
+
threat = await get_threat_by_id(db, threat_id)
|
296 |
+
if not threat:
|
297 |
+
raise ValueError(f"Threat with ID {threat_id} not found")
|
298 |
+
|
299 |
+
# Create indicator
|
300 |
+
db_indicator = Indicator(
|
301 |
+
threat_id=threat_id,
|
302 |
+
value=value,
|
303 |
+
indicator_type=indicator_type,
|
304 |
+
description=description,
|
305 |
+
is_verified=is_verified,
|
306 |
+
context=context,
|
307 |
+
source=source,
|
308 |
+
confidence_score=confidence_score,
|
309 |
+
first_seen=datetime.utcnow(),
|
310 |
+
last_seen=datetime.utcnow(),
|
311 |
+
)
|
312 |
+
|
313 |
+
db.add(db_indicator)
|
314 |
+
await db.commit()
|
315 |
+
await db.refresh(db_indicator)
|
316 |
+
|
317 |
+
return db_indicator
|
318 |
+
|
319 |
+
async def get_threat_statistics(
|
320 |
+
db: AsyncSession,
|
321 |
+
from_date: Optional[datetime] = None,
|
322 |
+
to_date: Optional[datetime] = None,
|
323 |
+
) -> Dict[str, Any]:
|
324 |
+
"""
|
325 |
+
Get threat statistics.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
db: Database session
|
329 |
+
from_date: Filter by discovered_at >= from_date
|
330 |
+
to_date: Filter by discovered_at <= to_date
|
331 |
+
|
332 |
+
Returns:
|
333 |
+
Dict[str, Any]: Threat statistics
|
334 |
+
"""
|
335 |
+
# Set default time range if not provided
|
336 |
+
if not to_date:
|
337 |
+
to_date = datetime.utcnow()
|
338 |
+
|
339 |
+
if not from_date:
|
340 |
+
from_date = to_date - timedelta(days=30)
|
341 |
+
|
342 |
+
# Get count by severity
|
343 |
+
severity_counts = {}
|
344 |
+
for severity in ThreatSeverity:
|
345 |
+
query = select(func.count(Threat.id)).filter(and_(
|
346 |
+
Threat.severity == severity,
|
347 |
+
Threat.discovered_at >= from_date,
|
348 |
+
Threat.discovered_at <= to_date,
|
349 |
+
))
|
350 |
+
result = await db.execute(query)
|
351 |
+
severity_counts[severity.value] = result.scalar() or 0
|
352 |
+
|
353 |
+
# Get count by status
|
354 |
+
status_counts = {}
|
355 |
+
for status in ThreatStatus:
|
356 |
+
query = select(func.count(Threat.id)).filter(and_(
|
357 |
+
Threat.status == status,
|
358 |
+
Threat.discovered_at >= from_date,
|
359 |
+
Threat.discovered_at <= to_date,
|
360 |
+
))
|
361 |
+
result = await db.execute(query)
|
362 |
+
status_counts[status.value] = result.scalar() or 0
|
363 |
+
|
364 |
+
# Get count by category
|
365 |
+
category_counts = {}
|
366 |
+
for category in ThreatCategory:
|
367 |
+
query = select(func.count(Threat.id)).filter(and_(
|
368 |
+
Threat.category == category,
|
369 |
+
Threat.discovered_at >= from_date,
|
370 |
+
Threat.discovered_at <= to_date,
|
371 |
+
))
|
372 |
+
result = await db.execute(query)
|
373 |
+
category_counts[category.value] = result.scalar() or 0
|
374 |
+
|
375 |
+
# Get total count
|
376 |
+
query = select(func.count(Threat.id)).filter(and_(
|
377 |
+
Threat.discovered_at >= from_date,
|
378 |
+
Threat.discovered_at <= to_date,
|
379 |
+
))
|
380 |
+
result = await db.execute(query)
|
381 |
+
total_count = result.scalar() or 0
|
382 |
+
|
383 |
+
# Get count by day
|
384 |
+
time_series = []
|
385 |
+
current_date = from_date.date()
|
386 |
+
end_date = to_date.date()
|
387 |
+
|
388 |
+
while current_date <= end_date:
|
389 |
+
next_date = current_date + timedelta(days=1)
|
390 |
+
query = select(func.count(Threat.id)).filter(and_(
|
391 |
+
Threat.discovered_at >= datetime.combine(current_date, datetime.min.time()),
|
392 |
+
Threat.discovered_at < datetime.combine(next_date, datetime.min.time()),
|
393 |
+
))
|
394 |
+
result = await db.execute(query)
|
395 |
+
count = result.scalar() or 0
|
396 |
+
time_series.append({
|
397 |
+
"date": current_date.isoformat(),
|
398 |
+
"count": count
|
399 |
+
})
|
400 |
+
current_date = next_date
|
401 |
+
|
402 |
+
# Return statistics
|
403 |
+
return {
|
404 |
+
"total_count": total_count,
|
405 |
+
"severity_counts": severity_counts,
|
406 |
+
"status_counts": status_counts,
|
407 |
+
"category_counts": category_counts,
|
408 |
+
"time_series": time_series,
|
409 |
+
"from_date": from_date.isoformat(),
|
410 |
+
"to_date": to_date.isoformat(),
|
411 |
+
}
|
src/api/services/user_service.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
2 |
+
from sqlalchemy.future import select
|
3 |
+
from sqlalchemy import update
|
4 |
+
from passlib.context import CryptContext
|
5 |
+
from typing import Optional, List, Dict, Any
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from src.models.user import User
|
9 |
+
from src.api.schemas import UserCreate, UserUpdate, UserInDB
|
10 |
+
|
11 |
+
# Configure logger
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
# Password context for hashing and verification
|
15 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
16 |
+
|
17 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
18 |
+
"""
|
19 |
+
Verify password against hash.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
plain_password: Plain password
|
23 |
+
hashed_password: Hashed password
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
bool: True if password is correct
|
27 |
+
"""
|
28 |
+
return pwd_context.verify(plain_password, hashed_password)
|
29 |
+
|
30 |
+
def get_password_hash(password: str) -> str:
|
31 |
+
"""
|
32 |
+
Hash password.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
password: Plain password
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
str: Hashed password
|
39 |
+
"""
|
40 |
+
return pwd_context.hash(password)
|
41 |
+
|
42 |
+
async def get_user_by_username(db: AsyncSession, username: str) -> Optional[UserInDB]:
|
43 |
+
"""
|
44 |
+
Get user by username.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
db: Database session
|
48 |
+
username: Username
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Optional[UserInDB]: User if found, None otherwise
|
52 |
+
"""
|
53 |
+
try:
|
54 |
+
result = await db.execute(select(User).where(User.username == username))
|
55 |
+
user = result.scalars().first()
|
56 |
+
|
57 |
+
if not user:
|
58 |
+
return None
|
59 |
+
|
60 |
+
# Convert SQLAlchemy model to Pydantic model
|
61 |
+
user_dict = {c.name: getattr(user, c.name) for c in user.__table__.columns}
|
62 |
+
return UserInDB(**user_dict)
|
63 |
+
except Exception as e:
|
64 |
+
logger.error(f"Error getting user by username: {e}")
|
65 |
+
return None
|
66 |
+
|
67 |
+
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
|
68 |
+
"""
|
69 |
+
Authenticate user.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
db: Database session
|
73 |
+
username: Username
|
74 |
+
password: Plain password
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Optional[UserInDB]: User if authenticated, None otherwise
|
78 |
+
"""
|
79 |
+
user = await get_user_by_username(db, username)
|
80 |
+
|
81 |
+
if not user:
|
82 |
+
return None
|
83 |
+
|
84 |
+
if not verify_password(password, user.hashed_password):
|
85 |
+
return None
|
86 |
+
|
87 |
+
return user
|
88 |
+
|
89 |
+
async def create_user(db: AsyncSession, user_data: UserCreate) -> Optional[UserInDB]:
|
90 |
+
"""
|
91 |
+
Create a new user.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
db: Database session
|
95 |
+
user_data: User data
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
Optional[UserInDB]: Created user
|
99 |
+
"""
|
100 |
+
try:
|
101 |
+
# Check if user already exists
|
102 |
+
existing_user = await get_user_by_username(db, user_data.username)
|
103 |
+
if existing_user:
|
104 |
+
return None
|
105 |
+
|
106 |
+
# Create new user
|
107 |
+
hashed_password = get_password_hash(user_data.password)
|
108 |
+
user = User(
|
109 |
+
username=user_data.username,
|
110 |
+
email=user_data.email,
|
111 |
+
full_name=user_data.full_name,
|
112 |
+
hashed_password=hashed_password,
|
113 |
+
is_active=user_data.is_active
|
114 |
+
)
|
115 |
+
|
116 |
+
db.add(user)
|
117 |
+
await db.commit()
|
118 |
+
await db.refresh(user)
|
119 |
+
|
120 |
+
# Convert SQLAlchemy model to Pydantic model
|
121 |
+
user_dict = {c.name: getattr(user, c.name) for c in user.__table__.columns}
|
122 |
+
return UserInDB(**user_dict)
|
123 |
+
except Exception as e:
|
124 |
+
logger.error(f"Error creating user: {e}")
|
125 |
+
await db.rollback()
|
126 |
+
return None
|
127 |
+
|
128 |
+
async def update_user(db: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserInDB]:
|
129 |
+
"""
|
130 |
+
Update user.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
db: Database session
|
134 |
+
user_id: User ID
|
135 |
+
user_data: User data
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
Optional[UserInDB]: Updated user
|
139 |
+
"""
|
140 |
+
try:
|
141 |
+
# Create update dictionary
|
142 |
+
update_data = user_data.dict(exclude_unset=True)
|
143 |
+
|
144 |
+
# Hash password if provided
|
145 |
+
if "password" in update_data:
|
146 |
+
update_data["hashed_password"] = get_password_hash(update_data.pop("password"))
|
147 |
+
|
148 |
+
# Update user
|
149 |
+
stmt = update(User).where(User.id == user_id).values(**update_data)
|
150 |
+
await db.execute(stmt)
|
151 |
+
await db.commit()
|
152 |
+
|
153 |
+
# Get updated user
|
154 |
+
result = await db.execute(select(User).where(User.id == user_id))
|
155 |
+
user = result.scalars().first()
|
156 |
+
|
157 |
+
if not user:
|
158 |
+
return None
|
159 |
+
|
160 |
+
# Convert SQLAlchemy model to Pydantic model
|
161 |
+
user_dict = {c.name: getattr(user, c.name) for c in user.__table__.columns}
|
162 |
+
return UserInDB(**user_dict)
|
163 |
+
except Exception as e:
|
164 |
+
logger.error(f"Error updating user: {e}")
|
165 |
+
await db.rollback()
|
166 |
+
return None
|