File size: 3,334 Bytes
f322570
 
 
 
ca77fd5
 
f322570
 
 
 
ca77fd5
f322570
ca77fd5
 
 
 
f322570
ca77fd5
5bde481
ca77fd5
 
 
 
 
 
 
 
 
eca7b02
 
591d257
4e539c0
9cfa170
 
 
 
 
 
 
4e539c0
 
 
 
 
d844a2a
eca7b02
d844a2a
 
eca7b02
 
d844a2a
 
eca7b02
 
 
 
 
 
 
 
 
f322570
ca77fd5
f322570
d844a2a
 
 
 
f322570
 
d844a2a
 
 
 
 
ca77fd5
d844a2a
ca77fd5
 
 
d844a2a
ca77fd5
eca7b02
ca77fd5
 
f322570
d844a2a
 
eca7b02
 
ca77fd5
 
eca7b02
d844a2a
ca77fd5
 
d844a2a
 
 
 
 
 
 
 
 
 
ca77fd5
d844a2a
 
f322570
 
 
13685f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import gradio as gr
import torch.nn as nn
import torch.nn.functional as F
import os
from pathlib import Path
from torch.nn import init
import torchvision.transforms as transforms
from PIL import Image

# ... [Keep all your existing model definitions and initialization code] ...

# Directory scanned for the example thumbnails shown in the gallery.
example_dir = "examples"

# Lower-case file extensions accepted as example images.
_EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')


def _list_example_images(directory):
    """Return sorted paths of files in *directory* with an image extension.

    The result is sorted because ``os.listdir`` order is arbitrary and gallery
    clicks are resolved by index into this list, so a stable order matters.
    A missing directory yields an empty list instead of crashing at import.
    """
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        return []
    return [
        os.path.join(directory, name)
        for name in sorted(names)
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    ]


# Precompute example image paths once (previously this scan ran twice).
example_images = _list_example_images(example_dir)

# Custom CSS for styling the example gallery thumbnails.
css = """
.centered-examples {
    margin: 0 auto !important;
    justify-content: center !important;
    gap: 8px !important;
}
.centered-examples .thumb {
    height: 100px !important;
    width: 100px !important;
    object-fit: cover !important;
}
"""

# Preprocessing pipeline: resize the shorter side to 256, center-crop to
# 224x224, convert to a float tensor, then normalize each channel with the
# ImageNet mean/std commonly used by torchvision pretrained models.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def predict(img_path) -> str:
    """Classify a single image file and return the predicted class label.

    Args:
        img_path: Filesystem path to the image, or a falsy value (``None`` /
            ``""``) when nothing has been selected yet.

    Returns:
        ``classes[i]`` for the highest-scoring output index ``i``; a prompt
        string when no image is selected; or ``"Error: <message>"`` if
        loading, preprocessing or inference raises.
    """
    if not img_path:
        return "Please select or upload an image first"

    try:
        # Context manager closes the underlying file handle promptly;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        # Add a leading batch dimension of size 1 for the model.
        tensor = preprocess(image).unsqueeze(0)

        # NOTE(review): assumes `model` is already in eval() mode where it is
        # defined (not visible here) — TODO confirm.
        with torch.inference_mode():
            outputs = model(tensor)
            _, pred = torch.max(outputs, 1)

        return classes[pred.item()]

    # Broad catch is deliberate: this is the UI boundary, so any failure is
    # shown in the Prediction box instead of crashing the event handler.
    except Exception as e:
        return f"Error: {str(e)}"

with gr.Blocks(title="Animal Classifier", css=css) as demo:
    gr.Markdown("## 🐾 Animal Classifier")
    gr.Markdown("Select an image below or upload your own, then click Classify")

    # Path of the image the Classify button runs on; written by gallery
    # selection and uploads, read by the classify handler.
    current_image = gr.State()

    with gr.Row():
        with gr.Column():
            image_preview = gr.Image(label="Selected Image", type="filepath")
            upload_btn = gr.UploadButton("Upload Custom Image", file_types=["image"])
            classify_btn = gr.Button("Classify 🚀", variant="primary")
        result = gr.Textbox(label="Prediction", lines=3)

    # Example gallery at bottom
    with gr.Row(variant="panel"):
        examples_gallery = gr.Gallery(
            value=example_images,
            label="Example Images (Click to Select)",
            columns=7,
            height=120,
            allow_preview=False,
            elem_classes=["centered-examples"]
        )

    def select_example(evt: gr.SelectData):
        """Return the clicked example's path for both the preview and the state.

        This handler is wired to two outputs, so it must return two values;
        returning a single path (as before) broke gallery selection.
        """
        path = example_images[evt.index]
        return path, path

    examples_gallery.select(
        fn=select_example,
        outputs=[image_preview, current_image],
        show_progress=False
    )

    def handle_upload(file):
        """Return the uploaded file's path for both the preview and the state.

        Depending on the Gradio version/configuration, UploadButton passes
        either a filepath string or a tempfile-like object exposing ``.name``;
        accept both instead of assuming ``.name`` exists.
        """
        path = getattr(file, "name", file)
        return path, path

    upload_btn.upload(
        fn=handle_upload,
        inputs=upload_btn,
        outputs=[image_preview, current_image]
    )

    # Handle classification of whatever image is currently selected.
    classify_btn.click(
        fn=predict,
        inputs=current_image,
        outputs=result
    )

def _main():
    """Launch the Gradio app with handler errors surfaced in the browser UI."""
    demo.launch(show_error=True)


if __name__ == "__main__":
    _main()