Replit Deployment committed on
Commit
89ae94f
·
1 Parent(s): 94b87a8

Deployment from Replit

Browse files
hf_database.py CHANGED
@@ -10,7 +10,12 @@ from sqlalchemy.orm import sessionmaker
10
  from sqlalchemy.pool import StaticPool
11
  from src.models.base import Base
12
  from src.models.user import User
13
- from src.api.security import get_password_hash
 
 
 
 
 
14
 
15
  # Configure logging
16
  logging.basicConfig(level=logging.INFO)
 
10
  from sqlalchemy.pool import StaticPool
11
  from src.models.base import Base
12
  from src.models.user import User
13
+ try:
14
+ # Try to import from src.api.security first (for local development)
15
+ from src.api.security import get_password_hash
16
+ except ImportError:
17
+ # Fall back to simplified security module for HF (copied during deployment)
18
+ from security_hf import get_password_hash
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO)
security_hf.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simplified security module for Hugging Face deployment.
3
+ Contains only the essential functions needed for HF deployment.
4
+ """
5
+ from passlib.context import CryptContext
6
+
7
+ # Set up password hashing
8
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
9
+
10
+ def get_password_hash(password: str) -> str:
11
+ """Hash a password for storage"""
12
+ return pwd_context.hash(password)
13
+
14
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
15
+ """Verify a password against a hash"""
16
+ return pwd_context.verify(plain_password, hashed_password)
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for src.
3
+ """
src/api/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for API.
3
+ """
src/api/auth.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Depends, HTTPException, status
2
+ from fastapi.security import OAuth2PasswordBearer
3
+ from sqlalchemy.ext.asyncio import AsyncSession
4
+ from jose import JWTError, jwt
5
+ from datetime import datetime, timedelta
6
+ from typing import Optional, Dict, Any
7
+ import os
8
+ import logging
9
+
10
+ from src.api.database import get_db
11
+ from src.api.schemas import TokenData, UserInDB
12
+ from src.api.services.user_service import get_user_by_username
13
+
14
+ # Configure logger
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Constants for JWT
18
+ SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production")
19
+ ALGORITHM = "HS256"
20
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
21
+
22
+ # OAuth2PasswordBearer for token extraction
23
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/token")
24
+
25
+ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
26
+ """
27
+ Create a JWT access token.
28
+
29
+ Args:
30
+ data: Dictionary of data to encode in the token
31
+ expires_delta: Optional expiration time delta
32
+
33
+ Returns:
34
+ str: JWT token
35
+ """
36
+ to_encode = data.copy()
37
+
38
+ if expires_delta:
39
+ expire = datetime.utcnow() + expires_delta
40
+ else:
41
+ expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
42
+
43
+ to_encode.update({"exp": expire})
44
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
45
+
46
+ return encoded_jwt
47
+
48
+ async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> UserInDB:
49
+ """
50
+ Get the current authenticated user based on the JWT token.
51
+
52
+ Args:
53
+ token: JWT token
54
+ db: Database session
55
+
56
+ Returns:
57
+ UserInDB: User data
58
+
59
+ Raises:
60
+ HTTPException: If authentication fails
61
+ """
62
+ credentials_exception = HTTPException(
63
+ status_code=status.HTTP_401_UNAUTHORIZED,
64
+ detail="Could not validate credentials",
65
+ headers={"WWW-Authenticate": "Bearer"},
66
+ )
67
+
68
+ try:
69
+ # Decode JWT
70
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
71
+ username: str = payload.get("sub")
72
+
73
+ if username is None:
74
+ raise credentials_exception
75
+
76
+ token_data = TokenData(username=username)
77
+ except JWTError as e:
78
+ logger.error(f"JWT error: {e}")
79
+ raise credentials_exception
80
+
81
+ # Get user from database
82
+ user = await get_user_by_username(db, username=token_data.username)
83
+
84
+ if user is None:
85
+ raise credentials_exception
86
+
87
+ return user
88
+
89
+ async def get_current_active_user(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
90
+ """
91
+ Get the current active user.
92
+
93
+ Args:
94
+ current_user: Current authenticated user
95
+
96
+ Returns:
97
+ UserInDB: User data
98
+
99
+ Raises:
100
+ HTTPException: If user is inactive
101
+ """
102
+ if not current_user.is_active:
103
+ raise HTTPException(status_code=400, detail="Inactive user")
104
+
105
+ return current_user
src/api/database.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Database configuration and setup for API.
3
+ """
4
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
5
+ from sqlalchemy.orm import sessionmaker
6
+ import os
7
+ from typing import AsyncGenerator
8
+ import logging
9
+
10
+ # Configure logger
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Get database URL from environment (convert synchronous to async URL)
14
+ db_url = os.getenv("DATABASE_URL", "")
15
+ if db_url.startswith("postgresql://"):
16
+ # Remove sslmode parameter if present which causes issues with asyncpg
17
+ if "?" in db_url:
18
+ base_url, params = db_url.split("?", 1)
19
+ param_list = params.split("&")
20
+ filtered_params = [p for p in param_list if not p.startswith("sslmode=")]
21
+ if filtered_params:
22
+ db_url = f"{base_url}?{'&'.join(filtered_params)}"
23
+ else:
24
+ db_url = base_url
25
+
26
+ ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
27
+ else:
28
+ ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
29
+
30
+ # Create async engine
31
+ engine = create_async_engine(
32
+ ASYNC_DATABASE_URL,
33
+ echo=False, # Set to True for debugging
34
+ future=True,
35
+ pool_size=5,
36
+ max_overflow=10
37
+ )
38
+
39
+ # Create async session factory
40
+ async_session = sessionmaker(
41
+ engine,
42
+ class_=AsyncSession,
43
+ expire_on_commit=False
44
+ )
45
+
46
+ async def get_db() -> AsyncGenerator[AsyncSession, None]:
47
+ """
48
+ Get database session generator.
49
+
50
+ Yields:
51
+ AsyncSession: Database session
52
+ """
53
+ session = async_session()
54
+ try:
55
+ yield session
56
+ await session.commit()
57
+ except Exception as e:
58
+ await session.rollback()
59
+ logger.error(f"Database error: {e}")
60
+ raise
61
+ finally:
62
+ await session.close()
63
+
64
+ # Dependency for getting DB session
65
+ async def get_db_session() -> AsyncSession:
66
+ """
67
+ Get database session.
68
+
69
+ Returns:
70
+ AsyncSession: Database session
71
+ """
72
+ async with async_session() as session:
73
+ return session
src/api/main.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Depends, HTTPException, status
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import logging
4
+ from typing import List
5
+
6
+ from src.api.database import get_db
7
+ from src.api.auth import get_current_user
8
+ from src.api.routers import threats_router, indicators_router, auth_router, admin_router
9
+
10
+ # Configure logging
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
14
+ handlers=[
15
+ logging.StreamHandler(),
16
+ logging.FileHandler("app.log")
17
+ ]
18
+ )
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Create FastAPI app
22
+ app = FastAPI(
23
+ title="CyberForge OSINT API",
24
+ description="API for Dark Web OSINT platform",
25
+ version="1.0.0"
26
+ )
27
+
28
+ # Add CORS middleware
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=["*"], # Update for production
32
+ allow_credentials=True,
33
+ allow_methods=["*"],
34
+ allow_headers=["*"],
35
+ )
36
+
37
+ # Include routers for different endpoints
38
+ app.include_router(auth_router.router, prefix="/api/v1")
39
+ app.include_router(
40
+ threats_router.router,
41
+ prefix="/api/v1/threats",
42
+ tags=["threats"],
43
+ dependencies=[Depends(get_current_user)]
44
+ )
45
+ app.include_router(
46
+ indicators_router.router,
47
+ prefix="/api/v1/indicators",
48
+ tags=["indicators"],
49
+ dependencies=[Depends(get_current_user)]
50
+ )
51
+ app.include_router(
52
+ admin_router.router,
53
+ prefix="/api/v1/admin",
54
+ tags=["admin"],
55
+ dependencies=[Depends(get_current_user)]
56
+ )
57
+
58
+ @app.get("/api/health")
59
+ async def health_check():
60
+ """Health check endpoint for monitoring."""
61
+ return {"status": "healthy", "version": "1.0.0"}
62
+
63
+ @app.on_event("startup")
64
+ async def startup_event():
65
+ """Event handler for application startup."""
66
+ logger.info("Starting the CyberForge OSINT API")
67
+ # Add any startup tasks here (database connection, cache warming, etc.)
68
+
69
+ @app.on_event("shutdown")
70
+ async def shutdown_event():
71
+ """Event handler for application shutdown."""
72
+ logger.info("Shutting down the CyberForge OSINT API")
73
+ # Add any cleanup tasks here (close connections, save state, etc.)
src/api/routers/auth_router.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Authentication router.
3
+
4
+ This module provides authentication endpoints for the API.
5
+ """
6
+ from datetime import datetime, timedelta
7
+ from typing import Any, Dict, Optional
8
+
9
+ from fastapi import APIRouter, Depends, HTTPException, status
10
+ from fastapi.security import OAuth2PasswordRequestForm
11
+ from sqlalchemy.ext.asyncio import AsyncSession
12
+
13
+ from src.api.database import get_db
14
+ from src.api.security import (
15
+ ACCESS_TOKEN_EXPIRE_MINUTES,
16
+ Token,
17
+ UserInDB,
18
+ authenticate_user,
19
+ create_access_token,
20
+ get_current_active_user,
21
+ )
22
+
23
+ router = APIRouter(tags=["authentication"])
24
+
25
+ @router.post("/token", response_model=Token)
26
+ async def login_for_access_token(
27
+ form_data: OAuth2PasswordRequestForm = Depends(),
28
+ db: AsyncSession = Depends(get_db)
29
+ ) -> Dict[str, Any]:
30
+ """
31
+ OAuth2 compatible token login, get an access token for future requests.
32
+ """
33
+ user = await authenticate_user(db, form_data.username, form_data.password)
34
+ if not user:
35
+ raise HTTPException(
36
+ status_code=status.HTTP_401_UNAUTHORIZED,
37
+ detail="Incorrect username or password",
38
+ headers={"WWW-Authenticate": "Bearer"},
39
+ )
40
+
41
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
42
+ access_token = create_access_token(
43
+ data={"sub": user.username, "scopes": user.scopes},
44
+ expires_delta=access_token_expires
45
+ )
46
+
47
+ expires_at = datetime.utcnow() + access_token_expires
48
+
49
+ return {
50
+ "access_token": access_token,
51
+ "token_type": "bearer",
52
+ "expires_at": expires_at
53
+ }
54
+
55
+ @router.get("/users/me", response_model=UserInDB)
56
+ async def read_users_me(
57
+ current_user: UserInDB = Depends(get_current_active_user)
58
+ ) -> UserInDB:
59
+ """
60
+ Get current user.
61
+ """
62
+ return current_user
63
+
64
+ @router.get("/users/me/scopes")
65
+ async def read_own_scopes(
66
+ current_user: UserInDB = Depends(get_current_active_user)
67
+ ) -> Dict[str, Any]:
68
+ """
69
+ Get current user's scopes.
70
+ """
71
+ return {
72
+ "username": current_user.username,
73
+ "scopes": current_user.scopes
74
+ }
src/api/routers/scraping_router.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, status, Body, BackgroundTasks
2
+ from sqlalchemy.ext.asyncio import AsyncSession
3
+ from typing import List, Optional, Dict, Any
4
+ import logging
5
+ from datetime import datetime
6
+
7
+ from src.api.database import get_db
8
+ from src.api.auth import get_current_user
9
+ from src.api.schemas import User, CrawlRequest, CrawlResult
10
+ from src.services.scraper import WebScraper, ScraperError
11
+ from src.services.tor_proxy import TorProxyService, TorProxyError
12
+
13
+ # Configure logger
14
+ logger = logging.getLogger(__name__)
15
+
16
+ router = APIRouter(
17
+ prefix="/scraping",
18
+ tags=["scraping"],
19
+ responses={404: {"description": "Not found"}}
20
+ )
21
+
22
+ # Initialize services
23
+ scraper = WebScraper()
24
+
25
+ @router.post("/test-tor", response_model=Dict[str, Any])
26
+ async def test_tor_connection(
27
+ current_user: User = Depends(get_current_user)
28
+ ):
29
+ """
30
+ Test Tor connection.
31
+
32
+ Args:
33
+ current_user: Current authenticated user
34
+
35
+ Returns:
36
+ Dict[str, Any]: Connection status
37
+ """
38
+ try:
39
+ tor_proxy = TorProxyService()
40
+ is_connected = await tor_proxy.check_connection()
41
+
42
+ return {
43
+ "status": "success",
44
+ "is_connected": is_connected,
45
+ "timestamp": datetime.utcnow().isoformat()
46
+ }
47
+ except TorProxyError as e:
48
+ logger.error(f"Tor proxy error: {e}")
49
+ raise HTTPException(
50
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
51
+ detail=f"Tor proxy error: {str(e)}"
52
+ )
53
+ except Exception as e:
54
+ logger.error(f"Error testing Tor connection: {e}")
55
+ raise HTTPException(
56
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
57
+ detail=f"An error occurred: {str(e)}"
58
+ )
59
+
60
+ @router.post("/scrape", response_model=Dict[str, Any])
61
+ async def scrape_page(
62
+ url: str,
63
+ use_tor: bool = Body(False),
64
+ current_user: User = Depends(get_current_user)
65
+ ):
66
+ """
67
+ Scrape a single page.
68
+
69
+ Args:
70
+ url: URL to scrape
71
+ use_tor: Whether to use Tor proxy
72
+ current_user: Current authenticated user
73
+
74
+ Returns:
75
+ Dict[str, Any]: Scraped content
76
+ """
77
+ try:
78
+ result = await scraper.extract_content(url, use_tor=use_tor)
79
+
80
+ return {
81
+ "status": "success",
82
+ "data": result,
83
+ "timestamp": datetime.utcnow().isoformat()
84
+ }
85
+ except ScraperError as e:
86
+ logger.error(f"Scraper error: {e}")
87
+ raise HTTPException(
88
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
89
+ detail=f"Scraper error: {str(e)}"
90
+ )
91
+ except Exception as e:
92
+ logger.error(f"Error scraping page: {e}")
93
+ raise HTTPException(
94
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
95
+ detail=f"An error occurred: {str(e)}"
96
+ )
97
+
98
+ @router.post("/crawl", response_model=Dict[str, Any])
99
+ async def crawl_site(
100
+ crawl_request: CrawlRequest,
101
+ background_tasks: BackgroundTasks,
102
+ current_user: User = Depends(get_current_user)
103
+ ):
104
+ """
105
+ Crawl a site.
106
+
107
+ Args:
108
+ crawl_request: Crawl request data
109
+ background_tasks: Background tasks
110
+ current_user: Current authenticated user
111
+
112
+ Returns:
113
+ Dict[str, Any]: Crawl status
114
+ """
115
+ # For longer crawls, we add them as background tasks
116
+ # This prevents timeouts on the API request
117
+
118
+ # Start crawl in background
119
+ if crawl_request.max_depth > 1 or '.onion' in crawl_request.url:
120
+ # Add to background tasks
121
+ background_tasks.add_task(
122
+ scraper.crawl,
123
+ crawl_request.url,
124
+ max_depth=crawl_request.max_depth,
125
+ max_pages=50,
126
+ keyword_filter=crawl_request.keywords
127
+ )
128
+
129
+ return {
130
+ "status": "started",
131
+ "message": "Crawl started in background",
132
+ "timestamp": datetime.utcnow().isoformat()
133
+ }
134
+ else:
135
+ # For simple crawls, we perform them synchronously
136
+ try:
137
+ results = await scraper.crawl(
138
+ crawl_request.url,
139
+ max_depth=crawl_request.max_depth,
140
+ max_pages=10,
141
+ keyword_filter=crawl_request.keywords
142
+ )
143
+
144
+ return {
145
+ "status": "completed",
146
+ "results": results,
147
+ "count": len(results),
148
+ "timestamp": datetime.utcnow().isoformat()
149
+ }
150
+ except ScraperError as e:
151
+ logger.error(f"Scraper error: {e}")
152
+ raise HTTPException(
153
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
154
+ detail=f"Scraper error: {str(e)}"
155
+ )
156
+ except Exception as e:
157
+ logger.error(f"Error crawling site: {e}")
158
+ raise HTTPException(
159
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
160
+ detail=f"An error occurred: {str(e)}"
161
+ )
src/api/routers/threats_router.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, status, Body, Query
2
+ from sqlalchemy.ext.asyncio import AsyncSession
3
+ from typing import List, Optional
4
+ import logging
5
+
6
+ from src.api.database import get_db
7
+ from src.api.auth import get_current_user
8
+ from src.api.schemas import (
9
+ Threat, ThreatCreate, ThreatUpdate, ThreatFilter,
10
+ PaginationParams, User
11
+ )
12
+ from src.api.services.threat_service import (
13
+ create_threat, get_threat_by_id, update_threat,
14
+ delete_threat, get_threats, get_threat_statistics
15
+ )
16
+
17
+ # Configure logger
18
+ logger = logging.getLogger(__name__)
19
+
20
+ router = APIRouter(
21
+ tags=["threats"],
22
+ responses={404: {"description": "Not found"}}
23
+ )
24
+
25
+ @router.post("/", response_model=Threat, status_code=status.HTTP_201_CREATED)
26
+ async def create_threat_endpoint(
27
+ threat_data: ThreatCreate,
28
+ db: AsyncSession = Depends(get_db),
29
+ current_user: User = Depends(get_current_user)
30
+ ):
31
+ """
32
+ Create a new threat.
33
+
34
+ Args:
35
+ threat_data: Threat data
36
+ db: Database session
37
+ current_user: Current authenticated user
38
+
39
+ Returns:
40
+ Threat: Created threat
41
+ """
42
+ try:
43
+ threat = await create_threat(db, threat_data)
44
+
45
+ if not threat:
46
+ raise HTTPException(
47
+ status_code=status.HTTP_400_BAD_REQUEST,
48
+ detail="Failed to create threat"
49
+ )
50
+
51
+ return threat
52
+ except Exception as e:
53
+ logger.error(f"Error creating threat: {e}")
54
+ raise HTTPException(
55
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
56
+ detail=f"An error occurred: {str(e)}"
57
+ )
58
+
59
+ @router.get("/{threat_id}", response_model=Threat)
60
+ async def get_threat_endpoint(
61
+ threat_id: int,
62
+ db: AsyncSession = Depends(get_db),
63
+ current_user: User = Depends(get_current_user)
64
+ ):
65
+ """
66
+ Get threat by ID.
67
+
68
+ Args:
69
+ threat_id: Threat ID
70
+ db: Database session
71
+ current_user: Current authenticated user
72
+
73
+ Returns:
74
+ Threat: Threat data
75
+ """
76
+ threat = await get_threat_by_id(db, threat_id)
77
+
78
+ if not threat:
79
+ raise HTTPException(
80
+ status_code=status.HTTP_404_NOT_FOUND,
81
+ detail=f"Threat with ID {threat_id} not found"
82
+ )
83
+
84
+ return threat
85
+
86
+ @router.put("/{threat_id}", response_model=Threat)
87
+ async def update_threat_endpoint(
88
+ threat_id: int,
89
+ threat_data: ThreatUpdate,
90
+ db: AsyncSession = Depends(get_db),
91
+ current_user: User = Depends(get_current_user)
92
+ ):
93
+ """
94
+ Update threat.
95
+
96
+ Args:
97
+ threat_id: Threat ID
98
+ threat_data: Threat data
99
+ db: Database session
100
+ current_user: Current authenticated user
101
+
102
+ Returns:
103
+ Threat: Updated threat
104
+ """
105
+ # Check if threat exists
106
+ threat = await get_threat_by_id(db, threat_id)
107
+
108
+ if not threat:
109
+ raise HTTPException(
110
+ status_code=status.HTTP_404_NOT_FOUND,
111
+ detail=f"Threat with ID {threat_id} not found"
112
+ )
113
+
114
+ # Update threat
115
+ updated_threat = await update_threat(db, threat_id, threat_data)
116
+
117
+ if not updated_threat:
118
+ raise HTTPException(
119
+ status_code=status.HTTP_400_BAD_REQUEST,
120
+ detail="Failed to update threat"
121
+ )
122
+
123
+ return updated_threat
124
+
125
+ @router.delete("/{threat_id}", status_code=status.HTTP_204_NO_CONTENT)
126
+ async def delete_threat_endpoint(
127
+ threat_id: int,
128
+ db: AsyncSession = Depends(get_db),
129
+ current_user: User = Depends(get_current_user)
130
+ ):
131
+ """
132
+ Delete threat.
133
+
134
+ Args:
135
+ threat_id: Threat ID
136
+ db: Database session
137
+ current_user: Current authenticated user
138
+ """
139
+ # Check if threat exists
140
+ threat = await get_threat_by_id(db, threat_id)
141
+
142
+ if not threat:
143
+ raise HTTPException(
144
+ status_code=status.HTTP_404_NOT_FOUND,
145
+ detail=f"Threat with ID {threat_id} not found"
146
+ )
147
+
148
+ # Delete threat
149
+ deleted = await delete_threat(db, threat_id)
150
+
151
+ if not deleted:
152
+ raise HTTPException(
153
+ status_code=status.HTTP_400_BAD_REQUEST,
154
+ detail="Failed to delete threat"
155
+ )
156
+
157
+ @router.get("/", response_model=List[Threat])
158
+ async def get_threats_endpoint(
159
+ pagination: PaginationParams = Depends(),
160
+ severity: Optional[List[str]] = Query(None),
161
+ status: Optional[List[str]] = Query(None),
162
+ category: Optional[List[str]] = Query(None),
163
+ search: Optional[str] = Query(None),
164
+ from_date: Optional[str] = Query(None),
165
+ to_date: Optional[str] = Query(None),
166
+ db: AsyncSession = Depends(get_db),
167
+ current_user: User = Depends(get_current_user)
168
+ ):
169
+ """
170
+ Get threats with filtering and pagination.
171
+
172
+ Args:
173
+ pagination: Pagination parameters
174
+ severity: Filter by severity
175
+ status: Filter by status
176
+ category: Filter by category
177
+ search: Search in title and description
178
+ from_date: Filter from date
179
+ to_date: Filter to date
180
+ db: Database session
181
+ current_user: Current authenticated user
182
+
183
+ Returns:
184
+ List[Threat]: List of threats
185
+ """
186
+ # Create filter params
187
+ filter_params = ThreatFilter(
188
+ severity=severity,
189
+ status=status,
190
+ category=category,
191
+ search=search,
192
+ from_date=from_date,
193
+ to_date=to_date
194
+ )
195
+
196
+ # Get threats
197
+ threats, total = await get_threats(db, filter_params, pagination)
198
+
199
+ return threats
200
+
201
+ @router.get("/statistics", response_model=dict)
202
+ async def get_threat_statistics_endpoint(
203
+ db: AsyncSession = Depends(get_db),
204
+ current_user: User = Depends(get_current_user)
205
+ ):
206
+ """
207
+ Get threat statistics.
208
+
209
+ Args:
210
+ db: Database session
211
+ current_user: Current authenticated user
212
+
213
+ Returns:
214
+ dict: Threat statistics
215
+ """
216
+ statistics = await get_threat_statistics(db)
217
+ return statistics
src/api/schemas.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API schemas for data validation and serialization.
3
+ """
4
+ from pydantic import BaseModel, Field, validator, EmailStr
5
+ from typing import Optional, List, Dict, Any, Union
6
+ from datetime import datetime
7
+ from enum import Enum
8
+
9
+ # Pagination
10
+ class PaginationParams(BaseModel):
11
+ """Pagination parameters."""
12
+ page: int = Field(1, ge=1, description="Page number")
13
+ size: int = Field(10, ge=1, le=100, description="Items per page")
14
+
15
+ # User schemas
16
+ class UserBase(BaseModel):
17
+ """Base user schema."""
18
+ username: str
19
+ email: EmailStr
20
+ full_name: Optional[str] = None
21
+ is_active: bool = True
22
+
23
+ class UserCreate(UserBase):
24
+ """User creation schema."""
25
+ password: str
26
+
27
+ class UserUpdate(BaseModel):
28
+ """User update schema."""
29
+ username: Optional[str] = None
30
+ email: Optional[EmailStr] = None
31
+ full_name: Optional[str] = None
32
+ is_active: Optional[bool] = None
33
+ password: Optional[str] = None
34
+
35
+ class UserResponse(UserBase):
36
+ """User response schema."""
37
+ id: int
38
+ is_superuser: bool = False
39
+
40
+ class Config:
41
+ orm_mode = True
42
+
43
+ # Token schemas
44
+ class Token(BaseModel):
45
+ """Token schema."""
46
+ access_token: str
47
+ token_type: str = "bearer"
48
+
49
+ class TokenPayload(BaseModel):
50
+ """Token payload schema."""
51
+ sub: Optional[int] = None
52
+
53
+ # Dark Web Content schemas
54
+ class DarkWebContentBase(BaseModel):
55
+ """Base schema for dark web content."""
56
+ url: str
57
+ title: Optional[str] = None
58
+ content: str
59
+ content_type: str
60
+ source_name: Optional[str] = None
61
+ source_type: Optional[str] = None
62
+ language: Optional[str] = None
63
+
64
+ class DarkWebContentCreate(DarkWebContentBase):
65
+ """Schema for creating dark web content."""
66
+ relevance_score: float = 0.0
67
+ sentiment_score: float = 0.0
68
+ entity_data: Optional[str] = None
69
+
70
+ class DarkWebContentUpdate(BaseModel):
71
+ """Schema for updating dark web content."""
72
+ title: Optional[str] = None
73
+ content_status: Optional[str] = None
74
+ relevance_score: Optional[float] = None
75
+ sentiment_score: Optional[float] = None
76
+ entity_data: Optional[str] = None
77
+
78
+ class DarkWebContentResponse(DarkWebContentBase):
79
+ """Schema for dark web content response."""
80
+ id: int
81
+ domain: Optional[str] = None
82
+ content_status: str
83
+ scraped_at: datetime
84
+ relevance_score: float
85
+ sentiment_score: float
86
+ entity_data: Optional[str] = None
87
+
88
+ class Config:
89
+ orm_mode = True
90
+
91
+ # Dark Web Mention schemas
92
+ class DarkWebMentionBase(BaseModel):
93
+ """Base schema for dark web mention."""
94
+ content_id: int
95
+ keyword: str
96
+ keyword_category: Optional[str] = None
97
+ context: Optional[str] = None
98
+ snippet: Optional[str] = None
99
+ mention_type: Optional[str] = None
100
+
101
+ class DarkWebMentionCreate(DarkWebMentionBase):
102
+ """Schema for creating dark web mention."""
103
+ confidence: float = 0.0
104
+ is_verified: bool = False
105
+
106
+ class DarkWebMentionUpdate(BaseModel):
107
+ """Schema for updating dark web mention."""
108
+ keyword_category: Optional[str] = None
109
+ mention_type: Optional[str] = None
110
+ confidence: Optional[float] = None
111
+ is_verified: Optional[bool] = None
112
+
113
+ class DarkWebMentionResponse(DarkWebMentionBase):
114
+ """Schema for dark web mention response."""
115
+ id: int
116
+ confidence: float
117
+ is_verified: bool
118
+ created_at: datetime
119
+
120
+ class Config:
121
+ orm_mode = True
122
+
123
+ # Threat schemas
124
+ class ThreatBase(BaseModel):
125
+ """Base schema for threat."""
126
+ title: str
127
+ description: str
128
+ severity: str
129
+ category: str
130
+
131
+ class ThreatCreate(ThreatBase):
132
+ """Schema for creating threat."""
133
+ status: str = "New"
134
+ source_url: Optional[str] = None
135
+ source_name: Optional[str] = None
136
+ source_type: Optional[str] = None
137
+ affected_entity: Optional[str] = None
138
+ affected_entity_type: Optional[str] = None
139
+ confidence_score: float = 0.0
140
+ risk_score: float = 0.0
141
+
142
+ class ThreatUpdate(BaseModel):
143
+ """Schema for updating threat."""
144
+ title: Optional[str] = None
145
+ description: Optional[str] = None
146
+ severity: Optional[str] = None
147
+ status: Optional[str] = None
148
+ category: Optional[str] = None
149
+ affected_entity: Optional[str] = None
150
+ affected_entity_type: Optional[str] = None
151
+ confidence_score: Optional[float] = None
152
+ risk_score: Optional[float] = None
153
+
154
+ class ThreatResponse(ThreatBase):
155
+ """Schema for threat response."""
156
+ id: int
157
+ status: str
158
+ source_url: Optional[str] = None
159
+ source_name: Optional[str] = None
160
+ source_type: Optional[str] = None
161
+ discovered_at: datetime
162
+ affected_entity: Optional[str] = None
163
+ affected_entity_type: Optional[str] = None
164
+ confidence_score: float
165
+ risk_score: float
166
+
167
+ class Config:
168
+ orm_mode = True
169
+
170
+ # Indicator schemas
171
+ class IndicatorBase(BaseModel):
172
+ """Base schema for indicator."""
173
+ threat_id: int
174
+ value: str
175
+ indicator_type: str
176
+ description: Optional[str] = None
177
+
178
+ class IndicatorCreate(IndicatorBase):
179
+ """Schema for creating indicator."""
180
+ is_verified: bool = False
181
+ context: Optional[str] = None
182
+ source: Optional[str] = None
183
+ confidence_score: float = 0.0
184
+
185
+ class IndicatorUpdate(BaseModel):
186
+ """Schema for updating indicator."""
187
+ description: Optional[str] = None
188
+ is_verified: Optional[bool] = None
189
+ context: Optional[str] = None
190
+ source: Optional[str] = None
191
+ confidence_score: Optional[float] = None
192
+
193
+ class IndicatorResponse(IndicatorBase):
194
+ """Schema for indicator response."""
195
+ id: int
196
+ is_verified: bool
197
+ context: Optional[str] = None
198
+ source: Optional[str] = None
199
+ confidence_score: float
200
+ first_seen: datetime
201
+ last_seen: datetime
202
+
203
+ class Config:
204
+ orm_mode = True
205
+
206
+ # Alert schemas
207
+ class AlertBase(BaseModel):
208
+ """Base schema for alert."""
209
+ title: str
210
+ description: str
211
+ severity: str
212
+ category: str
213
+
214
+ class AlertCreate(AlertBase):
215
+ """Schema for creating alert."""
216
+ source_url: Optional[str] = None
217
+ threat_id: Optional[int] = None
218
+ mention_id: Optional[int] = None
219
+
220
+ class AlertUpdate(BaseModel):
221
+ """Schema for updating alert."""
222
+ status: str
223
+ action_taken: Optional[str] = None
224
+ assigned_to_id: Optional[int] = None
225
+ is_read: Optional[bool] = None
226
+
227
+ class AlertResponse(AlertBase):
228
+ """Schema for alert response."""
229
+ id: int
230
+ status: str
231
+ generated_at: datetime
232
+ source_url: Optional[str] = None
233
+ is_read: bool
234
+ threat_id: Optional[int] = None
235
+ mention_id: Optional[int] = None
236
+ assigned_to_id: Optional[int] = None
237
+ action_taken: Optional[str] = None
238
+ resolved_at: Optional[datetime] = None
239
+
240
+ class Config:
241
+ orm_mode = True
242
+
243
+ # Report schemas
244
+ class ReportBase(BaseModel):
245
+ """Base schema for report."""
246
+ report_id: str
247
+ title: str
248
+ summary: str
249
+ content: str
250
+ report_type: str
251
+
252
+ class ReportCreate(ReportBase):
253
+ """Schema for creating report."""
254
+ status: str = "Draft"
255
+ severity: Optional[str] = None
256
+ publish_date: Optional[datetime] = None
257
+ time_period_start: Optional[datetime] = None
258
+ time_period_end: Optional[datetime] = None
259
+ keywords: Optional[str] = None
260
+ author_id: int
261
+ threat_ids: List[int] = []
262
+
263
+ class ReportUpdate(BaseModel):
264
+ """Schema for updating report."""
265
+ title: Optional[str] = None
266
+ summary: Optional[str] = None
267
+ content: Optional[str] = None
268
+ report_type: Optional[str] = None
269
+ status: Optional[str] = None
270
+ severity: Optional[str] = None
271
+ publish_date: Optional[datetime] = None
272
+ time_period_start: Optional[datetime] = None
273
+ time_period_end: Optional[datetime] = None
274
+ keywords: Optional[str] = None
275
+ threat_ids: Optional[List[int]] = None
276
+
277
+ class ReportResponse(ReportBase):
278
+ """Schema for report response."""
279
+ id: int
280
+ status: str
281
+ severity: Optional[str] = None
282
+ publish_date: Optional[datetime] = None
283
+ time_period_start: Optional[datetime] = None
284
+ time_period_end: Optional[datetime] = None
285
+ keywords: Optional[str] = None
286
+ author_id: int
287
+
288
+ class Config:
289
+ orm_mode = True
290
+
291
+ # Statistics response schemas
292
+ class ThreatStatisticsResponse(BaseModel):
293
+ """Schema for threat statistics response."""
294
+ total_count: int
295
+ severity_counts: Dict[str, int]
296
+ status_counts: Dict[str, int]
297
+ category_counts: Dict[str, int]
298
+ time_series: List[Dict[str, Any]]
299
+ from_date: str
300
+ to_date: str
301
+
302
+ class ContentStatisticsResponse(BaseModel):
303
+ """Schema for content statistics response."""
304
+ total_count: int
305
+ content_type_counts: Dict[str, int]
306
+ content_status_counts: Dict[str, int]
307
+ source_counts: Dict[str, int]
308
+ time_series: List[Dict[str, Any]]
309
+ from_date: str
310
+ to_date: str
src/api/security.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API Security Module
3
+
4
+ This module provides security features for the API, including:
5
+ 1. Authentication using JWT tokens
6
+ 2. Rate limiting to prevent abuse
7
+ 3. Role-based access control
8
+ 4. Request validation
9
+ 5. Audit logging
10
+ """
11
+ import os
12
+ import time
13
+ import logging
14
+ import secrets
15
+ from datetime import datetime, timedelta
16
+ from typing import Dict, List, Optional, Union, Any, Callable
17
+
18
+ from fastapi import Depends, HTTPException, Security, status, Request
19
+ from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
20
+ from jose import JWTError, jwt
21
+ from passlib.context import CryptContext
22
+ from sqlalchemy.ext.asyncio import AsyncSession
23
+ from sqlalchemy.future import select
24
+ from pydantic import BaseModel, EmailStr
25
+
26
+ from src.models.user import User
27
+ from src.api.database import get_db
28
+
29
+ # Configure logging
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Security configuration
33
+ SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_hex(32))
34
+ ALGORITHM = "HS256"
35
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
36
+ API_KEY_NAME = "X-API-Key"
37
+
38
+ # Set up password hashing
39
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
40
+
41
+ # Set up security schemes
42
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
43
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
44
+
45
+ # User models
46
class Token(BaseModel):
    """Bearer-token payload returned to an authenticated client."""
    # Signed JWT produced by create_access_token().
    access_token: str
    token_type: str
    # Absolute expiry timestamp of the token.
    expires_at: datetime
50
+
51
class TokenData(BaseModel):
    """Claims extracted from a decoded JWT (see get_current_user)."""
    username: Optional[str] = None
    scopes: List[str] = []
54
+
55
class UserInDB(BaseModel):
    """Internal user representation with resolved permission scopes."""
    id: int
    username: str
    email: EmailStr
    full_name: Optional[str] = None
    is_active: bool = True
    is_superuser: bool = False
    # Permission scopes granted via the ROLES table; filled in by get_user().
    scopes: List[str] = []

    class Config:
        # pydantic v2: allow building this model from ORM attribute access.
        from_attributes = True
66
+
67
+ # Rate limiting
68
class RateLimiter:
    """In-memory sliding-window rate limiter (per-process, not distributed)."""

    def __init__(self, rate_limit: int = 100, time_window: int = 60):
        """
        Args:
            rate_limit: Maximum number of requests allowed per window.
            time_window: Window length in seconds.
        """
        self.rate_limit = rate_limit
        self.time_window = time_window
        # key -> list of request timestamps still inside the window
        self.requests = {}

    def is_rate_limited(self, key: str) -> bool:
        """Record a request for *key* and report whether it exceeds the limit.

        Args:
            key: Client identifier (IP address, API key, ...).

        Returns:
            True when the client has already used up its quota for the
            current window; False otherwise (the request is then counted).
        """
        now = time.time()
        cutoff = now - self.time_window
        # Drop timestamps that have aged out of the window.
        recent = [t for t in self.requests.get(key, []) if t > cutoff]
        if len(recent) >= self.rate_limit:
            self.requests[key] = recent
            return True
        recent.append(now)
        self.requests[key] = recent
        return False
109
+
110
# Module-wide limiter instance shared by every request in this process.
rate_limiter = RateLimiter()

# Role-based access control: maps a role name to the permission scopes it
# grants. Scopes are checked by the has_scope() dependency below.
ROLES = {
    "admin": ["read:all", "write:all", "delete:all"],
    "analyst": ["read:all", "write:threats", "write:indicators", "write:reports"],
    "user": ["read:threats", "read:reports", "read:dashboard"],
    "api": ["read:all", "write:threats", "write:indicators"]
}
121
+
122
+ # Security utility functions
123
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password: str) -> str:
    """Return a bcrypt hash of *password* suitable for storage."""
    return pwd_context.hash(password)
130
+
131
async def get_user(db: AsyncSession, username: str) -> Optional[UserInDB]:
    """Look up *username* and return a UserInDB with role scopes, or None."""
    rows = await db.execute(select(User).filter(User.username == username))
    user_db = rows.scalars().first()
    if user_db is None:
        return None

    # Superusers receive the full admin scope set; every other account is
    # treated as the basic "user" role (real role lookup not implemented yet).
    role = "admin" if user_db.is_superuser else "user"

    return UserInDB(
        id=user_db.id,
        username=user_db.username,
        email=user_db.email,
        full_name=user_db.full_name,
        is_active=user_db.is_active,
        is_superuser=user_db.is_superuser,
        scopes=ROLES[role],
    )
157
+
158
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
    """Validate a username/password pair.

    Args:
        db: Database session.
        username: Login name to authenticate.
        password: Plain-text password to check.

    Returns:
        The user's profile (with scopes) on success, None on failure.
    """
    # Fetch the ORM row once. The previous implementation queried twice
    # (via get_user and then again for the hash), which was wasteful and
    # could dereference None if the row vanished between the two reads.
    result = await db.execute(select(User).filter(User.username == username))
    user_db = result.scalars().first()

    if user_db is None or not verify_password(password, user_db.hashed_password):
        return None

    # Reuse get_user so role scopes are attached consistently.
    return await get_user(db, username)
172
+
173
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Return a signed JWT containing *data* plus an `exp` expiry claim.

    Args:
        data: Claims to embed (typically {"sub": username, "scopes": [...]}).
        expires_delta: Optional lifetime; falls back to the module default
            when absent (or zero — truthiness is checked deliberately).
    """
    claims = data.copy()
    if expires_delta:
        expiry = datetime.utcnow() + expires_delta
    else:
        expiry = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims["exp"] = expiry
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
185
+
186
async def get_api_key_user(
    api_key: str,
    db: AsyncSession
) -> Optional[UserInDB]:
    """Resolve an API key to its associated user, with API-role scopes.

    Returns None when the key is unknown or the mapped user is missing.
    """
    # TODO: Replace with database-backed API key storage
    known_keys = {
        "test-api-key": "api_user",
        # Add more API keys here
    }

    username = known_keys.get(api_key)
    if username is None:
        return None

    user = await get_user(db, username)
    if user is None:
        return None

    # API-key callers always act with the "api" role, regardless of the
    # scopes get_user assigned.
    user.scopes = ROLES["api"]
    return user
