Gallai's picture
Create app.py
3d640d4 verified
raw
history blame contribute delete
2.24 kB
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from ResNet_for_CC import CC_model
# Clothing1M category names. The position of each label is the class index the
# model predicts, so this order must not change: classify_image maps the
# argmax index straight into this list.
CLOTHING1M_CLASSES = [
    "T-Shirt",      # 0
    "Shirt",        # 1
    "Knitwear",     # 2
    "Chiffon",      # 3
    "Sweater",      # 4
    "Hoodie",       # 5
    "Windbreaker",  # 6
    "Jacket",       # 7
    "Downcoat",     # 8
    "Suit",         # 9
    "Shawl",        # 10
    "Dress",        # 11
    "Vest",         # 12
    "Underwear",    # 13
]
# Initialize the model: build the architecture, then load the trained weights
# published on the Hugging Face Hub (downloaded, or reused from the local HF
# cache, by hf_hub_download). Tensors are mapped to CPU so no GPU is required.
model = CC_model()
model_path = hf_hub_download(repo_id="mohamdlog/CC", filename="CC_net.pt")
# weights_only=True restricts torch.load's unpickler to tensors and plain
# containers. A state dict needs nothing more, and this prevents a tampered
# checkpoint file from executing arbitrary code when it is loaded.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)
# Switch to evaluation mode (affects layers such as dropout / batch norm, if
# CC_model contains any) since this app only runs inference.
model.eval()
# Define preprocessing pipeline
def preprocess_image(image):
    """Convert an input image into a batched tensor for the classifier.

    Args:
        image: A NumPy array (what ``gr.Image`` passes by default) or a PIL
            image.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` with values in ``[0, 1]``.
        No mean/std normalization is applied; presumably this matches how
        CC_model was trained -- TODO confirm.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Force exactly three channels. RGBA uploads (e.g. PNGs with transparency)
    # or grayscale images would otherwise give 4- or 1-channel tensors.
    image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # unsqueeze(0) adds the batch dimension the model call expects.
    return transform(image).unsqueeze(0)
# Define classification function
def classify_image(image):
    """Classify a clothing image and return a human-readable prediction.

    Args:
        image: The value from the Gradio ``gr.Image`` input (a NumPy array by
            default), or a PIL image. Gradio passes ``None`` when the user
            submits without uploading anything.

    Returns:
        A two-line string: the predicted category name and its softmax
        probability formatted to two decimals.

    Raises:
        gr.Error: If no image was provided. Gradio shows this as a friendly
            message instead of failing inside preprocessing.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    input_tensor = preprocess_image(image)
    with torch.no_grad():
        logits = model(input_tensor)
    # Softmax turns the raw scores into per-class probabilities. The largest
    # probability gives both the predicted class and its confidence.
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    confidence, predicted_class_idx = probabilities[0].max(dim=0)
    predicted_class = CLOTHING1M_CLASSES[predicted_class_idx.item()]
    return f"Category: {predicted_class}\nConfidence: {confidence.item():.2f}"
# Create Gradio interface
def _collect_examples(directory):
    """Return sorted ``[[path]]`` example rows for the image files in *directory*.

    Subdirectories, hidden files (e.g. ``.DS_Store``) and non-image files are
    skipped so they never appear as broken examples. Returns ``None`` when the
    directory is missing or has no usable images; ``None`` is also
    gr.Interface's default for "no examples".
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}
    if not directory.is_dir():
        return None
    rows = [
        [str(path)]
        for path in sorted(directory.iterdir())
        if path.is_file()
        and not path.name.startswith(".")
        and path.suffix.lower() in image_suffixes
    ]
    return rows or None


# Markdown shown under the title. Lines must stay unindented (indented
# markdown renders as a code block). The table needs the |---| separator row
# to render as a table, and it mirrors CLOTHING1M_CLASSES.
_DESCRIPTION = """
**Upload an image of clothing, and the model will predict its category.**

Try using an image that doesn't belong to any of the available categories, and see how the result differs!

**Categories:**

| T-Shirt | Shirt | Knitwear | Chiffon | Sweater | Hoodie | Windbreaker |
|---|---|---|---|---|---|---|
| Jacket | Downcoat | Suit | Shawl | Dress | Vest | Underwear |
"""

interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(label="Uploaded Image"),
    outputs=gr.Text(label="Predicted Clothing"),
    title="Clothing Category Classifier",
    description=_DESCRIPTION,
    # Resolve relative to this file rather than the current working directory,
    # so examples are found no matter where the app is launched from.
    examples=_collect_examples(Path(__file__).resolve().parent / "examples"),
    flagging_mode="never",
    theme="soft",
)

# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()