"""HTTP endpoints that map user-supplied account/item names to standard names."""

import logging
import os
import re
import shutil
import time
import traceback
import uuid
from pathlib import Path

import pandas as pd
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
from fastapi.responses import FileResponse

from auth import get_current_user
from services.sentence_transformer_service import (
    SentenceTransformerService,
    sentence_transformer_service,
)
from data_lib.input_name_data import InputNameData
from data_lib.base_name_data import COL_NAME_SENTENCE
from mapping_lib.subject_mapper import SubjectMapper
from mapping_lib.name_mapper import NameMapper
from config import UPLOAD_DIR, OUTPUT_DIR
from models import (
    EmbeddingRequest,
    PredictRawRequest,
    PredictRawResponse,
    PredictRecord,
    PredictResult,
)

logger = logging.getLogger(__name__)

router = APIRouter()

# Input columns copied into every prediction result, in output order.
_OUTPUT_COLUMNS = ['シート名', '行', '科目', '中科目', '分類', '名称', '摘要', '備考', '確定']

# Columns predict_raw guarantees exist before building its response,
# mapped to the placeholder value used when the column is missing.
_RAW_REQUIRED_COLUMNS = {
    '確定': "",
    '標準科目': "",
    '標準項目名': "",
    '基準名称類似度': 0,
}


def get_sentence_service() -> SentenceTransformerService:
    """Dependency provider returning the shared sentence-transformer service."""
    return sentence_transformer_service


def _internal_error(context: str, exc: Exception) -> HTTPException:
    """Log *exc* with its traceback and build the HTTP 500 that reports it.

    Must be called from inside an ``except`` block so the traceback is captured.
    """
    logger.exception("%s: %s", context, exc)
    return HTTPException(status_code=500, detail=str(exc))


def _safe_filename_component(value) -> str:
    """Replace every character other than word chars, '.' and '-' with '_'.

    This keeps path separators in user-controlled values (e.g. a username)
    from escaping the target directory.
    """
    return re.sub(r"[^\w.-]", "_", str(value))


def _map_standard_names(
    input_data: InputNameData,
    sentence_service: SentenceTransformerService,
) -> pd.DataFrame:
    """Run subject mapping, preprocessing and standard-name prediction.

    Mutates ``input_data``: sets ``dic_standard_subject`` and calls
    ``process_data()`` on it.

    Returns:
        The DataFrame produced by ``NameMapper.predict``.

    Raises:
        HTTPException: status 500 if any stage fails; the stage is logged.
    """
    try:
        subject_mapper = SubjectMapper(
            sentence_transformer_helper=sentence_service.sentenceTransformerHelper,
            dic_subject_map=sentence_service.dic_standard_subject,
            similarity_threshold=0.9,
        )
        dic_subject_map = subject_mapper.map_standard_subjects(input_data.dataframe)
    except Exception as e:
        raise _internal_error("Error processing SubjectMapper", e) from e

    try:
        input_data.dic_standard_subject = dic_subject_map
        input_data.process_data()
    except Exception as e:
        raise _internal_error("Error processing inputData process_data", e) from e

    try:
        name_mapper = NameMapper(
            sentence_service.sentenceTransformerHelper,
            sentence_service.standardNameMapData,
            top_count=3,
        )
        return name_mapper.predict(input_data)
    except Exception as e:
        raise _internal_error("Error mapping standard names", e) from e


def _build_output_frame(input_data: InputNameData, df_predicted: pd.DataFrame) -> pd.DataFrame:
    """Combine the kept input columns with the predicted standard names/scores.

    The original index is preserved as an ``index`` column.
    """
    output_df = input_data.dataframe[_OUTPUT_COLUMNS].copy()
    output_df.reset_index(drop=False, inplace=True)
    # NOTE(review): Series assignment aligns on index labels. output_df now has
    # a fresh RangeIndex, so this assumes df_predicted is RangeIndex-aligned
    # with input_data.dataframe's row order -- confirm.
    output_df.loc[:, "出力_科目"] = df_predicted["標準科目"]
    output_df.loc[:, "出力_項目名"] = df_predicted["標準項目名"]
    output_df.loc[:, "出力_確率度"] = df_predicted["基準名称類似度"]
    return output_df


