Hammad712 commited on
Commit
49243b2
·
verified ·
1 Parent(s): e844077

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +208 -0
main.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # auth.py
2
+ import os
3
+ import uuid
4
+ import logging
5
+ from datetime import datetime, timedelta
6
+ from urllib.parse import quote_plus
7
+ from typing import Optional
8
+
9
+ from dotenv import load_dotenv
10
+ from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
11
+ from fastapi.responses import StreamingResponse
12
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
13
+ from jose import JWTError, jwt
14
+ from passlib.context import CryptContext
15
+ from pymongo import MongoClient
16
+ import gridfs
17
+
18
+ from models import User, UserUpdate, Token, LoginResponse
19
+ from config import CONNECTION_STRING, SECRET_KEY, ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS
20
+
21
+ load_dotenv()
22
+
23
+ logger = logging.getLogger("uvicorn")
24
+ logger.setLevel(logging.INFO)
25
+
26
+ # Updated MongoDB initialization: now using CONNECTION_STRING from config.py
27
+ client = MongoClient(CONNECTION_STRING)
28
+ db = client.users_database
29
+ users_collection = db.users
30
+ # GridFS instance for storing avatars
31
+ fs = gridfs.GridFS(db, collection="avatars")
32
+
33
+ # OAuth2 setup
34
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
35
+ router = APIRouter(prefix="/auth", tags=["auth"])
36
+
37
+ # Password hashing
38
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
39
+
40
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
41
+ return pwd_context.verify(plain_password, hashed_password)
42
+
43
+ def get_password_hash(password: str) -> str:
44
+ return pwd_context.hash(password)
45
+
46
+ def get_user(email: str) -> Optional[dict]:
47
+ return users_collection.find_one({"email": email})
48
+
49
+ def authenticate_user(email: str, password: str) -> Optional[dict]:
50
+ user = get_user(email)
51
+ if not user or not verify_password(password, user["hashed_password"]):
52
+ return None
53
+ return user
54
+
55
+ def create_token(data: dict, expires_delta: timedelta = None) -> str:
56
+ to_encode = data.copy()
57
+ expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
58
+ to_encode.update({"exp": expire})
59
+ algorithm = "HS256"
60
+ return jwt.encode(to_encode, SECRET_KEY, algorithm=algorithm)
61
+
62
+ def create_access_token(email: str) -> str:
63
+ return create_token({"sub": email}, timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
64
+
65
+ def create_refresh_token(email: str) -> str:
66
+ return create_token({"sub": email}, timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS))
67
+
68
+ def get_current_user(token: str = Depends(oauth2_scheme)) -> dict:
69
+ try:
70
+ payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
71
+ email: str = payload.get("sub")
72
+ if not email:
73
+ raise HTTPException(status_code=401, detail="Invalid credentials")
74
+ user = get_user(email)
75
+ if not user:
76
+ raise HTTPException(status_code=401, detail="User not found")
77
+ return user
78
+ except JWTError:
79
+ raise HTTPException(status_code=401, detail="Invalid token")
80
+
81
+ async def save_avatar_file_to_gridfs(file: UploadFile) -> str:
82
+ allowed_types = ["image/jpeg", "image/png", "image/gif"]
83
+ if file.content_type not in allowed_types:
84
+ logger.error(f"Unsupported file type: {file.content_type}")
85
+ raise HTTPException(
86
+ status_code=400,
87
+ detail="Invalid image format. Only JPEG, PNG, and GIF are accepted."
88
+ )
89
+ try:
90
+ contents = await file.read()
91
+ file_id = fs.put(contents, filename=file.filename, contentType=file.content_type)
92
+ logger.info(f"Avatar stored in GridFS with file_id: {file_id}")
93
+ return str(file_id)
94
+ except Exception as e:
95
+ logger.exception("Failed to store avatar in GridFS")
96
+ raise HTTPException(status_code=500, detail="Could not store avatar file in MongoDB.")
97
+
98
+ @router.post("/signup", response_model=Token)
99
+ async def signup(
100
+ request: Request,
101
+ name: str = Form(...),
102
+ email: str = Form(...),
103
+ password: str = Form(...),
104
+ avatar: Optional[UploadFile] = File(None)
105
+ ):
106
+ try:
107
+ _ = User(name=name, email=email, password=password)
108
+ except Exception as e:
109
+ logger.error(f"Validation error during signup: {e}")
110
+ raise HTTPException(status_code=400, detail=str(e))
111
+ if get_user(email):
112
+ logger.warning(f"Attempt to register already existing email: {email}")
113
+ raise HTTPException(status_code=400, detail="Email already registered")
114
+ hashed_password = get_password_hash(password)
115
+ user_data = {
116
+ "name": name,
117
+ "email": email,
118
+ "hashed_password": hashed_password,
119
+ "chat_histories": []
120
+ }
121
+ if avatar:
122
+ file_id = await save_avatar_file_to_gridfs(avatar)
123
+ user_data["avatar"] = file_id
124
+ users_collection.insert_one(user_data)
125
+ logger.info(f"New user registered: {email}")
126
+ return {
127
+ "access_token": create_access_token(email),
128
+ "refresh_token": create_refresh_token(email),
129
+ "token_type": "bearer"
130
+ }
131
+
132
+ @router.post("/login", response_model=LoginResponse)
133
+ async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
134
+ user = authenticate_user(form_data.username, form_data.password)
135
+ if not user:
136
+ logger.warning(f"Failed login attempt for: {form_data.username}")
137
+ raise HTTPException(status_code=401, detail="Incorrect username or password")
138
+ logger.info(f"User logged in: {user['email']}")
139
+ avatar_url = None
140
+ if "avatar" in user and user["avatar"]:
141
+ avatar_url = f"/auth/avatar/{user['avatar']}"
142
+ return {
143
+ "access_token": create_access_token(user["email"]),
144
+ "refresh_token": create_refresh_token(user["email"]),
145
+ "token_type": "bearer",
146
+ "name": user["name"],
147
+ "avatar": avatar_url
148
+ }
149
+
150
+ @router.get("/user/data")
151
+ async def get_user_data(request: Request, current_user: dict = Depends(get_current_user)):
152
+ avatar_url = None
153
+ if "avatar" in current_user and current_user["avatar"]:
154
+ avatar_url = f"/auth/avatar/{current_user['avatar']}"
155
+ return {
156
+ "name": current_user["name"],
157
+ "email": current_user["email"],
158
+ "avatar": avatar_url,
159
+ "chat_histories": current_user.get("chat_histories", [])
160
+ }
161
+
162
+ @router.put("/user/update")
163
+ async def update_user(
164
+ request: Request,
165
+ name: Optional[str] = Form(None),
166
+ email: Optional[str] = Form(None),
167
+ password: Optional[str] = Form(None),
168
+ avatar: Optional[UploadFile] = File(None),
169
+ current_user: dict = Depends(get_current_user)
170
+ ):
171
+ update_data = {}
172
+ if name is not None:
173
+ update_data["name"] = name
174
+ if email is not None:
175
+ update_data["email"] = email
176
+ if password is not None:
177
+ try:
178
+ _ = User(name=current_user["name"], email=current_user["email"], password=password)
179
+ except Exception as e:
180
+ logger.error(f"Password validation error during update: {e}")
181
+ raise HTTPException(status_code=400, detail=str(e))
182
+ update_data["hashed_password"] = get_password_hash(password)
183
+ if avatar:
184
+ file_id = await save_avatar_file_to_gridfs(avatar)
185
+ update_data["avatar"] = file_id
186
+ if not update_data:
187
+ logger.info("No update parameters provided")
188
+ raise HTTPException(status_code=400, detail="No update parameters provided")
189
+ users_collection.update_one({"email": current_user["email"]}, {"$set": update_data})
190
+ logger.info(f"User updated: {current_user['email']}")
191
+ return {"message": "User updated successfully"}
192
+
193
+ @router.post("/logout")
194
+ async def logout(request: Request, current_user: dict = Depends(get_current_user)):
195
+ logger.info(f"User logged out: {current_user['email']}")
196
+ return {"message": "User logged out successfully"}
197
+
198
+ from bson import ObjectId
199
+
200
+ @router.get("/avatar/{file_id}")
201
+ async def get_avatar(file_id: str):
202
+ try:
203
+ # Convert the file_id string to an ObjectId before fetching
204
+ file = fs.get(ObjectId(file_id))
205
+ return StreamingResponse(file, media_type=file.content_type)
206
+ except Exception as e:
207
+ logger.error(f"Avatar not found for file_id {file_id}: {e}")
208
+ raise HTTPException(status_code=404, detail="Avatar not found")