import logging
import os

import numpy as np
from evaluate import load
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from tqdm import tqdm
from transformers import BatchFeature, Trainer, TrainingArguments

from dataset import RetailDataset

metric = load("accuracy")
f1_score = load("f1")

np.random.seed(42)

logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
logger = logging.getLogger(__name__)


def prepare_dataset(
    images,
    labels,
    model,
    test_size=0.2,
    train_transform=None,
    val_transform=None,
    batch_size=512,
):
    """Split images/labels into train and test sets, preprocess them with the
    model's feature extractor, and wrap the results in RetailDataset objects.

    Falls back to using all data for both splits if the split fails.
    """
    logger.info("Preparing dataset")

    # Split the dataset into train and test sets
    try:
        images_train, images_test, labels_train, labels_test = train_test_split(
            images, labels, test_size=test_size
        )
    except ValueError:
        logger.warning(
            "Could not split dataset. Using all data for training and testing"
        )
        images_train = images
        labels_train = labels
        images_test = images
        labels_test = labels

    # Preprocess images in batches using the model's feature extractor
    images_train_prep = []
    images_test_prep = []
    for bs in tqdm(
        range(0, len(images_train), batch_size), desc="Preprocessing training images"
    ):
        images_train_batch = [
            Image.fromarray(np.array(image))
            for image in images_train[bs : bs + batch_size]
        ]
        images_train_batch = model.preprocess_image(images_train_batch)
        images_train_prep.extend(images_train_batch["pixel_values"])
    for bs in tqdm(
        range(0, len(images_test), batch_size), desc="Preprocessing test images"
    ):
        images_test_batch = [
            Image.fromarray(np.array(image))
            for image in images_test[bs : bs + batch_size]
        ]
        images_test_batch = model.preprocess_image(images_test_batch)
        images_test_prep.extend(images_test_batch["pixel_values"])

    # Create BatchFeatures
    train_batch_features = BatchFeature(data={"pixel_values": images_train_prep})
    test_batch_features = BatchFeature(data={"pixel_values": images_test_prep})

    # Create the datasets on the model's device
    train_dataset = RetailDataset(
        train_batch_features, labels_train, train_transform, device=model.device
    )
    test_dataset = RetailDataset(
        test_batch_features, labels_test, val_transform, device=model.device
    )

    logger.info("Train dataset: %d images", len(labels_train))
    logger.info("Test dataset: %d images", len(labels_test))

    return train_dataset, test_dataset
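
# The "accuracy" and "f1" metrics loaded above are never wired into the
# Trainer below. A minimal sketch of how they could be used, via the standard
# `compute_metrics` hook of transformers' Trainer (pass it as
# `compute_metrics=compute_metrics` when constructing the Trainer); the
# `average="macro"` choice is an assumption, not part of the original setup.
def compute_metrics(eval_pred):
    # Logits -> predicted class ids, then score against the reference labels
    predictions = np.argmax(eval_pred.predictions, axis=-1)
    accuracy = metric.compute(predictions=predictions, references=eval_pred.label_ids)
    f1 = f1_score.compute(
        predictions=predictions, references=eval_pred.label_ids, average="macro"
    )
    return {**accuracy, **f1}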

def re_training(images, labels, _model, save_model_path="new_model", num_epochs=10):
    """Fine-tune the model on the given images and labels with a Hugging Face
    Trainer and save the result to `save_model_path`."""
    # Bind the model at module scope so the nested transforms below resolve it
    global model
    model = _model

    labels = model.label_encoder.transform(labels)

    normalize = Normalize(
        mean=model.feature_extractor.image_mean, std=model.feature_extractor.image_std
    )

    def train_transforms(batch):
        return Compose(
            [
                RandomResizedCrop(model.feature_extractor.size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )(batch)

    def val_transforms(batch):
        return Compose(
            [
                Resize(model.feature_extractor.size),
                CenterCrop(model.feature_extractor.size),
                ToTensor(),
                normalize,
            ]
        )(batch)

    train_dataset, test_dataset = prepare_dataset(
        images, labels, model, 0.2, train_transforms, val_transforms
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="output",
            overwrite_output_dir=True,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=32,
            gradient_accumulation_steps=1,
            learning_rate=1e-6,
            weight_decay=0.01,
            eval_strategy="steps",
            eval_steps=1000,
            save_steps=3000,
            use_cpu=model.device.type == "cpu",  # Only force CPU if that's our device
        ),
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )

    trainer.train()
    model.save(save_model_path)
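
if __name__ == "__main__":
    # Minimal usage sketch. `RetailModel` and its `load` constructor are
    # hypothetical stand-ins for whatever wrapper provides the attributes used
    # above (feature_extractor, label_encoder, preprocess_image, device, save);
    # the random stand-in data should be replaced with real images and labels.
    from model import RetailModel  # hypothetical import

    model = RetailModel.load("base_model")  # hypothetical loader
    images = [
        np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(64)
    ]
    labels = ["product" if i % 2 else "shelf" for i in range(64)]
    re_training(images, labels, model, save_model_path="new_model", num_epochs=1)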