# Provenance (header copied from the hosting page): last commit 38fd181
# "run pre-commit" by pmkhanh7890; file size 5.11 kB.
import logging
import pytorch_lightning as pl
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import roc_auc_score
from torchmetrics import (
Accuracy,
Recall,
)
from torchvision.transforms import v2
# Module-wide logging: INFO and above go to ``training.log``. The file is
# truncated on every start (filemode="w"), and any handlers already attached
# to the root logger are replaced (force=True).
logging.basicConfig(
    level=logging.INFO,
    filename="training.log",
    filemode="w",
    force=True,
)

# Fine-tuned checkpoint loaded by ``image_generation_detection``.
CHECKPOINT = (
    "models/image_classifier/"
    "image-classifier-step=8008-val_loss=0.11.ckpt"
)
class ImageClassifier(pl.LightningModule):
    """Binary image classifier: a timm ResNet-50 with a single-logit head.

    ``sigmoid(logit)`` is treated as the probability of label ``1`` (the
    losses below use binary cross-entropy against ``labels.float()``).

    Args:
        lmd: Weight of the "SD" penalty ``lmd * mean(logit ** 2)`` added to
            the training loss (a squared-logit regulariser, presumably
            spectral decoupling). ``0`` disables it.
    """

    def __init__(self, lmd=0):
        super().__init__()
        self.model = timm.create_model(
            "resnet50",
            pretrained=True,
            num_classes=1,
        )
        self.accuracy = Accuracy(task="binary", threshold=0.5)
        self.recall = Recall(task="binary", threshold=0.5)
        # Per-batch preds/labels gathered in ``validation_step`` and consumed
        # (then cleared) in ``on_validation_epoch_end`` to compute ROC-AUC.
        self.validation_outputs = []
        self.lmd = lmd

    def forward(self, x):
        """Return raw logits, one per image (shape ``(batch, 1)``)."""
        return self.model(x)

    def training_step(self, batch):
        """Return BCE loss plus the weighted squared-logit (SD) penalty.

        ``batch`` is unpacked as ``(images, labels, _)``.
        """
        images, labels, _ = batch
        # squeeze(-1) rather than squeeze(): a single-image batch stays 1-D
        # instead of collapsing to 0-D (assumes ``labels`` is 1-D (batch,) --
        # TODO confirm against the dataloader).
        outputs = self.forward(images).squeeze(-1)
        erm_loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        logging.info(f"Training Step - ERM loss: {erm_loss.item()}")
        loss = erm_loss + self.lmd * (outputs**2).mean()  # SD loss penalty
        logging.info(f"Training Step - SD loss: {loss.item()}")
        return loss

    def validation_step(self, batch):
        """Log validation loss/accuracy/recall and stash outputs for AUC.

        Returns:
            dict with ``val_loss``, ``preds`` (sigmoid scores) and ``labels``.
        """
        images, labels, _ = batch
        # squeeze(-1) keeps single-image batches 1-D. A bare squeeze() would
        # make them 0-D, and 0-D tensors cannot be torch.cat-ed at epoch end.
        outputs = self.forward(images).squeeze(-1)
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        preds = torch.sigmoid(outputs)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        # Accumulate metric state and log the Metric objects themselves, so
        # Lightning computes accuracy/recall over the whole epoch rather than
        # averaging per-batch values (a per-batch mean is wrong for recall).
        self.accuracy.update(preds, labels.int())
        self.recall.update(preds, labels.int())
        self.log("val_acc", self.accuracy, prog_bar=True)
        self.log("val_recall", self.recall, prog_bar=True)
        output = {
            "val_loss": loss,
            "preds": preds.detach(),
            "labels": labels.detach(),
        }
        self.validation_outputs.append(output)
        logging.info(f"Validation Step - Batch loss: {loss.item()}")
        return output

    def predict_step(self, batch):
        """Return ``(scores, label, domain)`` with one sigmoid score per image."""
        images, label, domain = batch
        outputs = self.forward(images).squeeze(-1)
        preds = torch.sigmoid(outputs)
        return preds, label, domain

    def on_validation_epoch_end(self):
        """Log ROC-AUC over every prediction gathered this validation epoch.

        The buffer is cleared on every path (including the early returns), so
        stale outputs never leak into the next epoch.
        """
        outputs, self.validation_outputs = self.validation_outputs, []
        if not outputs:
            logging.warning("No outputs in validation step to process")
            return
        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        # ROC-AUC is undefined when only one class is present.
        if labels.unique().numel() == 1:
            logging.warning("Only one class in validation step")
            return
        # .float() first: numpy has no bfloat16 dtype, so converting a bf16
        # tensor straight to numpy would fail.
        auc_score = roc_auc_score(
            labels.cpu().numpy(),
            preds.float().cpu().numpy(),
        )
        # NOTE(review): under DDP this is each rank's AUC on its own shard,
        # averaged by sync_dist -- only an approximation of the global AUC.
        self.log("val_auc", auc_score, prog_bar=True, sync_dist=True)
        logging.info(f"Validation Epoch End - AUC score: {auc_score}")

    def configure_optimizers(self):
        """Return Adam over the backbone parameters with a fixed lr of 5e-4."""
        return torch.optim.Adam(self.model.parameters(), lr=0.0005)
