# app.py: fine-tune DistilBERT on the IMDB dataset and serve it with a Gradio UI
import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋”ฉ
dataset = load_dataset("imdb")
# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)  # binary sentiment head
# ๋ฐ์ดํ„ฐ์…‹์„ ๋ชจ๋ธ์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌ
def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
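
# Note (not in the original script): the full IMDB train split has 25,000 examples,
# which is slow to fine-tune inside a demo app. One option is to subsample first, e.g.
# tokenized_datasets["train"].shuffle(seed=42).select(range(2000)).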
# Training configuration
training_args = TrainingArguments(
    output_dir="./results",          # where checkpoints and results are saved
    num_train_epochs=3,              # number of training epochs
    per_device_train_batch_size=8,   # training batch size per device
    per_device_eval_batch_size=8,    # evaluation batch size per device
    evaluation_strategy="epoch",     # run evaluation at the end of every epoch
    logging_dir="./logs",            # where logs are saved
)
trainer = Trainer(
    model=model,                                # model to fine-tune
    args=training_args,                         # training arguments
    train_dataset=tokenized_datasets["train"],  # training split
    eval_dataset=tokenized_datasets["test"],    # evaluation split
)
# Start training
trainer.train()
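
# Optional (not in the original script): persist the fine-tuned weights so the app
# can reload them instead of retraining on every restart; both are standard
# Transformers APIs:
# trainer.save_model("./results/final")
# tokenizer.save_pretrained("./results/final")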
# ๊ทธ๋ผ๋””์˜ค ์ธํ„ฐํŽ˜์ด์Šค๋กœ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ UI์— ์—ฐ๊ฒฐ
def classify_text(text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()
return predicted_class
demo = gr.Interface(fn=classify_text, inputs="text", outputs="text")
# Launch the Gradio interface (after training completes)
demo.launch()