import torch from transformers import AutoModelForImageClassification, AutoProcessor, Trainer, TrainingArguments from datasets import load_dataset # 1. Carregar o modelo e o processador pré-treinado model_name = "google/vit-base-patch16-224-in21k" model = AutoModelForImageClassification.from_pretrained(model_name, num_labels=3) processor = AutoProcessor.from_pretrained(model_name) # 2. Carregar o dataset dataset = load_dataset("beans") # 3. Função de pré-processamento das imagens def preprocess_data(example): # Processar apenas a coluna "image" image = example['image'] # O processor transforma a imagem em um tensor que o modelo pode entender inputs = processor(images=image, return_tensors="pt") # O Trainer espera tensores puros, então convertemos pixel_values = inputs["pixel_values"].squeeze() # Remove dimensões extras labels = torch.tensor(example["labels"], dtype=torch.long) # Converte labels para tensor return {"pixel_values": pixel_values, "labels": labels} # 4. Aplicar o pré-processamento às imagens do dataset train_dataset = dataset["train"].map(preprocess_data, remove_columns=["image"]) eval_dataset = dataset["test"].map(preprocess_data, remove_columns=["image"]) # **Corrigir o formato do dataset** - Definir os formatos corretamente train_dataset.set_format(type="torch", columns=["pixel_values", "labels"]) eval_dataset.set_format(type="torch", columns=["pixel_values", "labels"]) # 5. Configurar os parâmetros de treinamento training_args = TrainingArguments( output_dir="./vit-finetuned", # Diretório para salvar o modelo treinado num_train_epochs=3, # Número de épocas para treinamento per_device_train_batch_size=8, # Tamanho do batch de treinamento evaluation_strategy="epoch", # Avaliar o modelo a cada época save_strategy="epoch", # Salvar o modelo a cada época save_total_limit=2 # Limitar o número de checkpoints salvos ) # 6. Configurar o Trainer trainer = Trainer( model=model, # O modelo treinado args=training_args, # Argumentos de treinamento train_dataset=train_dataset, # Dataset de treinamento eval_dataset=eval_dataset # Dataset de avaliação ) # 7. Iniciar o treinamento trainer.train() # 8. Salvar o modelo finetunado trainer.save_model("./vit-finetuned")