mominah commited on
Commit
0f75bc4
·
verified ·
1 Parent(s): 45773d5

Create auth.py

Browse files
Files changed (1) hide show
  1. auth.py +276 -0
auth.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import logging
4
+ from datetime import datetime, timedelta
5
+ from urllib.parse import quote_plus
6
+ from typing import Optional
7
+
8
+ from dotenv import load_dotenv
9
+ from fastapi import APIRouter, HTTPException, Depends, Request, status, UploadFile, File, Form
10
+ from fastapi.responses import JSONResponse
11
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
12
+ from slowapi import Limiter, _rate_limit_exceeded_handler
13
+ from slowapi.util import get_remote_address
14
+ from jose import JWTError, jwt
15
+ from passlib.context import CryptContext
16
+ from pydantic import BaseModel, EmailStr, Field, validator
17
+ from pymongo import MongoClient
18
+ import os.path
19
+
20
+ load_dotenv()
21
+
22
+ # Setup logging
23
+ logger = logging.getLogger("uvicorn")
24
+ logger.setLevel(logging.INFO)
25
+
26
+ # MongoDB setup for user management
27
+ password = quote_plus(os.getenv("MONGO_PASSWORD"))
28
+ MONGO_URL = os.getenv("MONGO_ENDPOINT").replace("${PASSWORD}", password)
29
+ client = MongoClient(MONGO_URL)
30
+ db = client.users_database
31
+ users_collection = db.users
32
+
33
+ # Rate limiter (applied to auth endpoints)
34
+ limiter = Limiter(key_func=get_remote_address, default_limits=["200 per day", "50 per hour"])
35
+
36
+ # OAuth2 setup
37
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
38
+
39
+ # Create an APIRouter instance
40
+ router = APIRouter()
41
+
42
+ # Pydantic models
43
+ class User(BaseModel):
44
+ name: str = Field(..., min_length=3, max_length=50)
45
+ email: EmailStr
46
+ password: str
47
+
48
+ @validator("password")
49
+ def validate_password(cls, value):
50
+ if len(value) < 8:
51
+ raise ValueError("Password must be at least 8 characters long.")
52
+ if not any(char.isdigit() for char in value):
53
+ raise ValueError("Password must include at least one number.")
54
+ if not any(char.isupper() for char in value):
55
+ raise ValueError("Password must include at least one uppercase letter.")
56
+ if not any(char.islower() for char in value):
57
+ raise ValueError("Password must include at least one lowercase letter.")
58
+ if not any(char in "!@#$%^&*()-_+=<>?/" for char in value):
59
+ raise ValueError("Password must include at least one special character.")
60
+ return value
61
+
62
+ class UserUpdate(BaseModel):
63
+ name: Optional[str] = Field(None, min_length=3, max_length=50)
64
+ email: Optional[EmailStr]
65
+ password: Optional[str]
66
+
67
+ @validator("password")
68
+ def validate_password(cls, value):
69
+ if value is not None:
70
+ if len(value) < 8:
71
+ raise ValueError("Password must be at least 8 characters long.")
72
+ if not any(char.isdigit() for char in value):
73
+ raise ValueError("Password must include at least one number.")
74
+ if not any(char.isupper() for char in value):
75
+ raise ValueError("Password must include at least one uppercase letter.")
76
+ if not any(char.islower() for char in value):
77
+ raise ValueError("Password must include at least one lowercase letter.")
78
+ if not any(char in "!@#$%^&*()-_+=<>?/" for char in value):
79
+ raise ValueError("Password must include at least one special character.")
80
+ return value
81
+
82
+ class Token(BaseModel):
83
+ access_token: str
84
+ refresh_token: str
85
+ token_type: str
86
+
87
+ class LoginResponse(Token):
88
+ name: str
89
+ avatar: Optional[str] = None
90
+
91
+ class TokenData(BaseModel):
92
+ email: Optional[str] = None
93
+
94
+ # Password hashing
95
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
96
+
97
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
98
+ return pwd_context.verify(plain_password, hashed_password)
99
+
100
+ def get_password_hash(password: str) -> str:
101
+ return pwd_context.hash(password)
102
+
103
+ def get_user(email: str) -> Optional[dict]:
104
+ return users_collection.find_one({"email": email})
105
+
106
+ def authenticate_user(email: str, password: str) -> Optional[dict]:
107
+ user = get_user(email)
108
+ if not user or not verify_password(password, user["hashed_password"]):
109
+ return None
110
+ return user
111
+
112
+ def create_token(data: dict, expires_delta: timedelta = None) -> str:
113
+ to_encode = data.copy()
114
+ expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
115
+ to_encode.update({"exp": expire})
116
+ secret_key = os.getenv("SECRET_KEY")
117
+ algorithm = "HS256"
118
+ return jwt.encode(to_encode, secret_key, algorithm=algorithm)
119
+
120
+ def create_access_token(email: str) -> str:
121
+ return create_token({"sub": email}, timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "90"))))
122
+
123
+ def create_refresh_token(email: str) -> str:
124
+ return create_token({"sub": email}, timedelta(days=int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))))
125
+
126
+ def get_current_user(token: str = Depends(oauth2_scheme)) -> dict:
127
+ secret_key = os.getenv("SECRET_KEY")
128
+ try:
129
+ payload = jwt.decode(token, secret_key, algorithms=["HS256"])
130
+ email: str = payload.get("sub")
131
+ if not email:
132
+ raise HTTPException(status_code=401, detail="Invalid credentials")
133
+ user = get_user(email)
134
+ if not user:
135
+ raise HTTPException(status_code=401, detail="User not found")
136
+ return user
137
+ except JWTError:
138
+ raise HTTPException(status_code=401, detail="Invalid token")
139
+
140
+ # Setup for avatar file saving
141
+ AVATAR_DIR = "avatars"
142
+ if not os.path.exists(AVATAR_DIR):
143
+ os.makedirs(AVATAR_DIR)
144
+
145
+ def save_avatar_file(file: UploadFile) -> str:
146
+ allowed_types = ["image/jpeg", "image/png", "image/gif"]
147
+ if file.content_type not in allowed_types:
148
+ logger.error(f"Unsupported file type: {file.content_type}")
149
+ raise HTTPException(
150
+ status_code=400,
151
+ detail="Invalid image format. Only JPEG, PNG, and GIF are accepted."
152
+ )
153
+ file_extension = os.path.splitext(file.filename)[1]
154
+ unique_filename = f"{uuid.uuid4()}{file_extension}"
155
+ file_path = os.path.join(AVATAR_DIR, unique_filename)
156
+ try:
157
+ contents = file.file.read()
158
+ with open(file_path, "wb") as f:
159
+ f.write(contents)
160
+ logger.info(f"Avatar saved as {file_path}")
161
+ except Exception as e:
162
+ logger.exception("Failed to save avatar file")
163
+ raise HTTPException(status_code=500, detail="Could not save avatar file.")
164
+ finally:
165
+ file.file.close()
166
+ return file_path
167
+
168
+ # ----- Auth Endpoints -----
169
+
170
+ @router.post("/signup", response_model=Token)
171
+ @limiter.limit("5/minute")
172
+ async def signup(
173
+ request: Request,
174
+ name: str = Form(...),
175
+ email: EmailStr = Form(...),
176
+ password: str = Form(...),
177
+ avatar: Optional[UploadFile] = File(None)
178
+ ):
179
+ # Validate input using the User model
180
+ try:
181
+ _ = User(name=name, email=email, password=password)
182
+ except Exception as e:
183
+ logger.error(f"Validation error during signup: {e}")
184
+ raise HTTPException(status_code=400, detail=str(e))
185
+ if get_user(email):
186
+ logger.warning(f"Attempt to register already existing email: {email}")
187
+ raise HTTPException(status_code=400, detail="Email already registered")
188
+ hashed_password = get_password_hash(password)
189
+ user_data = {
190
+ "name": name,
191
+ "email": email,
192
+ "hashed_password": hashed_password,
193
+ "chat_histories": [] # Initialize an empty array for chat histories and future data.
194
+ }
195
+ if avatar:
196
+ avatar_path = save_avatar_file(avatar)
197
+ user_data["avatar"] = f"/avatars/{os.path.basename(avatar_path)}"
198
+ users_collection.insert_one(user_data)
199
+ logger.info(f"New user registered: {email}")
200
+ return {
201
+ "access_token": create_access_token(email),
202
+ "refresh_token": create_refresh_token(email),
203
+ "token_type": "bearer"
204
+ }
205
+
206
+ @router.post("/login", response_model=LoginResponse)
207
+ @limiter.limit("10/minute")
208
+ async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
209
+ user = authenticate_user(form_data.username, form_data.password)
210
+ if not user:
211
+ logger.warning(f"Failed login attempt for: {form_data.username}")
212
+ raise HTTPException(status_code=401, detail="Incorrect username or password")
213
+ logger.info(f"User logged in: {user['email']}")
214
+ return {
215
+ "access_token": create_access_token(user["email"]),
216
+ "refresh_token": create_refresh_token(user["email"]),
217
+ "token_type": "bearer",
218
+ "name": user["name"],
219
+ "avatar": user.get("avatar")
220
+ }
221
+
222
+ @router.get("/user/data")
223
+ @limiter.limit("20/minute")
224
+ async def get_user_data(request: Request, current_user: dict = Depends(get_current_user)):
225
+ return {
226
+ "name": current_user["name"],
227
+ "email": current_user["email"],
228
+ "avatar": current_user.get("avatar"),
229
+ "chat_histories": current_user.get("chat_histories", [])
230
+ }
231
+
232
+ @router.put("/user/update")
233
+ @limiter.limit("10/minute")
234
+ async def update_user(
235
+ request: Request,
236
+ name: Optional[str] = Form(None),
237
+ email: Optional[EmailStr] = Form(None),
238
+ password: Optional[str] = Form(None),
239
+ avatar: Optional[UploadFile] = File(None),
240
+ current_user: dict = Depends(get_current_user)
241
+ ):
242
+ update_data = {}
243
+ if name is not None:
244
+ update_data["name"] = name
245
+ if email is not None:
246
+ update_data["email"] = email
247
+ if password is not None:
248
+ try:
249
+ _ = User(name=current_user["name"], email=current_user["email"], password=password)
250
+ except Exception as e:
251
+ logger.error(f"Password validation error during update: {e}")
252
+ raise HTTPException(status_code=400, detail=str(e))
253
+ update_data["hashed_password"] = get_password_hash(password)
254
+ if avatar:
255
+ avatar_path = save_avatar_file(avatar)
256
+ update_data["avatar"] = f"/avatars/{os.path.basename(avatar_path)}"
257
+ if not update_data:
258
+ logger.info("No update parameters provided")
259
+ raise HTTPException(status_code=400, detail="No update parameters provided")
260
+ users_collection.update_one({"email": current_user["email"]}, {"$set": update_data})
261
+ logger.info(f"User updated: {current_user['email']}")
262
+ return {"message": "User updated successfully"}
263
+
264
+ @router.post("/logout")
265
+ @limiter.limit("20/minute")
266
+ async def logout(request: Request, current_user: dict = Depends(get_current_user)):
267
+ logger.info(f"User logged out: {current_user['email']}")
268
+ return {"message": "User logged out successfully"}
269
+
270
+ @router.exception_handler(Exception)
271
+ async def generic_exception_handler(request: Request, exc):
272
+ logger.exception("Unhandled exception occurred in auth module")
273
+ return JSONResponse(
274
+ status_code=500,
275
+ content={"detail": "An internal server error occurred."}
276
+ )