@router.post("/predict")
async def predict(
    current_user=Depends(get_current_user),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(get_sentence_service),
):
    """
    Process an input CSV file and return standardized names (requires authentication).

    The upload is saved under UPLOAD_DIR and the result CSV is written under
    OUTPUT_DIR. The result is returned as a CSV download named
    ``output_<original stem>.csv``.
    """
    original_name = file.filename or ""
    if not original_name.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # The uuid fragment keeps two uploads by the same user within one second
    # from overwriting each other's files.
    file_tag = (
        f"{int(time.time())}_"
        f"{_safe_filename_component(current_user.username)}_"
        f"{uuid.uuid4().hex[:8]}"
    )
    input_file_path = os.path.join(UPLOAD_DIR, f"input_{file_tag}.csv")
    output_file_path = os.path.join(OUTPUT_DIR, f"output_{file_tag}.csv")

    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    # NOTE(review): the model work below is synchronous and blocks the event
    # loop of this async endpoint; consider offloading it to a thread pool once
    # the mapping helpers are confirmed to be thread-safe.
    try:
        start_time = time.perf_counter()

        try:
            input_data = InputNameData()
            input_data.load_data_from_csv(input_file_path)
        except Exception as e:
            raise _internal_error("Error processing load data", e) from e

        df_predicted = _map_standard_names(input_data, sentence_service)
        output_df = _build_output_frame(input_data, df_predicted)

        # utf_8_sig writes a BOM so Excel opens the Japanese text correctly.
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        logger.info("Execution time: %s seconds", time.perf_counter() - start_time)

        # FileResponse builds Content-Disposition itself. For non-ASCII
        # (e.g. Japanese) names it emits an RFC 5987 ``filename*`` value;
        # a hand-built header would fail latin-1 encoding.
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(original_name).stem}.csv",
            media_type="text/csv",
        )
    except HTTPException:
        # Already logged and shaped by the stage that failed; don't re-wrap.
        raise
    except Exception as e:
        raise _internal_error("Error processing file", e) from e


@router.post("/embeddings")
async def create_embeddings(
    request: EmbeddingRequest,
    current_user=Depends(get_current_user),
    sentence_service: SentenceTransformerService = Depends(get_sentence_service),
):
    """
    Create embeddings for a list of input sentences (requires authentication).

    Returns ``{"embeddings": [[...], ...]}`` with one vector per input sentence.
    """
    try:
        start_time = time.perf_counter()
        embeddings = sentence_service.sentenceTransformerHelper.create_embeddings(
            request.sentences
        )
        logger.info("Execution time: %s seconds", time.perf_counter() - start_time)
        # .tolist() converts the array returned by the helper into JSON-serializable lists.
        return {"embeddings": embeddings.tolist()}
    except Exception as e:
        raise _internal_error("Error creating embeddings", e) from e


@router.post("/predict-raw", response_model=PredictRawResponse)
async def predict_raw(
    request: PredictRawRequest,
    current_user=Depends(get_current_user),
    sentence_service: SentenceTransformerService = Depends(get_sentence_service),
):
    """
    Process raw input records and return standardized names (requires authentication).

    Runs the same pipeline as ``/predict`` on JSON records instead of a CSV upload.
    Returns one ``PredictResult`` per input record.
    """
    try:
        if not request.records:
            # Nothing to map; skip running the models on an empty frame.
            return PredictRawResponse(results=[])

        df = pd.DataFrame(
            [
                {
                    "科目": record.subject,
                    "中科目": record.sub_subject,
                    "分類": record.name_category,
                    "名称": record.name,
                    "摘要": record.abstract or "",
                    "備考": record.memo or "",
                    # Required by BaseNameData but not used; placeholders only.
                    "シート名": "",
                    "行": "",
                }
                for record in request.records
            ]
        )

        try:
            input_data = InputNameData(sentence_service.dic_standard_subject)
            # Use _add_raw_data instead of assigning .dataframe directly.
            input_data._add_raw_data(df)
        except Exception as e:
            raise _internal_error("Error processing input data", e) from e

        df_predicted = _map_standard_names(input_data, sentence_service)

        # For every required column missing from the prediction, add a
        # placeholder to both frames so the output/response building below
        # cannot fail on a missing key.
        for column, placeholder in _RAW_REQUIRED_COLUMNS.items():
            if column not in df_predicted.columns:
                df_predicted[column] = placeholder
                input_data.dataframe[column] = placeholder

        output_df = _build_output_frame(input_data, df_predicted)

        results = [
            PredictResult(
                subject=row["科目"],
                sub_subject=row["中科目"],
                name_category=row["分類"],
                name=row["名称"],
                abstract=row["摘要"],
                memo=row["備考"],
                confirmed=row["確定"],
                standard_subject=row["出力_科目"],
                standard_name=row["出力_項目名"],
                similarity_score=float(row["出力_確率度"]),
            )
            for _, row in output_df.iterrows()
        ]
        return PredictRawResponse(results=results)
    except HTTPException:
        # Already logged and shaped by the stage that failed; don't re-wrap.
        raise
    except Exception as e:
        raise _internal_error("Error processing records", e) from e