# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
import json | |
import torchvision.models as models | |
from transformers import AutoImageProcessor | |
# Load ImageNet class labels once at import time. `labels` is indexed by the
# integer class index produced by the model (see predict()).
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# A timeout keeps a stalled connection from hanging startup forever, and
# raise_for_status() surfaces HTTP failures (404/5xx) right here instead of
# letting an error page fail later inside the JSON parser.
response = requests.get(LABELS_URL, timeout=30)
response.raise_for_status()
labels = json.loads(response.text)
# In-memory cache for the (model, processor) pair built by load_model().
_LOADED_MODEL = None


def load_model():
    """Return the ResNet50 classifier and its image processor.

    The first call builds a torchvision ResNet50 with a 1000-way output
    layer, loads the weights published under ``jatingocodeo/ImageNet`` on the
    Hugging Face Hub onto the CPU, switches the model to eval mode and
    creates the matching ``AutoImageProcessor``. The resulting pair is kept
    in a module-level cache so later calls return immediately instead of
    rebuilding the network for every request.

    Returns:
        tuple: ``(model, processor)``.
    """
    global _LOADED_MODEL
    if _LOADED_MODEL is not None:
        return _LOADED_MODEL

    model_id = "jatingocodeo/ImageNet"

    # Initialize ResNet50 without pretrained weights; the checkpoint below
    # supplies them.
    model = models.resnet50(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 1000)  # 1000 ImageNet classes

    # Load model weights
    checkpoint = torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
        map_location="cpu",
    )
    model.load_state_dict(checkpoint)
    model.eval()

    # Create processor
    processor = AutoImageProcessor.from_pretrained(model_id)

    _LOADED_MODEL = (model, processor)
    return _LOADED_MODEL
def predict(image):
    """Classify an image and return its top-5 predictions.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.
            Assumed to be a PIL image, as configured on the Gradio input
            component.

    Returns:
        dict | None: Mapping of label name to probability in ``[0, 1]`` for
        the five most likely classes, or ``None`` when ``image`` is ``None``.

    Raises:
        gr.Error: If loading the model or running inference fails, so the UI
            shows an error message instead of rendering the error text as if
            it were a predicted class.
    """
    if image is None:
        return None
    try:
        # Load model and processor
        model, processor = load_model()

        # ResNet50 takes three-channel input; uploads may be RGBA, palette
        # or grayscale, so normalise the mode before preprocessing.
        inputs = processor(image.convert("RGB"), return_tensors="pt")

        # Get predictions
        with torch.no_grad():
            outputs = model(inputs.pixel_values)

        # Get probabilities and classes
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
        top_probs, top_indices = torch.topk(probs, k=5)

        # Format results. gr.Label expects confidences as fractions in
        # [0, 1] and does the percentage formatting itself, so the softmax
        # output is passed through unscaled.
        results = {}
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
            label = labels[idx]
            # Accumulate in case two class indices share a display name,
            # rather than letting the later one overwrite the earlier.
            results[label] = results.get(label, 0.0) + prob
        return results
    except Exception as e:
        raise gr.Error(f"Error processing image: {str(e)}") from e
# Static text and sample inputs shown in the Gradio UI.
title = "ImageNet Classifier"

# Markdown rendered underneath the title.
description = """
## ResNet50 ImageNet Classifier
This model classifies images into 1000 ImageNet categories. Upload an image or use one of the example images to get predictions.
### Instructions:
1. Upload an image using the input box below
2. The model will predict the top 5 classes for the image
3. Results show class names and confidence scores
### Model Details:
- Architecture: ResNet50
- Dataset: ImageNet
- Input Size: 224x224
- Number of Classes: 1000
"""

# Paths of the sample images offered as one-click examples.
examples = [
    "examples/dog.jpg",
    "examples/cat.jpg",
    "examples/bird.jpg",
    "examples/car.jpg",
    "examples/flower.jpg",
]
# Build the UI: a PIL image upload feeding predict(), with the result shown
# in a label widget limited to the five highest-scoring classes.
image_input = gr.Image(type="pil", label="Upload Image")
prediction_output = gr.Label(num_top_classes=5, label="Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_output,
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Start serving the app.
iface.launch()