|
import torch
|
|
import torchvision.transforms as transforms
|
|
|
|
from torchvision.models import ResNet50_Weights
|
|
|
|
|
|
from torchvision import models
|
|
|
|
# Pretrained ResNet-50 built with torchvision's recommended ("DEFAULT") weights.
# Module.eval() returns the module itself, so the mode switch is chained here;
# it puts layers such as batch norm / dropout into inference behaviour.
model = models.resnet50(weights=ResNet50_Weights.DEFAULT).eval()
|
|
|
|
|
|
# Human-readable class names, one per line. classify_image() indexes this list
# with the model's predicted output index, so line order must match the
# model's output order.
# NOTE(review): the path is resolved against the current working directory,
# not this script's directory — confirm that is intended.
# An explicit encoding keeps decoding independent of the platform locale.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
|
|
|
|
|
|
# Preprocessing for a single PIL image before it is fed to `model`.
# This follows torchvision's documented evaluation recipe for the ResNet-50
# IMAGENET1K_V2 weights (the DEFAULT at time of writing — re-check if the
# weights change; `ResNet50_Weights.DEFAULT.transforms()` exposes the recipe):
#   1. resize so the shorter side is 232 px (aspect ratio preserved, rather
#      than squashing the image to a fixed 224x224),
#   2. center-crop to 224x224,
#   3. convert to a float tensor scaled to [0, 1],
#   4. normalize each RGB channel with the ImageNet mean/std below.
transform = transforms.Compose([
    transforms.Resize(232),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
|
|
|
def classify_image(image):
    """Return the predicted class label for a single image.

    Args:
        image: A PIL image. The module-level ``transform`` pipeline (which
            uses ``ToTensor``) is applied to it. Images in a mode other
            than RGB (e.g. grayscale, RGBA, palette) are converted to RGB
            first, because the normalization step uses 3-channel mean/std.

    Returns:
        The entry of ``labels`` at the index of the highest-scoring model
        output.
    """
    # Normalize() is configured for exactly three channels; other PIL modes
    # would yield 1 or 4 channels and fail there.
    if hasattr(image, "convert") and getattr(image, "mode", None) != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension: the model is called on a batch of one.
    batch = transform(image).unsqueeze(0)

    # Inference only: don't record an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(batch)

    index = logits.argmax(dim=1).item()
    return labels[index]
|
|
|