Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
from model import AlexNet | |
from torchvision import transforms | |
# Saved state_dict for the AlexNet model; the path is resolved relative to the
# current working directory.
model_path = './alexnet_model_v1.pth'

# Rebuild the architecture and load the saved parameters onto the CPU so the app
# does not require a GPU. ``weights_only=True`` limits unpickling to tensors and
# plain containers (all a state_dict needs), so loading the checkpoint cannot
# execute arbitrary pickled code.
model = AlexNet()
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
)
# Switch to inference mode (affects layers such as dropout/batch-norm, if the
# architecture contains them).
model.eval()

# Class names indexed by model output position. This is the standard CIFAR-10
# class order; it must match the label order used when the model was trained.
labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def predict(inp):
    """Classify an image and return per-class probabilities.

    Args:
        inp: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        Dict mapping each name in ``labels`` to its softmax probability as a
        Python float. The dict is empty when no image was supplied.
    """
    if inp is None:
        return {}
    # Force three channels. PNG uploads are often RGBA, and grayscale/palette
    # images would otherwise become 4- or 1-channel tensors that an RGB model
    # rejects.
    # NOTE(review): no resize or normalization is applied here. If the model
    # was trained on fixed-size, normalized inputs (e.g. 32x32 CIFAR-10),
    # mirror that preprocessing here. TODO confirm against the training code.
    batch = transforms.ToTensor()(inp.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        # Take the single batch element's logits and turn them into probabilities.
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair labels with outputs directly instead of hard-coding the class count.
    return {label: float(p) for label, p in zip(labels, probabilities)}
# Example images shown under the UI (paths relative to the working directory).
example_images = ["frog.jpeg", "car.jpeg", "cat.jpeg", "ship.jpeg", "dog.jpeg"]

# Web UI: one PIL image in, the five most probable classes out. The custom CSS
# hides Gradio's page footer.
demo = gr.Interface(
    fn=predict,
    inputs=gr.components.Image(type="pil"),
    outputs=gr.components.Label(num_top_classes=5),
    examples=example_images,
    theme="default",
    css=".footer{display:none !important}",
)
demo.launch()