# Page residue from the Hugging Face Spaces UI ("Spaces: Running").
import logging
import os
import re
import shutil
import time
import uuid
from pathlib import Path

from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
from fastapi.responses import FileResponse

from auth import get_current_user
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
from data_lib.input_name_data import InputNameData
from data_lib.base_data import COL_NAME_SENTENCE
from mapping_lib.name_mapping_helper import NameMappingHelper
from config import UPLOAD_DIR, OUTPUT_DIR
# Router holding this module's endpoints; presumably included into the app elsewhere.
router = APIRouter()
_logger = logging.getLogger(__name__)

# Columns copied from the processed input into the downloadable result.
_OUTPUT_SOURCE_COLUMNS = ['シート名', '行', '科目', '分類', '名称', '摘要', '備考']
# Prediction columns copied from the name-mapping result.
_PREDICTED_COLUMNS = ["出力_科目", "出力_項目名", "出力_確率度"]


def _safe_filename_token(value) -> str:
    """Return ``str(value)`` with every char other than word chars, '.' and '-' replaced by '_'.

    Keeps path separators out of generated file names when embedding a username.
    """
    return re.sub(r"[^\w.-]", "_", str(value))


def _save_upload(upload, destination: str) -> None:
    """Copy the contents of the FastAPI ``UploadFile`` *upload* to *destination*.

    The upload's underlying file object is always closed, even if the copy fails.
    """
    try:
        with open(destination, "wb") as buffer:
            shutil.copyfileobj(upload.file, buffer)
    finally:
        upload.file.close()


def _run_prediction(sentence_service, input_file_path: str, output_file_path: str) -> None:
    """Run the name-mapping pipeline on *input_file_path* and write the result CSV.

    Raises:
        HTTPException: 400 if the processed input lacks any of the columns that
            are copied into the output.
    """
    transformer_helper = sentence_service.sentenceTransformerHelper

    # Process input data
    input_data = InputNameData(sentence_service.dic_standard_subject)
    input_data.load_data_from_csv(input_file_path)
    input_data.process_data()
    input_name_sentences = input_data.dataframe[COL_NAME_SENTENCE]
    input_name_sentence_embeddings = transformer_helper.create_embeddings(input_name_sentences)

    # Create similarity matrix between sample (reference) and input embeddings
    similarity_matrix = transformer_helper.create_similarity_matrix_from_embeddings(
        sentence_service.sample_name_sentence_embeddings,
        input_name_sentence_embeddings,
    )

    # Map standard names
    name_mapping_helper = NameMappingHelper(
        transformer_helper,
        input_data,
        sentence_service.sampleData,
        input_name_sentence_embeddings,
        sentence_service.sample_name_sentence_embeddings,
        similarity_matrix,
    )
    df_predicted = name_mapping_helper.map_standard_names()

    # A malformed upload is a client error; report it as 400 instead of letting
    # the KeyError from column selection become a 500.
    missing = [c for c in _OUTPUT_SOURCE_COLUMNS if c not in input_data.dataframe.columns]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Input CSV is missing required columns: {missing}",
        )

    # Create output dataframe and save to CSV
    output_df = input_data.dataframe[_OUTPUT_SOURCE_COLUMNS].copy()
    output_df.reset_index(drop=False, inplace=True)
    for column in _PREDICTED_COLUMNS:
        output_df.loc[:, column] = df_predicted[column]
    # Save with utf_8_sig (BOM) encoding for Japanese Excel compatibility
    output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")


# NOTE(review): no route decorator (e.g. ``@router.post(...)``) is visible for
# this function; confirm it is actually registered on ``router``.
async def predict(
    current_user=Depends(get_current_user),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(lambda: sentence_transformer_service),
):
    """
    Process an input CSV file and return standardized names (requires authentication).

    The upload is saved under UPLOAD_DIR and run through InputNameData and
    NameMappingHelper. The result is written under OUTPUT_DIR and returned as a
    CSV download named ``output_<original stem>.csv``.

    Raises:
        HTTPException: 400 if the upload is not a ``.csv`` file or lacks the
            expected columns; 500 for any other processing failure.
    """
    # UploadFile.filename may be None; treat that as "not a CSV" rather than crashing.
    filename = file.filename or ""
    if not filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Save uploaded file. The random token prevents two uploads by the same user
    # within the same second from overwriting each other's input/output files.
    timestamp = int(time.time())
    user_token = _safe_filename_token(current_user.username)
    unique_suffix = f"{timestamp}_{user_token}_{uuid.uuid4().hex}"
    input_file_path = os.path.join(UPLOAD_DIR, f"input_{unique_suffix}.csv")
    output_file_path = os.path.join(OUTPUT_DIR, f"output_{unique_suffix}.csv")
    _save_upload(file, input_file_path)

    # NOTE(review): this runs synchronous file I/O plus embedding/similarity
    # computation inside an ``async def``, which blocks the event loop while it
    # runs. Consider offloading (e.g. fastapi.concurrency.run_in_threadpool)
    # once the shared service objects are confirmed safe across threads.
    try:
        _run_prediction(sentence_service, input_file_path, output_file_path)
    except HTTPException:
        raise
    except Exception as e:
        _logger.exception("Error processing file %r", filename)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Let FileResponse build Content-Disposition (attachment) itself. It switches
    # to RFC 5987 ``filename*=utf-8''...`` for non-ASCII names, which a hand-written
    # header cannot carry because response headers are latin-1 encoded. Do not
    # override Content-Type either: media_type already declares text/csv.
    return FileResponse(
        path=output_file_path,
        filename=f"output_{Path(filename).stem}.csv",
        media_type="text/csv",
    )