212
+
213
+ # Dependencies for FastAPI
214
async def rate_limit(request: Request):
    """FastAPI dependency enforcing the global in-memory rate limit.

    The client is identified by its API key header when present, falling
    back to the source IP address.
    """
    client_key = request.headers.get(API_KEY_NAME) or request.client.host

    if not rate_limiter.is_rate_limited(client_key):
        return True

    logger.warning(f"Rate limit exceeded for {client_key}")
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail="Rate limit exceeded. Please try again later."
    )
227
+
228
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    api_key: str = Security(api_key_header),
    db: AsyncSession = Depends(get_db)
) -> UserInDB:
    """Resolve the caller from an API key (checked first) or a bearer JWT.

    NOTE(review): oauth2_scheme is constructed without auto_error=False, so
    a request carrying only an API key and no Authorization header may be
    rejected by FastAPI before this function runs — confirm against the
    route configuration.
    """
    auth_failure = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # API-key authentication takes precedence over the JWT.
    if api_key:
        key_user = await get_api_key_user(api_key, db)
        if key_user:
            return key_user

    # Fall back to validating the bearer token.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        if username is None:
            raise auth_failure
        token_data = TokenData(
            username=username,
            scopes=payload.get("scopes", [])
        )
    except JWTError:
        raise auth_failure

    user = await get_user(db, username=token_data.username)
    if user is None:
        raise auth_failure
    return user
