# Provenance (originally non-Python page residue, kept as a comment so the
# module remains importable):
#   author: ugaray96
#   commit: 0f734ea (unverified) — "Refactor and improve model, app, and
#   training components"
import logging
import os
import numpy as np
from evaluate import load
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
from tqdm import tqdm
from transformers import BatchFeature, Trainer, TrainingArguments
from dataset import RetailDataset
# Evaluation metrics from the HF `evaluate` hub, loaded once at import time.
# NOTE(review): neither `metric` nor `f1_score` is referenced in this file's
# visible code — presumably consumed by a compute_metrics hook elsewhere;
# confirm before removing.
metric = load("accuracy")
f1_score = load("f1")
# Seed NumPy's global RNG so runs are reproducible.
np.random.seed(42)
# Log level is configurable via the LOGGER_LEVEL env var; defaults to WARNING.
logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
logger = logging.getLogger(__name__)
def _preprocess_in_batches(images, model, batch_size, desc):
    """Run the model's feature extractor over `images` in fixed-size batches.

    Batching bounds peak memory when the image list is large.

    Args:
        images: sequence of images convertible via `np.array` /
            `Image.fromarray`.
        model: object exposing `preprocess_image(list_of_pil_images)` that
            returns a mapping with a "pixel_values" entry.
        batch_size: number of images per `preprocess_image` call.
        desc: progress-bar label for tqdm.

    Returns:
        Flat list of preprocessed pixel-value entries, one per input image.
    """
    pixel_values = []
    for start in tqdm(range(0, len(images), batch_size), desc=desc):
        batch = [
            Image.fromarray(np.array(image))
            for image in images[start : start + batch_size]
        ]
        batch = model.preprocess_image(batch)
        pixel_values.extend(batch["pixel_values"])
    return pixel_values


def prepare_dataset(
    images,
    labels,
    model,
    test_size=0.2,
    train_transform=None,
    val_transform=None,
    batch_size=512,
):
    """Split images/labels into train/test sets and wrap them as RetailDatasets.

    Args:
        images: sequence of images (anything `np.array` / `Image.fromarray`
            accepts).
        labels: sequence of already-encoded labels, parallel to `images`.
        model: project model wrapper; only `preprocess_image` and `device`
            are used here.
        test_size: fraction of samples held out for evaluation.
        train_transform: optional transform applied by the train dataset.
        val_transform: optional transform applied by the test dataset.
        batch_size: images per feature-extractor call during preprocessing.

    Returns:
        Tuple of (train_dataset, test_dataset) RetailDataset instances.
    """
    logger.info("Preparing dataset")
    # Split the dataset in train and test. train_test_split raises ValueError
    # (e.g. too few samples for the requested split); in that case we
    # deliberately fall back to using all data for BOTH sides rather than fail.
    try:
        images_train, images_test, labels_train, labels_test = train_test_split(
            images, labels, test_size=test_size
        )
    except ValueError:
        logger.warning(
            "Could not split dataset. Using all data for training and testing"
        )
        images_train = images
        labels_train = labels
        images_test = images
        labels_test = labels
    # Preprocess both splits with the model's feature extractor.
    images_train_prep = _preprocess_in_batches(
        images_train, model, batch_size, "Preprocessing training images"
    )
    images_test_prep = _preprocess_in_batches(
        images_test, model, batch_size, "Preprocessing test images"
    )
    # Wrap the preprocessed tensors as HF BatchFeatures.
    train_batch_features = BatchFeature(data={"pixel_values": images_train_prep})
    test_batch_features = BatchFeature(data={"pixel_values": images_test_prep})
    # Create the datasets on the same device as the model.
    train_dataset = RetailDataset(
        train_batch_features, labels_train, train_transform, device=model.device
    )
    test_dataset = RetailDataset(
        test_batch_features, labels_test, val_transform, device=model.device
    )
    logger.info("Train dataset: %d images", len(labels_train))
    logger.info("Test dataset: %d images", len(labels_test))
    return train_dataset, test_dataset
def re_training(images, labels, _model, save_model_path="new_model", num_epochs=10):
    """Fine-tune `_model` on (images, labels) and save it to `save_model_path`.

    Args:
        images: training images (anything `prepare_dataset` accepts).
        labels: raw labels; encoded below via the model's label encoder.
        _model: project model wrapper exposing `label_encoder`,
            `feature_extractor`, `device` and `save`, and usable as an
            HF-Trainer model.
        save_model_path: path passed to `model.save` after training.
        num_epochs: number of training epochs.
    """
    # NOTE(review): the model is deliberately published as a module-level
    # global — the transform closures below read it, and presumably other
    # module-level code (e.g. a metrics callback) does too; confirm before
    # removing the `global` hack.
    global model
    model = _model
    # Encode labels with the model's fitted label encoder.
    labels = model.label_encoder.transform(labels)
    # Normalize with the statistics the backbone's feature extractor declares.
    normalize = Normalize(
        mean=model.feature_extractor.image_mean, std=model.feature_extractor.image_std
    )
    def train_transforms(batch):
        # Training-time augmentation: random resized crop + horizontal flip.
        # NOTE(review): assumes `feature_extractor.size` is an int/tuple
        # accepted by torchvision transforms; in newer transformers versions
        # it can be a dict — confirm against the installed version.
        return Compose(
            [
                RandomResizedCrop(model.feature_extractor.size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )(batch)
    def val_transforms(batch):
        # Deterministic eval-time preprocessing: resize then center-crop.
        return Compose(
            [
                Resize(model.feature_extractor.size),
                CenterCrop(model.feature_extractor.size),
                ToTensor(),
                normalize,
            ]
        )(batch)
    # 80/20 split; the transforms above are applied lazily by the datasets.
    train_dataset, test_dataset = prepare_dataset(
        images, labels, model, 0.2, train_transforms, val_transforms
    )
    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="output",
            overwrite_output_dir=True,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=32,
            gradient_accumulation_steps=1,
            learning_rate=0.000001,  # very small LR: fine-tuning a pretrained model
            weight_decay=0.01,
            eval_strategy="steps",
            eval_steps=1000,
            save_steps=3000,
            use_cpu=model.device.type == "cpu",  # Only force CPU if that's our device
        ),
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    trainer.train()
    # Persist the fine-tuned model.
    model.save(save_model_path)