File size: 3,215 Bytes
58f09d1 bbf9c8b 627db6a a4fc148 58f09d1 627db6a a4fc148 58f09d1 627db6a a4fc148 627db6a a4fc148 627db6a a4fc148 627db6a a4fc148 627db6a a4fc148 bbf9c8b 627db6a a4fc148 627db6a a4fc148 bbf9c8b 627db6a bbf9c8b 627db6a a4fc148 bbf9c8b a4fc148 bbf9c8b a4fc148 627db6a a4fc148 627db6a a4fc148 627db6a a4fc148 bbf9c8b a4fc148 bbf9c8b 627db6a bbf9c8b 627db6a bbf9c8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
import threading
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
# GPU๊ฐ ์๋ CPU์์ ์คํํ๋๋ก ์ค์
device = torch.device("cpu")
# IMDb ๋ฐ์ดํฐ์
๋ก๋ฉ
dataset = load_dataset("imdb")
# ๋ฐ์ดํฐ์
์ ํ
์คํธ ์ปฌ๋ผ ์๋ ๊ฐ์ง
text_column = dataset["train"].column_names[0] # ๊ธฐ๋ณธ์ ์ผ๋ก "text"์ผ ๊ฐ๋ฅ์ฑ์ด ๋์
# ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ก๋ฉ
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.to(device) # ๋ชจ๋ธ์ CPU๋ก ์ด๋
# ๋ฐ์ดํฐ์
์ ๋ชจ๋ธ์ ๋ง๊ฒ ์ ์ฒ๋ฆฌ
def tokenize_function(examples):
return tokenizer(examples[text_column], padding="max_length", truncation=True)
tokenized_train_datasets = dataset["train"].map(tokenize_function, batched=True, batch_size=None, remove_columns=[text_column])
tokenized_test_datasets = dataset["test"].map(tokenize_function, batched=True, batch_size=None, remove_columns=[text_column])
# ํ๋ จ ์ค์ (GPU ์ฌ์ฉ ์ ํจ)
training_args = TrainingArguments(
output_dir="./results", # ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก
num_train_epochs=1, # ํ๋ จ ์ํญ ์ 1๋ก ์ค์ (๋น ๋ฅด๊ฒ ํ
์คํธ)
per_device_train_batch_size=4, # ๋ฐฐ์น ํฌ๊ธฐ ์ค์ด๊ธฐ (CPU์์๋ ์์ ๊ฐ ์ถ์ฒ)
per_device_eval_batch_size=4, # ๋ฐฐ์น ํฌ๊ธฐ ์ค์ด๊ธฐ
evaluation_strategy="epoch", # ์ํญ๋ง๋ค ๊ฒ์ฆ
save_strategy="epoch",
logging_dir="./logs", # ๋ก๊ทธ ์ ์ฅ ๊ฒฝ๋ก
logging_steps=100, # 100 ์คํ
๋ง๋ค ๋ก๊ทธ ์ถ๋ ฅ
report_to="none", # ํ๊น
ํ์ด์ค ์
๋ก๋ ์ ๋ก๊น
๋นํ์ฑํ
load_best_model_at_end=True, # ์ต์์ ๋ชจ๋ธ๋ก ์ข
๋ฃ
no_cuda=True # โ GPU ์ฌ์ฉํ์ง ์๋๋ก ์ค์
)
# ํ๋ จ ํจ์
def train_model():
trainer = Trainer(
model=model, # ํ๋ จํ ๋ชจ๋ธ
args=training_args, # ํ๋ จ ์ธ์
train_dataset=tokenized_train_datasets, # ํ๋ จ ๋ฐ์ดํฐ์
eval_dataset=tokenized_test_datasets, # ํ๊ฐ ๋ฐ์ดํฐ์
)
trainer.train()
# ํ๋ จ์ ๋ณ๋์ ์ค๋ ๋์์ ์คํ
def start_training():
train_thread = threading.Thread(target=train_model)
train_thread.start()
# ํ
์คํธ ๋ถ๋ฅ ํจ์ (CPU์์ ์คํ)
def classify_text(text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
with torch.no_grad(): # ๋ถํ์ํ ์ฐ์ฐ ๋ฐฉ์ง
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()
return str(predicted_class) # Gradio์์ ๋ฌธ์์ด ๋ฐํ์ด ๋ ์์ ์
# Gradio ์ธํฐํ์ด์ค ์ค์
demo = gr.Interface(fn=classify_text, inputs="text", outputs="text")
# ํ๋ จ ์์๊ณผ Gradio UI ์คํ
def launch_app():
start_training() # ํ๋ จ ์์
demo.launch() # Gradio UI ์คํ
# ํ๊น
ํ์ด์ค Spaces์ ์
๋ก๋ํ ๋ ์คํ
if __name__ == "__main__":
launch_app()
|