Spaces:
Sleeping
Sleeping
import logging | |
import os | |
import shutil | |
import time | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from sklearn.metrics import classification_report | |
from sklearn.preprocessing import LabelEncoder | |
from tqdm import tqdm | |
from transformers import ViTFeatureExtractor, ViTModel | |
from transformers.modeling_outputs import SequenceClassifierOutput | |
from train import ( | |
f1_score, | |
metric, | |
re_training, | |
) | |
# Root directory for data; overridable through the DATA_PATH env variable.
data_path = os.environ.get("DATA_PATH", "./data")

# LOGGER_LEVEL may be given as a level name in any case ("info", "DEBUG") or
# as a number ("10"). logging.basicConfig only accepts upper-case names or
# ints, so normalise the env value instead of crashing the import with
# "ValueError: Unknown level".
_logger_level = os.getenv("LOGGER_LEVEL", logging.WARNING)
if isinstance(_logger_level, str):
    _logger_level = _logger_level.strip()
    _logger_level = (
        int(_logger_level) if _logger_level.isdigit() else _logger_level.upper()
    )
logging.basicConfig(level=_logger_level)
logger = logging.getLogger(__name__)
class ViTForImageClassification(nn.Module):
    """ViT backbone with a dropout + linear classification head.

    Wraps a pretrained ``ViTModel`` and its ``ViTFeatureExtractor``. A
    ``LabelEncoder`` maps class names to integer ids. Helpers cover
    prediction, evaluation, saving/loading and retraining from image folders
    laid out as ``<path>/<label>/<image files>``.
    """

    @staticmethod
    def get_device():
        """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def __init__(self, model_name, num_labels=24, dropout=0.25, image_size=224):
        """Load the pretrained backbone and build the classification head.

        Args:
            model_name: Model id or path passed to ``from_pretrained`` for both
                the backbone and the feature extractor.
            num_labels: Number of outputs of the linear head.
            dropout: Dropout probability applied before the head.
            image_size: Assigned to the feature extractor's ``size``. It is
                also the square side images are resized to when loading
                folders from disk.
        """
        logger.info("Loading model")
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.feature_extractor.do_resize = True
        self.feature_extractor.size = image_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels
        self.label_encoder = LabelEncoder()
        self.device = self.get_device()
        logger.info("Using device: %s", self.device)
        self.model_name = model_name
        # Module.to() moves every submodule (self.vit, self.classifier) too.
        self.to(self.device)
        logger.info("Model loaded")

    def forward(self, pixel_values, labels=None):
        """Run backbone and head on a batch of preprocessed images.

        Args:
            pixel_values: Tensor produced by ``preprocess_image``. It is moved
                to ``self.device`` here.
            labels: Optional tensor of integer class ids. When given, a
                cross-entropy loss is computed.

        Returns:
            ``SequenceClassifierOutput`` holding ``loss`` (``None`` without
            labels), ``logits`` and the backbone's hidden states / attentions.
        """
        logger.info("Forwarding")
        pixel_values = pixel_values.to(self.device)
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the first token of the last hidden layer.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        loss = None
        if labels is not None:
            # Labels may arrive on the CPU while logits live on the GPU.
            labels = labels.to(logits.device)
            # Use the head's real output width so a reloaded head with a
            # different class count still reshapes correctly.
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, logits.size(-1)), labels.view(-1)
            )
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def preprocess_image(self, images):
        """Convert images to model inputs via the feature extractor (PyTorch tensors)."""
        logger.info("Preprocessing images")
        return self.feature_extractor(images, return_tensors="pt")

    def predict(
        self, images, batch_size=32, classes_names=True, return_probabilities=False
    ):
        """Predict the most likely class for each image.

        Args:
            images: A list of images, or a single image (wrapped in a list).
            batch_size: Number of images per forward pass.
            classes_names: If true, map predicted ids back to label names with
                the label encoder.
            return_probabilities: If true, also return the per-class
                probabilities of every image.

        Returns:
            ``(classes, confidences)``, or ``(classes, confidences,
            probabilities)`` when ``return_probabilities`` is set.
            ``probabilities`` covers all images, one list per image.
        """
        logger.info("Predicting")
        if not isinstance(images, list):
            images = [images]
        classes_list = []
        confidence_list = []
        probabilities_list = []
        # Inference must not use dropout. Switch to eval mode and restore each
        # submodule's previous mode afterwards, so a caller that is training
        # is unaffected.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                for start in tqdm(
                    range(0, len(images), batch_size),
                    desc="Preprocessing training images",
                ):
                    pixel_values = self.preprocess_image(
                        images[start : start + batch_size]
                    )["pixel_values"]
                    output = self.forward(pixel_values, None)
                    probs = output.logits.softmax(dim=-1).tolist()
                    classes_list.extend(np.argmax(probs, axis=1))
                    confidence_list.extend(np.max(probs, axis=1))
                    probabilities_list.extend(probs)
        finally:
            for module, mode in previous_modes:
                module.training = mode
        if classes_names:
            classes_list = self.label_encoder.inverse_transform(classes_list)
        if return_probabilities:
            return classes_list, confidence_list, probabilities_list
        return classes_list, confidence_list

    def save(self, path):
        """Write ``model.pt`` (state dict) and ``label_encoder.npy`` into ``path``."""
        logger.info("Saving model")
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pt"))
        np.save(os.path.join(path, "label_encoder.npy"), self.label_encoder.classes_)

    def load(self, path):
        """Load weights and label classes previously written by ``save``.

        If either file is missing, a warning is logged and the model is left
        unchanged. On success the model ends up on ``self.device`` in eval
        mode.
        """
        logger.info("Loading model")
        encoder_path = os.path.join(path, "label_encoder.npy")
        model_path = os.path.join(path, "model.pt")
        if not (os.path.exists(encoder_path) and os.path.exists(model_path)):
            logger.warning("Label encoder or model not found")
            return
        self.label_encoder.classes_ = np.load(encoder_path)
        # Keep num_labels in sync with the saved class count.
        self.num_labels = len(self.label_encoder.classes_)
        # Rebuild the head so its shape matches the saved state dict.
        self.classifier = nn.Linear(self.vit.config.hidden_size, self.num_labels)
        self.load_state_dict(torch.load(model_path, map_location=self.device))
        self.to(self.device)
        self.eval()

    def evaluate(self, images, labels):
        """Predict ``images`` and print a classification report, accuracy and macro F1.

        ``labels`` are label names. They are encoded with the fitted label
        encoder.
        """
        logger.info("Evaluating")
        references = self.label_encoder.transform(labels)
        predictions, _ = self.predict(images, classes_names=False)
        metrics = metric.compute(predictions=predictions, references=references)
        f1 = f1_score.compute(
            predictions=predictions, references=references, average="macro"
        )
        print(
            classification_report(
                references,
                predictions,
                labels=list(range(len(self.label_encoder.classes_))),
                target_names=self.label_encoder.classes_,
            )
        )
        print(f"Accuracy: {metrics['accuracy']}")
        print(f"F1: {f1}")

    def partial_fit(self, images, labels, save_model_path="new_model", num_epochs=10):
        """Retrain on new data, reload the saved result, then evaluate it on that data.

        Training itself is delegated to ``re_training``. The model is then
        reloaded from ``save_model_path``.
        """
        logger.info("Partial fitting")
        # NOTE: this method freezes no parameters; the whole module is put in
        # training mode. Any freezing would have to happen in re_training.
        self.train()
        re_training(images, labels, self, save_model_path, num_epochs)
        self.load(save_model_path)
        self.eval()
        self.evaluate(images, labels)

    def __load_from_path(self, path, num_per_label=None):
        """Load RGB images from ``<path>/<label>/<file>``, resized to the extractor size.

        Args:
            path: Directory whose sub-directories are named after labels.
            num_per_label: Optional cap on images loaded per label.

        Returns:
            ``(images, labels)`` as parallel lists. Unreadable files are
            reported and skipped.
        """
        images = []
        labels = []
        image_shape = (self.feature_extractor.size, self.feature_extractor.size)
        for label in sorted(os.listdir(path)):
            label_folder_path = os.path.join(path, label)
            if not os.path.isdir(label_folder_path):
                # Stray files (e.g. .DS_Store) are not label folders.
                continue
            count = 0
            for image_file in tqdm(
                sorted(os.listdir(label_folder_path)),
                desc="Resizing images for label {}".format(label),
            ):
                # Check the cap first so num_per_label=0 loads nothing.
                if num_per_label is not None and count >= num_per_label:
                    break
                file_path = os.path.join(label_folder_path, image_file)
                try:
                    # The context manager closes the file handle. resize() and
                    # convert() return new, fully loaded images.
                    with Image.open(file_path) as opened:
                        image = opened
                        if image.size != image_shape:
                            image = image.resize(image_shape)
                        rgb_image = image.convert("RGB")
                except Exception as e:  # best-effort: skip unreadable files
                    print(f"ERROR - Could not resize image {file_path} - {e}")
                    continue
                images.append(rgb_image)
                labels.append(label)
                count += 1
        return images, labels

    def retrain_from_path(
        self,
        path="./data/feedback",
        num_per_label=None,
        save_model_path="new_model",
        remove_path=False,
        num_epochs=10,
        save_new_data=data_path + "/new_data",
    ):
        """Retrain on the labelled image folders under ``path``.

        Args:
            path: Folder laid out as ``<path>/<label>/<image files>``.
            num_per_label: Optional cap on images loaded per label.
            save_model_path: Directory the retrained model is saved to and
                reloaded from.
            remove_path: If true, delete ``path`` after retraining.
            num_epochs: Number of training epochs passed to ``partial_fit``.
            save_new_data: If not ``None``, copy the loaded images there as
                JPEGs, one sub-folder per label.
        """
        logger.info("Retraining from path")
        images, labels = self.__load_from_path(path, num_per_label)
        self.partial_fit(images, labels, save_model_path, num_epochs)
        if save_new_data is not None:
            logger.info("Saving new data")
            for i, (image, label) in enumerate(zip(images, labels)):
                label_path = os.path.join(save_new_data, label)
                os.makedirs(label_path, exist_ok=True)
                image.save(
                    os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg")
                )
        if remove_path:
            logger.info("Removing feedback path")
            shutil.rmtree(path)

    def evaluate_from_path(self, path, num_per_label=None):
        """Load labelled image folders from ``path`` and run ``evaluate`` on them."""
        logger.info("Evaluating from path")
        images, labels = self.__load_from_path(path, num_per_label)
        self.evaluate(images, labels)