import gradio as gr import threading import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments from datasets import load_dataset # GPU가 아닌 CPU에서 실행하도록 설정 device = torch.device("cpu") # IMDb 데이터셋 로딩 dataset = load_dataset("imdb") # 데이터셋의 텍스트 컬럼 자동 감지 text_column = dataset["train"].column_names[0] # 기본적으로 "text"일 가능성이 높음 # 모델과 토크나이저 로딩 model_name = "distilbert-base-uncased" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name) model.to(device) # 모델을 CPU로 이동 # 데이터셋을 모델에 맞게 전처리 def tokenize_function(examples): return tokenizer(examples[text_column], padding="max_length", truncation=True) tokenized_train_datasets = dataset["train"].map(tokenize_function, batched=True, batch_size=None, remove_columns=[text_column]) tokenized_test_datasets = dataset["test"].map(tokenize_function, batched=True, batch_size=None, remove_columns=[text_column]) # 훈련 설정 (GPU 사용 안 함) training_args = TrainingArguments( output_dir="./results", # 결과 저장 경로 num_train_epochs=1, # 훈련 에폭 수 1로 설정 (빠르게 테스트) per_device_train_batch_size=4, # 배치 크기 줄이기 (CPU에서는 작은 값 추천) per_device_eval_batch_size=4, # 배치 크기 줄이기 evaluation_strategy="epoch", # 에폭마다 검증 save_strategy="epoch", logging_dir="./logs", # 로그 저장 경로 logging_steps=100, # 100 스텝마다 로그 출력 report_to="none", # 허깅페이스 업로드 시 로깅 비활성화 load_best_model_at_end=True, # 최상의 모델로 종료 no_cuda=True # ❌ GPU 사용하지 않도록 설정 ) # 훈련 함수 def train_model(): trainer = Trainer( model=model, # 훈련할 모델 args=training_args, # 훈련 인자 train_dataset=tokenized_train_datasets, # 훈련 데이터셋 eval_dataset=tokenized_test_datasets, # 평가 데이터셋 ) trainer.train() # 훈련을 별도의 스레드에서 실행 def start_training(): train_thread = threading.Thread(target=train_model) train_thread.start() # 텍스트 분류 함수 (CPU에서 실행) def classify_text(text): inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device) with torch.no_grad(): # 불필요한 연산 방지 outputs = model(**inputs) logits = outputs.logits predicted_class = logits.argmax(-1).item() return str(predicted_class) # Gradio에서 문자열 반환이 더 안정적 # Gradio 인터페이스 설정 demo = gr.Interface(fn=classify_text, inputs="text", outputs="text") # 훈련 시작과 Gradio UI 실행 def launch_app(): start_training() # 훈련 시작 demo.launch() # Gradio UI 실행 # 허깅페이스 Spaces에 업로드할 때 실행 if __name__ == "__main__": launch_app()