mominah commited on
Commit
392c4bc
·
verified ·
1 Parent(s): d7e3cc8

Update auth.py

Browse files
Files changed (1) hide show
  1. auth.py +33 -25
auth.py CHANGED
@@ -7,13 +7,14 @@ from typing import List, Optional, Any
7
 
8
  from dotenv import load_dotenv
9
  from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
10
- from fastapi.responses import JSONResponse
11
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
12
  from jose import JWTError, jwt
13
  from passlib.context import CryptContext
14
  from pydantic import BaseModel, EmailStr, Field, validator
15
  from pymongo import MongoClient
16
  import os.path
 
17
 
18
  load_dotenv()
19
 
@@ -27,6 +28,8 @@ MONGO_URL = os.getenv("CONNECTION_STRING").replace("${PASSWORD}", password)
27
  client = MongoClient(MONGO_URL)
28
  db = client.users_database
29
  users_collection = db.users
 
 
30
 
31
  # OAuth2 setup
32
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@@ -132,12 +135,8 @@ def get_current_user(token: str = Depends(oauth2_scheme)) -> dict:
132
  except JWTError:
133
  raise HTTPException(status_code=401, detail="Invalid token")
134
 
135
- # Setup for avatar file saving
136
- AVATAR_DIR = "avatars"
137
- if not os.path.exists(AVATAR_DIR):
138
- os.makedirs(AVATAR_DIR)
139
-
140
- def save_avatar_file(file: UploadFile) -> str:
141
  allowed_types = ["image/jpeg", "image/png", "image/gif"]
142
  if file.content_type not in allowed_types:
143
  logger.error(f"Unsupported file type: {file.content_type}")
@@ -145,20 +144,14 @@ def save_avatar_file(file: UploadFile) -> str:
145
  status_code=400,
146
  detail="Invalid image format. Only JPEG, PNG, and GIF are accepted."
147
  )
148
- file_extension = os.path.splitext(file.filename)[1]
149
- unique_filename = f"{uuid.uuid4()}{file_extension}"
150
- file_path = os.path.join(AVATAR_DIR, unique_filename)
151
  try:
152
- contents = file.file.read()
153
- with open(file_path, "wb") as f:
154
- f.write(contents)
155
- logger.info(f"Avatar saved as {file_path}")
156
  except Exception as e:
157
- logger.exception("Failed to save avatar file")
158
- raise HTTPException(status_code=500, detail="Could not save avatar file.")
159
- finally:
160
- file.file.close()
161
- return file_path
162
 
163
  # ----- Auth Endpoints -----
164
 
@@ -186,8 +179,8 @@ async def signup(
186
  "chat_histories": []
187
  }
188
  if avatar:
189
- avatar_path = save_avatar_file(avatar)
190
- user_data["avatar"] = f"/avatars/{os.path.basename(avatar_path)}"
191
  users_collection.insert_one(user_data)
192
  logger.info(f"New user registered: {email}")
