"""Authentication, subscription and API-quota helpers backed by SQLite.

Users, subscriptions, usage stats and refresh tokens are stored in a SQLite
database located at ``DB_PATH`` (env var, default ``/tmp/data/user_data.db``).
Access tokens are JWTs signed with ``JWT_SECRET`` using ``ALGORITHM``.
"""
import hashlib  # noqa: F401 -- not used in this module; kept so existing imports of it keep working
import logging
import os
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from dotenv import load_dotenv
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.hash import bcrypt  # noqa: F401 -- unused; hashing below imports the `bcrypt` package locally
from pydantic import BaseModel

# Load environment variables (JWT_SECRET, DB_PATH, ...) from a .env file if present.
load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('auth')

# Security configuration
SECRET_KEY = os.getenv("JWT_SECRET", "your-secret-key-for-development-only")
if os.getenv("JWT_SECRET") is None:
    # Tokens signed with a well-known key can be forged by anyone; make this visible.
    logger.warning("JWT_SECRET is not set; falling back to an insecure development key")
ALGORITHM = "HS256"
JWT_EXPIRATION_DELTA = timedelta(days=1)  # Access tokens are valid for 1 day

# Default database location; get_db_connection() re-reads DB_PATH at call time.
DB_PATH = os.getenv("DB_PATH", "/tmp/data/user_data.db")

# FastAPI OAuth2 scheme: reads the bearer token from the Authorization header.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


# Pydantic models for FastAPI
class User(BaseModel):
    id: str
    email: str
    subscription_tier: str = "free_tier"
    subscription_expiry: Optional[datetime] = None
    api_calls_remaining: int = 5
    last_reset_date: Optional[datetime] = None


class UserCreate(BaseModel):
    email: str
    password: str


class Token(BaseModel):
    access_token: str
    token_type: str


class TokenData(BaseModel):
    user_id: Optional[str] = None


# Subscription tiers: price, enabled feature names and numeric limits per tier.
# NOTE(review): free_tier currently lists every feature and larger limits than
# standard_tier — confirm this is intentional and not leftover test config.
SUBSCRIPTION_TIERS = {
    "free_tier": {
        "price": 0,
        "currency": "INR",
        "features": [
            "basic_document_analysis", "basic_risk_assessment", "document_analysis",
            "video_analysis", "audio_analysis", "chatbot",
            "detailed_risk_assessment", "contract_clause_analysis",
        ],
        "limits": {
            "document_size_mb": 100,
            "documents_per_month": 100,
            "video_size_mb": 100,
            "audio_size_mb": 100,
            "daily_api_calls": 100,
            "max_document_size_mb": 100,
        },
    },
    "standard_tier": {
        "price": 799,
        "currency": "INR",
        "features": [
            "basic_document_analysis", "basic_risk_assessment",
            "video_analysis", "audio_analysis", "chatbot",
        ],
        "limits": {
            "document_size_mb": 20,
            "documents_per_month": 20,
            "video_size_mb": 100,
            "audio_size_mb": 50,
            "daily_api_calls": 100,
            "max_document_size_mb": 20,
        },
    },
    "premium_tier": {
        "price": 1499,
        "currency": "INR",
        "features": [
            "basic_document_analysis", "basic_risk_assessment",
            "video_analysis", "audio_analysis", "chatbot",
            "detailed_risk_assessment", "contract_clause_analysis",
        ],
        "limits": {
            "document_size_mb": 50,
            "documents_per_month": 999999,
            "video_size_mb": 500,
            "audio_size_mb": 200,
            "daily_api_calls": 1000,
            "max_document_size_mb": 50,
        },
    },
}


# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------

def get_db_connection():
    """Open a SQLite connection whose rows support access by column name.

    The path is taken from the ``DB_PATH`` environment variable at call time,
    falling back to the module-level ``DB_PATH``. The parent directory is
    created if needed.

    Raises:
        Exception: if the directory cannot be created or SQLite cannot open
            the database (the original error is chained as ``__cause__``).
    """
    db_path = os.getenv("DB_PATH", DB_PATH)
    try:
        db_dir = os.path.dirname(db_path)
        # dirname() is "" for a bare filename; os.makedirs("") would raise.
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        conn = sqlite3.connect(db_path)
        conn.row_factory = sqlite3.Row  # rows support row["column"] and dict(row)
        return conn
    except (sqlite3.Error, OSError) as e:
        logger.error(f"Database connection error: {e}")
        raise Exception(f"Database connection failed: {e}") from e


