hardik8588 commited on
Commit
6517783
·
verified ·
1 Parent(s): ff92fec

Update auth.py

Browse files
Files changed (1) hide show
  1. auth.py +655 -654
auth.py CHANGED
@@ -1,655 +1,656 @@
1
- import sqlite3
2
- import uuid
3
- import os
4
- import logging
5
- from datetime import datetime, timedelta
6
- import hashlib # Use hashlib instead of jwt
7
- from passlib.hash import bcrypt
8
- from dotenv import load_dotenv
9
- from fastapi import Depends, HTTPException
10
- from fastapi.security import OAuth2PasswordBearer
11
- from pydantic import BaseModel
12
- from typing import Optional
13
- from fastapi import HTTPException, status
14
- import jwt
15
- from jwt.exceptions import PyJWTError
16
- import sqlite3
17
-
18
- # Load environment variables
19
- load_dotenv()
20
-
21
- # Configure logging
22
- logging.basicConfig(
23
- level=logging.INFO,
24
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
25
- )
26
- logger = logging.getLogger('auth')
27
-
28
- # Security configuration
29
- SECRET_KEY = os.getenv("JWT_SECRET", "your-secret-key-for-development-only")
30
- ALGORITHM = "HS256"
31
- JWT_EXPIRATION_DELTA = timedelta(days=1) # Token valid for 1 day
32
- # Database path from environment variable or default
33
- # Fix the incorrect DB_PATH
34
- DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
35
-
36
- # FastAPI OAuth2 scheme
37
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
38
-
39
- # Pydantic models for FastAPI
40
- class User(BaseModel):
41
- id: str
42
- email: str
43
- subscription_tier: str = "free_tier"
44
- subscription_expiry: Optional[datetime] = None
45
- api_calls_remaining: int = 5
46
- last_reset_date: Optional[datetime] = None
47
-
48
- class UserCreate(BaseModel):
49
- email: str
50
- password: str
51
-
52
- class Token(BaseModel):
53
- access_token: str
54
- token_type: str
55
-
56
- class TokenData(BaseModel):
57
- user_id: Optional[str] = None
58
-
59
- # Subscription tiers and limits
60
- # Update the SUBSCRIPTION_TIERS dictionary
61
- SUBSCRIPTION_TIERS = {
62
- "free_tier": {
63
- "price": 0,
64
- "currency": "INR",
65
- "features": ["basic_document_analysis", "basic_risk_assessment"],
66
- "limits": {
67
- "document_size_mb": 5,
68
- "documents_per_month": 3,
69
- "video_size_mb": 0,
70
- "audio_size_mb": 0,
71
- "daily_api_calls": 10, # <-- Add this
72
- "max_document_size_mb": 5 # <-- Add this
73
- }
74
- },
75
- "standard_tier": {
76
- "price": 799,
77
- "currency": "INR",
78
- "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot"],
79
- "limits": {
80
- "document_size_mb": 20,
81
- "documents_per_month": 20,
82
- "video_size_mb": 100,
83
- "audio_size_mb": 50,
84
- "daily_api_calls": 100, # <-- Add this
85
- "max_document_size_mb": 20 # <-- Add this
86
- }
87
- },
88
- "premium_tier": {
89
- "price": 1499,
90
- "currency": "INR",
91
- "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot", "detailed_risk_assessment", "contract_clause_analysis"],
92
- "limits": {
93
- "document_size_mb": 50,
94
- "documents_per_month": 999999,
95
- "video_size_mb": 500,
96
- "audio_size_mb": 200,
97
- "daily_api_calls": 1000, # <-- Add this
98
- "max_document_size_mb": 50 # <-- Add this
99
- }
100
- }
101
- }
102
-
103
- # Database connection management
104
- def get_db_connection():
105
- """Create and return a database connection with proper error handling"""
106
- try:
107
- # Ensure the directory exists
108
- db_dir = os.path.dirname(DB_PATH)
109
- os.makedirs(db_dir, exist_ok=True)
110
-
111
- conn = sqlite3.connect(DB_PATH)
112
- conn.row_factory = sqlite3.Row # Return rows as dictionaries
113
- return conn
114
- except sqlite3.Error as e:
115
- logger.error(f"Database connection error: {e}")
116
- raise Exception(f"Database connection failed: {e}")
117
-
118
- # Database setup
119
- # In the init_auth_db function, update the CREATE TABLE statement to match our schema
120
- def init_auth_db():
121
- """Initialize the authentication database with required tables"""
122
- try:
123
- conn = get_db_connection()
124
- c = conn.cursor()
125
-
126
- # Create users table with the correct schema
127
- c.execute('''
128
- CREATE TABLE IF NOT EXISTS users (
129
- id TEXT PRIMARY KEY,
130
- email TEXT UNIQUE NOT NULL,
131
- hashed_password TEXT NOT NULL,
132
- password TEXT,
133
- subscription_tier TEXT DEFAULT 'free_tier',
134
- is_active BOOLEAN DEFAULT 1,
135
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
136
- api_calls_remaining INTEGER DEFAULT 10,
137
- last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
138
- )
139
- ''')
140
-
141
- # Create subscriptions table
142
- c.execute('''
143
- CREATE TABLE IF NOT EXISTS subscriptions (
144
- id TEXT PRIMARY KEY,
145
- user_id TEXT,
146
- tier TEXT,
147
- plan_id TEXT,
148
- status TEXT,
149
- created_at TIMESTAMP,
150
- expires_at TIMESTAMP,
151
- paypal_subscription_id TEXT,
152
- FOREIGN KEY (user_id) REFERENCES users (id)
153
- )
154
- ''')
155
-
156
- # Create usage stats table
157
- c.execute('''
158
- CREATE TABLE IF NOT EXISTS usage_stats (
159
- id TEXT PRIMARY KEY,
160
- user_id TEXT,
161
- month INTEGER,
162
- year INTEGER,
163
- analyses_used INTEGER,
164
- FOREIGN KEY (user_id) REFERENCES users (id)
165
- )
166
- ''')
167
-
168
- # Create tokens table for refresh tokens
169
- c.execute('''
170
- CREATE TABLE IF NOT EXISTS refresh_tokens (
171
- user_id TEXT,
172
- token TEXT,
173
- expires_at TIMESTAMP,
174
- FOREIGN KEY (user_id) REFERENCES users (id)
175
- )
176
- ''')
177
-
178
- conn.commit()
179
- logger.info("Database initialized successfully")
180
- except Exception as e:
181
- logger.error(f"Database initialization error: {e}")
182
- raise
183
- finally:
184
- if conn:
185
- conn.close()
186
-
187
- # Initialize the database
188
- init_auth_db()
189
-
190
- # Password hashing with bcrypt
191
- # Update the password hashing and verification functions to use a more reliable method
192
-
193
- # Replace these functions
194
- # Remove these conflicting functions
195
- # def hash_password(password):
196
- # """Hash a password using bcrypt"""
197
- # return bcrypt.hash(password)
198
- #
199
- # def verify_password(plain_password, hashed_password):
200
- # """Verify a password against its hash"""
201
- # return bcrypt.verify(plain_password, hashed_password)
202
-
203
- # Keep only these improved functions
204
- def hash_password(password):
205
- """Hash a password using bcrypt"""
206
- # Use a more direct approach to avoid bcrypt version issues
207
- import bcrypt
208
- # Convert password to bytes if it's not already
209
- if isinstance(password, str):
210
- password = password.encode('utf-8')
211
- # Generate salt and hash
212
- salt = bcrypt.gensalt()
213
- hashed = bcrypt.hashpw(password, salt)
214
- # Return as string for storage
215
- return hashed.decode('utf-8')
216
-
217
- def verify_password(plain_password, hashed_password):
218
- """Verify a password against its hash"""
219
- import bcrypt
220
- # Convert inputs to bytes if they're not already
221
- if isinstance(plain_password, str):
222
- plain_password = plain_password.encode('utf-8')
223
- if isinstance(hashed_password, str):
224
- hashed_password = hashed_password.encode('utf-8')
225
-
226
- try:
227
- # Use direct bcrypt verification
228
- return bcrypt.checkpw(plain_password, hashed_password)
229
- except Exception as e:
230
- logger.error(f"Password verification error: {e}")
231
- return False
232
-
233
- # User registration
234
- def register_user(email, password):
235
- try:
236
- conn = get_db_connection()
237
- c = conn.cursor()
238
-
239
- # Check if user already exists
240
- c.execute("SELECT * FROM users WHERE email = ?", (email,))
241
- if c.fetchone():
242
- return False, "Email already registered"
243
-
244
- # Create new user
245
- user_id = str(uuid.uuid4())
246
-
247
- # Add more detailed logging
248
- logger.info(f"Registering new user with email: {email}")
249
- hashed_pw = hash_password(password)
250
- logger.info(f"Password hashed successfully: {bool(hashed_pw)}")
251
-
252
- c.execute("""
253
- INSERT INTO users
254
- (id, email, hashed_password, subscription_tier, api_calls_remaining, last_reset_date)
255
- VALUES (?, ?, ?, ?, ?, ?)
256
- """, (user_id, email, hashed_pw, "free_tier", 5, datetime.now()))
257
-
258
- conn.commit()
259
- logger.info(f"User registered successfully: {email}")
260
-
261
- # Verify the user was actually stored
262
- c.execute("SELECT * FROM users WHERE email = ?", (email,))
263
- stored_user = c.fetchone()
264
- logger.info(f"User verification after registration: {bool(stored_user)}")
265
-
266
- access_token = create_access_token(user_id)
267
- return True, {
268
- "user_id": user_id,
269
- "access_token": access_token,
270
- "token_type": "bearer"
271
- }
272
- except Exception as e:
273
- logger.error(f"User registration error: {e}")
274
- return False, f"Registration failed: {str(e)}"
275
- finally:
276
- if conn:
277
- conn.close()
278
-
279
- # User login
280
- # Fix the authenticate_user function
281
- # In the authenticate_user function, update the password verification to use hashed_password
282
- def authenticate_user(email, password):
283
- """Authenticate a user and return user data with tokens"""
284
- try:
285
- conn = get_db_connection()
286
- c = conn.cursor()
287
-
288
- # Get user by email
289
- c.execute("SELECT * FROM users WHERE email = ? AND is_active = 1", (email,))
290
- user = c.fetchone()
291
-
292
- if not user:
293
- logger.warning(f"User not found: {email}")
294
- return None
295
-
296
- # Add debug logging for password verification
297
- logger.info(f"Verifying password for user: {email}")
298
- logger.info(f"Stored hashed password: {user['hashed_password'][:20]}...")
299
-
300
- try:
301
- # Check if password verification works
302
- is_valid = verify_password(password, user['hashed_password'])
303
- logger.info(f"Password verification result: {is_valid}")
304
-
305
- if not is_valid:
306
- logger.warning(f"Password verification failed for user: {email}")
307
- return None
308
- except Exception as e:
309
- logger.error(f"Password verification error: {e}")
310
- return None
311
-
312
- # Update last login time if column exists
313
- try:
314
- c.execute("UPDATE users SET last_login = ? WHERE id = ?",
315
- (datetime.now(), user['id']))
316
- conn.commit()
317
- except sqlite3.OperationalError:
318
- # last_login column might not exist
319
- pass
320
-
321
- # Convert sqlite3.Row to dict to use get() method
322
- user_dict = dict(user)
323
-
324
- # Create and return a User object
325
- return User(
326
- id=user_dict['id'],
327
- email=user_dict['email'],
328
- subscription_tier=user_dict.get('subscription_tier', 'free_tier'),
329
- subscription_expiry=None, # Handle this properly if needed
330
- api_calls_remaining=user_dict.get('api_calls_remaining', 5),
331
- last_reset_date=user_dict.get('last_reset_date')
332
- )
333
- except Exception as e:
334
- logger.error(f"Login error: {e}")
335
- return None
336
- finally:
337
- if conn:
338
- conn.close()
339
-
340
- # Token generation and validation - completely replaced
341
- def create_access_token(user_id):
342
- """Create a new access token for a user"""
343
- try:
344
- # Create a JWT token with user_id and expiration
345
- expiration = datetime.now() + JWT_EXPIRATION_DELTA
346
-
347
- # Create a token payload
348
- payload = {
349
- "sub": user_id,
350
- "exp": expiration.timestamp()
351
- }
352
-
353
- # Generate the JWT token
354
- token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
355
-
356
- logger.info(f"Created access token for user: {user_id}")
357
- return token
358
- except Exception as e:
359
- logger.error(f"Token creation error: {e}")
360
- return None
361
-
362
-
363
- def update_auth_db_schema():
364
- """Update the authentication database schema with any missing columns"""
365
- try:
366
- conn = get_db_connection()
367
- c = conn.cursor()
368
-
369
- # Check if tier column exists in subscriptions table
370
- c.execute("PRAGMA table_info(subscriptions)")
371
- columns = [column[1] for column in c.fetchall()]
372
-
373
- # Add tier column if it doesn't exist
374
- if "tier" not in columns:
375
- logger.info("Adding 'tier' column to subscriptions table")
376
- c.execute("ALTER TABLE subscriptions ADD COLUMN tier TEXT")
377
- conn.commit()
378
- logger.info("Database schema updated successfully")
379
-
380
- conn.close()
381
- except Exception as e:
382
- logger.error(f"Database schema update error: {e}")
383
- raise HTTPException(
384
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
385
- detail=f"Database schema update error: {str(e)}"
386
- )
387
-
388
- # Add this to your get_current_user function
389
- async def get_current_user(token: str = Depends(oauth2_scheme)):
390
- credentials_exception = HTTPException(
391
- status_code=status.HTTP_401_UNAUTHORIZED,
392
- detail="Could not validate credentials",
393
- headers={"WWW-Authenticate": "Bearer"},
394
- )
395
- try:
396
- # Decode the JWT token
397
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
398
- user_id: str = payload.get("sub")
399
- if user_id is None:
400
- logger.error("Token missing 'sub' field")
401
- raise credentials_exception
402
- except Exception as e:
403
- logger.error(f"Token validation error: {str(e)}")
404
- raise credentials_exception
405
-
406
- # Get user from database
407
- conn = get_db_connection()
408
- cursor = conn.cursor()
409
- cursor.execute("SELECT id, email, subscription_tier, is_active FROM users WHERE id = ?", (user_id,))
410
- user_data = cursor.fetchone()
411
- conn.close()
412
-
413
- if user_data is None:
414
- logger.error(f"User not found: {user_id}")
415
- raise credentials_exception
416
-
417
- user = User(
418
- id=user_data[0],
419
- email=user_data[1],
420
- subscription_tier=user_data[2],
421
- is_active=bool(user_data[3])
422
- )
423
-
424
- return user
425
-
426
- async def get_current_active_user(current_user: User = Depends(get_current_user)):
427
- """Get the current active user"""
428
- return current_user
429
-
430
- def create_user_subscription(email, tier):
431
- """Create a subscription for a user"""
432
- try:
433
- # Get user by email
434
- conn = get_db_connection()
435
- c = conn.cursor()
436
-
437
- # Get user ID
438
- c.execute("SELECT id FROM users WHERE email = ?", (email,))
439
- user_data = c.fetchone()
440
-
441
- if not user_data:
442
- return False, "User not found"
443
-
444
- user_id = user_data['id']
445
-
446
- # Check if tier is valid
447
- valid_tiers = ["standard_tier", "premium_tier"]
448
- if tier not in valid_tiers:
449
- return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}"
450
-
451
- # Create subscription
452
- subscription_id = str(uuid.uuid4())
453
- created_at = datetime.now()
454
- expires_at = created_at + timedelta(days=30) # 30-day subscription
455
-
456
- # Insert subscription
457
- c.execute("""
458
- INSERT INTO subscriptions
459
- (id, user_id, tier, status, created_at, expires_at)
460
- VALUES (?, ?, ?, ?, ?, ?)
461
- """, (subscription_id, user_id, tier, "active", created_at, expires_at))
462
-
463
- # Update user's subscription tier
464
- c.execute("""
465
- UPDATE users
466
- SET subscription_tier = ?
467
- WHERE id = ?
468
- """, (tier, user_id))
469
-
470
- conn.commit()
471
-
472
- return True, {
473
- "id": subscription_id,
474
- "user_id": user_id,
475
- "tier": tier,
476
- "status": "active",
477
- "created_at": created_at.isoformat(),
478
- "expires_at": expires_at.isoformat()
479
- }
480
- except Exception as e:
481
- logger.error(f"Subscription creation error: {e}")
482
- return False, f"Failed to create subscription: {str(e)}"
483
- finally:
484
- if conn:
485
- conn.close()
486
-
487
- def get_user(user_id: str):
488
- """Get user by ID"""
489
- try:
490
- conn = get_db_connection()
491
- c = conn.cursor()
492
-
493
- # Get user
494
- c.execute("SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,))
495
- user_data = c.fetchone()
496
-
497
- if not user_data:
498
- return None
499
-
500
- # Convert to User model
501
- user_dict = dict(user_data)
502
-
503
- # Handle datetime conversions if needed
504
- if user_dict.get("subscription_expiry") and isinstance(user_dict["subscription_expiry"], str):
505
- user_dict["subscription_expiry"] = datetime.fromisoformat(user_dict["subscription_expiry"])
506
- if user_dict.get("last_reset_date") and isinstance(user_dict["last_reset_date"], str):
507
- user_dict["last_reset_date"] = datetime.fromisoformat(user_dict["last_reset_date"])
508
-
509
- return User(
510
- id=user_dict['id'],
511
- email=user_dict['email'],
512
- subscription_tier=user_dict['subscription_tier'],
513
- subscription_expiry=user_dict.get('subscription_expiry'),
514
- api_calls_remaining=user_dict.get('api_calls_remaining', 5),
515
- last_reset_date=user_dict.get('last_reset_date')
516
- )
517
- except Exception as e:
518
- logger.error(f"Get user error: {e}")
519
- return None
520
- finally:
521
- if conn:
522
- conn.close()
523
-
524
- def check_subscription_access(user: User, feature: str, file_size_mb: Optional[float] = None):
525
- """Check if the user has access to the requested feature and file size"""
526
- # Check if subscription is expired
527
- if user.subscription_tier != "free_tier" and user.subscription_expiry and user.subscription_expiry < datetime.now():
528
- # Downgrade to free tier if subscription expired
529
- user.subscription_tier = "free_tier"
530
- user.api_calls_remaining = SUBSCRIPTION_TIERS["free_tier"]["daily_api_calls"]
531
- with get_db_connection() as conn:
532
- c = conn.cursor()
533
- c.execute("""
534
- UPDATE users
535
- SET subscription_tier = ?, api_calls_remaining = ?
536
- WHERE id = ?
537
- """, (user.subscription_tier, user.api_calls_remaining, user.id))
538
- conn.commit()
539
-
540
- # Reset API calls if needed
541
- user = reset_api_calls_if_needed(user)
542
-
543
- # Check if user has API calls remaining
544
- if user.api_calls_remaining <= 0:
545
- raise HTTPException(
546
- status_code=429,
547
- detail="API call limit reached for today. Please upgrade your subscription or try again tomorrow."
548
- )
549
-
550
- # Check if feature is available in user's subscription tier
551
- tier_features = SUBSCRIPTION_TIERS[user.subscription_tier]["features"]
552
- if feature not in tier_features:
553
- raise HTTPException(
554
- status_code=403,
555
- detail=f"The {feature} feature is not available in your {user.subscription_tier} subscription. Please upgrade to access this feature."
556
- )
557
-
558
- # Check file size limit if applicable
559
- if file_size_mb:
560
- max_size = SUBSCRIPTION_TIERS[user.subscription_tier]["max_document_size_mb"]
561
- if file_size_mb > max_size:
562
- raise HTTPException(
563
- status_code=413,
564
- detail=f"File size exceeds the {max_size}MB limit for your {user.subscription_tier} subscription. Please upgrade or use a smaller file."
565
- )
566
-
567
- # Decrement API calls remaining
568
- user.api_calls_remaining -= 1
569
- with get_db_connection() as conn:
570
- c = conn.cursor()
571
- c.execute("""
572
- UPDATE users
573
- SET api_calls_remaining = ?
574
- WHERE id = ?
575
- """, (user.api_calls_remaining, user.id))
576
- conn.commit()
577
-
578
- return True
579
-
580
- def reset_api_calls_if_needed(user: User):
581
- """Reset API call counter if it's a new day"""
582
- today = datetime.now().date()
583
- if user.last_reset_date is None or user.last_reset_date.date() < today:
584
- tier_limits = SUBSCRIPTION_TIERS[user.subscription_tier]
585
- user.api_calls_remaining = tier_limits["daily_api_calls"]
586
- user.last_reset_date = datetime.now()
587
- # Update the user in the database
588
- with get_db_connection() as conn:
589
- c = conn.cursor()
590
- c.execute("""
591
- UPDATE users
592
- SET api_calls_remaining = ?, last_reset_date = ?
593
- WHERE id = ?
594
- """, (user.api_calls_remaining, user.last_reset_date, user.id))
595
- conn.commit()
596
-
597
- return user
598
-
599
- def login_user(email, password):
600
- """Login a user with email and password"""
601
- try:
602
- # Authenticate user
603
- user = authenticate_user(email, password)
604
- if not user:
605
- return False, "Incorrect username or password"
606
-
607
- # Create access token
608
- access_token = create_access_token(user.id)
609
-
610
- # Create refresh token
611
- refresh_token = str(uuid.uuid4())
612
- expires_at = datetime.now() + timedelta(days=30)
613
-
614
- # Store refresh token
615
- conn = get_db_connection()
616
- c = conn.cursor()
617
- c.execute("INSERT INTO refresh_tokens VALUES (?, ?, ?)",
618
- (user.id, refresh_token, expires_at))
619
- conn.commit()
620
-
621
- # Get subscription info
622
- c.execute("SELECT * FROM subscriptions WHERE user_id = ? AND status = 'active'", (user.id,))
623
- subscription = c.fetchone()
624
-
625
- # Convert subscription to dict if it exists, otherwise set to None
626
- subscription_dict = dict(subscription) if subscription else None
627
-
628
- conn.close()
629
-
630
- return True, {
631
- "user_id": user.id,
632
- "email": user.email,
633
- "access_token": access_token,
634
- "refresh_token": refresh_token,
635
- "subscription": subscription_dict
636
- }
637
- except Exception as e:
638
- logger.error(f"Login error: {e}")
639
- return False, f"Login failed: {str(e)}"
640
-
641
-
642
- def get_subscription_plans():
643
- """
644
- Returns a list of available subscription plans based on SUBSCRIPTION_TIERS.
645
- """
646
- plans = []
647
- for tier, details in SUBSCRIPTION_TIERS.items():
648
- plans.append({
649
- "tier": tier,
650
- "price": details["price"],
651
- "currency": details["currency"],
652
- "features": details["features"],
653
- "limits": details["limits"]
654
- })
 
655
  return plans
 
1
+ import sqlite3
2
+
3
+ import uuid
4
+ import os
5
+ import logging
6
+ from datetime import datetime, timedelta
7
+ import hashlib # Use hashlib instead of jwt
8
+ from passlib.hash import bcrypt
9
+ from dotenv import load_dotenv
10
+ from fastapi import Depends, HTTPException
11
+ from fastapi.security import OAuth2PasswordBearer
12
+ from pydantic import BaseModel
13
+ from typing import Optional
14
+ from fastapi import HTTPException, status
15
+ import jwt
16
+ from jose import JWTError
17
+ import sqlite3
18
+
19
+ # Load environment variables
20
+ load_dotenv()
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
26
+ )
27
+ logger = logging.getLogger('auth')
28
+
29
+ # Security configuration
30
+ SECRET_KEY = os.getenv("JWT_SECRET", "your-secret-key-for-development-only")
31
+ ALGORITHM = "HS256"
32
+ JWT_EXPIRATION_DELTA = timedelta(days=1) # Token valid for 1 day
33
+ # Database path from environment variable or default
34
+ # Fix the incorrect DB_PATH
35
+ DB_PATH = os.getenv("DB_PATH", os.path.join(os.path.dirname(__file__), "data/user_data.db"))
36
+
37
+ # FastAPI OAuth2 scheme
38
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
39
+
40
+ # Pydantic models for FastAPI
41
+ class User(BaseModel):
42
+ id: str
43
+ email: str
44
+ subscription_tier: str = "free_tier"
45
+ subscription_expiry: Optional[datetime] = None
46
+ api_calls_remaining: int = 5
47
+ last_reset_date: Optional[datetime] = None
48
+
49
+ class UserCreate(BaseModel):
50
+ email: str
51
+ password: str
52
+
53
+ class Token(BaseModel):
54
+ access_token: str
55
+ token_type: str
56
+
57
+ class TokenData(BaseModel):
58
+ user_id: Optional[str] = None
59
+
60
+ # Subscription tiers and limits
61
+ # Update the SUBSCRIPTION_TIERS dictionary
62
+ SUBSCRIPTION_TIERS = {
63
+ "free_tier": {
64
+ "price": 0,
65
+ "currency": "INR",
66
+ "features": ["basic_document_analysis", "basic_risk_assessment"],
67
+ "limits": {
68
+ "document_size_mb": 5,
69
+ "documents_per_month": 3,
70
+ "video_size_mb": 0,
71
+ "audio_size_mb": 0,
72
+ "daily_api_calls": 10, # <-- Add this
73
+ "max_document_size_mb": 5 # <-- Add this
74
+ }
75
+ },
76
+ "standard_tier": {
77
+ "price": 799,
78
+ "currency": "INR",
79
+ "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot"],
80
+ "limits": {
81
+ "document_size_mb": 20,
82
+ "documents_per_month": 20,
83
+ "video_size_mb": 100,
84
+ "audio_size_mb": 50,
85
+ "daily_api_calls": 100, # <-- Add this
86
+ "max_document_size_mb": 20 # <-- Add this
87
+ }
88
+ },
89
+ "premium_tier": {
90
+ "price": 1499,
91
+ "currency": "INR",
92
+ "features": ["basic_document_analysis", "basic_risk_assessment", "video_analysis", "audio_analysis", "chatbot", "detailed_risk_assessment", "contract_clause_analysis"],
93
+ "limits": {
94
+ "document_size_mb": 50,
95
+ "documents_per_month": 999999,
96
+ "video_size_mb": 500,
97
+ "audio_size_mb": 200,
98
+ "daily_api_calls": 1000, # <-- Add this
99
+ "max_document_size_mb": 50 # <-- Add this
100
+ }
101
+ }
102
+ }
103
+
104
+ # Database connection management
105
+ def get_db_connection():
106
+ """Create and return a database connection with proper error handling"""
107
+ try:
108
+ # Ensure the directory exists
109
+ db_dir = os.path.dirname(DB_PATH)
110
+ os.makedirs(db_dir, exist_ok=True)
111
+
112
+ conn = sqlite3.connect(DB_PATH)
113
+ conn.row_factory = sqlite3.Row # Return rows as dictionaries
114
+ return conn
115
+ except sqlite3.Error as e:
116
+ logger.error(f"Database connection error: {e}")
117
+ raise Exception(f"Database connection failed: {e}")
118
+
119
+ # Database setup
120
+ # In the init_auth_db function, update the CREATE TABLE statement to match our schema
121
+ def init_auth_db():
122
+ """Initialize the authentication database with required tables"""
123
+ try:
124
+ conn = get_db_connection()
125
+ c = conn.cursor()
126
+
127
+ # Create users table with the correct schema
128
+ c.execute('''
129
+ CREATE TABLE IF NOT EXISTS users (
130
+ id TEXT PRIMARY KEY,
131
+ email TEXT UNIQUE NOT NULL,
132
+ hashed_password TEXT NOT NULL,
133
+ password TEXT,
134
+ subscription_tier TEXT DEFAULT 'free_tier',
135
+ is_active BOOLEAN DEFAULT 1,
136
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
137
+ api_calls_remaining INTEGER DEFAULT 10,
138
+ last_reset_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP
139
+ )
140
+ ''')
141
+
142
+ # Create subscriptions table
143
+ c.execute('''
144
+ CREATE TABLE IF NOT EXISTS subscriptions (
145
+ id TEXT PRIMARY KEY,
146
+ user_id TEXT,
147
+ tier TEXT,
148
+ plan_id TEXT,
149
+ status TEXT,
150
+ created_at TIMESTAMP,
151
+ expires_at TIMESTAMP,
152
+ paypal_subscription_id TEXT,
153
+ FOREIGN KEY (user_id) REFERENCES users (id)
154
+ )
155
+ ''')
156
+
157
+ # Create usage stats table
158
+ c.execute('''
159
+ CREATE TABLE IF NOT EXISTS usage_stats (
160
+ id TEXT PRIMARY KEY,
161
+ user_id TEXT,
162
+ month INTEGER,
163
+ year INTEGER,
164
+ analyses_used INTEGER,
165
+ FOREIGN KEY (user_id) REFERENCES users (id)
166
+ )
167
+ ''')
168
+
169
+ # Create tokens table for refresh tokens
170
+ c.execute('''
171
+ CREATE TABLE IF NOT EXISTS refresh_tokens (
172
+ user_id TEXT,
173
+ token TEXT,
174
+ expires_at TIMESTAMP,
175
+ FOREIGN KEY (user_id) REFERENCES users (id)
176
+ )
177
+ ''')
178
+
179
+ conn.commit()
180
+ logger.info("Database initialized successfully")
181
+ except Exception as e:
182
+ logger.error(f"Database initialization error: {e}")
183
+ raise
184
+ finally:
185
+ if conn:
186
+ conn.close()
187
+
188
+ # Initialize the database
189
+ init_auth_db()
190
+
191
+ # Password hashing with bcrypt
192
+ # Update the password hashing and verification functions to use a more reliable method
193
+
194
+ # Replace these functions
195
+ # Remove these conflicting functions
196
+ # def hash_password(password):
197
+ # """Hash a password using bcrypt"""
198
+ # return bcrypt.hash(password)
199
+ #
200
+ # def verify_password(plain_password, hashed_password):
201
+ # """Verify a password against its hash"""
202
+ # return bcrypt.verify(plain_password, hashed_password)
203
+
204
+ # Keep only these improved functions
205
+ def hash_password(password):
206
+ """Hash a password using bcrypt"""
207
+ # Use a more direct approach to avoid bcrypt version issues
208
+ import bcrypt
209
+ # Convert password to bytes if it's not already
210
+ if isinstance(password, str):
211
+ password = password.encode('utf-8')
212
+ # Generate salt and hash
213
+ salt = bcrypt.gensalt()
214
+ hashed = bcrypt.hashpw(password, salt)
215
+ # Return as string for storage
216
+ return hashed.decode('utf-8')
217
+
218
+ def verify_password(plain_password, hashed_password):
219
+ """Verify a password against its hash"""
220
+ import bcrypt
221
+ # Convert inputs to bytes if they're not already
222
+ if isinstance(plain_password, str):
223
+ plain_password = plain_password.encode('utf-8')
224
+ if isinstance(hashed_password, str):
225
+ hashed_password = hashed_password.encode('utf-8')
226
+
227
+ try:
228
+ # Use direct bcrypt verification
229
+ return bcrypt.checkpw(plain_password, hashed_password)
230
+ except Exception as e:
231
+ logger.error(f"Password verification error: {e}")
232
+ return False
233
+
234
+ # User registration
235
+ def register_user(email, password):
236
+ try:
237
+ conn = get_db_connection()
238
+ c = conn.cursor()
239
+
240
+ # Check if user already exists
241
+ c.execute("SELECT * FROM users WHERE email = ?", (email,))
242
+ if c.fetchone():
243
+ return False, "Email already registered"
244
+
245
+ # Create new user
246
+ user_id = str(uuid.uuid4())
247
+
248
+ # Add more detailed logging
249
+ logger.info(f"Registering new user with email: {email}")
250
+ hashed_pw = hash_password(password)
251
+ logger.info(f"Password hashed successfully: {bool(hashed_pw)}")
252
+
253
+ c.execute("""
254
+ INSERT INTO users
255
+ (id, email, hashed_password, subscription_tier, api_calls_remaining, last_reset_date)
256
+ VALUES (?, ?, ?, ?, ?, ?)
257
+ """, (user_id, email, hashed_pw, "free_tier", 5, datetime.now()))
258
+
259
+ conn.commit()
260
+ logger.info(f"User registered successfully: {email}")
261
+
262
+ # Verify the user was actually stored
263
+ c.execute("SELECT * FROM users WHERE email = ?", (email,))
264
+ stored_user = c.fetchone()
265
+ logger.info(f"User verification after registration: {bool(stored_user)}")
266
+
267
+ access_token = create_access_token(user_id)
268
+ return True, {
269
+ "user_id": user_id,
270
+ "access_token": access_token,
271
+ "token_type": "bearer"
272
+ }
273
+ except Exception as e:
274
+ logger.error(f"User registration error: {e}")
275
+ return False, f"Registration failed: {str(e)}"
276
+ finally:
277
+ if conn:
278
+ conn.close()
279
+
280
+ # User login
281
+ # Fix the authenticate_user function
282
+ # In the authenticate_user function, update the password verification to use hashed_password
283
+ def authenticate_user(email, password):
284
+ """Authenticate a user and return user data with tokens"""
285
+ try:
286
+ conn = get_db_connection()
287
+ c = conn.cursor()
288
+
289
+ # Get user by email
290
+ c.execute("SELECT * FROM users WHERE email = ? AND is_active = 1", (email,))
291
+ user = c.fetchone()
292
+
293
+ if not user:
294
+ logger.warning(f"User not found: {email}")
295
+ return None
296
+
297
+ # Add debug logging for password verification
298
+ logger.info(f"Verifying password for user: {email}")
299
+ logger.info(f"Stored hashed password: {user['hashed_password'][:20]}...")
300
+
301
+ try:
302
+ # Check if password verification works
303
+ is_valid = verify_password(password, user['hashed_password'])
304
+ logger.info(f"Password verification result: {is_valid}")
305
+
306
+ if not is_valid:
307
+ logger.warning(f"Password verification failed for user: {email}")
308
+ return None
309
+ except Exception as e:
310
+ logger.error(f"Password verification error: {e}")
311
+ return None
312
+
313
+ # Update last login time if column exists
314
+ try:
315
+ c.execute("UPDATE users SET last_login = ? WHERE id = ?",
316
+ (datetime.now(), user['id']))
317
+ conn.commit()
318
+ except sqlite3.OperationalError:
319
+ # last_login column might not exist
320
+ pass
321
+
322
+ # Convert sqlite3.Row to dict to use get() method
323
+ user_dict = dict(user)
324
+
325
+ # Create and return a User object
326
+ return User(
327
+ id=user_dict['id'],
328
+ email=user_dict['email'],
329
+ subscription_tier=user_dict.get('subscription_tier', 'free_tier'),
330
+ subscription_expiry=None, # Handle this properly if needed
331
+ api_calls_remaining=user_dict.get('api_calls_remaining', 5),
332
+ last_reset_date=user_dict.get('last_reset_date')
333
+ )
334
+ except Exception as e:
335
+ logger.error(f"Login error: {e}")
336
+ return None
337
+ finally:
338
+ if conn:
339
+ conn.close()
340
+
341
+ # Token generation and validation - completely replaced
342
+ def create_access_token(user_id):
343
+ """Create a new access token for a user"""
344
+ try:
345
+ # Create a JWT token with user_id and expiration
346
+ expiration = datetime.now() + JWT_EXPIRATION_DELTA
347
+
348
+ # Create a token payload
349
+ payload = {
350
+ "sub": user_id,
351
+ "exp": expiration.timestamp()
352
+ }
353
+
354
+ # Generate the JWT token
355
+ token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
356
+
357
+ logger.info(f"Created access token for user: {user_id}")
358
+ return token
359
+ except Exception as e:
360
+ logger.error(f"Token creation error: {e}")
361
+ return None
362
+
363
+
364
+ def update_auth_db_schema():
365
+ """Update the authentication database schema with any missing columns"""
366
+ try:
367
+ conn = get_db_connection()
368
+ c = conn.cursor()
369
+
370
+ # Check if tier column exists in subscriptions table
371
+ c.execute("PRAGMA table_info(subscriptions)")
372
+ columns = [column[1] for column in c.fetchall()]
373
+
374
+ # Add tier column if it doesn't exist
375
+ if "tier" not in columns:
376
+ logger.info("Adding 'tier' column to subscriptions table")
377
+ c.execute("ALTER TABLE subscriptions ADD COLUMN tier TEXT")
378
+ conn.commit()
379
+ logger.info("Database schema updated successfully")
380
+
381
+ conn.close()
382
+ except Exception as e:
383
+ logger.error(f"Database schema update error: {e}")
384
+ raise HTTPException(
385
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
386
+ detail=f"Database schema update error: {str(e)}"
387
+ )
388
+
389
+ # Add this to your get_current_user function
390
+ async def get_current_user(token: str = Depends(oauth2_scheme)):
391
+ credentials_exception = HTTPException(
392
+ status_code=status.HTTP_401_UNAUTHORIZED,
393
+ detail="Could not validate credentials",
394
+ headers={"WWW-Authenticate": "Bearer"},
395
+ )
396
+ try:
397
+ # Decode the JWT token
398
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
399
+ user_id: str = payload.get("sub")
400
+ if user_id is None:
401
+ logger.error("Token missing 'sub' field")
402
+ raise credentials_exception
403
+ except Exception as e:
404
+ logger.error(f"Token validation error: {str(e)}")
405
+ raise credentials_exception
406
+
407
+ # Get user from database
408
+ conn = get_db_connection()
409
+ cursor = conn.cursor()
410
+ cursor.execute("SELECT id, email, subscription_tier, is_active FROM users WHERE id = ?", (user_id,))
411
+ user_data = cursor.fetchone()
412
+ conn.close()
413
+
414
+ if user_data is None:
415
+ logger.error(f"User not found: {user_id}")
416
+ raise credentials_exception
417
+
418
+ user = User(
419
+ id=user_data[0],
420
+ email=user_data[1],
421
+ subscription_tier=user_data[2],
422
+ is_active=bool(user_data[3])
423
+ )
424
+
425
+ return user
426
+
427
+ async def get_current_active_user(current_user: User = Depends(get_current_user)):
428
+ """Get the current active user"""
429
+ return current_user
430
+
431
+ def create_user_subscription(email, tier):
432
+ """Create a subscription for a user"""
433
+ try:
434
+ # Get user by email
435
+ conn = get_db_connection()
436
+ c = conn.cursor()
437
+
438
+ # Get user ID
439
+ c.execute("SELECT id FROM users WHERE email = ?", (email,))
440
+ user_data = c.fetchone()
441
+
442
+ if not user_data:
443
+ return False, "User not found"
444
+
445
+ user_id = user_data['id']
446
+
447
+ # Check if tier is valid
448
+ valid_tiers = ["standard_tier", "premium_tier"]
449
+ if tier not in valid_tiers:
450
+ return False, f"Invalid tier: {tier}. Must be one of {valid_tiers}"
451
+
452
+ # Create subscription
453
+ subscription_id = str(uuid.uuid4())
454
+ created_at = datetime.now()
455
+ expires_at = created_at + timedelta(days=30) # 30-day subscription
456
+
457
+ # Insert subscription
458
+ c.execute("""
459
+ INSERT INTO subscriptions
460
+ (id, user_id, tier, status, created_at, expires_at)
461
+ VALUES (?, ?, ?, ?, ?, ?)
462
+ """, (subscription_id, user_id, tier, "active", created_at, expires_at))
463
+
464
+ # Update user's subscription tier
465
+ c.execute("""
466
+ UPDATE users
467
+ SET subscription_tier = ?
468
+ WHERE id = ?
469
+ """, (tier, user_id))
470
+
471
+ conn.commit()
472
+
473
+ return True, {
474
+ "id": subscription_id,
475
+ "user_id": user_id,
476
+ "tier": tier,
477
+ "status": "active",
478
+ "created_at": created_at.isoformat(),
479
+ "expires_at": expires_at.isoformat()
480
+ }
481
+ except Exception as e:
482
+ logger.error(f"Subscription creation error: {e}")
483
+ return False, f"Failed to create subscription: {str(e)}"
484
+ finally:
485
+ if conn:
486
+ conn.close()
487
+
488
+ def get_user(user_id: str):
489
+ """Get user by ID"""
490
+ try:
491
+ conn = get_db_connection()
492
+ c = conn.cursor()
493
+
494
+ # Get user
495
+ c.execute("SELECT * FROM users WHERE id = ? AND is_active = 1", (user_id,))
496
+ user_data = c.fetchone()
497
+
498
+ if not user_data:
499
+ return None
500
+
501
+ # Convert to User model
502
+ user_dict = dict(user_data)
503
+
504
+ # Handle datetime conversions if needed
505
+ if user_dict.get("subscription_expiry") and isinstance(user_dict["subscription_expiry"], str):
506
+ user_dict["subscription_expiry"] = datetime.fromisoformat(user_dict["subscription_expiry"])
507
+ if user_dict.get("last_reset_date") and isinstance(user_dict["last_reset_date"], str):
508
+ user_dict["last_reset_date"] = datetime.fromisoformat(user_dict["last_reset_date"])
509
+
510
+ return User(
511
+ id=user_dict['id'],
512
+ email=user_dict['email'],
513
+ subscription_tier=user_dict['subscription_tier'],
514
+ subscription_expiry=user_dict.get('subscription_expiry'),
515
+ api_calls_remaining=user_dict.get('api_calls_remaining', 5),
516
+ last_reset_date=user_dict.get('last_reset_date')
517
+ )
518
+ except Exception as e:
519
+ logger.error(f"Get user error: {e}")
520
+ return None
521
+ finally:
522
+ if conn:
523
+ conn.close()
524
+
525
+ def check_subscription_access(user: User, feature: str, file_size_mb: Optional[float] = None):
526
+ """Check if the user has access to the requested feature and file size"""
527
+ # Check if subscription is expired
528
+ if user.subscription_tier != "free_tier" and user.subscription_expiry and user.subscription_expiry < datetime.now():
529
+ # Downgrade to free tier if subscription expired
530
+ user.subscription_tier = "free_tier"
531
+ user.api_calls_remaining = SUBSCRIPTION_TIERS["free_tier"]["daily_api_calls"]
532
+ with get_db_connection() as conn:
533
+ c = conn.cursor()
534
+ c.execute("""
535
+ UPDATE users
536
+ SET subscription_tier = ?, api_calls_remaining = ?
537
+ WHERE id = ?
538
+ """, (user.subscription_tier, user.api_calls_remaining, user.id))
539
+ conn.commit()
540
+
541
+ # Reset API calls if needed
542
+ user = reset_api_calls_if_needed(user)
543
+
544
+ # Check if user has API calls remaining
545
+ if user.api_calls_remaining <= 0:
546
+ raise HTTPException(
547
+ status_code=429,
548
+ detail="API call limit reached for today. Please upgrade your subscription or try again tomorrow."
549
+ )
550
+
551
+ # Check if feature is available in user's subscription tier
552
+ tier_features = SUBSCRIPTION_TIERS[user.subscription_tier]["features"]
553
+ if feature not in tier_features:
554
+ raise HTTPException(
555
+ status_code=403,
556
+ detail=f"The {feature} feature is not available in your {user.subscription_tier} subscription. Please upgrade to access this feature."
557
+ )
558
+
559
+ # Check file size limit if applicable
560
+ if file_size_mb:
561
+ max_size = SUBSCRIPTION_TIERS[user.subscription_tier]["max_document_size_mb"]
562
+ if file_size_mb > max_size:
563
+ raise HTTPException(
564
+ status_code=413,
565
+ detail=f"File size exceeds the {max_size}MB limit for your {user.subscription_tier} subscription. Please upgrade or use a smaller file."
566
+ )
567
+
568
+ # Decrement API calls remaining
569
+ user.api_calls_remaining -= 1
570
+ with get_db_connection() as conn:
571
+ c = conn.cursor()
572
+ c.execute("""
573
+ UPDATE users
574
+ SET api_calls_remaining = ?
575
+ WHERE id = ?
576
+ """, (user.api_calls_remaining, user.id))
577
+ conn.commit()
578
+
579
+ return True
580
+
581
+ def reset_api_calls_if_needed(user: User):
582
+ """Reset API call counter if it's a new day"""
583
+ today = datetime.now().date()
584
+ if user.last_reset_date is None or user.last_reset_date.date() < today:
585
+ tier_limits = SUBSCRIPTION_TIERS[user.subscription_tier]
586
+ user.api_calls_remaining = tier_limits["daily_api_calls"]
587
+ user.last_reset_date = datetime.now()
588
+ # Update the user in the database
589
+ with get_db_connection() as conn:
590
+ c = conn.cursor()
591
+ c.execute("""
592
+ UPDATE users
593
+ SET api_calls_remaining = ?, last_reset_date = ?
594
+ WHERE id = ?
595
+ """, (user.api_calls_remaining, user.last_reset_date, user.id))
596
+ conn.commit()
597
+
598
+ return user
599
+
600
+ def login_user(email, password):
601
+ """Login a user with email and password"""
602
+ try:
603
+ # Authenticate user
604
+ user = authenticate_user(email, password)
605
+ if not user:
606
+ return False, "Incorrect username or password"
607
+
608
+ # Create access token
609
+ access_token = create_access_token(user.id)
610
+
611
+ # Create refresh token
612
+ refresh_token = str(uuid.uuid4())
613
+ expires_at = datetime.now() + timedelta(days=30)
614
+
615
+ # Store refresh token
616
+ conn = get_db_connection()
617
+ c = conn.cursor()
618
+ c.execute("INSERT INTO refresh_tokens VALUES (?, ?, ?)",
619
+ (user.id, refresh_token, expires_at))
620
+ conn.commit()
621
+
622
+ # Get subscription info
623
+ c.execute("SELECT * FROM subscriptions WHERE user_id = ? AND status = 'active'", (user.id,))
624
+ subscription = c.fetchone()
625
+
626
+ # Convert subscription to dict if it exists, otherwise set to None
627
+ subscription_dict = dict(subscription) if subscription else None
628
+
629
+ conn.close()
630
+
631
+ return True, {
632
+ "user_id": user.id,
633
+ "email": user.email,
634
+ "access_token": access_token,
635
+ "refresh_token": refresh_token,
636
+ "subscription": subscription_dict
637
+ }
638
+ except Exception as e:
639
+ logger.error(f"Login error: {e}")
640
+ return False, f"Login failed: {str(e)}"
641
+
642
+
643
+ def get_subscription_plans():
644
+ """
645
+ Returns a list of available subscription plans based on SUBSCRIPTION_TIERS.
646
+ """
647
+ plans = []
648
+ for tier, details in SUBSCRIPTION_TIERS.items():
649
+ plans.append({
650
+ "tier": tier,
651
+ "price": details["price"],
652
+ "currency": details["currency"],
653
+ "features": details["features"],
654
+ "limits": details["limits"]
655
+ })
656
  return plans