269
+
270
async def get_current_active_user(
    current_user: UserInDB = Depends(get_current_user)
) -> UserInDB:
    """Dependency that rejects disabled accounts.

    Raises:
        HTTPException(400): when the authenticated user is marked inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")

    return current_user
282
+
283
def has_scope(required_scopes: List[str]):
    """Build a dependency that rejects users missing any required scope.

    Args:
        required_scopes: Scopes the caller must all possess.

    Returns:
        An async dependency yielding the user, or raising 403 naming the
        first missing scope.
    """
    async def _has_scope(
        current_user: UserInDB = Depends(get_current_active_user)
    ) -> UserInDB:
        missing = [s for s in required_scopes if s not in current_user.scopes]
        if missing:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Permission denied. Required scope: {missing[0]}"
            )
        return current_user

    return _has_scope
306
+
307
def admin_only(
    current_user: UserInDB = Depends(get_current_active_user)
) -> UserInDB:
    """Dependency that lets only superusers through (403 otherwise)."""
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Permission denied. Admin access required."
    )
320
+
321
+ # Audit logging middleware
322
async def audit_log_middleware(request: Request, call_next):
    """Middleware that audit-logs every request and its response timing."""
    # Request metadata for the audit trail.
    method = request.method
    path = request.url.path
    client_host = request.client.host
    user_agent = request.headers.get("User-Agent", "Unknown")

    # A prior dependency may have stashed the authenticated user on state.
    user = getattr(request.state, "user", None)
    username = user.username if user else "Anonymous"

    logger.info(
        f"API Request: {method} {path} | User: {username} | "
        f"Client: {client_host} | User-Agent: {user_agent}"
    )

    started = time.time()
    response = await call_next(request)
    elapsed = time.time() - started

    logger.info(
        f"API Response: {method} {path} | Status: {response.status_code} | "
        f"Time: {elapsed:.4f}s | User: {username}"
    )

    return response
356
+
357
+ # API key validation middleware
358
def validate_api_key(request: Request):
    """Dependency that rejects requests lacking a valid API key header.

    Raises:
        HTTPException(401): when the header is absent or the key is unknown.
    """
    api_key = request.headers.get(API_KEY_NAME)
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
        )

    # TODO: validate against a database instead of this hardcoded list.
    valid_keys = ["test-api-key"]  # Replace with database lookup
    if api_key not in valid_keys:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
        )

    return True
src/api/services/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Package initialization for API services.
3
+ """
src/api/services/alert_service.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Service for alert operations.
3
+ """
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.future import select
6
+ from sqlalchemy import func, or_, and_
7
+ from datetime import datetime
8
+ from typing import List, Optional, Dict, Any, Union
9
+
10
+ from src.models.alert import Alert, AlertStatus, AlertCategory
11
+ from src.models.threat import ThreatSeverity
12
+ from src.api.schemas import PaginationParams
13
+
14
async def create_alert(
    db: AsyncSession,
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: AlertCategory,
    source_url: Optional[str] = None,
    threat_id: Optional[int] = None,
    mention_id: Optional[int] = None,
) -> Alert:
    """Persist a new alert in the NEW, unread state and return it.

    Args:
        db: Database session.
        title: Alert title.
        description: Alert description.
        severity: Alert severity.
        category: Alert category.
        source_url: Optional source URL for the alert.
        threat_id: Optional ID of a related threat.
        mention_id: Optional ID of a related dark web mention.

    Returns:
        The freshly inserted Alert.
    """
    alert = Alert(
        title=title,
        description=description,
        severity=severity,
        status=AlertStatus.NEW,
        category=category,
        generated_at=datetime.utcnow(),
        source_url=source_url,
        is_read=False,
        threat_id=threat_id,
        mention_id=mention_id,
    )

    db.add(alert)
    await db.commit()
    await db.refresh(alert)
    return alert
58
+
59
async def get_alert_by_id(db: AsyncSession, alert_id: int) -> Optional[Alert]:
    """Fetch a single alert by primary key, or None when absent."""
    stmt = select(Alert).where(Alert.id == alert_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
72
+
73
async def get_alerts(
    db: AsyncSession,
    pagination: PaginationParams,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[AlertStatus]] = None,
    category: Optional[List[AlertCategory]] = None,
    is_read: Optional[bool] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> List[Alert]:
    """Return one page of alerts, newest first, matching the given filters.

    Args:
        db: Database session.
        pagination: Page number and page size.
        severity/status/category: Optional membership filters.
        is_read: Optional read-flag filter.
        search_query: Case-insensitive substring match on title/description.
        from_date/to_date: Inclusive bounds on generated_at.
    """
    stmt = select(Alert)

    if severity:
        stmt = stmt.where(Alert.severity.in_(severity))
    if status:
        stmt = stmt.where(Alert.status.in_(status))
    if category:
        stmt = stmt.where(Alert.category.in_(category))
    if is_read is not None:
        stmt = stmt.where(Alert.is_read == is_read)
    if search_query:
        pattern = f"%{search_query}%"
        stmt = stmt.where(or_(
            Alert.title.ilike(pattern),
            Alert.description.ilike(pattern),
        ))
    if from_date:
        stmt = stmt.where(Alert.generated_at >= from_date)
    if to_date:
        stmt = stmt.where(Alert.generated_at <= to_date)

    # Newest first, then page window.
    stmt = stmt.order_by(Alert.generated_at.desc())
    offset = (pagination.page - 1) * pagination.size
    stmt = stmt.offset(offset).limit(pagination.size)

    rows = await db.execute(stmt)
    return rows.scalars().all()
135
+
136
async def count_alerts(
    db: AsyncSession,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[AlertStatus]] = None,
    category: Optional[List[AlertCategory]] = None,
    is_read: Optional[bool] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> int:
    """Count alerts matching the same filters accepted by get_alerts."""
    stmt = select(func.count(Alert.id))

    if severity:
        stmt = stmt.where(Alert.severity.in_(severity))
    if status:
        stmt = stmt.where(Alert.status.in_(status))
    if category:
        stmt = stmt.where(Alert.category.in_(category))
    if is_read is not None:
        stmt = stmt.where(Alert.is_read == is_read)
    if search_query:
        pattern = f"%{search_query}%"
        stmt = stmt.where(or_(
            Alert.title.ilike(pattern),
            Alert.description.ilike(pattern),
        ))
    if from_date:
        stmt = stmt.where(Alert.generated_at >= from_date)
    if to_date:
        stmt = stmt.where(Alert.generated_at <= to_date)

    rows = await db.execute(stmt)
    return rows.scalar()
192
+
193
async def update_alert_status(
    db: AsyncSession,
    alert_id: int,
    status: AlertStatus,
    action_taken: Optional[str] = None,
) -> Optional[Alert]:
    """Change an alert's status, stamping resolution time when resolved.

    Args:
        db: Database session.
        alert_id: Target alert ID.
        status: New status value.
        action_taken: Optional free-text description of the action taken.

    Returns:
        The updated alert, or None when no alert has *alert_id*.
    """
    target = await get_alert_by_id(db, alert_id)
    if target is None:
        return None

    target.status = status
    if action_taken:
        target.action_taken = action_taken
    if status == AlertStatus.RESOLVED:
        target.resolved_at = datetime.utcnow()
    target.updated_at = datetime.utcnow()

    await db.commit()
    await db.refresh(target)
    return target
229
+
230
async def mark_alert_as_read(
    db: AsyncSession,
    alert_id: int,
) -> Optional[Alert]:
    """Flag an alert as read.

    Returns:
        The updated alert, or None when no alert has *alert_id*.
    """
    target = await get_alert_by_id(db, alert_id)
    if target is None:
        return None

    target.is_read = True
    target.updated_at = datetime.utcnow()

    await db.commit()
    await db.refresh(target)
    return target
255
+
256
async def assign_alert(
    db: AsyncSession,
    alert_id: int,
    user_id: int,
) -> Optional[Alert]:
    """Assign an alert to a user and move it to the ASSIGNED status.

    Args:
        db: Database session.
        alert_id: Target alert ID.
        user_id: ID of the user receiving the alert.

    Returns:
        The updated alert, or None when no alert has *alert_id*.
    """
    target = await get_alert_by_id(db, alert_id)
    if target is None:
        return None

    target.assigned_to_id = user_id
    target.status = AlertStatus.ASSIGNED
    target.updated_at = datetime.utcnow()

    await db.commit()
    await db.refresh(target)
    return target
284
+
285
async def get_alert_counts_by_severity(
    db: AsyncSession,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Dict[str, int]:
    """Get count of alerts by severity.

    Args:
        db: Database session.
        from_date: Filter by generated_at >= from_date.
        to_date: Filter by generated_at <= to_date.

    Returns:
        Dict[str, int]: Mapping of severity value to count; every severity
        appears in the result, with 0 for severities that have no alerts.
    """
    # Start with every severity at zero so the result shape is stable.
    counts = {severity.value: 0 for severity in ThreatSeverity}

    # One GROUP BY query instead of a round-trip per severity value
    # (the previous implementation issued len(ThreatSeverity) queries).
    query = select(Alert.severity, func.count(Alert.id)).group_by(Alert.severity)
    if from_date:
        query = query.filter(Alert.generated_at >= from_date)
    if to_date:
        query = query.filter(Alert.generated_at <= to_date)

    result = await db.execute(query)
    for severity, count in result.all():
        counts[severity.value] = count

    return counts
src/api/services/dark_web_content_service.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Service for dark web content operations.
3
+ """
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.future import select
6
+ from sqlalchemy import func, or_, text
7
+ from datetime import datetime
8
+ from typing import List, Optional, Dict, Any, Union
9
+
10
+ from src.models.dark_web_content import DarkWebContent, DarkWebMention, ContentType, ContentStatus
11
+ from src.models.threat import Threat, ThreatCategory, ThreatSeverity, ThreatStatus
12
+ from src.api.schemas import PaginationParams
13
+
14
async def create_content(
    db: AsyncSession,
    url: str,
    content: str,
    title: Optional[str] = None,
    content_type: ContentType = ContentType.OTHER,
    content_status: ContentStatus = ContentStatus.NEW,
    source_name: Optional[str] = None,
    source_type: Optional[str] = None,
    language: Optional[str] = None,
    relevance_score: float = 0.0,
    sentiment_score: float = 0.0,
    entity_data: Optional[str] = None,
) -> DarkWebContent:
    """Create a new dark web content entry.

    Args:
        db: Database session.
        url: URL of the content.
        content: Text content.
        title: Title of the content.
        content_type: Type of content.
        content_status: Status of content.
        source_name: Name of the source.
        source_type: Type of source.
        language: Language of the content.
        relevance_score: Relevance score (0-1).
        sentiment_score: Sentiment score (-1 to 1).
        entity_data: JSON string of extracted entities.

    Returns:
        DarkWebContent: Created content.
    """
    # Derive the domain from the URL when possible. urlparse raises
    # ValueError on malformed input (e.g. an invalid port); the previous
    # bare `except:` also swallowed KeyboardInterrupt/SystemExit.
    domain = None
    if url:
        from urllib.parse import urlparse
        try:
            domain = urlparse(url).netloc
        except ValueError:
            # Leave domain unset rather than failing the whole insert.
            domain = None

    db_content = DarkWebContent(
        url=url,
        domain=domain,
        title=title,
        content=content,
        content_type=content_type,
        content_status=content_status,
        source_name=source_name,
        source_type=source_type,
        language=language,
        scraped_at=datetime.utcnow(),
        relevance_score=relevance_score,
        sentiment_score=sentiment_score,
        entity_data=entity_data,
    )

    db.add(db_content)
    await db.commit()
    await db.refresh(db_content)

    return db_content
