viniciusgribas committed on
Commit
9495c6e
·
1 Parent(s): 4f68470

novos files

Browse files
Files changed (2) hide show
  1. app.py +134 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=import-error
2
+ import gradio as gr
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ from transformers import ViTForImageClassification, ViTImageProcessor
9
+
10
+ # Load pre-trained Vision Transformer model
11
# Checkpoint identifier shared by the image processor and the classifier,
# so both are loaded from the same pre-trained Vision Transformer.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
14
+
15
+ # Function to predict image class
16
def classify_image(image):
    """Classify an image and chart the five most probable classes.

    Args:
        image: The image to classify, passed straight to the module-level
            ``processor``; ``None`` when no image is available.

    Returns:
        A ``(predicted_class, figure)`` tuple. ``predicted_class`` is the
        label of the highest-scoring class and ``figure`` is a Matplotlib
        figure holding a horizontal bar chart of the top-5 probabilities in
        percent. ``(None, None)`` is returned when ``image`` is ``None``.
    """
    # Local import: build the figure directly instead of through pyplot, so
    # it is never stored in pyplot's global figure registry. Figures made
    # with plt.subplots() stay alive until explicitly closed, which leaks
    # one figure per request in a long-running app.
    from matplotlib.figure import Figure

    if image is None:
        return None, None

    # Process image
    inputs = processor(images=image, return_tensors="pt")

    # Make prediction (no gradients needed for inference)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Get predicted class
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = model.config.id2label[predicted_class_idx]

    # Get top 5 predictions as plain Python values
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top5_prob, top5_indices = torch.topk(probs, 5)
    classes = [model.config.id2label[idx] for idx in top5_indices.tolist()]
    probabilities = [prob * 100 for prob in top5_prob.tolist()]

    # Create horizontal bar chart
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    bars = ax.barh(classes, probabilities, color='#4C72B0')
    # barh draws the first category at the bottom; flip the axis so the most
    # probable class is listed first.
    ax.invert_yaxis()
    # Values are percentages; a fixed range leaves room for the text label
    # even when a bar is close to 100%.
    ax.set_xlim(0, 110)
    ax.set_xlabel('Probability (%)')
    ax.set_title('Top 5 Predictions')

    # Add percentage labels just past the end of each bar
    for bar, probability in zip(bars, probabilities):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{probability:.1f}%',
                va='center', fontsize=10)

    # Improve layout of this specific figure (not pyplot's "current" one)
    fig.tight_layout()

    return predicted_class, fig
59
+
60
+ # Create Gradio interface
61
# Used only to check which example images are present on disk.
from pathlib import Path

with gr.Blocks(title="Image Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ Image Classification Tool

        This application uses a Vision Transformer (ViT) model to classify images into 1,000 different categories.

        Upload an image or take a photo to see what the AI recognizes in it!
        """
    )

    with gr.Row():
        # Left column: image input and the explicit "classify" button.
        with gr.Column():
            image_input = gr.Image(
                label="Upload or capture an image",
                type="pil",
                height=400
            )
            classify_btn = gr.Button("Classify Image", variant="primary")

        # Right column: predicted label and the top-5 confidence chart.
        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            confidence_plot = gr.Plot(label="Confidence Levels")

    # Add examples. Only files that actually exist are offered: with
    # cache_examples=True the examples are passed through classify_image,
    # so a missing file would fail instead of merely showing a broken
    # thumbnail.
    example_images = [
        path
        for path in (
            "examples/dog.jpg",
            "examples/cat.jpg",
            "examples/coffee.jpg",
            "examples/laptop.jpg",
            "examples/beach.jpg",
        )
        if Path(path).is_file()
    ]

    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=[prediction, confidence_plot],
            fn=classify_image,
            cache_examples=True
        )

    # Set up the click event
    classify_btn.click(
        fn=classify_image,
        inputs=image_input,
        outputs=[prediction, confidence_plot]
    )

    # Set up the input change event (classify_image returns (None, None)
    # for a missing image, which resets both outputs)
    image_input.change(
        fn=classify_image,
        inputs=image_input,
        outputs=[prediction, confidence_plot]
    )

    gr.Markdown("""
    ### How it works

    This tool uses a Vision Transformer (ViT) model pre-trained on ImageNet, enabling it to recognize 1,000
    different object categories ranging from animals and plants to vehicles, household items, and more.

    ### Applications

    - **Content Categorization**: Automatically organize image libraries
    - **Accessibility**: Help describe images for visually impaired users
    - **Education**: Learn about objects in the world around you
    - **Data Analysis**: Process and categorize large image datasets

    Created by [Vinicius Guerra e Ribas](https://viniciusgribas.netlify.app/)
    """)

# Launch the app only when run as a script, not when imported
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=3.32.0
2
+ torch>=1.7.0
3
+ torchvision>=0.8.0
4
+ transformers>=4.26.0
5
+ matplotlib>=3.5.0
6
+ numpy>=1.20.0
7
+ Pillow>=8.0.0