@contextmanager
def _db_session():
    """Yield a connection from get_db_connection() and always close it.

    Unlike ``with sqlite3.connect(...)``, which only commits/rolls back,
    this guarantees the connection is closed. Callers still commit explicitly.
    """
    conn = get_db_connection()
    try:
        yield conn
    finally:
        conn.close()


def _to_db_timestamp(value):
    """Serialize a datetime exactly as sqlite3's (deprecated) default adapter did."""
    return value.isoformat(" ")


def _parse_db_timestamp(value):
    """Convert a TIMESTAMP value read from SQLite to a datetime.

    Returns None for NULL or unparseable values (the latter is logged).
    """
    if value is None or isinstance(value, datetime):
        return value
    try:
        return datetime.fromisoformat(value)
    except (TypeError, ValueError):
        logger.warning(f"Ignoring unparseable timestamp value: {value!r}")
        return None


def _row_to_user(row):
    """Build a User model from a ``users`` table row (sqlite3.Row)."""
    data = dict(row)
    calls = data.get("api_calls_remaining")
    return User(
        id=data["id"],
        email=data["email"],
        subscription_tier=data.get("subscription_tier") or "free_tier",
        # The users table has no subscription_expiry column, so this is
        # normally None; kept for rows/schemas that do provide it.
        subscription_expiry=_parse_db_timestamp(data.get("subscription_expiry")),
        api_calls_remaining=5 if calls is None else calls,
        last_reset_date=_parse_db_timestamp(data.get("last_reset_date")),
    )


def init_auth_db():
    """Create the users, subscriptions, usage_stats and refresh_tokens tables if missing.

    Raises:
        Exception: any connection or SQL error is logged and re-raised.
    """
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()

        c.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id TEXT PRIMARY KEY,
                email TEXT UNIQUE NOT NULL,
                hashed_password TEXT NOT NULL,
                password TEXT,
                subscription_tier TEXT DEFAULT 'free_tier',
                is_active BOOLEAN DEFAULT 1,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                api_calls_remaining INTEGER DEFAULT 10,
                last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')

        c.execute('''
            CREATE TABLE IF NOT EXISTS subscriptions (
                id TEXT PRIMARY KEY,
                user_id TEXT,
                tier TEXT,
                plan_id TEXT,
                status TEXT,
                created_at TIMESTAMP,
                expires_at TIMESTAMP,
                paypal_subscription_id TEXT,
                FOREIGN KEY (user_id) REFERENCES users (id)
            )
        ''')

        c.execute('''
            CREATE TABLE IF NOT EXISTS usage_stats (
                id TEXT PRIMARY KEY,
                user_id TEXT,
                month INTEGER,
                year INTEGER,
                analyses_used INTEGER,
                FOREIGN KEY (user_id) REFERENCES users (id)
            )
        ''')

        c.execute('''
            CREATE TABLE IF NOT EXISTS refresh_tokens (
                user_id TEXT,
                token TEXT,
                expires_at TIMESTAMP,
                FOREIGN KEY (user_id) REFERENCES users (id)
            )
        ''')

        conn.commit()
        logger.info("Database initialized successfully")
    except Exception as e:
        logger.error(f"Database initialization error: {e}")
        raise
    finally:
        if conn is not None:
            conn.close()


# Initialize the database at import time.
init_auth_db()


# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------

def hash_password(password):
    """Hash a password with bcrypt and return the hash as a UTF-8 string.

    Args:
        password: plain-text password as ``str`` or ``bytes``.
    """
    # Local import: the module-level name ``bcrypt`` is passlib's handler,
    # so the ``bcrypt`` package is imported here under the same name.
    import bcrypt

    if isinstance(password, str):
        password = password.encode('utf-8')
    hashed = bcrypt.hashpw(password, bcrypt.gensalt())
    return hashed.decode('utf-8')


