Viihtorugo
Add files
a2e27b6
import torch
from transformers import AutoModelForImageClassification, AutoProcessor, Trainer, TrainingArguments
from datasets import load_dataset
# 1. Carregar o modelo e o processador pré-treinado
model_name = "google/vit-base-patch16-224-in21k"
model = AutoModelForImageClassification.from_pretrained(model_name, num_labels=3)
processor = AutoProcessor.from_pretrained(model_name)
# 2. Carregar o dataset
dataset = load_dataset("beans")
# 3. Função de pré-processamento das imagens
def preprocess_data(example):
# Processar apenas a coluna "image"
image = example['image']
# O processor transforma a imagem em um tensor que o modelo pode entender
inputs = processor(images=image, return_tensors="pt")
# O Trainer espera tensores puros, então convertemos
pixel_values = inputs["pixel_values"].squeeze() # Remove dimensões extras
labels = torch.tensor(example["labels"], dtype=torch.long) # Converte labels para tensor
return {"pixel_values": pixel_values, "labels": labels}
# 4. Aplicar o pré-processamento às imagens do dataset
train_dataset = dataset["train"].map(preprocess_data, remove_columns=["image"])
eval_dataset = dataset["test"].map(preprocess_data, remove_columns=["image"])
# **Corrigir o formato do dataset** - Definir os formatos corretamente
train_dataset.set_format(type="torch", columns=["pixel_values", "labels"])
eval_dataset.set_format(type="torch", columns=["pixel_values", "labels"])
# 5. Configurar os parâmetros de treinamento
training_args = TrainingArguments(
output_dir="./vit-finetuned", # Diretório para salvar o modelo treinado
num_train_epochs=3, # Número de épocas para treinamento
per_device_train_batch_size=8, # Tamanho do batch de treinamento
evaluation_strategy="epoch", # Avaliar o modelo a cada época
save_strategy="epoch", # Salvar o modelo a cada época
save_total_limit=2 # Limitar o número de checkpoints salvos
)
# 6. Configurar o Trainer
trainer = Trainer(
model=model, # O modelo treinado
args=training_args, # Argumentos de treinamento
train_dataset=train_dataset, # Dataset de treinamento
eval_dataset=eval_dataset # Dataset de avaliação
)
# 7. Iniciar o treinamento
trainer.train()
# 8. Salvar o modelo finetunado
trainer.save_model("./vit-finetuned")