"""Classify PIL images with an ImageNet-pretrained ResNet-50."""

import torch
import torchvision.transforms as transforms
from torchvision import models
from torchvision.models import ResNet50_Weights

# Load the pretrained ResNet-50 and put it in evaluation mode for inference.
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()

# Load the ImageNet class labels, one per line. classify_image indexes this
# list with the model's output index, so line order must match the model's
# class order. Explicit encoding avoids depending on the platform default.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]

# Preprocessing pipeline: resize to 224x224, convert to a tensor, and
# normalize with the per-channel mean/std given below.
# NOTE(review): this squashes the image to 224x224 regardless of aspect
# ratio; torchvision's reference preprocessing for these weights
# (ResNet50_Weights.DEFAULT.transforms()) resizes and center-crops instead.
# Consider switching — confirm the accuracy impact first.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def classify_image(image):
    """Return the predicted label for *image*.

    Args:
        image: A PIL image. Images not in "RGB" mode (e.g. grayscale "L",
            "RGBA", palette "P") are converted to RGB first, because the
            3-value normalization and the model both require exactly three
            channels.

    Returns:
        The entry of ``labels`` at the index of the model's highest output
        score.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension: the model consumes a batch of images.
    batch_t = transform(image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        out = model(batch_t)

    index = out.argmax(dim=1).item()
    return labels[index]