193
  return {
@@ -203,20 +196,26 @@ async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends
203
  logger.warning(f"Failed login attempt for: {form_data.username}")
204
  raise HTTPException(status_code=401, detail="Incorrect username or password")
205
  logger.info(f"User logged in: {user['email']}")
 
 
 
206
  return {
207
  "access_token": create_access_token(user["email"]),
208
  "refresh_token": create_refresh_token(user["email"]),
209
  "token_type": "bearer",
210
  "name": user["name"],
211
- "avatar": user.get("avatar")
212
  }
213
 
214
  @router.get("/user/data")
215
  async def get_user_data(request: Request, current_user: dict = Depends(get_current_user)):
 
 
 
216
  return {
217
  "name": current_user["name"],
218
  "email": current_user["email"],
219
- "avatar": current_user.get("avatar"),
220
  "chat_histories": current_user.get("chat_histories", [])
221
  }
222
 
@@ -242,8 +241,8 @@ async def update_user(
242
  raise HTTPException(status_code=400, detail=str(e))
243
  update_data["hashed_password"] = get_password_hash(password)
244
  if avatar:
245
- avatar_path = save_avatar_file(avatar)
246
- update_data["avatar"] = f"/avatars/{os.path.basename(avatar_path)}"
247
  if not update_data:
248
  logger.info("No update parameters provided")
249
  raise HTTPException(status_code=400, detail="No update parameters provided")
@@ -255,3 +254,12 @@ async def update_user(
255
  async def logout(request: Request, current_user: dict = Depends(get_current_user)):
256
  logger.info(f"User logged out: {current_user['email']}")
257
  return {"message": "User logged out successfully"}
 
 
 
 
 
 
 
 
 
 
7
 
8
  from dotenv import load_dotenv
9
  from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
10
+ from fastapi.responses import JSONResponse, StreamingResponse
11
  from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
12
  from jose import JWTError, jwt
13
  from passlib.context import CryptContext
14
  from pydantic import BaseModel, EmailStr, Field, validator
15
  from pymongo import MongoClient
16
  import os.path
17
+ import gridfs # For storing binary files in MongoDB
18
 
19
  load_dotenv()
20
 
 
28
  client = MongoClient(MONGO_URL)
29
  db = client.users_database
30
  users_collection = db.users
31
+ # Create a GridFS instance for storing avatars in the "avatars" collection
32
+ fs = gridfs.GridFS(db, collection="avatars")
33
 
34
  # OAuth2 setup
35
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 
135
  except JWTError:
136
  raise HTTPException(status_code=401, detail="Invalid token")
137
 
138
+ # Instead of saving avatar to disk, store it in GridFS.
139
+ async def save_avatar_file_to_gridfs(file: UploadFile) -> str:
 
 
 
 
140
  allowed_types = ["image/jpeg", "image/png", "image/gif"]
141
  if file.content_type not in allowed_types:
142
  logger.error(f"Unsupported file type: {file.content_type}")
 
144
  status_code=400,
145
  detail="Invalid image format. Only JPEG, PNG, and GIF are accepted."
146
  )
 
 
 
147
  try:
148
+ contents = await file.read()
149
+ file_id = fs.put(contents, filename=file.filename, contentType=file.content_type)
150
+ logger.info(f"Avatar stored in GridFS with file_id: {file_id}")
151
+ return str(file_id)
152
  except Exception as e:
153
+ logger.exception("Failed to store avatar in GridFS")
154
+ raise HTTPException(status_code=500, detail="Could not store avatar file in MongoDB.")
 
 
 
155
 
156
  # ----- Auth Endpoints -----
157
 
 
179
  "chat_histories": []
180
  }
181
  if avatar:
182
+ file_id = await save_avatar_file_to_gridfs(avatar)
183
+ user_data["avatar"] = file_id # store the GridFS file id
184
  users_collection.insert_one(user_data)
185
  logger.info(f"New user registered: {email}")
186
  return {
 
196
  logger.warning(f"Failed login attempt for: {form_data.username}")
197
  raise HTTPException(status_code=401, detail="Incorrect username or password")
198
  logger.info(f"User logged in: {user['email']}")
199
+ avatar_url = None
200
+ if "avatar" in user and user["avatar"]:
201
+ avatar_url = f"/auth/avatar/{user['avatar']}"
202
  return {
203
  "access_token": create_access_token(user["email"]),
204
  "refresh_token": create_refresh_token(user["email"]),
205
  "token_type": "bearer",
206
  "name": user["name"],
207
+ "avatar": avatar_url
208
  }
209
 
210
  @router.get("/user/data")
211
  async def get_user_data(request: Request, current_user: dict = Depends(get_current_user)):
212
+ avatar_url = None
213
+ if "avatar" in current_user and current_user["avatar"]:
214
+ avatar_url = f"/auth/avatar/{current_user['avatar']}"
215
  return {
216
  "name": current_user["name"],
217
  "email": current_user["email"],
218
+ "avatar": avatar_url,
219
  "chat_histories": current_user.get("chat_histories", [])
220
  }
221
 
 
241
  raise HTTPException(status_code=400, detail=str(e))
242
  update_data["hashed_password"] = get_password_hash(password)
243
  if avatar:
244
+ file_id = await save_avatar_file_to_gridfs(avatar)
245
+ update_data["avatar"] = file_id
246
  if not update_data:
247
  logger.info("No update parameters provided")
248
  raise HTTPException(status_code=400, detail="No update parameters provided")
 
254
  async def logout(request: Request, current_user: dict = Depends(get_current_user)):
255
  logger.info(f"User logged out: {current_user['email']}")
256
  return {"message": "User logged out successfully"}
257
+
258
+ @router.get("/avatar/{file_id}")
259
+ async def get_avatar(file_id: str):
260
+ try:
261
+ file = fs.get(file_id)
262
+ return StreamingResponse(file, media_type=file.content_type)
263
+ except Exception as e:
264
+ logger.error(f"Avatar not found for file_id {file_id}: {e}")
265
+ raise HTTPException(status_code=404, detail="Avatar not found")