Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoModelForImageClassification, AutoProcessor, Trainer, TrainingArguments | |
from datasets import load_dataset | |
# 1. Carregar o modelo e o processador pré-treinado | |
model_name = "google/vit-base-patch16-224-in21k" | |
model = AutoModelForImageClassification.from_pretrained(model_name, num_labels=3) | |
processor = AutoProcessor.from_pretrained(model_name) | |
# 2. Carregar o dataset | |
dataset = load_dataset("beans") | |
# 3. Função de pré-processamento das imagens | |
def preprocess_data(example): | |
# Processar apenas a coluna "image" | |
image = example['image'] | |
# O processor transforma a imagem em um tensor que o modelo pode entender | |
inputs = processor(images=image, return_tensors="pt") | |
# O Trainer espera tensores puros, então convertemos | |
pixel_values = inputs["pixel_values"].squeeze() # Remove dimensões extras | |
labels = torch.tensor(example["labels"], dtype=torch.long) # Converte labels para tensor | |
return {"pixel_values": pixel_values, "labels": labels} | |
# 4. Aplicar o pré-processamento às imagens do dataset | |
train_dataset = dataset["train"].map(preprocess_data, remove_columns=["image"]) | |
eval_dataset = dataset["test"].map(preprocess_data, remove_columns=["image"]) | |
# **Corrigir o formato do dataset** - Definir os formatos corretamente | |
train_dataset.set_format(type="torch", columns=["pixel_values", "labels"]) | |
eval_dataset.set_format(type="torch", columns=["pixel_values", "labels"]) | |
# 5. Configurar os parâmetros de treinamento | |
training_args = TrainingArguments( | |
output_dir="./vit-finetuned", # Diretório para salvar o modelo treinado | |
num_train_epochs=3, # Número de épocas para treinamento | |
per_device_train_batch_size=8, # Tamanho do batch de treinamento | |
evaluation_strategy="epoch", # Avaliar o modelo a cada época | |
save_strategy="epoch", # Salvar o modelo a cada época | |
save_total_limit=2 # Limitar o número de checkpoints salvos | |
) | |
# 6. Configurar o Trainer | |
trainer = Trainer( | |
model=model, # O modelo treinado | |
args=training_args, # Argumentos de treinamento | |
train_dataset=train_dataset, # Dataset de treinamento | |
eval_dataset=eval_dataset # Dataset de avaliação | |
) | |
# 7. Iniciar o treinamento | |
trainer.train() | |
# 8. Salvar o modelo finetunado | |
trainer.save_model("./vit-finetuned") | |