mohamdlog commited on
Commit
8ce0469
·
1 Parent(s): 3ef806a

Add app.py and examples

Browse files
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ from pathlib import Path
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from huggingface_hub import hf_hub_download
8
+ from ResNet_for_CC import CC_model
9
+
10
+ # Define the Clothing1M class labels
11
+ CLOTHING1M_CLASSES = [
12
+ "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater",
13
+ "Hoodie", "Windbreaker", "Jacket", "Downcoat",
14
+ "Suit", "Shawl", "Dress", "Vest", "Underwear"
15
+ ]
16
+
17
+ # Initialize the model
18
+ model = CC_model()
19
+ model_path = hf_hub_download(repo_id="mohamdlog/CC", filename="CC_net.pt")
20
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
21
+ model.eval()
22
+
23
+ # Define preprocessing pipeline
24
+ def preprocess_image(image):
25
+ if isinstance(image, np.ndarray):
26
+ image = Image.fromarray(image)
27
+ transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ ])
31
+ return transform(image).unsqueeze(0)
32
+
33
+ # Define classification function
34
+ def classify_image(image):
35
+ input_tensor = preprocess_image(image)
36
+ with torch.no_grad():
37
+ output = model(input_tensor)
38
+
39
+ # Get predicted class and confidence
40
+ probabilities = torch.nn.functional.softmax(output, dim=1)
41
+ predicted_class_idx = output.argmax(dim=1).item()
42
+ predicted_class = CLOTHING1M_CLASSES[predicted_class_idx]
43
+ confidence = probabilities[0][predicted_class_idx].item()
44
+
45
+ return f"Category: {predicted_class}\nConfidence: {confidence:.2f}"
46
+
47
+ # Create Gradio interface
48
+ interface = gr.Interface(
49
+ fn=classify_image,
50
+ inputs=gr.Image(label="Uploaded Image"),
51
+ outputs=gr.Text(label="Predicted Clothing"),
52
+ title="Clothing Category Classifier",
53
+ description="Upload an image of clothing, and the model will predict its category.",
54
+ examples = [[str(file)] for file in Path("examples").glob("*")],
55
+ flagging_mode="never",
56
+ theme="soft"
57
+ )
58
+
59
+ # Launch the interface
60
+ if __name__ == "__main__":
61
+ interface.launch()
examples/example1.jpg ADDED
examples/example2.jpg ADDED
examples/example3.jpg ADDED
examples/example4.jpg ADDED
examples/example5.jpg ADDED