# meisaicheck-api / routes / predict.py
# NOTE: the lines below were HuggingFace page chrome captured with the file;
# kept here as a comment so the module remains valid Python.
#   author: vumichien — "change logic from sentence name to representative name"
#   commit: 01ae535 (raw / history / blame, 3.64 kB)
import os
import time
import shutil
from pathlib import Path
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
from fastapi.responses import FileResponse
from auth import get_current_user
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
from data_lib.input_name_data import InputNameData
from data_lib.base_data import COL_NAME_SENTENCE
from mapping_lib.name_mapping_helper import NameMappingHelper
from config import UPLOAD_DIR, OUTPUT_DIR
router = APIRouter()
@router.post("/predict")
async def predict(
    current_user=Depends(get_current_user),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(lambda: sentence_transformer_service)
):
    """
    Process an uploaded CSV of input names and return a CSV with the
    standardized (representative) names appended (requires authentication).

    Pipeline: save upload -> load/normalize rows -> embed name sentences ->
    similarity vs. pre-computed sample embeddings -> map to standard names ->
    write result CSV -> stream it back as an attachment.

    Raises:
        HTTPException 400: upload has no filename or is not a ``.csv`` file.
        HTTPException 500: any failure inside the mapping pipeline.
    """
    # `filename` can be None on malformed multipart uploads; guard before endswith.
    if not file.filename or not file.filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Timestamp + username keeps on-disk paths unique per request/user.
    timestamp = int(time.time())
    input_file_path = os.path.join(UPLOAD_DIR, f"input_{timestamp}_{current_user.username}.csv")
    output_file_path = os.path.join(OUTPUT_DIR, f"output_{timestamp}_{current_user.username}.csv")

    # Persist the upload; always close the spooled temp file, even on write failure.
    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    try:
        # Load and normalize the input rows, then embed the name sentences.
        input_data = InputNameData(sentence_service.dic_standard_subject)
        input_data.load_data_from_csv(input_file_path)
        input_data.process_data()
        input_name_sentences = input_data.dataframe[COL_NAME_SENTENCE]
        input_name_sentence_embeddings = sentence_service.sentenceTransformerHelper.create_embeddings(
            input_name_sentences
        )

        # Similarity between the service's pre-computed sample embeddings
        # and the embeddings of this request's input names.
        similarity_matrix = sentence_service.sentenceTransformerHelper.create_similarity_matrix_from_embeddings(
            sentence_service.sample_name_sentence_embeddings,
            input_name_sentence_embeddings,
        )

        # Map each input name to its standard (representative) name.
        name_mapping_helper = NameMappingHelper(
            sentence_service.sentenceTransformerHelper,
            input_data,
            sentence_service.sampleData,
            input_name_sentence_embeddings,
            sentence_service.sample_name_sentence_embeddings,
            similarity_matrix,
        )
        df_predicted = name_mapping_helper.map_standard_names()

        # Assemble the output: original columns plus the three predicted columns.
        column_to_keep = ['シート名', '行', '科目', '分類', '名称', '摘要', '備考']
        output_df = input_data.dataframe[column_to_keep].copy()
        output_df.reset_index(drop=False, inplace=True)
        output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
        output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
        output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]

        # utf_8_sig (UTF-8 with BOM) so Japanese text opens correctly in Excel.
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        download_name = f"output_{Path(file.filename).stem}.csv"
        return FileResponse(
            path=output_file_path,
            filename=download_name,
            media_type="text/csv",
            # FIX: the previous hand-set Content-Type header
            # ("application/x-www-form-urlencoded") contradicted media_type;
            # FileResponse derives Content-Type from media_type, so only the
            # Content-Disposition override is needed here.
            headers={
                "Content-Disposition": f'attachment; filename="{download_name}"',
            },
        )
    except Exception as e:
        # Route-boundary handler: log and surface pipeline failures as 500s.
        print(f"Error processing file: {e}")
        raise HTTPException(status_code=500, detail=str(e))