def load_image(image_path, transform=None):
    """Load an image file as RGB and optionally apply ``transform``.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file).
        transform: Optional callable applied to the RGB ``PIL.Image``.

    Returns:
        The RGB ``PIL.Image``, or ``transform(image)`` when a transform is
        given.
    """
    # ``convert`` returns a new, fully loaded image, so the source file can be
    # closed right away instead of lingering until garbage collection.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    if transform is not None:
        image = transform(image)
    return image
def predict_single_image(image_path, model, transform=None):
    """Return the model's sigmoid score for one image file.

    The image is loaded via ``load_image``. The transformed result must be a
    tensor, because it is moved to the device and given a leading batch dim.
    As a side effect, ``model`` is moved to CUDA (when available, else CPU)
    and switched to eval mode.

    Args:
        image_path: Path of the image to score.
        model: Module mapping a ``(1, ...)`` batch to a single logit.
        transform: Callable turning the RGB image into a tensor.

    Returns:
        float: ``sigmoid(logit)`` for the image.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pixels = load_image(image_path, transform)
    model.to(target)
    batch = pixels.to(target)
    model.eval()
    with torch.no_grad():
        logit = model(batch.unsqueeze(0)).squeeze()
        score = torch.sigmoid(logit).item()
    return score
# Loaded classifiers keyed by checkpoint path. Loading re-runs
# ``ImageClassifier.__init__`` (which builds a pretrained timm backbone) and
# then the checkpoint weights, so doing it per call is very expensive.
_MODEL_CACHE = {}


def image_generation_detection(image_path, threshold=0.2):
    """Classify an image as human-made or machine-generated.

    Args:
        image_path: Path of the image to classify.
        threshold: Scores ``<= threshold`` are labelled "HUMAN"; higher
            scores are labelled "MACHINE".

    Returns:
        ``(label, confidence)``. ``label`` is "HUMAN" or "MACHINE".
        ``confidence`` is ``0.5`` plus the score's distance from
        ``threshold``, capped at ``1.0``. It is always a float.
    """
    model = _MODEL_CACHE.get(CHECKPOINT)
    if model is None:
        model = ImageClassifier.load_from_checkpoint(CHECKPOINT)
        _MODEL_CACHE[CHECKPOINT] = model
    transform = v2.Compose(
        [
            transforms.ToTensor(),
            v2.CenterCrop((256, 256)),
        ],
    )
    prediction = predict_single_image(image_path, model, transform)

    if prediction <= threshold:
        result = "Most likely human"
        image_prediction_label = "HUMAN"
    else:
        result = "Most likely machine"
        image_prediction_label = "MACHINE"
    # Confidence grows linearly with the distance from the decision threshold.
    # 1.0 (not int 1) keeps the return type consistently float.
    image_confidence = min(1.0, 0.5 + abs(prediction - threshold))
    result += f" with confidence = {round(image_confidence * 100, 2)}%"
    logging.info(result)
    return image_prediction_label, image_confidence
if __name__ == "__main__":
    import sys

    # Take the image path from the command line when given; otherwise fall
    # back to the placeholder path.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "path_to_your_image.jpg"
    image_prediction_label, image_confidence = image_generation_detection(
        image_path,
    )
    print(f"{image_prediction_label} (confidence = {image_confidence:.2%})")