# ResNet-18 / app.py
# Your Name
# title added
# 71ca213
from functools import lru_cache
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image
from torchvision.models import ResNet18_Weights, resnet18
# Image classified when no input is given (e.g. the Gradio image box was
# cleared). Resolved next to this file so it works regardless of the CWD.
_DEFAULT_IMAGE = Path(__file__).resolve().parent / "examples" / "steak.jpeg"


@lru_cache(maxsize=1)
def _load_model():
    """Build the pretrained ResNet-18, its transform and class names once.

    Cached so repeated ``predict`` calls reuse the same model instead of
    re-instantiating it and reloading the weights on every request.

    Returns:
        ``(model, transform, categories)``: the model already in eval mode,
        the preprocessing transform bundled with ``ResNet18_Weights.DEFAULT``,
        and the category names from that weights' metadata.
    """
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    model.eval()
    return model, weights.transforms(), weights.meta["categories"]


def _to_pil_rgb(img_path) -> Image.Image:
    """Coerce a supported input (None, array, PIL image, path) to RGB PIL."""
    if img_path is None:
        with Image.open(_DEFAULT_IMAGE) as src:
            return src.convert("RGB")
    if isinstance(img_path, np.ndarray):
        arr = img_path.astype("uint8")
        # A trailing singleton channel axis is not a mode PIL can infer.
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        # Let PIL infer the mode (L / RGB / RGBA) from the array shape, then
        # normalise to 3 channels; forcing "RGB" misreads 2-D or 4-channel data.
        return Image.fromarray(arr).convert("RGB")
    if isinstance(img_path, Image.Image):
        return img_path.convert("RGB")
    if isinstance(img_path, (str, Path)):
        with Image.open(img_path) as src:
            return src.convert("RGB")
    raise TypeError(
        f"Unsupported image input type: {type(img_path).__name__}"
    )


def predict(img_path=None) -> str:
    """Classify an image with a pretrained ResNet-18 and return its label.

    Args:
        img_path: The image to classify. Accepted forms:
            * ``None``: the bundled ``examples/steak.jpeg`` is used.
            * a ``numpy.ndarray`` (presumably what the Gradio ``gr.Image()``
              input supplies), cast to ``uint8`` and converted to RGB.
            * a ``PIL.Image.Image``.
            * a filesystem path (``str`` or ``pathlib.Path``).

    Returns:
        The name of the highest-scoring category from
        ``ResNet18_Weights.DEFAULT.meta["categories"]``.

    Raises:
        TypeError: if ``img_path`` is none of the supported types.
    """
    model, transform, categories = _load_model()
    img = _to_pil_rgb(img_path)
    batch = transform(img).unsqueeze(0)  # add the batch dimension
    with torch.inference_mode():
        logits = model(batch)
    # softmax is monotonic, so argmax over raw logits picks the same class.
    pred_class = logits.argmax(dim=1).item()
    predicted_label = categories[pred_class]
    print(f"Predicted class: {predicted_label}")
    return predicted_label
import numpy as np
import gradio as gr
# Gradio UI: one image input (passed to ``predict``) and a label output.
# The title previously held mojibake ("πŸš—"), which is the UTF-8 encoding of
# the 🚗 emoji mis-decoded as a legacy codepage; restored to the emoji.
demo = gr.Interface(
    predict,
    gr.Image(),
    "label",
    title="ResNet-18_1K 🚗",
    description="Upload an image to see classification probabilities based on ResNet-18 with 1K classes",
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()