def verify_password(plain_password, hashed_password):
    """Return True if *plain_password* matches the bcrypt *hashed_password*.

    Never raises: any verification error (e.g. a malformed hash) is logged
    and reported as False.
    """
    import bcrypt  # see hash_password for why this is a local import

    if isinstance(plain_password, str):
        plain_password = plain_password.encode('utf-8')
    if isinstance(hashed_password, str):
        hashed_password = hashed_password.encode('utf-8')

    try:
        return bcrypt.checkpw(plain_password, hashed_password)
    except Exception as e:
        logger.error(f"Password verification error: {e}")
        return False


# ---------------------------------------------------------------------------
# Registration / login
# ---------------------------------------------------------------------------

def register_user(email, password):
    """Create a free-tier user and issue an access token.

    Returns:
        ``(True, {"user_id", "access_token", "token_type"})`` on success,
        otherwise ``(False, error_message)``.
    """
    conn = None  # bound before try so the finally clause is always safe
    try:
        conn = get_db_connection()
        c = conn.cursor()

        c.execute("SELECT * FROM users WHERE email = ?", (email,))
        if c.fetchone():
            return False, "Email already registered"

        user_id = str(uuid.uuid4())
        logger.info(f"Registering new user with email: {email}")
        hashed_pw = hash_password(password)

        c.execute("""
            INSERT INTO users (id, email, hashed_password, subscription_tier, api_calls_remaining, last_reset_date)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (user_id, email, hashed_pw, "free_tier", 5, _to_db_timestamp(datetime.now())))
        conn.commit()
        logger.info(f"User registered successfully: {email}")

        access_token = create_access_token(user_id)
        return True, {
            "user_id": user_id,
            "access_token": access_token,
            "token_type": "bearer"
        }
    except Exception as e:
        logger.error(f"User registration error: {e}")
        return False, f"Registration failed: {str(e)}"
    finally:
        if conn is not None:
            conn.close()


def authenticate_user(email, password):
    """Check an active user's email/password.

    Returns:
        The matching ``User`` on success; ``None`` if the user is unknown,
        inactive, the password is wrong, or any error occurs.
    """
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()

        c.execute("SELECT * FROM users WHERE email = ? AND is_active = 1", (email,))
        user = c.fetchone()
        if not user:
            logger.warning(f"User not found: {email}")
            return None

        # Do not log any part of the stored hash: its prefix exposes the
        # bcrypt cost and salt.
        logger.info(f"Verifying password for user: {email}")
        if not verify_password(password, user['hashed_password']):
            logger.warning(f"Password verification failed for user: {email}")
            return None

        # Best effort: the schema created above has no last_login column.
        try:
            c.execute(
                "UPDATE users SET last_login = ? WHERE id = ?",
                (_to_db_timestamp(datetime.now()), user['id']),
            )
            conn.commit()
        except sqlite3.OperationalError:
            pass

        return _row_to_user(user)
    except Exception as e:
        logger.error(f"Login error: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()


# ---------------------------------------------------------------------------
# Tokens
# ---------------------------------------------------------------------------

def create_access_token(user_id):
    """Return a signed JWT with ``sub=user_id`` expiring after JWT_EXPIRATION_DELTA.

    Returns None (after logging) if encoding fails.
    """
    try:
        # Aware UTC time yields the correct epoch seconds for the exp claim.
        expiration = datetime.now(timezone.utc) + JWT_EXPIRATION_DELTA
        payload = {
            "sub": user_id,
            "exp": int(expiration.timestamp()),  # JWT NumericDate
        }
        token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
        logger.info(f"Created access token for user: {user_id}")
        return token
    except Exception as e:
        logger.error(f"Token creation error: {e}")
        return None


def update_auth_db_schema():
    """Add missing columns to existing tables (currently ``subscriptions.tier``).

    Raises:
        HTTPException: 500 if the schema check or update fails.
    """
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()

        c.execute("PRAGMA table_info(subscriptions)")
        columns = [column[1] for column in c.fetchall()]  # index 1 is the column name

        if "tier" not in columns:
            logger.info("Adding 'tier' column to subscriptions table")
            c.execute("ALTER TABLE subscriptions ADD COLUMN tier TEXT")

        conn.commit()
        logger.info("Database schema updated successfully")
    except Exception as e:
        logger.error(f"Database schema update error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Database schema update error: {str(e)}"
        ) from e
    finally:
        if conn is not None:
            conn.close()


async def get_current_user(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency: resolve the bearer token to an active ``User``.

    All user fields (quota counter, last reset date, tier) are loaded, so
    downstream quota checks see real values.

    Raises:
        HTTPException: 401 if the token is invalid/expired, lacks ``sub``,
            or refers to an unknown or inactive user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except Exception as e:
        logger.error(f"Token validation error: {str(e)}")
        raise credentials_exception from e

    user_id = payload.get("sub")
    if user_id is None:
        logger.error("Token missing 'sub' field")
        raise credentials_exception

    # NOTE(review): synchronous sqlite3 I/O inside an async dependency blocks
    # the event loop; consider a threadpool if this becomes a bottleneck.
    with _db_session() as conn:
        row = conn.execute(
            "SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,)
        ).fetchone()

    if row is None:
        logger.error(f"User not found: {user_id}")
        raise credentials_exception

    return _row_to_user(row)


async def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return the current user (get_current_user already rejects inactive users)."""
    return current_user


# ---------------------------------------------------------------------------
# Subscriptions and quota
# ---------------------------------------------------------------------------

def create_user_subscription(email, tier):
    """Create a 30-day active subscription and set the user's tier.

    Returns:
        ``(True, subscription_dict)`` on success, otherwise
        ``(False, error_message)``.
    """
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()

        c.execute("SELECT id FROM users WHERE email = ?", (email,))
        user_data = c.fetchone()
        if not user_data:
            return False, "User not found"
        user_id = user_data['id']

        # Every configured tier except the free one can be subscribed to.
        valid_tiers = [name for name in SUBSCRIPTION_TIERS if name != "free_tier"]
        if tier not in valid_tiers:
            return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}"

        subscription_id = str(uuid.uuid4())
        created_at = datetime.now()
        expires_at = created_at + timedelta(days=30)  # 30-day subscription

        c.execute("""
            INSERT INTO subscriptions (id, user_id, tier, status, created_at, expires_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (subscription_id, user_id, tier, "active",
              _to_db_timestamp(created_at), _to_db_timestamp(expires_at)))

        c.execute("""
            UPDATE users
            SET subscription_tier = ?
            WHERE id = ?
        """, (tier, user_id))

        conn.commit()

        return True, {
            "id": subscription_id,
            "user_id": user_id,
            "tier": tier,
            "status": "active",
            "created_at": created_at.isoformat(),
            "expires_at": expires_at.isoformat()
        }
    except Exception as e:
        logger.error(f"Subscription creation error: {e}")
        return False, f"Failed to create subscription: {str(e)}"
    finally:
        if conn is not None:
            conn.close()


def get_user(user_id: str):
    """Return the active ``User`` with *user_id*, or None if missing or on error."""
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()
        c.execute("SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,))
        user_data = c.fetchone()
        if not user_data:
            return None
        return _row_to_user(user_data)
    except Exception as e:
        logger.error(f"Get user error: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()


def check_subscription_access(user: User, feature: str, file_size_mb: Optional[float] = None):
    """Enforce tier features, file-size limit and daily quota, then consume one call.

    Side effects: may downgrade an expired subscription to free_tier, may
    reset the daily counter, and decrements ``api_calls_remaining`` — both on
    *user* and in the database.

    Raises:
        HTTPException: 429 if no calls remain, 403 if *feature* is not in the
            user's tier, 413 if *file_size_mb* exceeds the tier's limit.
    """
    if (user.subscription_tier != "free_tier"
            and user.subscription_expiry
            and user.subscription_expiry < datetime.now()):
        # Subscription expired: downgrade to the free tier.
        user.subscription_tier = "free_tier"
        user.api_calls_remaining = SUBSCRIPTION_TIERS["free_tier"]["limits"]["daily_api_calls"]
        with _db_session() as conn:
            conn.execute("""
                UPDATE users
                SET subscription_tier = ?, api_calls_remaining = ?
                WHERE id = ?
            """, (user.subscription_tier, user.api_calls_remaining, user.id))
            conn.commit()

    user = reset_api_calls_if_needed(user)

    if user.api_calls_remaining <= 0:
        raise HTTPException(
            status_code=429,
            detail="API call limit reached for today. Please upgrade your subscription or try again tomorrow."
        )

    tier_features = SUBSCRIPTION_TIERS[user.subscription_tier]["features"]
    if feature not in tier_features:
        raise HTTPException(
            status_code=403,
            detail=f"The {feature} feature is not available in your {user.subscription_tier} subscription. Please upgrade to access this feature."
        )

    # NOTE(review): only max_document_size_mb is enforced, even for video/audio
    # features; video_size_mb / audio_size_mb are currently unused — confirm.
    if file_size_mb:
        max_size = SUBSCRIPTION_TIERS[user.subscription_tier]["limits"]["max_document_size_mb"]
        if file_size_mb > max_size:
            raise HTTPException(
                status_code=413,
                detail=f"File size exceeds the {max_size}MB limit for your {user.subscription_tier} subscription. Please upgrade or use a smaller file."
            )

    user.api_calls_remaining -= 1
    with _db_session() as conn:
        conn.execute("""
            UPDATE users
            SET api_calls_remaining = ?
            WHERE id = ?
        """, (user.api_calls_remaining, user.id))
        conn.commit()

    return True


def reset_api_calls_if_needed(user: User):
    """Refill the user's daily API-call quota if it was last reset before today.

    Updates *user* in place, persists the change, and returns *user*.
    """
    today = datetime.now().date()
    if user.last_reset_date is None or user.last_reset_date.date() < today:
        tier_limits = SUBSCRIPTION_TIERS[user.subscription_tier]["limits"]
        user.api_calls_remaining = tier_limits["daily_api_calls"]
        user.last_reset_date = datetime.now()

        with _db_session() as conn:
            conn.execute("""
                UPDATE users
                SET api_calls_remaining = ?, last_reset_date = ?
                WHERE id = ?
            """, (user.api_calls_remaining, _to_db_timestamp(user.last_reset_date), user.id))
            conn.commit()

    return user


def login_user(email, password):
    """Authenticate and issue an access token plus a 30-day refresh token.

    Returns:
        ``(True, {"user_id", "email", "access_token", "refresh_token",
        "subscription"})`` on success, otherwise ``(False, error_message)``.
        ``subscription`` is one active subscription row as a dict, or None.
    """
    conn = None
    try:
        user = authenticate_user(email, password)
        if not user:
            return False, "Incorrect username or password"

        access_token = create_access_token(user.id)

        refresh_token = str(uuid.uuid4())
        expires_at = datetime.now() + timedelta(days=30)

        conn = get_db_connection()
        c = conn.cursor()
        c.execute("INSERT INTO refresh_tokens VALUES (?, ?, ?)",
                  (user.id, refresh_token, _to_db_timestamp(expires_at)))
        conn.commit()

        c.execute("SELECT * FROM subscriptions WHERE user_id = ? AND status = 'active'", (user.id,))
        subscription = c.fetchone()
        subscription_dict = dict(subscription) if subscription else None

        return True, {
            "user_id": user.id,
            "email": user.email,
            "access_token": access_token,
            "refresh_token": refresh_token,
            "subscription": subscription_dict
        }
    except Exception as e:
        logger.error(f"Login error: {e}")
        return False, f"Login failed: {str(e)}"
    finally:
        if conn is not None:
            conn.close()


def get_subscription_plans():
    """Return one plan dict per entry in SUBSCRIPTION_TIERS."""
    return [
        {
            "tier": tier,
            "price": details["price"],
            "currency": details["currency"],
            "features": details["features"],
            "limits": details["limits"],
        }
        for tier, details in SUBSCRIPTION_TIERS.items()
    ]