File size: 908 Bytes
b083548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torchvision.transforms as transforms

from torchvision.models import ResNet50_Weights

# ์‚ฌ์ „ ํ•™์Šต๋œ ResNet ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
from torchvision import models

# ResNet-50 built with torchvision's DEFAULT pretrained weights and switched to
# evaluation mode for inference. nn.Module.eval() returns the module itself,
# so the call can be chained onto construction.
model = models.resnet50(weights=ResNet50_Weights.DEFAULT).eval()

# ๋ผ๋ฒจ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ (ImageNet ํด๋ž˜์Šค ๋ผ๋ฒจ)
with open("imagenet_classes.txt") as f:
    labels = [line.strip() for line in f.readlines()]

# ๋ณ€ํ™˜ ํŒŒ์ดํ”„๋ผ์ธ (ํฌ๊ธฐ ์กฐ์ •, ํ…์„œ ๋ณ€ํ™˜, ์ •๊ทœํ™” ๋“ฑ)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

def classify_image(image):
    """Classify a PIL image and return the label of the top-scoring class.

    Args:
        image: A PIL image. Images in any mode other than "RGB" (for example
            greyscale "L", "RGBA" or palette "P") are converted to RGB first.

    Returns:
        The entry of ``labels`` at the index of the highest model output.
        Ties resolve to the first maximal index.
    """
    # The Normalize step uses exactly 3 mean/std values; 1- or 4-channel
    # inputs would make it fail, so coerce to RGB. Already-RGB images are
    # passed through without making a copy.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    # Inference only: don't record an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(batch)

    best_index = logits.argmax(dim=1).item()
    return labels[best_index]