import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from ResNet_for_CC import CC_model  # Import the model
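# NOTE: ResNet_for_CC is assumed to be a local module shipped alongside this
# script that defines CC_model, a ResNet-based classifier for Clothing1M.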
# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model
model_path = "CC_net.pt"
model = CC_model(num_classes1=14)

# Load model weights (strict=False tolerates missing/unexpected keys in the checkpoint)
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()  # Inference mode: disables dropout and freezes batch-norm statistics
# Clothing1M class labels
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear"
]
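# The label order must match the class-index order used when training the
# 14-way head: the argmax over the logits below indexes directly into this list.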
# Image preprocessing function
def preprocess_image(image):
    """Applies the standard ImageNet transformations to the input PIL image."""
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0).to(device)
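# Example (hypothetical local file, not part of this app):
#   preprocess_image(Image.open("shirt.jpg").convert("RGB"))
# returns a float tensor of shape [1, 3, 224, 224] already moved to `device`.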
# Classification function
def classify_image(image):
    """Processes the input image and returns the predicted clothing category."""
    print("\n[INFO] Received image for classification.")
    try:
        image = Image.fromarray(image)   # Convert the numpy array from Gradio to PIL format
        image = preprocess_image(image)  # Apply transformations
        print("[INFO] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(image)

        # The model may return a tuple (e.g., features and logits); keep the logits tensor
        if isinstance(output, tuple):
            output = output[1]

        print(f"[DEBUG] Model output shape: {output.shape}")
        print(f"[DEBUG] Model output values: {output}")

        if output.shape[1] != 14:
            return f"[ERROR] Model output mismatch! Expected 14 but got {output.shape[1]}."

        # Convert logits to probabilities
        probabilities = F.softmax(output, dim=1)
        print(f"[DEBUG] Softmax probabilities: {probabilities}")

        # Get the predicted class index
        predicted_class = torch.argmax(probabilities, dim=1).item()
        print(f"[INFO] Predicted class index: {predicted_class}")

        # Validate the index before using it to look up a label
        if 0 <= predicted_class < len(class_labels):
            predicted_label = class_labels[predicted_class]
            confidence = probabilities[0][predicted_class].item() * 100
            print(f"[INFO] Predicted class: {predicted_label}")
            return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
        else:
            return "[ERROR] Model returned an invalid class index."

    except Exception as e:
        print(f"[ERROR] Exception during classification: {e}")
        return "Error in classification. Check console for details."
# Gradio interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Clothing1M Image Classifier",
    description="Upload a clothing image, and the model will classify it into one of the 14 categories."
)
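# NOTE: with inputs=gr.Image(type="pil"), Gradio would pass a PIL image directly,
# and the Image.fromarray conversion in classify_image could be dropped.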
# Run the interface
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()