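"""FastAPI service for standardizing names in uploaded CSV files.

At startup it loads a Japanese sentence-transformer model together with
pre-computed sample-name embeddings and similarities, clusters the sample
names, and builds a search tree. The /predict endpoint then maps the names
in an uploaded CSV to standard subjects/names and returns the result as CSV.
"""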
import sys
import os
import time
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import FileResponse
import uvicorn
import traceback
import pickle
import shutil
from pathlib import Path
from contextlib import asynccontextmanager
import pandas as pd
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, "meisai-check-ai"))
from sentence_transformer_lib.sentence_transformer_helper import (
    SentenceTransformerHelper,
)
from data_lib.input_name_data import InputNameData
from data_lib.subject_data import SubjectData
from data_lib.sample_name_data import SampleNameData
from clustering_lib.sentence_clustering_lib import SentenceClusteringLib
from data_lib.base_data import (
    COL_STANDARD_NAME,
    COL_STANDARD_NAME_KEY,
    COL_STANDARD_SUBJECT,
)
from mapping_lib.name_mapping_helper import NameMappingHelper
# Initialize global variables for model and data
sentenceTransformerHelper = None
dic_standard_subject = None
sample_name_sentence_embeddings = None
sample_name_sentence_similarities = None
sampleData = None
sentence_clustering_lib = None
name_groups = None
# Create required directories if they don't exist
os.makedirs(os.path.join(current_dir, "data"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "uploads"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "outputs"), exist_ok=True)
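# Expected layout (inferred from the paths used below):
#   data/    - subjectData.csv, sampleData.csv, and the pre-computed *.pkl files
#   uploads/ - incoming CSVs, saved per request
#   outputs/ - generated result CSVs, returned to the client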
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events"""
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, sentence_clustering_lib, name_groups
    try:
        # Load sentence transformer model
        sentenceTransformerHelper = SentenceTransformerHelper(
            convert_to_zenkaku_flag=True, replace_words=None, keywords=None
        )
        sentenceTransformerHelper.load_model_by_name(
            "Detomo/cl-nagoya-sup-simcse-ja-for-standard-name-v1_0"
        )
        # Load standard subject dictionary
        dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(
            os.path.join(current_dir, "data", "subjectData.csv")
        )
        # Load pre-computed embeddings and similarities
        with open(
            os.path.join(
                current_dir,
                "data",
                "sample_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            ),
            "rb",
        ) as f:
            sample_name_sentence_embeddings = pickle.load(f)
        with open(
            os.path.join(
                current_dir,
                "data",
                "sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            ),
            "rb",
        ) as f:
            sample_name_sentence_similarities = pickle.load(f)
        # Load and process sample data
        sampleData = SampleNameData()
        file_path = os.path.join(current_dir, "data", "sampleData.csv")
        sampleData.load_data_from_csv(file_path)
        sampleData.process_data()
        # Create sentence clusters
        sentence_clustering_lib = SentenceClusteringLib(sample_name_sentence_embeddings)
        best_name_eps = 0.07
        name_groups, _ = sentence_clustering_lib.create_sentence_cluster(best_name_eps)
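        # name_groups holds the cluster label for each sample-name embedding;
        # the labels are attached to sampleData below and reused at predict time.
        # (best_name_eps is the distance threshold passed to the clustering step.)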
        sampleData._create_key_column(
            COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
        )
        sampleData.set_name_sentence_labels(name_groups)
        sampleData.build_search_tree()
print("Models and data loaded successfully") | |
except Exception as e: | |
print(f"Error during startup: {e}") | |
traceback.print_exc() | |
yield # This is where the app runs | |
# Cleanup code (if needed) goes here | |
print("Shutting down application") | |
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def root():
    return {"message": "Hello World"}
@app.get("/health")
async def health_check():
    return {"status": "ok", "timestamp": time.time()}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Process an input CSV file and return standardized names
    """
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, name_groups

    if not file.filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Save uploaded file
    timestamp = int(time.time())
    input_file_path = os.path.join(current_dir, "uploads", f"input_{timestamp}.csv")
    # Use CSV format with correct extension
    output_file_path = os.path.join(current_dir, "outputs", f"output_{timestamp}.csv")
    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()
    try:
        # Process input data
        inputData = InputNameData(dic_standard_subject)
        inputData.load_data_from_csv(input_file_path)
        inputData.process_data()

        # Map standard names
        nameMappingHelper = NameMappingHelper(
            sentenceTransformerHelper,
            inputData,
            sampleData,
            sample_name_sentence_embeddings,
            sample_name_sentence_similarities,
        )
        df_predicted = nameMappingHelper.map_standard_names()

        # Create output dataframe and save to CSV - copy to avoid SettingWithCopyWarning
        # columns_to_keep = ["ファイル名", "シート名", "行", "科目", "名称"]
        # output_df = inputData.dataframe[columns_to_keep].copy()
        output_df = inputData.dataframe.copy()
        print(df_predicted.columns)

        # Use .loc to avoid SettingWithCopyWarning
        output_df.loc[:, COL_STANDARD_SUBJECT] = df_predicted[COL_STANDARD_SUBJECT]
        output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
        output_df.loc[:, "参考_名称"] = df_predicted["参考_名称"]
        output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]

        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        # Return the file as a download; FileResponse's media_type sets the
        # Content-Type header, so only the disposition needs to be added
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            media_type="text/csv",
            headers={
                "Content-Disposition": f'attachment; filename="output_{Path(file.filename).stem}.csv"',
            },
        )
    except Exception as e:
        print(f"Error processing file: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
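# Example client call (assumes the /predict route above and a server on
# localhost:8000; adjust host/port to your deployment):
#   curl -X POST http://localhost:8000/predict \
#        -F "file=@input.csv" -o output.csv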