Aarookie committed
Commit dae28cf · verified · 1 Parent(s): f14bb22

Upload enhanced-waste-classification-webapp.py

enhanced-waste-classification-webapp.py ADDED
@@ -0,0 +1,229 @@
import gradio as gr
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import pandas as pd
import json
import os
from datetime import datetime

# Define classification categories
waste_categories = ["General Waste", "Recyclable Waste"]

# More specific waste item types for tracking
waste_items = [
    "Water Bottle", "Plastic Bag", "Food Container", "Paper", "Cardboard",
    "Glass Bottle", "Aluminum Can", "Food Waste", "Coffee Cup", "Other"
]

# Load CLIP model and processor (will be loaded from Hugging Face Hub)
model_name = "openai/clip-vit-base-patch16"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# File to store historical data
HISTORY_FILE = "waste_history.json"

def load_history():
    """Load historical waste classification data from file."""
    if os.path.exists(HISTORY_FILE):
        try:
            with open(HISTORY_FILE, "r") as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError):
            return {"classifications": []}
    return {"classifications": []}


def save_history(history):
    """Save historical waste classification data to file."""
    with open(HISTORY_FILE, "w") as f:
        json.dump(history, f)
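
# For reference, a sketch of the JSON structure that save_history() writes and
# load_history() reads back. The entry values shown here are illustrative only,
# not real data:
#
# {
#     "classifications": [
#         {
#             "timestamp": "2025-01-01T12:00:00",
#             "waste_item": "Water Bottle",
#             "waste_category": "Recyclable Waste",
#             "confidence": 0.91
#         }
#     ]
# }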

def classify_waste_category(image):
    """Classify waste as General or Recyclable using CLIP model."""
    if image is None:
        return None, None

    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Process the image with CLIP for waste category
    inputs = processor(text=waste_categories, images=image, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1).numpy()[0]

    # Get the prediction and confidence
    predicted_class_idx = np.argmax(probs)
    predicted_category = waste_categories[predicted_class_idx]
    confidence = probs[predicted_class_idx]

    return predicted_category, confidence

def classify_waste_item(image):
    """Identify the specific waste item type using CLIP model."""
    if image is None:
        return None, None

    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Process the image with CLIP for specific waste item
    inputs = processor(text=waste_items, images=image, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1).numpy()[0]

    # Get the prediction and confidence
    predicted_item_idx = np.argmax(probs)
    predicted_item = waste_items[predicted_item_idx]
    confidence = probs[predicted_item_idx]

    return predicted_item, confidence

def process_image(image, waste_item_selected=None):
    """Main function to process images from webcam or upload."""
    if image is None:
        return "No image provided", None, None, None

    # Classify waste category (General or Recyclable)
    category, category_confidence = classify_waste_category(image)

    # Get specific waste item type
    if waste_item_selected and waste_item_selected != "Auto-detect":
        item = waste_item_selected
        item_confidence = 1.0  # User manually selected, so confidence is 100%
    else:
        item, item_confidence = classify_waste_item(image)

    # Format the result
    result = f"This appears to be: {item}\nClassified as: {category}\n"
    result += f"Category confidence: {category_confidence:.2f}"

    # Per-category confidence as a DataFrame for the confidence bar plot
    recyclable_conf = category_confidence if category == waste_categories[1] else 1 - category_confidence
    category_probs = pd.DataFrame({
        "category": waste_categories,
        "confidence": [1 - recyclable_conf, recyclable_conf],
    })

    # Record this classification in history
    history = load_history()
    history["classifications"].append({
        "timestamp": datetime.now().isoformat(),
        "waste_item": item,
        "waste_category": category,
        "confidence": float(category_confidence)
    })
    save_history(history)

    # Generate statistics for the history tab
    item_counts = {}
    category_counts = {"General Waste": 0, "Recyclable Waste": 0}

    for entry in history["classifications"]:
        entry_item = entry["waste_item"]
        entry_category = entry["waste_category"]

        if entry_item not in item_counts:
            item_counts[entry_item] = {"General Waste": 0, "Recyclable Waste": 0}

        item_counts[entry_item][entry_category] += 1
        category_counts[entry_category] += 1

    # Item-by-category counts as a DataFrame for the history bar plot
    history_rows = []
    for entry_item, counts in item_counts.items():
        for entry_category, count in counts.items():
            if count > 0:
                history_rows.append({"item": entry_item, "category": entry_category, "count": count})
    history_data = pd.DataFrame(history_rows, columns=["item", "category", "count"])

    # Overall category distribution as label -> proportion pairs
    total = sum(category_counts.values()) or 1
    category_distribution = {name: count / total for name, count in category_counts.items()}

    return result, category_probs, history_data, category_distribution
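
# Illustrative usage of process_image() outside the Gradio UI, e.g. as a quick
# smoke test. "sample.jpg" is a placeholder filename and is assumed to exist in
# the working directory:
#
#   img = np.array(Image.open("sample.jpg"))
#   text, conf_df, history_df, distribution = process_image(img)
#   print(text)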

# Create Gradio interface
with gr.Blocks(title="EcoCan - Waste Classification") as demo:
    gr.Markdown("""
    # EcoCan - AI Waste Classification

    Use your webcam or phone camera to take a photo of waste and classify it as General Waste or Recyclable Waste.
    """)

    with gr.Tabs():
        with gr.Tab("Classify Waste"):
            with gr.Row():
                with gr.Column():
                    # Input options: webcam capture or image upload
                    input_image = gr.Image(
                        sources=["webcam", "upload"],
                        type="numpy",
                        label="Take a photo or upload an image"
                    )

                    # Optional manual selection
                    waste_item_dropdown = gr.Dropdown(
                        choices=["Auto-detect"] + waste_items,
                        value="Auto-detect",
                        label="Or manually select waste item type (optional)"
                    )

                    # Classify button
                    classify_button = gr.Button("Classify Waste", variant="primary")

                with gr.Column():
                    # Output components
                    result_text = gr.Textbox(label="Classification Result")
                    confidence_chart = gr.BarPlot(
                        x="category",
                        y="confidence",
                        title="Classification Confidence",
                        y_lim=[0, 1],
                        tooltip=["category", "confidence"],
                        color="category"
                    )

        with gr.Tab("History"):
            gr.Markdown("### Waste Classification History")

            with gr.Row():
                with gr.Column():
                    history_chart = gr.BarPlot(
                        x="item",
                        y="count",
                        color="category",
                        group="category",
                        title="Items Classified by Category",
                        tooltip=["item", "category", "count"],
                        height=400
                    )

                with gr.Column():
                    # Overall distribution of waste categories, shown as
                    # label -> proportion pairs
                    category_distribution = gr.Label(
                        label="Waste Categories Distribution",
                        num_top_classes=2
                    )

    # Set up event handlers
    classify_button.click(
        process_image,
        inputs=[input_image, waste_item_dropdown],
        outputs=[result_text, confidence_chart, history_chart, category_distribution]
    )

    # Also process when an image is captured or uploaded without clicking the button
    input_image.change(
        process_image,
        inputs=[input_image, waste_item_dropdown],
        outputs=[result_text, confidence_chart, history_chart, category_distribution]
    )

# Launch locally for testing
if __name__ == "__main__":
    demo.launch()
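
# Note on dependencies: this app needs the packages imported above (gradio,
# torch, transformers, pillow, numpy, pandas); exact versions are not pinned
# in this file.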