# -*- coding: utf-8 -*- # --- 필요한 모듈 임포트 --- import gradio as gr from transformers import DonutProcessor, VisionEncoderDecoderModel from PIL import Image import torch import re import json import os import warnings # --- 경고 메시지 무시 --- # UserWarning: TypedStorage is deprecated 는 PyTorch 관련 경고로 무시해도 괜찮습니다. warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") # Future or other warnings if needed warnings.filterwarnings("ignore", category=FutureWarning) # --- 모델 및 프로세서 경로 정의 --- # Hugging Face Spaces 저장소 내부에 모델 파일을 복사했다고 가정합니다. # 저장소 루트에 donut_sroie_finetuned 폴더가 있고 그 안에 final_model 이 있는 구조 model_path_finetuned = "greene6517/finetuned_donut_sroie" model_name_base = "naver-clova-ix/donut-base" # Base 모델은 Hub에서 직접 로드 # --- Fine-tuned Processor 및 모델 로딩 --- print(f"Loading Fine-tuned processor from Hub: {model_path_finetuned}") # 로그 메시지도 확인 try: # local_files_only=True 가 없어야 함! model_path_finetuned 변수 사용 확인! processor = DonutProcessor.from_pretrained(model_path_finetuned) print("Successfully loaded fine-tuned processor from Hub.") except Exception as e: print(f"FATAL: Could not load fine-tuned processor from Hub: {e}") exit() print(f"Loading Fine-tuned model from Hub: {model_path_finetuned}") # 로그 메시지도 확인 try: # local_files_only=True 가 없어야 함! model_path_finetuned 변수 사용 확인! model_finetuned = VisionEncoderDecoderModel.from_pretrained(model_path_finetuned) print("Successfully loaded fine-tuned model from Hub.") except Exception as e: print(f"FATAL: Could not load fine-tuned model from Hub: {e}") exit() print(f"Loading Fine-tuned model from: {model_path_finetuned}") try: # local_files_only=True 를 사용하여 Spaces 저장소 내 파일만 사용하도록 강제 model_finetuned = VisionEncoderDecoderModel.from_pretrained(model_path_finetuned, local_files_only=True) print("Successfully loaded fine-tuned model locally from Space repo.") except Exception as e: print(f"Error loading fine-tuned model locally: {e}. Check if model files exist at the path.") # 필요시 Hub에서 로드 시도하는 로직 추가 가능 (단, 모델이 Hub에 업로드 되어 있어야 함) # try: # model_finetuned = VisionEncoderDecoderModel.from_pretrained("your-hf-username/your-model-repo-name") # Hub 경로 예시 # print("Loaded fine-tuned model from Hub as fallback.") # except Exception as e2: # print(f"FATAL: Could not load fine-tuned model locally or from Hub: {e2}") # exit() # 여기서는 로컬 로딩 실패 시 일단 종료하도록 함 (수정 필요시 주석 해제) print(f"FATAL: Could not load fine-tuned model locally: {e}") exit() # --- Base Processor 및 모델 로딩 (Hub에서 직접) --- print(f"Loading Base processor from: {model_name_base}") try: processor_base = DonutProcessor.from_pretrained(model_name_base) print("Successfully loaded base processor.") except Exception as e: print(f"FATAL: Could not load base processor: {e}") exit() print(f"Loading Base model from: {model_name_base}") try: model_base = VisionEncoderDecoderModel.from_pretrained(model_name_base) print("Successfully loaded base model.") except Exception as e: print(f"FATAL: Could not load base model: {e}") exit() # --- 장치 설정 및 모델 이동 --- # Spaces 환경에서는 CPU 또는 할당된 GPU를 사용합니다. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"\nUsing device: {device}") # 모델을 해당 장치로 이동 try: model_finetuned.to(device) model_base.to(device) print("Models moved to device.") # 평가 모드 설정 (필수) model_finetuned.eval() model_base.eval() print("Models set to evaluation mode.") except Exception as e: print(f"Error moving models to device or setting eval mode: {e}") exit() # --- Helper function to clean generated sequence (주로 Fine-tuned용) --- def clean_sequence(sequence, processor_to_use, prompt_token_str=None): """Removes prompt, EOS, PAD tokens from a generated sequence.""" cleaned = sequence try: # Standard tokens first eos_token = processor_to_use.tokenizer.eos_token if processor_to_use.tokenizer.eos_token else "" # Default EOS pad_token = processor_to_use.tokenizer.pad_token if processor_to_use.tokenizer.pad_token else "" # Default PAD cleaned = cleaned.replace(eos_token, "").replace(pad_token, "").strip() # Add BOS token removal if it exists and appears if hasattr(processor_to_use.tokenizer, 'bos_token') and processor_to_use.tokenizer.bos_token: cleaned = cleaned.replace(processor_to_use.tokenizer.bos_token, "").strip() # Specific prompt removal (case-insensitive start check can be robust) if prompt_token_str: # Simple startswith check might be enough if prompt is always at the beginning if cleaned.startswith(prompt_token_str): cleaned = cleaned[len(prompt_token_str):].strip() # Regex version (more robust but slightly slower) # cleaned = re.sub(f"^{re.escape(prompt_token_str)}", "", cleaned, flags=re.IGNORECASE).strip() except Exception as e: print(f"Warning: Error during sequence cleaning: {e}") return sequence # Return original if cleaning fails return cleaned # --- Helper function to parse SROIE format --- def token2json_simple(text): """Parses value format into a dictionary.""" output = {} # Regex to find ... patterns, handling potential spaces and newlines in value # It captures the key name (e.g., "company") and the value between the tags. parts = re.findall(r"([\s\S]*?)", text) for key, value in parts: # Strip leading/trailing whitespace from key and value output[key.strip()] = value.strip() # Add info if parsing failed but text was present if not output and text and not text.isspace(): output["parsing_info"] = "Could not parse SROIE key-value pairs from the cleaned sequence." output["cleaned_sequence_preview"] = text[:200] + "..." # Show preview elif not text or text.isspace(): output["parsing_info"] = "Empty sequence after cleaning, nothing to parse." return output # --- 통합 이미지 처리 및 추론 함수 --- # 데코레이터 추가: 그래디언트 계산 비활성화 (추론 시 메모리 절약 및 속도 향상) @torch.no_grad() def process_image_comparison(image_input): if image_input is None: no_image_msg = {"error": "이미지를 업로드해주세요."} # Ensure JSON output for Gradio component return json.dumps(no_image_msg, indent=2, ensure_ascii=False), json.dumps(no_image_msg, indent=2, ensure_ascii=False) try: # Gradio's numpy input needs conversion image = Image.fromarray(image_input).convert("RGB") except Exception as e: error_msg = {"error": f"이미지 변환 오류: {e}"} error_json_str = json.dumps(error_msg, indent=2, ensure_ascii=False) return error_json_str, error_json_str results_ft_json_str = "{}" results_base_json_str = "{}" sequence_ft_raw = "N/A" sequence_base_raw = "N/A" # === Fine-tuned 모델 추론 === try: pixel_values_ft = processor(image, return_tensors="pt").pixel_values.to(device) task_prompt_ft = "" # Fine-tuned 모델의 시작 프롬프트 decoder_input_ids_ft = processor.tokenizer( task_prompt_ft, add_special_tokens=False, return_tensors="pt" ).input_ids.to(device) # 생성 시 필요한 파라미터 설정 generation_config_ft = { "max_length": model_finetuned.config.decoder.max_position_embeddings, "pad_token_id": processor.tokenizer.pad_token_id, "eos_token_id": processor.tokenizer.eos_token_id, "use_cache": True, "bad_words_ids": [[processor.tokenizer.unk_token_id]] if processor.tokenizer.unk_token_id else None, "return_dict_in_generate": True, "decoder_input_ids": decoder_input_ids_ft # 시작 프롬프트 제공 } outputs_ft = model_finetuned.generate(pixel_values_ft, **generation_config_ft) sequence_ft_raw = processor.batch_decode(outputs_ft.sequences)[0] # print(f"\nFine-tuned Raw Output: {sequence_ft_raw}") # 서버 로그에 출력 (디버깅용) # Fine-tuned 모델 결과 클리닝 sequence_ft_cleaned = clean_sequence(sequence_ft_raw, processor, prompt_token_str=task_prompt_ft) # print(f"Fine-tuned Cleaned Output: {sequence_ft_cleaned}") # 서버 로그에 출력 (디버깅용) # 클리닝된 결과 파싱 result_json_ft = token2json_simple(sequence_ft_cleaned) result_json_ft["raw_decoded_sequence_preview"] = sequence_ft_raw[:200] + "..." # 원본 결과 프리뷰 추가 # 최종 JSON 문자열 변환 results_ft_json_str = json.dumps(result_json_ft, indent=2, ensure_ascii=False, sort_keys=False) except Exception as e: print(f"Error during fine-tuned model inference: {e}") import traceback traceback.print_exc() # detailed error log on server results_ft_json_str = json.dumps({ "error": f"Fine-tuned 모델 추론 오류: {e}", "raw_decoded_sequence_before_error": sequence_ft_raw }, indent=2, ensure_ascii=False) # === Base 모델 추론 === try: pixel_values_base = processor_base(image, return_tensors="pt").pixel_values.to(device) # Base 모델용 프롬프트 (예: 또는 다른 일반 문서 프롬프트) # 여기서는 이전 코드와 동일하게 사용 task_prompt_base = "" # Base 모델은 해당 프롬프트 토큰이 없을 수 있으므로 확인 또는 다른 프롬프트 사용 필요 # 여기서는 일단 진행 try: decoder_input_ids_base = processor_base.tokenizer( task_prompt_base, add_special_tokens=False, return_tensors="pt", ).input_ids.to(device) except Exception as tokenizer_e: print(f"Warning: Base processor cannot tokenize prompt '{task_prompt_base}'. Using default generation. Error: {tokenizer_e}") decoder_input_ids_base = None # 프롬프트 없이 생성 # 생성 파라미터 설정 generation_config_base = { "max_length": model_base.config.decoder.max_position_embeddings, "early_stopping": True, "pad_token_id": processor_base.tokenizer.pad_token_id, "eos_token_id": processor_base.tokenizer.eos_token_id, "use_cache": True, "num_beams": 1, # Greedy decoding "bad_words_ids": [[processor_base.tokenizer.unk_token_id]] if processor_base.tokenizer.unk_token_id else None, "return_dict_in_generate": True, } # 프롬프트가 성공적으로 인코딩 되었으면 추가 if decoder_input_ids_base is not None: generation_config_base["decoder_input_ids"] = decoder_input_ids_base outputs_base = model_base.generate(pixel_values_base, **generation_config_base) sequence_base_raw = processor_base.batch_decode(outputs_base.sequences)[0] # print(f"\nBase Raw Output: {sequence_base_raw}") # 서버 로그에 출력 (디버깅용) # Base 모델 결과 클리닝 (skip_special_tokens 사용) sequence_base_cleaned = processor_base.batch_decode(outputs_base.sequences, skip_special_tokens=True)[0] # print(f"Base Cleaned Output (skip_special_tokens): {sequence_base_cleaned}") # 서버 로그에 출력 (디버깅용) # 결과 딕셔너리 생성 result_json_base = { "raw_decoded_sequence_preview": sequence_base_raw[:200] + "...", # 원본 결과 프리뷰 "output_skip_special_tokens": sequence_base_cleaned # 클리닝된 결과 } # 최종 JSON 문자열 변환 results_base_json_str = json.dumps(result_json_base, indent=2, ensure_ascii=False, sort_keys=False) except Exception as e: print(f"Error during base model inference: {e}") import traceback traceback.print_exc() # detailed error log on server results_base_json_str = json.dumps({ "error": f"Base 모델 추론 오류: {e}", "raw_decoded_sequence_before_error": sequence_base_raw # Include raw if available }, indent=2, ensure_ascii=False) # 두 모델의 결과를 JSON 문자열 형태로 반환 return results_ft_json_str, results_base_json_str # --- Gradio 인터페이스 정의 --- # CSS 스타일 정의 custom_css = """ body { background-color: #f0f4f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } #main_title { text-align: center; color: #1a5276; font-size: 2.3em; font-weight: 600; margin-top: 20px; margin-bottom: 5px; } #sub_description { text-align: center; color: #566573; font-size: 1.0em; margin-bottom: 25px; } .gradio-container { border-radius: 10px !important; box-shadow: 0 3px 10px rgba(0,0,0,0.08); padding: 25px !important; } footer { display: none !important; } /* Hide Gradio footer */ #output-title-ft, #output-title-base { color: #1a5276; font-weight: 600; margin-bottom: 8px; font-size: 1.2em; border-bottom: 2px solid #aed6f1; padding-bottom: 4px; } #output_row > div.gradio-column { border: 1px solid #d5dbdb; padding: 15px !important; border-radius: 8px; background-color: #ffffff; margin: 0 8px !important; box-shadow: 0 1px 3px rgba(0,0,0,0.04); } #json_output_ft > div:nth-child(2), #json_output_base > div:nth-child(2) { max-height: 600px; overflow-y: auto !important; } /* JSON output scroll */ """ # Gradio Blocks 인터페이스 구성 with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo: gr.Markdown("# Donut 모델 비교: Fine-tuned vs Base", elem_id="main_title") gr.Markdown("영수증 이미지를 업로드하면 Fine-tuned 모델(SROIE 파싱)과 Base 모델의 추출 결과를 비교합니다.", elem_id="sub_description") with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="numpy", label="🧾 영수증 이미지 업로드") submit_btn = gr.Button("🚀 결과 비교 시작", variant="primary", scale=0) # --- 예제 이미지 부분은 Spaces 환경에서 경로 문제가 있을 수 있어 일단 주석 처리 --- # 만약 예제 이미지를 Space 저장소에 함께 업로드하고 경로를 맞출 수 있다면 주석 해제 가능 example_img_dir = "example" # Space 저장소 루트에 있는 'example' 폴더 지정 # list comprehension 사용하여 존재하는 파일만 목록으로 만듦 example_paths = [os.path.join(example_img_dir, f) for f in ["1.jpg", "2.jpg"] if os.path.exists(os.path.join(example_img_dir, f))] if example_paths: gr.Examples(examples=example_paths, inputs=image_input, label="예제 이미지 클릭 (클릭 후 '결과 비교 시작' 버튼 누르세요)") else: gr.Markdown("_(예제 이미지를 찾을 수 없습니다. 'example' 폴더 확인 필요)_") with gr.Column(scale=2): with gr.Row(elem_id="output_row"): with gr.Column(scale=1): gr.Markdown("### ✨ Fine-tuned Model (SROIE 파싱)", elem_id="output-title-ft") json_output_ft = gr.JSON(label="Fine-tuned 결과 (JSON)", elem_id="json_output_ft") with gr.Column(scale=1): gr.Markdown("### 💡 Base Model (Raw + Cleaned)", elem_id="output-title-base") json_output_base = gr.JSON(label="Base 모델 결과 (JSON)", elem_id="json_output_base") # 버튼 클릭 시 실행할 함수 및 입출력 정의 submit_btn.click( fn=process_image_comparison, inputs=image_input, outputs=[json_output_ft, json_output_base] # 함수가 반환하는 순서대로 컴포넌트 지정 ) # --- Gradio 앱 실행 --- # Hugging Face Spaces 에서 실행될 때는 이 부분이 호출됩니다. if __name__ == "__main__": # share=True 는 Spaces 환경에서는 필요 없습니다. demo.launch()