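"""FastAPI service for meisai-check-ai name standardization.

On startup the app loads a Japanese sentence-transformer model along with
pre-computed sample-name embeddings, similarities, and clusters; JWT-protected
endpoints then let users upload a CSV and download a copy with the predicted
standard subject, standard name, and confidence columns appended.
"""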
import os
import pickle
import shutil
import sys
import time
import traceback
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Annotated, Optional

import jwt
import pandas as pd
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, status
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from pydantic import BaseModel
# Make the bundled meisai-check-ai package importable
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, "meisai-check-ai"))
from sentence_transformer_lib.sentence_transformer_helper import (
    SentenceTransformerHelper,
)
from data_lib.input_name_data import InputNameData
from data_lib.subject_data import SubjectData
from data_lib.sample_name_data import SampleNameData
from clustering_lib.sentence_clustering_lib import SentenceClusteringLib
from data_lib.base_data import (
    COL_STANDARD_NAME,
    COL_STANDARD_NAME_KEY,
    COL_STANDARD_SUBJECT,
)
from mapping_lib.name_mapping_helper import NameMappingHelper
# Global model and data state, populated during application startup
sentenceTransformerHelper = None
dic_standard_subject = None
sample_name_sentence_embeddings = None
sample_name_sentence_similarities = None
sampleData = None
sentence_clustering_lib = None
name_groups = None
# Create working directories if they don't exist
os.makedirs(os.path.join(current_dir, "data"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "uploads"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "outputs"), exist_ok=True)
# Authentication settings
# NOTE: a hard-coded secret is fine for a demo, but in production this should
# come from an environment variable or a secrets manager.
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24  # Tokens expire after 24 hours

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 scheme; tokenUrl must match the login route below
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Token and user models
class Token(BaseModel):
    access_token: str
    token_type: str


class TokenData(BaseModel):
    username: Optional[str] = None


class User(BaseModel):
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    disabled: Optional[bool] = None


class UserInDB(User):
    hashed_password: str
# Fake users database with bcrypt-hashed passwords
users_db = {
    "chien_vm": {
        "username": "chien_vm",
        "full_name": "Chien VM",
        "email": "[email protected]",
        "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
        "disabled": False,
    },
    "hoi_nv": {
        "username": "hoi_nv",
        "full_name": "Hoi NV",
        "email": "[email protected]",
        "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
        "disabled": False,
    },
}
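# Both demo users share the same bcrypt hash; a replacement entry for this
# table can be generated with the CryptContext defined above, e.g.:
#
#   print(pwd_context.hash("new-password"))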
# Authentication helper functions
def verify_password(plain_password: str, hashed_password: str) -> bool:
    return pwd_context.verify(plain_password, hashed_password)


def get_user(db, username: str) -> Optional[UserInDB]:
    if username in db:
        return UserInDB(**db[username])
    return None


def authenticate_user(fake_db, username: str, password: str):
    """Return the user if the credentials are valid, otherwise False."""
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT, defaulting to the configured expiry window."""
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInDB:
    """Resolve the bearer token to a user, or raise 401."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except InvalidTokenError:
        raise credentials_exception
    user = get_user(users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
) -> User:
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
@asynccontextmanager  # required for FastAPI lifespan; the import was present but unused
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events"""
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, sentence_clustering_lib, name_groups

    try:
        # Load sentence transformer model
        sentenceTransformerHelper = SentenceTransformerHelper(
            convert_to_zenkaku_flag=True, replace_words=None, keywords=None
        )
        sentenceTransformerHelper.load_model_by_name(
            "Detomo/cl-nagoya-sup-simcse-ja-for-standard-name-v1_0"
        )

        # Load standard subject dictionary
        dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(
            "data/subjectData.csv"
        )

        # Load pre-computed embeddings and similarities
        with open(
            "data/sample_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            "rb",
        ) as f:
            sample_name_sentence_embeddings = pickle.load(f)
        with open(
            "data/sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            "rb",
        ) as f:
            sample_name_sentence_similarities = pickle.load(f)

        # Load and process sample data
        sampleData = SampleNameData()
        file_path = os.path.join(current_dir, "data", "sampleData.csv")
        sampleData.load_data_from_csv(file_path)
        sampleData.process_data()

        # Create sentence clusters
        sentence_clustering_lib = SentenceClusteringLib(sample_name_sentence_embeddings)
        best_name_eps = 0.07
        name_groups, _ = sentence_clustering_lib.create_sentence_cluster(best_name_eps)

        sampleData._create_key_column(
            COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
        )
        sampleData.set_name_sentence_labels(name_groups)
        sampleData.build_search_tree()

        print("Models and data loaded successfully")
    except Exception as e:
        print(f"Error during startup: {e}")
        traceback.print_exc()

    yield  # The application runs while the context manager is suspended here

    # Cleanup code (if needed) goes here
    print("Shutting down application")
app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root():
    return {"message": "Hello World"}


@app.get("/health")  # route path assumed; adjust if the deployment uses another
async def health_check():
    return {"status": "ok", "timestamp": time.time()}
@app.post("/token")  # must match OAuth2PasswordBearer(tokenUrl="token") above
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
    """
    Login endpoint to get an access token
    """
    user = authenticate_user(users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return Token(access_token=access_token, token_type="bearer")
@app.post("/predict")  # route path assumed; adjust if the deployment uses another
async def predict(
    current_user: Annotated[User, Depends(get_current_active_user)],
    file: UploadFile = File(...),
):
    """
    Process an input CSV file and return standardized names (requires authentication)
    """
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, name_groups

    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Save the uploaded file; timestamp and username keep filenames unique per request
    timestamp = int(time.time())
    input_file_path = os.path.join(
        current_dir, "uploads", f"input_{timestamp}_{current_user.username}.csv"
    )
    output_file_path = os.path.join(
        current_dir, "outputs", f"output_{timestamp}_{current_user.username}.csv"
    )
    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()
    try:
        # Process input data
        inputData = InputNameData(dic_standard_subject)
        inputData.load_data_from_csv(input_file_path)
        inputData.process_data()

        # Map standard names
        nameMappingHelper = NameMappingHelper(
            sentenceTransformerHelper,
            inputData,
            sampleData,
            sample_name_sentence_embeddings,
            sample_name_sentence_similarities,
        )
        df_predicted = nameMappingHelper.map_standard_names()

        # Create the output dataframe and save it to CSV
        print("Columns of inputData.dataframe:", inputData.dataframe.columns)
        columns_to_keep = ["シート名", "行", "科目", "分類", "名称", "摘要", "備考"]
        output_df = inputData.dataframe[columns_to_keep].copy()
        output_df.reset_index(drop=False, inplace=True)
        output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
        output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
        output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]

        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")
        # Return the file as a download; media_type sets the Content-Type header,
        # so no conflicting override is needed
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            media_type="text/csv",
            headers={
                "Content-Disposition": (
                    f'attachment; filename="output_{Path(file.filename).stem}.csv"'
                ),
            },
        )
    except Exception as e:
        print(f"Error processing file: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
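# Example usage (a sketch; assumes the server runs locally on port 8000 and
# that "/predict" is the prediction route assumed above; the demo passwords
# are only stored hashed here, so substitute a real one):
#
#   # 1. Obtain a bearer token (form-encoded, as OAuth2PasswordRequestForm expects):
#   curl -X POST http://localhost:8000/token \
#        -d "username=chien_vm" -d "password=<password>"
#
#   # 2. Upload a CSV and download the standardized output:
#   curl -X POST http://localhost:8000/predict \
#        -H "Authorization: Bearer <access_token>" \
#        -F "file=@input.csv" -o output.csv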