Spaces:
Sleeping
Sleeping
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Checkpoint to serve. A small model was picked so the app stays within the
# resource limits of a Hugging Face Space.
model_name = "microsoft/resnet-18"

# The preprocessor and the classifier are loaded once, at import time, and
# then shared by every call to classify_image().
feature_extractor, model = (
    AutoFeatureExtractor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)
# Function to classify image | |
def classify_image(image): | |
if image is None: | |
return "No image provided", None | |
try: | |
# Process image | |
inputs = feature_extractor(images=image, return_tensors="pt") | |
# Make prediction | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
logits = outputs.logits | |
# Get predicted class | |
predicted_class_idx = logits.argmax(-1).item() | |
predicted_class = model.config.id2label[predicted_class_idx] | |
# Get top 5 predictions | |
probs = torch.nn.functional.softmax(logits, dim=-1)[0] | |
top5_prob, top5_indices = torch.topk(probs, 5) | |
# Create plot for visualization | |
fig, ax = plt.subplots(figsize=(10, 5)) | |
# Get class names and probabilities | |
classes = [model.config.id2label[idx.item()] for idx in top5_indices] | |
probabilities = [prob.item() * 100 for prob in top5_prob] | |
# Create horizontal bar chart | |
bars = ax.barh(classes, probabilities, color='#4C72B0') | |
ax.set_xlabel('Probability (%)') | |
ax.set_title('Top 5 Predictions') | |
# Add percentage labels | |
for i, bar in enumerate(bars): | |
width = bar.get_width() | |
ax.text(width + 1, bar.get_y() + bar.get_height()/2, | |
f'{probabilities[i]:.1f}%', | |
va='center', fontsize=10) | |
# Improve layout | |
plt.tight_layout() | |
return predicted_class, fig | |
except Exception as e: | |
return f"Error: {str(e)}", None | |
# Gradio front end. A single PIL image goes in; the top-1 label and a bar chart
# of the top predictions come out.
_image_input = gr.Image(type="pil")
_result_outputs = [
    gr.Textbox(label="Prediction"),
    gr.Plot(label="Confidence Levels"),
]

demo = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_result_outputs,
    title="🖼️ Image Classification Tool",
    description="Upload an image to see what the AI recognizes in it!",
    allow_flagging="never",  # hide the "Flag" button
    examples=[],  # deliberately empty: no bundled example images to ship
    theme=gr.themes.Soft(),
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()