"""Fine-tune an ImageNet-21k pre-trained ViT image classifier on CIFAR-10.

Running this module as a script performs the following steps:

1. Download the CIFAR-10 dataset and the ``google/vit-base-patch16-224-in21k``
   checkpoint.
2. Attach a fresh 10-way classification head to the model.
3. Train with the Hugging Face ``Trainer``, evaluating accuracy after every
   epoch.
4. Save the fine-tuned model and its image processor to ``OUTPUT_DIR``.
"""

import dataclasses

import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    ViTForImageClassification,
    ViTImageProcessor,
)

MODEL_NAME = "google/vit-base-patch16-224-in21k"
RESULTS_DIR = "./results"
OUTPUT_DIR = "./computer-vision-cifar-10"

# `TrainingArguments.evaluation_strategy` was renamed to `eval_strategy` in newer
# transformers releases, and the old name was later removed. Use whichever name
# the installed version defines so the script runs on both old and new releases.
_EVAL_STRATEGY_KEY = (
    "eval_strategy"
    if any(f.name == "eval_strategy" for f in dataclasses.fields(TrainingArguments))
    else "evaluation_strategy"
)


def compute_metrics(eval_pred):
    """Return top-1 accuracy for one ``Trainer`` evaluation pass.

    Args:
        eval_pred: ``(logits, label_ids)`` pair (an ``EvalPrediction``)
            produced by ``Trainer.evaluate``.

    Returns:
        ``{"accuracy": float}``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}


def collate_fn(examples):
    """Stack per-example tensors from the on-the-fly transform into a batch.

    Each example is a dict holding a ``pixel_values`` tensor and an int
    ``labels`` value, as produced by the transform set in ``main``.
    """
    return {
        "pixel_values": torch.stack([ex["pixel_values"] for ex in examples]),
        "labels": torch.tensor([ex["labels"] for ex in examples]),
    }


def main():
    """Load data and model, fine-tune, evaluate each epoch, and save."""
    # Load the dataset (CIFAR-10: an "img" image column and a "label" ClassLabel).
    dataset = load_dataset("cifar10")
    label_names = dataset["train"].features["label"].names

    # Load the pre-trained backbone and its image processor. The in21k
    # checkpoint ships no fine-tuned head, so the head size must be given
    # explicitly; otherwise it defaults to 2 labels, not the 10 CIFAR-10 classes.
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(label_names),
        id2label=dict(enumerate(label_names)),
        label2id={name: idx for idx, name in enumerate(label_names)},
    )

    def transform(batch):
        """Resize/normalise images lazily, only when a batch is requested."""
        inputs = processor(
            [img.convert("RGB") for img in batch["img"]], return_tensors="pt"
        )
        inputs["labels"] = batch["label"]
        return inputs

    # Preprocess on the fly. An eager `.map` would store float pixel arrays
    # for all 60k images in the Arrow cache, which takes tens of gigabytes.
    dataset = dataset.with_transform(transform)

    # Training hyper-parameters.
    training_args = TrainingArguments(
        output_dir=RESULTS_DIR,
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        num_train_epochs=3,
        weight_decay=0.01,
        # Keep the raw "img"/"label" columns: the lazy transform needs them,
        # and Trainer would otherwise drop them as unused by the model.
        remove_unused_columns=False,
        **{_EVAL_STRATEGY_KEY: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    # Train, then persist the model together with the processor it expects.
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)


if __name__ == "__main__":
    main()