ResNet-18 / app.py
Your Name
first commit
beb105b
raw
history blame
1.32 kB
from functools import lru_cache
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights
# Default example image, resolved next to this file rather than against the
# current working directory so `predict()` works wherever the app is launched.
_DEFAULT_IMAGE_PATH = Path(__file__).resolve().parent / "examples" / "steak.jpeg"

# Pretrained weights enum; provides the model weights, the matching
# preprocessing transform and the category names in `meta["categories"]`.
_WEIGHTS = ResNet18_Weights.DEFAULT


@lru_cache(maxsize=1)
def _load_model_and_transform():
    """Build the ResNet-18 model (in eval mode) and its transform once.

    Constructing the model loads the pretrained weights, which is far too
    expensive to repeat on every request; the cached pair is reused by every
    later call to `predict`.
    """
    model = resnet18(weights=_WEIGHTS)
    model.eval()
    return model, _WEIGHTS.transforms()


def _to_rgb_image(img_path) -> Image.Image:
    """Normalise any supported input kind to a 3-channel RGB PIL image."""
    if img_path is None:
        img_path = _DEFAULT_IMAGE_PATH
    if isinstance(img_path, Image.Image):
        return img_path.convert("RGB")
    if isinstance(img_path, np.ndarray):
        array = img_path.astype("uint8")
        # Pillow cannot build an image from an (H, W, 1) array; drop the
        # singleton channel so it is read as grayscale instead.
        if array.ndim == 3 and array.shape[2] == 1:
            array = array[..., 0]
        # Let Pillow infer the mode (L / RGB / RGBA) from the array shape and
        # convert afterwards; forcing "RGB" mis-reads grayscale/RGBA data.
        return Image.fromarray(array).convert("RGB")
    if isinstance(img_path, (str, Path)):
        # Context manager so the file handle is released once decoded.
        with Image.open(img_path) as image:
            return image.convert("RGB")
    raise TypeError(
        "img_path must be None, a file path, a PIL image or a numpy array, "
        f"got {type(img_path).__name__}"
    )


def predict(img_path=None) -> str:
    """Classify an image with the pretrained ResNet-18 and return its label.

    Args:
        img_path: The image to classify. Accepts a numpy array (e.g. from a
            Gradio image input), a PIL image, or a path (``str``/``Path``) to
            an image file. When ``None``, ``examples/steak.jpeg`` next to
            this file is used.

    Returns:
        The predicted category name from ``ResNet18_Weights.DEFAULT``.

    Raises:
        TypeError: If ``img_path`` is of an unsupported type.
    """
    model, transform = _load_model_and_transform()
    image = _to_rgb_image(img_path)
    batch = transform(image).unsqueeze(0)  # add the batch dimension
    with torch.inference_mode():
        logits = model(batch)
    # softmax is monotonic, so the arg-max of the raw logits is the same class.
    pred_class = logits.argmax(dim=1).item()
    predicted_label = _WEIGHTS.meta["categories"][pred_class]
    print(f"Predicted class: {predicted_label}")
    return predicted_label
import numpy as np
import gradio as gr
# Gradio UI: an image input component feeds `predict`, and the string it
# returns is rendered by a label output component.
demo = gr.Interface(fn=predict, inputs=gr.Image(), outputs="label")

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    demo.launch()