79
+
80
async def get_content_by_id(db: AsyncSession, content_id: int) -> Optional[DarkWebContent]:
    """Fetch a single dark web content row by primary key, or None."""
    stmt = select(DarkWebContent).where(DarkWebContent.id == content_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
93
+
94
async def get_contents(
    db: AsyncSession,
    pagination: PaginationParams,
    content_type: Optional[List[ContentType]] = None,
    content_status: Optional[List[ContentStatus]] = None,
    source_name: Optional[str] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> List[DarkWebContent]:
    """Return one page of dark web content, newest scrape first.

    Args:
        db: Database session.
        pagination: Page number and page size.
        content_type/content_status: Optional membership filters.
        source_name: Exact source-name filter.
        search_query: Case-insensitive substring match on title/content.
        from_date/to_date: Inclusive bounds on scraped_at.
    """
    stmt = select(DarkWebContent)

    if content_type:
        stmt = stmt.where(DarkWebContent.content_type.in_(content_type))
    if content_status:
        stmt = stmt.where(DarkWebContent.content_status.in_(content_status))
    if source_name:
        stmt = stmt.where(DarkWebContent.source_name == source_name)
    if search_query:
        pattern = f"%{search_query}%"
        stmt = stmt.where(or_(
            DarkWebContent.title.ilike(pattern),
            DarkWebContent.content.ilike(pattern),
        ))
    if from_date:
        stmt = stmt.where(DarkWebContent.scraped_at >= from_date)
    if to_date:
        stmt = stmt.where(DarkWebContent.scraped_at <= to_date)

    stmt = stmt.order_by(DarkWebContent.scraped_at.desc())
    offset = (pagination.page - 1) * pagination.size
    stmt = stmt.offset(offset).limit(pagination.size)

    rows = await db.execute(stmt)
    return rows.scalars().all()
151
+
152
async def count_contents(
    db: AsyncSession,
    content_type: Optional[List[ContentType]] = None,
    content_status: Optional[List[ContentStatus]] = None,
    source_name: Optional[str] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> int:
    """Count dark web content rows matching the get_contents filters."""
    stmt = select(func.count(DarkWebContent.id))

    if content_type:
        stmt = stmt.where(DarkWebContent.content_type.in_(content_type))
    if content_status:
        stmt = stmt.where(DarkWebContent.content_status.in_(content_status))
    if source_name:
        stmt = stmt.where(DarkWebContent.source_name == source_name)
    if search_query:
        pattern = f"%{search_query}%"
        stmt = stmt.where(or_(
            DarkWebContent.title.ilike(pattern),
            DarkWebContent.content.ilike(pattern),
        ))
    if from_date:
        stmt = stmt.where(DarkWebContent.scraped_at >= from_date)
    if to_date:
        stmt = stmt.where(DarkWebContent.scraped_at <= to_date)

    rows = await db.execute(stmt)
    return rows.scalar()
203
+
204
async def create_mention(
    db: AsyncSession,
    content_id: int,
    keyword: str,
    keyword_category: Optional[str] = None,
    context: Optional[str] = None,
    snippet: Optional[str] = None,
    mention_type: Optional[str] = None,
    confidence: float = 0.0,
    is_verified: bool = False,
) -> DarkWebMention:
    """Persist a keyword mention discovered in a piece of dark web content.

    Args:
        db: Database session.
        content_id: Content row the mention was found in.
        keyword: Matched keyword.
        keyword_category: Optional category of the keyword.
        context: Text surrounding the mention.
        snippet: Extract of text containing the mention.
        mention_type: Optional mention type label.
        confidence: Confidence score (0-1).
        is_verified: Whether the mention has been verified.

    Returns:
        The freshly inserted DarkWebMention.
    """
    mention = DarkWebMention(
        content_id=content_id,
        keyword=keyword,
        keyword_category=keyword_category,
        context=context,
        snippet=snippet,
        mention_type=mention_type,
        confidence=confidence,
        is_verified=is_verified,
    )

    db.add(mention)
    await db.commit()
    await db.refresh(mention)
    return mention
248
+
249
async def get_mention_by_id(db: AsyncSession, mention_id: int) -> Optional[DarkWebMention]:
    """Fetch a single dark web mention by primary key, or None."""
    stmt = select(DarkWebMention).where(DarkWebMention.id == mention_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
262
+
263
async def get_mentions(
    db: AsyncSession,
    pagination: PaginationParams,
    keyword: Optional[str] = None,
    content_id: Optional[int] = None,
    is_verified: Optional[bool] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> List[DarkWebMention]:
    """Return one page of dark web mentions, newest first.

    Args:
        db: Database session.
        pagination: Page number and page size.
        keyword: Case-insensitive substring match on the keyword.
        content_id: Restrict to mentions of one content row.
        is_verified: Optional verification-flag filter.
        from_date/to_date: Inclusive bounds on created_at.
    """
    stmt = select(DarkWebMention)

    if keyword:
        stmt = stmt.where(DarkWebMention.keyword.ilike(f"%{keyword}%"))
    if content_id:
        stmt = stmt.where(DarkWebMention.content_id == content_id)
    if is_verified is not None:
        stmt = stmt.where(DarkWebMention.is_verified == is_verified)
    if from_date:
        stmt = stmt.where(DarkWebMention.created_at >= from_date)
    if to_date:
        stmt = stmt.where(DarkWebMention.created_at <= to_date)

    stmt = stmt.order_by(DarkWebMention.created_at.desc())
    offset = (pagination.page - 1) * pagination.size
    stmt = stmt.offset(offset).limit(pagination.size)

    rows = await db.execute(stmt)
    return rows.scalars().all()
311
+
312
async def create_threat_from_content(
    db: AsyncSession,
    content_id: int,
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: ThreatCategory,
    confidence_score: float = 0.0,
) -> Threat:
    """Promote a dark web content record into a NEW threat.

    Args:
        db: Database session.
        content_id: Content row backing the threat.
        title: Threat title.
        description: Threat description.
        severity: Threat severity.
        category: Threat category.
        confidence_score: Confidence score (0-1).

    Returns:
        The created Threat.

    Raises:
        ValueError: when no content row exists for *content_id*.
    """
    source = await get_content_by_id(db, content_id)
    if source is None:
        raise ValueError(f"Content with ID {content_id} not found")

    # Imported locally — presumably to avoid a circular import between the
    # service modules; confirm before moving to module scope.
    from src.api.services.threat_service import create_threat

    return await create_threat(
        db=db,
        title=title,
        description=description,
        severity=severity,
        category=category,
        status=ThreatStatus.NEW,
        source_url=source.url,
        source_name=source.source_name,
        source_type=source.source_type,
        confidence_score=confidence_score,
    )
src/api/services/report_service.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Service for working with intelligence reports.
3
+ """
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.future import select
6
+ from sqlalchemy import update, delete, func, desc, and_, or_
7
+ from typing import List, Optional, Dict, Any, Union
8
+ import logging
9
+ from datetime import datetime
10
+
11
+ from src.models.report import Report, ReportType, ReportStatus
12
+ from src.models.threat import ThreatSeverity
13
+ from src.api.schemas import PaginationParams
14
+
15
+ # Configure logger
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
async def create_report(
    db: AsyncSession,
    title: str,
    summary: str,
    content: str,
    report_type: ReportType,
    report_id: str,
    status: ReportStatus = ReportStatus.DRAFT,
    severity: Optional[ThreatSeverity] = None,
    publish_date: Optional[datetime] = None,
    time_period_start: Optional[datetime] = None,
    time_period_end: Optional[datetime] = None,
    keywords: Optional[List[str]] = None,
    source_data: Optional[Dict[str, Any]] = None,
    author_id: Optional[int] = None,
) -> Report:
    """
    Create and persist a new intelligence report.

    Args:
        db: Database session
        title: Report title
        summary: Report summary
        content: Report body
        report_type: Type of report
        report_id: Custom human-readable ID (e.g. RPT-2023-0001)
        status: Initial report status (defaults to DRAFT)
        severity: Optional severity rating
        publish_date: Optional publication date
        time_period_start: Start of the period covered
        time_period_end: End of the period covered
        keywords: Keywords related to the report (defaults to empty list)
        source_data: Sources and references (defaults to empty dict)
        author_id: ID of the report author

    Returns:
        Report: The persisted report with its database-assigned fields loaded
    """
    new_report = Report(
        title=title,
        summary=summary,
        content=content,
        report_type=report_type,
        report_id=report_id,
        status=status,
        severity=severity,
        publish_date=publish_date,
        time_period_start=time_period_start,
        time_period_end=time_period_end,
        # Normalize missing collections so the JSON columns never hold NULL.
        keywords=keywords or [],
        source_data=source_data or {},
        author_id=author_id,
    )

    db.add(new_report)
    await db.commit()
    # Refresh to pick up server-generated values (id, timestamps).
    await db.refresh(new_report)
    return new_report
78
+
79
+
80
async def get_report_by_id(db: AsyncSession, report_id: int) -> Optional[Report]:
    """
    Fetch a single report by its primary key.

    Args:
        db: Database session
        report_id: Primary key of the report

    Returns:
        Optional[Report]: The report, or None when no row matches
    """
    stmt = select(Report).where(Report.id == report_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
95
+
96
+
97
async def get_report_by_custom_id(db: AsyncSession, custom_id: str) -> Optional[Report]:
    """
    Fetch a single report by its human-readable identifier.

    Args:
        db: Database session
        custom_id: Custom report ID (e.g. RPT-2023-0001)

    Returns:
        Optional[Report]: The report, or None when no row matches
    """
    stmt = select(Report).where(Report.report_id == custom_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
112
+
113
+
114
async def get_reports(
    db: AsyncSession,
    pagination: PaginationParams,
    report_type: Optional[List[ReportType]] = None,
    status: Optional[List[ReportStatus]] = None,
    severity: Optional[List[ThreatSeverity]] = None,
    search_query: Optional[str] = None,
    keywords: Optional[List[str]] = None,
    author_id: Optional[int] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> List[Report]:
    """
    List reports matching the given filters, newest first.

    Args:
        db: Database session
        pagination: Page number and page size
        report_type: Restrict to these report types
        status: Restrict to these statuses
        severity: Restrict to these severities
        search_query: Case-insensitive substring match on title/summary/content
        keywords: Require each keyword to appear in the report's keyword list
        author_id: Restrict to a single author
        from_date: Only reports created at or after this time
        to_date: Only reports created at or before this time

    Returns:
        List[Report]: One page of matching reports
    """
    stmt = select(Report)

    if report_type:
        stmt = stmt.where(Report.report_type.in_(report_type))
    if status:
        stmt = stmt.where(Report.status.in_(status))
    if severity:
        stmt = stmt.where(Report.severity.in_(severity))

    if search_query:
        stmt = stmt.where(
            or_(
                Report.title.ilike(f"%{search_query}%"),
                Report.summary.ilike(f"%{search_query}%"),
                Report.content.ilike(f"%{search_query}%"),
            )
        )

    if keywords:
        # JSON array column: each requested keyword must be contained.
        for keyword in keywords:
            stmt = stmt.where(Report.keywords.contains([keyword]))

    if author_id:
        stmt = stmt.where(Report.author_id == author_id)
    if from_date:
        stmt = stmt.where(Report.created_at >= from_date)
    if to_date:
        stmt = stmt.where(Report.created_at <= to_date)

    # Newest first, then slice out the requested page.
    stmt = (
        stmt.order_by(desc(Report.created_at))
        .offset((pagination.page - 1) * pagination.size)
        .limit(pagination.size)
    )

    rows = await db.execute(stmt)
    return rows.scalars().all()
184
+
185
+
186
async def count_reports(
    db: AsyncSession,
    report_type: Optional[List[ReportType]] = None,
    status: Optional[List[ReportStatus]] = None,
    severity: Optional[List[ThreatSeverity]] = None,
    search_query: Optional[str] = None,
    keywords: Optional[List[str]] = None,
    author_id: Optional[int] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> int:
    """
    Count reports matching the given filters.

    Accepts the same filter arguments as :func:`get_reports` (minus
    pagination); the two must stay in sync so page counts match page
    contents.

    Returns:
        int: Number of matching reports (0 when none match)
    """
    stmt = select(func.count(Report.id))

    if report_type:
        stmt = stmt.where(Report.report_type.in_(report_type))
    if status:
        stmt = stmt.where(Report.status.in_(status))
    if severity:
        stmt = stmt.where(Report.severity.in_(severity))

    if search_query:
        stmt = stmt.where(
            or_(
                Report.title.ilike(f"%{search_query}%"),
                Report.summary.ilike(f"%{search_query}%"),
                Report.content.ilike(f"%{search_query}%"),
            )
        )

    if keywords:
        # JSON array column: each requested keyword must be contained.
        for keyword in keywords:
            stmt = stmt.where(Report.keywords.contains([keyword]))

    if author_id:
        stmt = stmt.where(Report.author_id == author_id)
    if from_date:
        stmt = stmt.where(Report.created_at >= from_date)
    if to_date:
        stmt = stmt.where(Report.created_at <= to_date)

    result = await db.execute(stmt)
    # scalar() is typed Optional; coerce to 0 so the declared int return
    # type always holds for callers doing arithmetic on the count.
    return result.scalar() or 0
241
+
242
+
243
async def update_report(
    db: AsyncSession,
    report_id: int,
    title: Optional[str] = None,
    summary: Optional[str] = None,
    content: Optional[str] = None,
    report_type: Optional[ReportType] = None,
    status: Optional[ReportStatus] = None,
    severity: Optional[ThreatSeverity] = None,
    publish_date: Optional[datetime] = None,
    time_period_start: Optional[datetime] = None,
    time_period_end: Optional[datetime] = None,
    keywords: Optional[List[str]] = None,
    source_data: Optional[Dict[str, Any]] = None,
) -> Optional[Report]:
    """
    Partially update a report; only non-None arguments are applied.

    Args:
        db: Database session
        report_id: Primary key of the report to update
        Other args: New values for the corresponding fields; None means
            "leave unchanged"

    Returns:
        Optional[Report]: The updated report, or None if it does not exist
    """
    report = await get_report_by_id(db, report_id)
    if report is None:
        return None

    # Apply every field that was explicitly supplied (None = untouched).
    updates = {
        "title": title,
        "summary": summary,
        "content": content,
        "report_type": report_type,
        "status": status,
        "severity": severity,
        "publish_date": publish_date,
        "time_period_start": time_period_start,
        "time_period_end": time_period_end,
        "keywords": keywords,
        "source_data": source_data,
    }
    for field, value in updates.items():
        if value is not None:
            setattr(report, field, value)

    await db.commit()
    await db.refresh(report)
    return report
312
+
313
+
314
async def add_threat_to_report(
    db: AsyncSession,
    report_id: int,
    threat_id: int,
) -> Optional[Report]:
    """
    Associate a threat with a report (idempotent).

    Args:
        db: Database session
        report_id: Primary key of the report
        threat_id: Primary key of the threat

    Returns:
        Optional[Report]: The updated report, or None if either the report
        or the threat does not exist
    """
    from src.api.services.threat_service import get_threat_by_id

    report = await get_report_by_id(db, report_id)
    threat = await get_threat_by_id(db, threat_id)

    if not report or not threat:
        return None

    # Guard against duplicate association rows; mirrors the membership
    # check in remove_threat_from_report.
    if threat not in report.threats:
        report.threats.append(threat)
        await db.commit()
        await db.refresh(report)

    return report
345
+
346
+
347
async def remove_threat_from_report(
    db: AsyncSession,
    report_id: int,
    threat_id: int,
) -> Optional[Report]:
    """
    Detach a threat from a report if it is currently associated.

    Args:
        db: Database session
        report_id: Primary key of the report
        threat_id: Primary key of the threat

    Returns:
        Optional[Report]: The (possibly unchanged) report, or None if either
        the report or the threat does not exist
    """
    from src.api.services.threat_service import get_threat_by_id

    report = await get_report_by_id(db, report_id)
    threat = await get_threat_by_id(db, threat_id)

    if report is None or threat is None:
        return None

    # Only commit when there is actually something to remove.
    if threat in report.threats:
        report.threats.remove(threat)
        await db.commit()
        await db.refresh(report)

    return report
379
+
380
+
381
async def publish_report(
    db: AsyncSession,
    report_id: int,
) -> Optional[Report]:
    """
    Mark a report as published, stamping the publish date if unset.

    Args:
        db: Database session
        report_id: Primary key of the report

    Returns:
        Optional[Report]: The published report, or None if it does not exist
    """
    report = await get_report_by_id(db, report_id)
    if report is None:
        return None

    report.status = ReportStatus.PUBLISHED
    # Preserve a previously scheduled publish date; only default when empty.
    if not report.publish_date:
        report.publish_date = datetime.utcnow()

    await db.commit()
    await db.refresh(report)
    return report
410
+
411
+
412
async def archive_report(
    db: AsyncSession,
    report_id: int,
) -> Optional[Report]:
    """
    Move a report into the ARCHIVED state.

    Args:
        db: Database session
        report_id: Primary key of the report

    Returns:
        Optional[Report]: The archived report, or None if it does not exist
    """
    report = await get_report_by_id(db, report_id)
    if report is None:
        return None

    report.status = ReportStatus.ARCHIVED
    await db.commit()
    await db.refresh(report)
    return report
src/api/services/search_history_service.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Search History and Trends Service
3
+
4
+ This service manages search history, saved searches, and trend analysis.
5
+ """
6
+ import logging
7
+ import json
8
+ from datetime import datetime, timedelta
9
+ from typing import List, Dict, Any, Optional, Tuple, Union
10
+
11
+ from sqlalchemy import func, desc, and_, or_, text
12
+ from sqlalchemy.future import select
13
+ from sqlalchemy.ext.asyncio import AsyncSession
14
+ from sqlalchemy.orm import selectinload
15
+
16
+ from src.models.search_history import SearchHistory, SearchResult, SavedSearch, TrendTopic
17
+ from src.models.dark_web_content import DarkWebContent
18
+ from src.models.user import User
19
+
20
+ # Configure logging
21
+ logger = logging.getLogger(__name__)
22
+
23
async def add_search_history(
    db: AsyncSession,
    query: str,
    user_id: Optional[int] = None,
    result_count: int = 0,
    category: Optional[str] = None,
    is_saved: bool = False,
    notes: Optional[str] = None,
    tags: Optional[str] = None
) -> SearchHistory:
    """
    Record a search in the history log and refresh trend statistics.

    Args:
        db: Database session
        query: The search query text
        user_id: ID of the user who searched, if authenticated
        result_count: Number of results the search returned
        category: Category of the search
        is_saved: Whether the entry is flagged as saved
        notes: Free-form notes
        tags: Comma-separated tags

    Returns:
        SearchHistory: The persisted history entry
    """
    entry = SearchHistory(
        query=query,
        user_id=user_id,
        result_count=result_count,
        category=category,
        is_saved=is_saved,
        notes=notes,
        tags=tags
    )

    db.add(entry)
    await db.commit()
    await db.refresh(entry)

    # Every recorded search feeds the trending-topics statistics.
    await update_trend_data(db, query, category)

    return entry
67
+
68
async def add_search_result(
    db: AsyncSession,
    search_id: int,
    url: str,
    title: Optional[str] = None,
    snippet: Optional[str] = None,
    source: Optional[str] = None,
    relevance_score: float = 0.0,
    content_id: Optional[int] = None
) -> SearchResult:
    """
    Attach one result row to a recorded search.

    Args:
        db: Database session
        search_id: ID of the parent SearchHistory entry
        url: URL of the result
        title: Title of the result
        snippet: Text snippet from the result
        source: Where the result came from
        relevance_score: Relevance of the result to the query
        content_id: ID of matching DarkWebContent, when already indexed

    Returns:
        SearchResult: The persisted result row
    """
    row = SearchResult(
        search_id=search_id,
        url=url,
        title=title,
        snippet=snippet,
        source=source,
        relevance_score=relevance_score,
        content_id=content_id
    )

    db.add(row)
    await db.commit()
    await db.refresh(row)
    return row
109
+
110
async def get_search_history(
    db: AsyncSession,
    skip: int = 0,
    limit: int = 100,
    user_id: Optional[int] = None,
    query_filter: Optional[str] = None,
    date_from: Optional[datetime] = None,
    date_to: Optional[datetime] = None,
    category: Optional[str] = None,
    is_saved: Optional[bool] = None,
    include_results: bool = False
) -> List[SearchHistory]:
    """
    List search-history entries, newest first, with optional filters.

    Args:
        db: Database session
        skip: Offset into the result set
        limit: Maximum rows to return
        user_id: Only entries belonging to this user
        query_filter: Case-insensitive substring match on the query text
        date_from: Only entries at or after this timestamp
        date_to: Only entries at or before this timestamp
        category: Only entries in this category
        is_saved: Only entries with this saved flag
        include_results: Eagerly load each entry's result rows

    Returns:
        List[SearchHistory]: Matching entries
    """
    stmt = select(SearchHistory)

    if user_id is not None:
        stmt = stmt.where(SearchHistory.user_id == user_id)
    if query_filter:
        stmt = stmt.where(SearchHistory.query.ilike(f"%{query_filter}%"))
    if date_from:
        stmt = stmt.where(SearchHistory.timestamp >= date_from)
    if date_to:
        stmt = stmt.where(SearchHistory.timestamp <= date_to)
    if category:
        stmt = stmt.where(SearchHistory.category == category)
    if is_saved is not None:
        stmt = stmt.where(SearchHistory.is_saved == is_saved)

    if include_results:
        # Avoid N+1 queries when callers want the result rows too.
        stmt = stmt.options(selectinload(SearchHistory.search_results))

    stmt = stmt.order_by(desc(SearchHistory.timestamp)).offset(skip).limit(limit)

    rows = await db.execute(stmt)
    return rows.scalars().all()
170
+
171
async def get_search_by_id(
    db: AsyncSession,
    search_id: int,
    include_results: bool = False
) -> Optional[SearchHistory]:
    """
    Fetch one search-history entry by primary key.

    Args:
        db: Database session
        search_id: Primary key of the entry
        include_results: Eagerly load the entry's result rows

    Returns:
        Optional[SearchHistory]: The entry, or None when absent
    """
    stmt = select(SearchHistory).where(SearchHistory.id == search_id)
    if include_results:
        stmt = stmt.options(selectinload(SearchHistory.search_results))

    rows = await db.execute(stmt)
    return rows.scalars().first()
194
+
195
async def delete_search_history(db: AsyncSession, search_id: int) -> bool:
    """
    Permanently delete a search-history entry.

    Args:
        db: Database session
        search_id: Primary key of the entry to delete

    Returns:
        bool: True when the entry existed and was deleted, False otherwise
    """
    entry = await get_search_by_id(db, search_id)
    if entry is None:
        return False

    await db.delete(entry)
    await db.commit()
    return True
213
+
214
async def save_search(
    db: AsyncSession,
    search_id: int,
    is_saved: bool = True,
    notes: Optional[str] = None,
    tags: Optional[str] = None
) -> Optional[SearchHistory]:
    """
    Toggle the saved flag on a search-history entry and update its metadata.

    Args:
        db: Database session
        search_id: Primary key of the entry
        is_saved: True to save, False to unsave
        notes: New notes; None leaves existing notes unchanged, "" clears them
        tags: New comma-separated tags; None leaves existing tags unchanged,
            "" clears them

    Returns:
        Optional[SearchHistory]: The updated entry, or None when absent
    """
    search = await get_search_by_id(db, search_id)
    if search is None:
        return None

    search.is_saved = is_saved

    # Use `is not None` (not truthiness) so callers can clear a field by
    # passing an empty string; None still means "leave unchanged".
    if notes is not None:
        search.notes = notes
    if tags is not None:
        search.tags = tags

    await db.commit()
    await db.refresh(search)
    return search
249
+
250
async def create_saved_search(
    db: AsyncSession,
    name: str,
    query: str,
    user_id: int,
    frequency: int = 24,
    notification_enabled: bool = True,
    threshold: int = 1,
    category: Optional[str] = None
) -> SavedSearch:
    """
    Create a saved search that the system re-runs on a schedule.

    Args:
        db: Database session
        name: Display name for the saved search
        query: Search query to re-run
        user_id: Owning user
        frequency: Re-run interval in hours (0 = manual only)
        notification_enabled: Send notifications for new results
        threshold: Minimum new results needed to trigger a notification
        category: Category of the search

    Returns:
        SavedSearch: The persisted saved search
    """
    saved = SavedSearch(
        name=name,
        query=query,
        user_id=user_id,
        frequency=frequency,
        notification_enabled=notification_enabled,
        threshold=threshold,
        category=category
    )

    db.add(saved)
    await db.commit()
    await db.refresh(saved)
    return saved
291
+
292
async def get_saved_searches(
    db: AsyncSession,
    user_id: Optional[int] = None,
    is_active: Optional[bool] = None,
    skip: int = 0,
    limit: int = 100
) -> List[SavedSearch]:
    """
    List saved searches, ordered alphabetically by name.

    Args:
        db: Database session
        user_id: Only saved searches owned by this user
        is_active: Only saved searches with this active flag
        skip: Offset into the result set
        limit: Maximum rows to return

    Returns:
        List[SavedSearch]: Matching saved searches
    """
    stmt = select(SavedSearch)

    if user_id is not None:
        stmt = stmt.where(SavedSearch.user_id == user_id)
    if is_active is not None:
        stmt = stmt.where(SavedSearch.is_active == is_active)

    stmt = stmt.order_by(SavedSearch.name).offset(skip).limit(limit)

    rows = await db.execute(stmt)
    return rows.scalars().all()
326
+
327
async def update_trend_data(
    db: AsyncSession,
    query: str,
    category: Optional[str] = None
) -> None:
    """
    Fold a search query into the trending-topics statistics.

    Each sufficiently long term (> 3 characters) in the query counts as one
    mention of that topic for this search.

    Args:
        db: Database session
        query: Search query text
        category: Category of the search, applied to topics lacking one
    """
    # Deduplicate terms while preserving order: a query that repeats a word
    # should count as one mention of that topic, not several.
    topics = list(dict.fromkeys(
        t.strip() for t in query.split() if len(t.strip()) > 3
    ))

    for topic in topics:
        result = await db.execute(
            select(TrendTopic).where(TrendTopic.topic == topic)
        )
        trend_topic = result.scalars().first()

        if trend_topic:
            trend_topic.last_seen = datetime.utcnow()
            trend_topic.mention_count += 1

            # Growth rate = extrapolated mentions per day over the topic's
            # observed lifetime.
            hours_alive = (
                trend_topic.last_seen - trend_topic.first_seen
            ).total_seconds() / 3600
            if hours_alive > 0:
                trend_topic.growth_rate = (
                    trend_topic.mention_count / hours_alive
                ) * 24

            # Backfill a category only if the topic has none yet.
            if category and not trend_topic.category:
                trend_topic.category = category
        else:
            db.add(TrendTopic(
                topic=topic,
                category=category,
                mention_count=1,
                growth_rate=1.0  # Seed value until a trend can be computed
            ))

    await db.commit()
375
+
376
async def get_trending_topics(
    db: AsyncSession,
    days: int = 7,
    limit: int = 20,
    category: Optional[str] = None,
    min_mentions: int = 3
) -> List[TrendTopic]:
    """
    Return active topics trending within the recent window, fastest-growing
    first.

    Args:
        db: Database session
        days: Look-back window in days
        limit: Maximum topics to return
        category: Only topics in this category
        min_mentions: Minimum mention count to qualify

    Returns:
        List[TrendTopic]: Topics ordered by descending growth rate
    """
    cutoff_date = datetime.utcnow() - timedelta(days=days)

    stmt = select(TrendTopic).where(
        and_(
            TrendTopic.last_seen >= cutoff_date,
            TrendTopic.mention_count >= min_mentions,
            # .is_(True) is the SQLAlchemy idiom for boolean columns
            # (avoids the `== True` lint smell, E712).
            TrendTopic.is_active.is_(True)
        )
    )

    if category:
        stmt = stmt.where(TrendTopic.category == category)

    stmt = stmt.order_by(desc(TrendTopic.growth_rate)).limit(limit)

    result = await db.execute(stmt)
    return result.scalars().all()
413
+
414
async def get_search_frequency(
    db: AsyncSession,
    days: int = 30,
    interval: str = 'day'
) -> List[Dict[str, Any]]:
    """
    Bucket search counts over time for charting.

    Args:
        db: Database session
        days: Look-back window in days
        interval: Bucket size: 'hour', 'day', 'week', or 'month' (any other
            value falls back to 'month')

    Returns:
        List[Dict[str, Any]]: ``[{"interval": <bucket start>, "count": n}, ...]``
        in chronological order
    """
    cutoff_date = datetime.utcnow() - timedelta(days=days)

    # Map the requested interval onto a date_trunc precision; unknown values
    # fall back to monthly buckets, matching the original else-branch.
    # (The previous version also built per-branch date-format strings that
    # were never used; that dead code is removed.)
    precision = {'hour': 'hour', 'day': 'day', 'week': 'week'}.get(interval, 'month')
    trunc_expr = func.date_trunc(precision, SearchHistory.timestamp)

    stmt = select(
        trunc_expr.label('interval'),
        func.count(SearchHistory.id).label('count')
    ).where(
        SearchHistory.timestamp >= cutoff_date
    ).group_by(
        'interval'
    ).order_by(
        'interval'
    )

    result = await db.execute(stmt)
    return [{"interval": row.interval, "count": row.count} for row in result.all()]
463
+
464
async def get_popular_searches(
    db: AsyncSession,
    days: int = 30,
    limit: int = 10
) -> List[Dict[str, Any]]:
    """
    Rank the most frequently issued search queries in a recent window.

    Args:
        db: Database session
        days: Look-back window in days
        limit: Maximum number of queries to return

    Returns:
        List[Dict[str, Any]]: ``[{"query": q, "count": n}, ...]`` sorted by
        descending count
    """
    cutoff_date = datetime.utcnow() - timedelta(days=days)

    stmt = (
        select(
            SearchHistory.query,
            func.count(SearchHistory.id).label('count')
        )
        .where(SearchHistory.timestamp >= cutoff_date)
        .group_by(SearchHistory.query)
        .order_by(desc('count'))
        .limit(limit)
    )

    result = await db.execute(stmt)
    return [{"query": row.query, "count": row.count} for row in result.all()]
497
+
498
async def get_search_categories(
    db: AsyncSession,
    days: int = 30
) -> List[Dict[str, Any]]:
    """
    Tally searches by category within a recent window.

    Args:
        db: Database session
        days: Look-back window in days

    Returns:
        List[Dict[str, Any]]: ``[{"category": c, "count": n}, ...]`` sorted
        by descending count; a NULL category is labelled "Uncategorized"
    """
    cutoff_date = datetime.utcnow() - timedelta(days=days)

    stmt = (
        select(
            SearchHistory.category,
            func.count(SearchHistory.id).label('count')
        )
        .where(
            and_(
                SearchHistory.timestamp >= cutoff_date,
                SearchHistory.category.is_not(None)
            )
        )
        .group_by(SearchHistory.category)
        .order_by(desc('count'))
    )

    result = await db.execute(stmt)
    return [
        {"category": row.category or "Uncategorized", "count": row.count}
        for row in result.all()
    ]
532
+
533
async def get_search_trend_analysis(
    db: AsyncSession,
    days: int = 90,
    trend_days: int = 7,
    limit: int = 10
) -> Dict[str, Any]:
    """
    Assemble a combined dashboard payload of search-trend statistics.

    Args:
        db: Database session
        days: Overall analysis window in days
        trend_days: Short-term window used for trending topics and velocity
        limit: Maximum items per sub-list

    Returns:
        Dict[str, Any]: Keys ``frequency``, ``popular_searches``,
        ``trending_topics``, ``categories``, ``recent_popular``,
        ``velocity`` (percentage change between the two most recent
        ``trend_days``-long periods) and ``total_searches``.
    """
    # Gather the individual analyses.
    frequency = await get_search_frequency(db, days, 'day')
    popular = await get_popular_searches(db, days, limit)
    trending = await get_trending_topics(db, trend_days, limit)
    categories = await get_search_categories(db, days)
    recent_popular = await get_popular_searches(db, 1, limit)

    # Velocity: compare the last `trend_days` against the `trend_days`
    # immediately before them.
    cutoff_recent = datetime.utcnow() - timedelta(days=trend_days)
    cutoff_previous = cutoff_recent - timedelta(days=trend_days)

    recent_stmt = select(func.count(SearchHistory.id)).where(
        SearchHistory.timestamp >= cutoff_recent
    )
    previous_stmt = select(func.count(SearchHistory.id)).where(
        and_(
            SearchHistory.timestamp >= cutoff_previous,
            SearchHistory.timestamp < cutoff_recent
        )
    )

    count_recent = (await db.execute(recent_stmt)).scalar() or 0
    count_previous = (await db.execute(previous_stmt)).scalar() or 0

    if count_previous > 0:
        velocity = (count_recent - count_previous) / count_previous * 100
    else:
        # No baseline: treat any recent activity as 100% growth.
        velocity = 100.0 if count_recent > 0 else 0.0

    return {
        "frequency": frequency,
        "popular_searches": popular,
        "trending_topics": [
            {"topic": t.topic, "mentions": t.mention_count, "growth_rate": t.growth_rate}
            for t in trending
        ],
        "categories": categories,
        "recent_popular": recent_popular,
        "velocity": velocity,
        "total_searches": {
            "total": count_recent + count_previous,
            "recent": count_recent,
            "previous": count_previous
        }
    }
src/api/services/subscription_service.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Subscription service.
3
+
4
+ This module provides functions for managing subscriptions.
5
+ """
6
+ import os
7
+ import logging
8
+ from datetime import datetime, timedelta
9
+ from typing import List, Dict, Any, Optional, Tuple, Union
10
+
11
+ import stripe
12
+ from sqlalchemy.ext.asyncio import AsyncSession
13
+ from sqlalchemy import select, update, delete
14
+ from sqlalchemy.orm import joinedload
15
+
16
+ from src.models.subscription import (
17
+ SubscriptionPlan, UserSubscription, PaymentHistory,
18
+ SubscriptionTier, BillingPeriod, SubscriptionStatus, PaymentStatus
19
+ )
20
+ from src.models.user import User
21
+
22
+ # Set up Stripe API key
23
+ stripe.api_key = os.environ.get("STRIPE_SECRET_KEY")
24
+ STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY")
25
+
26
+ # Set up logging
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
async def get_subscription_plans(
    db: AsyncSession,
    active_only: bool = True
) -> List[SubscriptionPlan]:
    """
    Fetch subscription plans from the database.

    Args:
        db: Database session
        active_only: When True, restrict the result to plans marked active

    Returns:
        List of subscription plans
    """
    statement = select(SubscriptionPlan)
    if active_only:
        # Explicit `== True` is required: SQLAlchemy needs a SQL boolean
        # expression, not Python truthiness.
        statement = statement.where(SubscriptionPlan.is_active == True)  # noqa: E712
    rows = await db.execute(statement)
    return rows.scalars().all()
53
+
54
+
55
async def get_subscription_plan_by_id(
    db: AsyncSession,
    plan_id: int
) -> Optional[SubscriptionPlan]:
    """
    Look up a single subscription plan by primary key.

    Args:
        db: Database session
        plan_id: Primary key of the plan to fetch

    Returns:
        The matching plan, or None when no plan has that ID
    """
    result = await db.execute(
        select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id)
    )
    return result.scalars().first()
74
+
75
+
76
async def get_subscription_plan_by_tier(
    db: AsyncSession,
    tier: SubscriptionTier
) -> Optional[SubscriptionPlan]:
    """
    Look up a subscription plan by its tier.

    Args:
        db: Database session
        tier: Tier of the plan to fetch

    Returns:
        The first plan with that tier, or None when none exists
    """
    result = await db.execute(
        select(SubscriptionPlan).where(SubscriptionPlan.tier == tier)
    )
    return result.scalars().first()
95
+
96
+
97
async def create_subscription_plan(
    db: AsyncSession,
    name: str,
    tier: SubscriptionTier,
    description: str,
    price_monthly: float,
    price_annually: float,
    max_alerts: int = 10,
    max_reports: int = 5,
    max_searches_per_day: int = 20,
    max_monitoring_keywords: int = 10,
    max_data_retention_days: int = 30,
    supports_api_access: bool = False,
    supports_live_feed: bool = False,
    supports_dark_web_monitoring: bool = False,
    supports_export: bool = False,
    supports_advanced_analytics: bool = False,
    create_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Create a new subscription plan, optionally mirrored as a Stripe product.

    At most one plan may exist per tier; a duplicate tier aborts creation.
    Stripe failures are logged but do not abort — the plan is still stored
    locally (possibly without Stripe IDs).

    Args:
        db: Database session
        name: Name of the plan
        tier: Tier of the plan (must be unique among plans)
        description: Description of the plan
        price_monthly: Monthly price in USD
        price_annually: Annual price in USD
        max_alerts: Maximum number of alerts allowed
        max_reports: Maximum number of reports allowed
        max_searches_per_day: Maximum number of searches per day
        max_monitoring_keywords: Maximum number of monitoring keywords
        max_data_retention_days: Maximum number of days to retain data
        supports_api_access: Whether the plan supports API access
        supports_live_feed: Whether the plan supports live feed
        supports_dark_web_monitoring: Whether the plan supports dark web monitoring
        supports_export: Whether the plan supports data export
        supports_advanced_analytics: Whether the plan supports advanced analytics
        create_stripe_product: Whether to create a Stripe product for this plan

    Returns:
        Created subscription plan, or None if a plan with the tier already exists
    """
    # Enforce one plan per tier.
    if await get_subscription_plan_by_tier(db, tier) is not None:
        logger.warning(f"Subscription plan with tier {tier} already exists")
        return None

    stripe_product_id = None
    stripe_monthly_price_id = None
    stripe_annual_price_id = None

    if create_stripe_product and stripe.api_key:
        def _flag(enabled: bool) -> str:
            # Stripe metadata values are strings; encode booleans as yes/no.
            return "yes" if enabled else "no"

        try:
            product = stripe.Product.create(
                name=name,
                description=description,
                metadata={
                    "tier": tier.value,
                    "max_alerts": max_alerts,
                    "max_reports": max_reports,
                    "max_searches_per_day": max_searches_per_day,
                    "max_monitoring_keywords": max_monitoring_keywords,
                    "max_data_retention_days": max_data_retention_days,
                    "supports_api_access": _flag(supports_api_access),
                    "supports_live_feed": _flag(supports_live_feed),
                    "supports_dark_web_monitoring": _flag(supports_dark_web_monitoring),
                    "supports_export": _flag(supports_export),
                    "supports_advanced_analytics": _flag(supports_advanced_analytics),
                },
            )
            stripe_product_id = product.id

            # One recurring price per billing cadence; Stripe amounts are cents.
            monthly_price = stripe.Price.create(
                product=product.id,
                unit_amount=int(price_monthly * 100),
                currency="usd",
                recurring={"interval": "month"},
                metadata={"billing_period": "monthly"},
            )
            stripe_monthly_price_id = monthly_price.id

            annual_price = stripe.Price.create(
                product=product.id,
                unit_amount=int(price_annually * 100),
                currency="usd",
                recurring={"interval": "year"},
                metadata={"billing_period": "annually"},
            )
            stripe_annual_price_id = annual_price.id

            logger.info(f"Created Stripe product {product.id} for plan {name}")
        except Exception as e:
            # Best-effort: the local plan is still created below even when
            # Stripe provisioning fails part-way.
            logger.error(f"Failed to create Stripe product for plan {name}: {e}")

    plan = SubscriptionPlan(
        name=name,
        tier=tier,
        description=description,
        price_monthly=price_monthly,
        price_annually=price_annually,
        max_alerts=max_alerts,
        max_reports=max_reports,
        max_searches_per_day=max_searches_per_day,
        max_monitoring_keywords=max_monitoring_keywords,
        max_data_retention_days=max_data_retention_days,
        supports_api_access=supports_api_access,
        supports_live_feed=supports_live_feed,
        supports_dark_web_monitoring=supports_dark_web_monitoring,
        supports_export=supports_export,
        supports_advanced_analytics=supports_advanced_analytics,
        stripe_product_id=stripe_product_id,
        stripe_monthly_price_id=stripe_monthly_price_id,
        stripe_annual_price_id=stripe_annual_price_id,
    )

    db.add(plan)
    await db.commit()
    await db.refresh(plan)

    return plan
229
+
230
+
231
async def update_subscription_plan(
    db: AsyncSession,
    plan_id: int,
    name: Optional[str] = None,
    description: Optional[str] = None,
    price_monthly: Optional[float] = None,
    price_annually: Optional[float] = None,
    is_active: Optional[bool] = None,
    max_alerts: Optional[int] = None,
    max_reports: Optional[int] = None,
    max_searches_per_day: Optional[int] = None,
    max_monitoring_keywords: Optional[int] = None,
    max_data_retention_days: Optional[int] = None,
    supports_api_access: Optional[bool] = None,
    supports_live_feed: Optional[bool] = None,
    supports_dark_web_monitoring: Optional[bool] = None,
    supports_export: Optional[bool] = None,
    supports_advanced_analytics: Optional[bool] = None,
    update_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Partially update a subscription plan; None arguments are left unchanged.

    When the plan is linked to a Stripe product, the product metadata is
    updated too. Stripe prices are immutable, so a price change creates a
    NEW Stripe price and stores its ID on the plan. Stripe failures are
    logged and do not block the local update.

    Args:
        db: Database session
        plan_id: ID of the plan to update
        name: New name of the plan
        description: New description of the plan
        price_monthly: New monthly price of the plan
        price_annually: New annual price of the plan
        is_active: New active status of the plan
        max_alerts: New maximum number of alerts allowed
        max_reports: New maximum number of reports allowed
        max_searches_per_day: New maximum number of searches per day
        max_monitoring_keywords: New maximum number of monitoring keywords
        max_data_retention_days: New maximum number of days to retain data
        supports_api_access: New API access support status
        supports_live_feed: New live feed support status
        supports_dark_web_monitoring: New dark web monitoring support status
        supports_export: New data export support status
        supports_advanced_analytics: New advanced analytics support status
        update_stripe_product: Whether to update the Stripe product for this plan

    Returns:
        Updated subscription plan, or None when the plan does not exist
    """
    plan = await get_subscription_plan_by_id(db, plan_id)
    if plan is None:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    # Collect only the columns the caller actually supplied.
    candidate_values = {
        "name": name,
        "description": description,
        "price_monthly": price_monthly,
        "price_annually": price_annually,
        "is_active": is_active,
        "max_alerts": max_alerts,
        "max_reports": max_reports,
        "max_searches_per_day": max_searches_per_day,
        "max_monitoring_keywords": max_monitoring_keywords,
        "max_data_retention_days": max_data_retention_days,
        "supports_api_access": supports_api_access,
        "supports_live_feed": supports_live_feed,
        "supports_dark_web_monitoring": supports_dark_web_monitoring,
        "supports_export": supports_export,
        "supports_advanced_analytics": supports_advanced_analytics,
    }
    update_data = {key: value for key, value in candidate_values.items() if value is not None}

    if update_stripe_product and plan.stripe_product_id and stripe.api_key:
        try:
            product_update_data = {}
            if name is not None:
                product_update_data["name"] = name
            if description is not None:
                product_update_data["description"] = description

            # Numeric limits are passed through; booleans become yes/no strings.
            numeric_meta = {
                "max_alerts": max_alerts,
                "max_reports": max_reports,
                "max_searches_per_day": max_searches_per_day,
                "max_monitoring_keywords": max_monitoring_keywords,
                "max_data_retention_days": max_data_retention_days,
            }
            flag_meta = {
                "supports_api_access": supports_api_access,
                "supports_live_feed": supports_live_feed,
                "supports_dark_web_monitoring": supports_dark_web_monitoring,
                "supports_export": supports_export,
                "supports_advanced_analytics": supports_advanced_analytics,
            }
            metadata_update = {k: v for k, v in numeric_meta.items() if v is not None}
            metadata_update.update(
                {k: ("yes" if v else "no") for k, v in flag_meta.items() if v is not None}
            )
            if metadata_update:
                product_update_data["metadata"] = metadata_update

            if product_update_data:
                stripe.Product.modify(plan.stripe_product_id, **product_update_data)

            # Stripe prices cannot be amended in place: mint a replacement
            # price object and remember its ID locally.
            if price_monthly is not None and plan.stripe_monthly_price_id:
                new_monthly_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=int(price_monthly * 100),  # Stripe uses cents
                    currency="usd",
                    recurring={"interval": "month"},
                    metadata={"billing_period": "monthly"},
                )
                update_data["stripe_monthly_price_id"] = new_monthly_price.id

            if price_annually is not None and plan.stripe_annual_price_id:
                new_annual_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=int(price_annually * 100),  # Stripe uses cents
                    currency="usd",
                    recurring={"interval": "year"},
                    metadata={"billing_period": "annually"},
                )
                update_data["stripe_annual_price_id"] = new_annual_price.id

            logger.info(f"Updated Stripe product {plan.stripe_product_id} for plan {plan.name}")
        except Exception as e:
            logger.error(f"Failed to update Stripe product for plan {plan.name}: {e}")

    if update_data:
        await db.execute(
            update(SubscriptionPlan)
            .where(SubscriptionPlan.id == plan_id)
            .values(**update_data)
        )
        await db.commit()
        # Re-read so the caller sees the persisted state.
        plan = await get_subscription_plan_by_id(db, plan_id)

    return plan
425
+
426
+
427
async def get_user_subscription(
    db: AsyncSession,
    user_id: int
) -> Optional[UserSubscription]:
    """
    Return the user's current (non-canceled) subscription, if any.

    The related plan is eagerly loaded so callers can inspect plan limits
    without an extra round trip.

    Args:
        db: Database session
        user_id: ID of the user

    Returns:
        User subscription or None if not found
    """
    statement = (
        select(UserSubscription)
        .where(
            UserSubscription.user_id == user_id,
            UserSubscription.status != SubscriptionStatus.CANCELED,
        )
        .options(joinedload(UserSubscription.plan))
    )
    result = await db.execute(statement)
    return result.scalars().first()
452
+
453
+
454
async def get_user_subscription_by_id(
    db: AsyncSession,
    subscription_id: int
) -> Optional[UserSubscription]:
    """
    Fetch a subscription record by primary key, with its plan eagerly loaded.

    Args:
        db: Database session
        subscription_id: ID of the subscription

    Returns:
        User subscription or None if not found
    """
    statement = (
        select(UserSubscription)
        .where(UserSubscription.id == subscription_id)
        .options(joinedload(UserSubscription.plan))
    )
    result = await db.execute(statement)
    return result.scalars().first()
478
+
479
+
480
async def create_user_subscription(
    db: AsyncSession,
    user_id: int,
    plan_id: int,
    billing_period: BillingPeriod = BillingPeriod.MONTHLY,
    create_stripe_subscription: bool = True,
    payment_method_id: Optional[str] = None
) -> Optional[UserSubscription]:
    """
    Create a new user subscription, optionally backed by a Stripe subscription.

    A user may hold only one active subscription at a time. On success a
    PaymentHistory row is also recorded (SUCCEEDED when Stripe billing was
    set up, PENDING otherwise).

    Args:
        db: Database session
        user_id: ID of the user
        plan_id: ID of the subscription plan
        billing_period: Billing period (monthly or annually)
        create_stripe_subscription: Whether to create a Stripe subscription
        payment_method_id: ID of the payment method to use (required if
            create_stripe_subscription is True)

    Returns:
        Created user subscription or None if creation failed (unknown user or
        plan, existing active subscription, invalid billing period, missing
        payment method / Stripe price, or a Stripe API failure)
    """
    # Validate the user exists.
    query = select(User).where(User.id == user_id)
    result = await db.execute(query)
    user = result.scalars().first()
    if not user:
        logger.warning(f"User with ID {user_id} not found")
        return None

    # Validate the plan exists.
    plan = await get_subscription_plan_by_id(db, plan_id)
    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    # Enforce at most one active subscription per user.
    existing_subscription = await get_user_subscription(db, user_id)
    if existing_subscription:
        logger.warning(f"User with ID {user_id} already has an active subscription")
        return None

    # Derive period length, price and Stripe price from the billing cadence.
    # NOTE: monthly/annual periods are approximated as 30/365 days.
    now = datetime.utcnow()
    if billing_period == BillingPeriod.MONTHLY:
        current_period_end = now + timedelta(days=30)
        price = plan.price_monthly
        stripe_price_id = plan.stripe_monthly_price_id
    elif billing_period == BillingPeriod.ANNUALLY:
        current_period_end = now + timedelta(days=365)
        price = plan.price_annually
        stripe_price_id = plan.stripe_annual_price_id
    else:
        logger.warning(f"Invalid billing period: {billing_period}")
        return None

    stripe_subscription_id = None
    stripe_customer_id = None

    if create_stripe_subscription and stripe.api_key and plan.stripe_product_id:
        if not payment_method_id:
            logger.warning("Payment method ID is required to create a Stripe subscription")
            return None

        # Bug fix: previously a plan with no Stripe price for this billing
        # period passed price=None to Stripe, failing with an opaque API
        # error. Fail fast with a clear log message instead.
        if not stripe_price_id:
            logger.warning(
                f"Plan {plan_id} has no Stripe price for billing period {billing_period}"
            )
            return None

        try:
            # Reuse an existing Stripe customer for this email, else create one.
            customers = stripe.Customer.list(email=user.email)
            if customers.data:
                customer = customers.data[0]
            else:
                customer = stripe.Customer.create(
                    email=user.email,
                    name=user.full_name,
                    metadata={"user_id": user_id}
                )
            stripe_customer_id = customer.id

            # Attach the payment method and make it the default for invoices.
            stripe.PaymentMethod.attach(
                payment_method_id,
                customer=stripe_customer_id
            )
            stripe.Customer.modify(
                stripe_customer_id,
                invoice_settings={
                    "default_payment_method": payment_method_id
                }
            )

            subscription = stripe.Subscription.create(
                customer=stripe_customer_id,
                items=[
                    {"price": stripe_price_id}
                ],
                expand=["latest_invoice.payment_intent"]
            )
            stripe_subscription_id = subscription.id

            logger.info(f"Created Stripe subscription {subscription.id} for user {user_id}")
        except Exception as e:
            logger.error(f"Failed to create Stripe subscription for user {user_id}: {e}")
            return None

    # Persist the subscription locally.
    subscription = UserSubscription(
        user_id=user_id,
        plan_id=plan_id,
        status=SubscriptionStatus.ACTIVE,
        billing_period=billing_period,
        current_period_start=now,
        current_period_end=current_period_end,
        stripe_subscription_id=stripe_subscription_id,
        stripe_customer_id=stripe_customer_id
    )
    db.add(subscription)
    await db.commit()
    await db.refresh(subscription)

    # Record the initial payment for auditing/billing history.
    if subscription.id:
        payment_status = PaymentStatus.SUCCEEDED if stripe_subscription_id else PaymentStatus.PENDING
        payment = PaymentHistory(
            user_id=user_id,
            subscription_id=subscription.id,
            amount=price,
            currency="USD",
            status=payment_status
        )
        db.add(payment)
        await db.commit()

    return subscription
627
+
628
+
629
async def cancel_user_subscription(
    db: AsyncSession,
    subscription_id: int,
    cancel_stripe_subscription: bool = True
) -> Optional[UserSubscription]:
    """
    Cancel a user subscription.

    Args:
        db: Database session
        subscription_id: ID of the subscription to cancel
        cancel_stripe_subscription: Whether to cancel the Stripe subscription

    Returns:
        Canceled user subscription or None when the subscription does not exist
    """
    subscription = await get_user_subscription_by_id(db, subscription_id)
    if subscription is None:
        logger.warning(f"Subscription with ID {subscription_id} not found")
        return None

    if cancel_stripe_subscription and subscription.stripe_subscription_id and stripe.api_key:
        try:
            # Ask Stripe to stop renewing instead of terminating immediately,
            # so the customer keeps access for the already-paid period.
            stripe.Subscription.modify(
                subscription.stripe_subscription_id,
                cancel_at_period_end=True
            )
            logger.info(f"Canceled Stripe subscription {subscription.stripe_subscription_id} at period end")
        except Exception as e:
            # Best-effort: a Stripe failure does not block local cancellation.
            logger.error(f"Failed to cancel Stripe subscription {subscription.stripe_subscription_id}: {e}")

    # NOTE(review): the local record is marked CANCELED immediately while
    # Stripe cancels only at period end — confirm this mismatch is intentional.
    await db.execute(
        update(UserSubscription)
        .where(UserSubscription.id == subscription_id)
        .values(
            status=SubscriptionStatus.CANCELED,
            canceled_at=datetime.utcnow()
        )
    )
    await db.commit()

    # Re-read so the caller sees the persisted, canceled state.
    return await get_user_subscription_by_id(db, subscription_id)
src/api/services/threat_service.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Service for threat operations.
3
+ """
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.future import select
6
+ from sqlalchemy import func, or_, and_
7
+ from datetime import datetime, timedelta
8
+ from typing import List, Optional, Dict, Any, Union
9
+
10
+ from src.models.threat import Threat, ThreatSeverity, ThreatStatus, ThreatCategory
11
+ from src.models.indicator import Indicator, IndicatorType
12
+ from src.api.schemas import PaginationParams
13
+
14
async def create_threat(
    db: AsyncSession,
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: ThreatCategory,
    status: ThreatStatus = ThreatStatus.NEW,
    source_url: Optional[str] = None,
    source_name: Optional[str] = None,
    source_type: Optional[str] = None,
    affected_entity: Optional[str] = None,
    affected_entity_type: Optional[str] = None,
    confidence_score: float = 0.0,
    risk_score: float = 0.0,
) -> Threat:
    """
    Persist a new threat record, stamped with the current UTC time.

    Args:
        db: Database session
        title: Threat title
        description: Threat description
        severity: Threat severity
        category: Threat category
        status: Threat status (defaults to NEW)
        source_url: URL of the source
        source_name: Name of the source
        source_type: Type of source
        affected_entity: Name of affected entity
        affected_entity_type: Type of affected entity
        confidence_score: Confidence score (0-1)
        risk_score: Risk score (0-1)

    Returns:
        Threat: Created threat
    """
    new_threat = Threat(
        title=title,
        description=description,
        severity=severity,
        category=category,
        status=status,
        source_url=source_url,
        source_name=source_name,
        source_type=source_type,
        discovered_at=datetime.utcnow(),  # discovery time == creation time
        affected_entity=affected_entity,
        affected_entity_type=affected_entity_type,
        confidence_score=confidence_score,
        risk_score=risk_score,
    )
    db.add(new_threat)
    await db.commit()
    # Refresh to populate server-generated fields (e.g. the primary key).
    await db.refresh(new_threat)
    return new_threat
71
+
72
async def get_threat_by_id(db: AsyncSession, threat_id: int) -> Optional[Threat]:
    """
    Fetch a single threat by primary key.

    Args:
        db: Database session
        threat_id: Threat ID

    Returns:
        Optional[Threat]: Threat or None if not found
    """
    rows = await db.execute(select(Threat).where(Threat.id == threat_id))
    return rows.scalars().first()
85
+
86
async def get_threats(
    db: AsyncSession,
    pagination: PaginationParams,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[ThreatStatus]] = None,
    category: Optional[List[ThreatCategory]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> List[Threat]:
    """
    List threats, newest first, with optional filters and pagination.

    Args:
        db: Database session
        pagination: Pagination parameters (1-based page, page size)
        severity: Filter by severity
        status: Filter by status
        category: Filter by category
        search_query: Case-insensitive substring match on title or description
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        List[Threat]: List of threats
    """
    # Collect all active filter conditions, then apply them in one shot.
    conditions = []
    if severity:
        conditions.append(Threat.severity.in_(severity))
    if status:
        conditions.append(Threat.status.in_(status))
    if category:
        conditions.append(Threat.category.in_(category))
    if search_query:
        pattern = f"%{search_query}%"
        conditions.append(or_(
            Threat.title.ilike(pattern),
            Threat.description.ilike(pattern),
        ))
    if from_date:
        conditions.append(Threat.discovered_at >= from_date)
    if to_date:
        conditions.append(Threat.discovered_at <= to_date)

    stmt = select(Threat)
    if conditions:
        stmt = stmt.filter(and_(*conditions))

    # Newest threats first; pagination is 1-based.
    stmt = (
        stmt.order_by(Threat.discovered_at.desc())
        .offset((pagination.page - 1) * pagination.size)
        .limit(pagination.size)
    )

    rows = await db.execute(stmt)
    return rows.scalars().all()
143
+
144
async def count_threats(
    db: AsyncSession,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[ThreatStatus]] = None,
    category: Optional[List[ThreatCategory]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> int:
    """
    Count threats matching the same filters as get_threats.

    Args:
        db: Database session
        severity: Filter by severity
        status: Filter by status
        category: Filter by category
        search_query: Case-insensitive substring match on title or description
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        int: Count of threats (0 when no row is returned)
    """
    query = select(func.count(Threat.id))

    # Apply filters (kept in sync with get_threats).
    if severity:
        query = query.filter(Threat.severity.in_(severity))
    if status:
        query = query.filter(Threat.status.in_(status))
    if category:
        query = query.filter(Threat.category.in_(category))
    if search_query:
        search_filter = or_(
            Threat.title.ilike(f"%{search_query}%"),
            Threat.description.ilike(f"%{search_query}%")
        )
        query = query.filter(search_filter)
    if from_date:
        query = query.filter(Threat.discovered_at >= from_date)
    if to_date:
        query = query.filter(Threat.discovered_at <= to_date)

    result = await db.execute(query)
    # Fix: scalar() is typed Optional; coalesce to 0 so the declared
    # `-> int` contract always holds.
    count = result.scalar()
    return count if count is not None else 0
195
+
196
async def update_threat(
    db: AsyncSession,
    threat_id: int,
    title: Optional[str] = None,
    description: Optional[str] = None,
    severity: Optional[ThreatSeverity] = None,
    status: Optional[ThreatStatus] = None,
    category: Optional[ThreatCategory] = None,
    affected_entity: Optional[str] = None,
    affected_entity_type: Optional[str] = None,
    confidence_score: Optional[float] = None,
    risk_score: Optional[float] = None,
) -> Optional[Threat]:
    """
    Partially update a threat; None arguments are left unchanged.

    Args:
        db: Database session
        threat_id: Threat ID
        title: New title
        description: New description
        severity: New severity
        status: New status
        category: New category
        affected_entity: New affected entity
        affected_entity_type: New affected entity type
        confidence_score: New confidence score
        risk_score: New risk score

    Returns:
        Optional[Threat]: Updated threat or None if not found
    """
    threat = await get_threat_by_id(db, threat_id)
    if threat is None:
        return None

    # Apply only the fields the caller supplied.
    requested_changes = {
        "title": title,
        "description": description,
        "severity": severity,
        "status": status,
        "category": category,
        "affected_entity": affected_entity,
        "affected_entity_type": affected_entity_type,
        "confidence_score": confidence_score,
        "risk_score": risk_score,
    }
    for field_name, new_value in requested_changes.items():
        if new_value is not None:
            setattr(threat, field_name, new_value)

    # Always bump the modification timestamp, even for no-op calls
    # (matches the previous behavior).
    threat.updated_at = datetime.utcnow()

    await db.commit()
    await db.refresh(threat)
    return threat
265
+
266
async def add_indicator_to_threat(
    db: AsyncSession,
    threat_id: int,
    value: str,
    indicator_type: IndicatorType,
    description: Optional[str] = None,
    is_verified: bool = False,
    context: Optional[str] = None,
    source: Optional[str] = None,
    confidence_score: float = 0.0,
) -> Indicator:
    """
    Add an indicator to a threat.

    Args:
        db: Database session
        threat_id: Threat ID
        value: Indicator value
        indicator_type: Indicator type
        description: Description of the indicator
        is_verified: Whether the indicator is verified
        context: Context of the indicator
        source: Source of the indicator
        confidence_score: Confidence score (0-1)

    Returns:
        Indicator: Created indicator

    Raises:
        ValueError: If no threat with the given ID exists
    """
    # Check if threat exists before inserting the child row.
    threat = await get_threat_by_id(db, threat_id)
    if not threat:
        raise ValueError(f"Threat with ID {threat_id} not found")

    # Capture one timestamp so first_seen and last_seen are identical at
    # creation (two separate utcnow() calls could differ by microseconds).
    now = datetime.utcnow()

    db_indicator = Indicator(
        threat_id=threat_id,
        value=value,
        indicator_type=indicator_type,
        description=description,
        is_verified=is_verified,
        context=context,
        source=source,
        confidence_score=confidence_score,
        first_seen=now,
        last_seen=now,
    )

    db.add(db_indicator)
    await db.commit()
    await db.refresh(db_indicator)

    return db_indicator
318
+
319
async def _count_threats(db: AsyncSession, *filters) -> int:
    """Return the number of Threat rows matching all given filter clauses."""
    result = await db.execute(select(func.count(Threat.id)).filter(and_(*filters)))
    return result.scalar() or 0

async def get_threat_statistics(
    db: AsyncSession,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Dict[str, Any]:
    """
    Get threat statistics.

    Args:
        db: Database session
        from_date: Filter by discovered_at >= from_date (defaults to 30 days
            before to_date)
        to_date: Filter by discovered_at <= to_date (defaults to now)

    Returns:
        Dict[str, Any]: Counts by severity/status/category, a total count,
        and a per-day time series over the requested window
    """
    # Default window: the last 30 days ending now.
    if not to_date:
        to_date = datetime.utcnow()
    if not from_date:
        from_date = to_date - timedelta(days=30)

    # Shared date-range clauses (inclusive on both ends, like the original
    # per-dimension queries).
    in_range = (
        Threat.discovered_at >= from_date,
        Threat.discovered_at <= to_date,
    )

    # One count query per enum member; small enums keep this cheap.
    severity_counts = {
        severity.value: await _count_threats(db, Threat.severity == severity, *in_range)
        for severity in ThreatSeverity
    }
    status_counts = {
        status.value: await _count_threats(db, Threat.status == status, *in_range)
        for status in ThreatStatus
    }
    category_counts = {
        category.value: await _count_threats(db, Threat.category == category, *in_range)
        for category in ThreatCategory
    }

    total_count = await _count_threats(db, *in_range)

    # Per-day buckets: [midnight, next midnight) half-open intervals.
    time_series = []
    current_date = from_date.date()
    end_date = to_date.date()
    while current_date <= end_date:
        next_date = current_date + timedelta(days=1)
        count = await _count_threats(
            db,
            Threat.discovered_at >= datetime.combine(current_date, datetime.min.time()),
            Threat.discovered_at < datetime.combine(next_date, datetime.min.time()),
        )
        time_series.append({
            "date": current_date.isoformat(),
            "count": count
        })
        current_date = next_date

    return {
        "total_count": total_count,
        "severity_counts": severity_counts,
        "status_counts": status_counts,
        "category_counts": category_counts,
        "time_series": time_series,
        "from_date": from_date.isoformat(),
        "to_date": to_date.isoformat(),
    }
src/api/services/user_service.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy.ext.asyncio import AsyncSession
2
+ from sqlalchemy.future import select
3
+ from sqlalchemy import update
4
+ from passlib.context import CryptContext
5
+ from typing import Optional, List, Dict, Any
6
+ import logging
7
+
8
+ from src.models.user import User
9
+ from src.api.schemas import UserCreate, UserUpdate, UserInDB
10
+
11
+ # Configure logger
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Password context for hashing and verification
15
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
16
+
17
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plaintext password against its stored hash.

    Args:
        plain_password: Candidate password supplied by the caller
        hashed_password: Hash previously produced by get_password_hash

    Returns:
        bool: True if the password matches the hash
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
29
+
30
def get_password_hash(password: str) -> str:
    """
    Hash a plaintext password for storage.

    Args:
        password: Plain password

    Returns:
        str: Salted bcrypt hash suitable for verify_password
    """
    hashed = pwd_context.hash(password)
    return hashed
41
+
42
async def get_user_by_username(db: AsyncSession, username: str) -> Optional[UserInDB]:
    """
    Get user by username.

    Args:
        db: Database session
        username: Username

    Returns:
        Optional[UserInDB]: User if found, None otherwise (also None on
        any database error, which is logged)
    """
    try:
        result = await db.execute(select(User).where(User.username == username))
        row = result.scalars().first()
        if row is None:
            return None

        # Map the ORM row onto the Pydantic schema column-by-column.
        fields = {column.name: getattr(row, column.name) for column in row.__table__.columns}
        return UserInDB(**fields)
    except Exception as exc:
        logger.error(f"Error getting user by username: {exc}")
        return None
66
+
67
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
    """
    Authenticate user.

    Args:
        db: Database session
        username: Username
        password: Plain password

    Returns:
        Optional[UserInDB]: User if authenticated, None otherwise
    """
    candidate = await get_user_by_username(db, username)

    # Unknown username and wrong password are deliberately indistinguishable.
    if candidate is None or not verify_password(password, candidate.hashed_password):
        return None

    return candidate
88
+
89
async def create_user(db: AsyncSession, user_data: UserCreate) -> Optional[UserInDB]:
    """
    Create a new user.

    Args:
        db: Database session
        user_data: User data

    Returns:
        Optional[UserInDB]: Created user, or None if the username is already
        taken or the insert fails (failure is logged and rolled back)
    """
    try:
        # Refuse duplicate usernames up front.
        if await get_user_by_username(db, user_data.username) is not None:
            return None

        record = User(
            username=user_data.username,
            email=user_data.email,
            full_name=user_data.full_name,
            hashed_password=get_password_hash(user_data.password),
            is_active=user_data.is_active,
        )
        db.add(record)
        await db.commit()
        await db.refresh(record)

        # Map the persisted ORM row onto the Pydantic schema.
        fields = {column.name: getattr(record, column.name) for column in record.__table__.columns}
        return UserInDB(**fields)
    except Exception as exc:
        logger.error(f"Error creating user: {exc}")
        await db.rollback()
        return None
127
+
128
async def update_user(db: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserInDB]:
    """
    Update user.

    Args:
        db: Database session
        user_id: User ID
        user_data: User data (only fields that were explicitly set are applied)

    Returns:
        Optional[UserInDB]: Updated user, or None if the user does not exist
        or the update fails (failure is logged and rolled back)
    """
    try:
        # Only fields the caller explicitly set are applied.
        update_data = user_data.dict(exclude_unset=True)

        # Hash password if provided; never store the plaintext key.
        if "password" in update_data:
            update_data["hashed_password"] = get_password_hash(update_data.pop("password"))

        # Skip the UPDATE entirely when no fields were supplied:
        # update(...).values() with no values fails to compile in SQLAlchemy,
        # which the broad except below would mask as a spurious None.
        if update_data:
            stmt = update(User).where(User.id == user_id).values(**update_data)
            await db.execute(stmt)
            await db.commit()

        # Re-fetch so the returned object reflects the persisted state.
        result = await db.execute(select(User).where(User.id == user_id))
        user = result.scalars().first()
        if not user:
            return None

        # Map the ORM row onto the Pydantic schema column-by-column.
        user_dict = {c.name: getattr(user, c.name) for c in user.__table__.columns}
        return UserInDB(**user_dict)
    except Exception as e:
        logger.error(f"Error updating user: {e}")
        await db.rollback()
        return None