Spaces:
Running
Running
import torch | |
import numpy as np | |
import gradio as gr | |
from pathlib import Path | |
from PIL import Image | |
from torchvision import transforms | |
from huggingface_hub import hf_hub_download | |
from ResNet_for_CC import CC_model | |
# Clothing1M category names. classify_image maps the model's predicted index
# to a name by position in this list, so the order presumably has to match the
# label order the checkpoint was trained with — do not reorder.
CLOTHING1M_CLASSES = [
    "T-Shirt",      # 0
    "Shirt",        # 1
    "Knitwear",     # 2
    "Chiffon",      # 3
    "Sweater",      # 4
    "Hoodie",       # 5
    "Windbreaker",  # 6
    "Jacket",       # 7
    "Downcoat",     # 8
    "Suit",         # 9
    "Shawl",        # 10
    "Dress",        # 11
    "Vest",         # 12
    "Underwear",    # 13
]
# Build the classifier once at import time, load its weights from the Hub,
# and switch it to evaluation mode for inference.
model = CC_model()
model_path = hf_hub_download(repo_id="mohamdlog/CC", filename="CC_net.pt")
# The checkpoint is fetched from a remote repository, so restrict unpickling to
# tensors and primitive containers (weights_only=True) instead of allowing
# arbitrary pickled objects to execute code. load_state_dict only needs a
# mapping of tensors. Requires torch >= 1.13; this is the default from 2.6 on.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)
model.eval()
# Preprocessing pipeline. It is built once at import time instead of on every
# call, because it never changes.
# NOTE(review): no Normalize step is applied; presumably this matches how
# CC_net.pt was trained — confirm before adding mean/std normalization.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def preprocess_image(image):
    """Convert an input image into a batched tensor for the model.

    Args:
        image: A ``numpy.ndarray`` (as delivered by ``gr.Image``) or a PIL
            image.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the image converted to
        RGB, resized to 224x224, converted by ``ToTensor``, and given a
        leading batch dimension of size 1.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Force exactly 3 channels. Grayscale ("L"), RGBA, or palette images would
    # otherwise yield 1- or 4-channel tensors, which a 3-channel network
    # presumably cannot consume.
    image = image.convert("RGB")
    return _TRANSFORM(image).unsqueeze(0)
# Classification entry point used by the Gradio interface.
def classify_image(image):
    """Predict the clothing category of a single image.

    Args:
        image: The value from the ``gr.Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A two-line string with the predicted category name and its softmax
        probability formatted to two decimals.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an opaque exception
        # from deep inside the preprocessing pipeline.
        raise gr.Error("Please upload an image first.")

    input_tensor = preprocess_image(image)
    with torch.no_grad():
        logits = model(input_tensor)

    # Softmax over the class dimension. Only the first (and only) item of the
    # batch built by preprocess_image is inspected.
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    predicted_class_idx = logits.argmax(dim=1).item()
    predicted_class = CLOTHING1M_CLASSES[predicted_class_idx]
    confidence = probabilities[0][predicted_class_idx].item()
    return f"Category: {predicted_class}\nConfidence: {confidence:.2f}"
# File suffixes accepted as example images (compared case-insensitively).
_EXAMPLE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".tif", ".tiff"}


def _categories_markdown(classes, per_row=7):
    """Render *classes* as a GitHub-flavored Markdown table, *per_row* per row."""
    if not classes:
        return ""
    rows = [list(classes[i:i + per_row]) for i in range(0, len(classes), per_row)]
    # Pad a short final row so every row has the same number of cells.
    rows = [row + [""] * (per_row - len(row)) for row in rows]
    lines = ["| " + " | ".join(row) + " |" for row in rows]
    # GFM only recognizes a table when its first row is followed by a
    # delimiter row; without it the pipes render as literal text.
    lines.insert(1, "|" + "---|" * per_row)
    return "\n".join(lines)


def _example_paths(directory="examples"):
    """Return sorted ``[[path]]`` example rows for image files in *directory*.

    Returns None when no example images exist, so the interface shows no
    examples section.
    """
    paths = sorted(
        path
        for path in Path(directory).glob("*")
        if path.is_file() and path.suffix.lower() in _EXAMPLE_SUFFIXES
    )
    return [[str(path)] for path in paths] or None


# Create the Gradio interface. The category table is generated from
# CLOTHING1M_CLASSES so the UI always lists exactly the labels the app predicts.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(label="Uploaded Image"),
    outputs=gr.Text(label="Predicted Clothing"),
    title="Clothing Category Classifier",
    description=f"""
**Upload an image of clothing, and the model will predict its category.**

Try using an image that doesn't belong to any of the available categories, and see how the result differs!

**Categories:**

{_categories_markdown(CLOTHING1M_CLASSES)}
""",
    examples=_example_paths(),
    flagging_mode="never",
    theme="soft",
)
def main():
    """Start the Gradio server that serves the clothing classifier UI."""
    interface.launch()


# Launch only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()