fmegahed commited on
Commit
95f56a4
·
verified ·
1 Parent(s): bf99d96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -37
app.py CHANGED
@@ -2,48 +2,256 @@ import streamlit as st
2
  import torch
3
  import open_clip
4
  from PIL import Image
 
5
  from classifier import few_shot_fault_classification
6
 
7
- # Load lightweight CLIP model
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
10
- model = model.to(device)
11
- model.eval()
 
12
 
13
- st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- st.markdown("Upload **10 Nominal Images**, **10 Defective Images**, and one or more **Test Images** to classify.")
 
 
 
 
 
 
 
 
16
 
17
- col1, col2 = st.columns(2)
18
- with col1:
19
- nominal_files = st.file_uploader("Upload Nominal Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
20
- with col2:
21
- defective_files = st.file_uploader("Upload Defective Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
 
 
22
 
23
- test_files = st.file_uploader("Upload Test Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- if st.button("Classify Test Images"):
26
- if len(nominal_files) < 1 or len(defective_files) < 1 or len(test_files) < 1:
27
- st.warning("Please upload at least 1 image in each category.")
28
- else:
29
- st.info("Running classification...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- nominal_imgs = [preprocess(Image.open(f).convert("RGB")).unsqueeze(0).to(device) for f in nominal_files]
32
- defective_imgs = [preprocess(Image.open(f).convert("RGB")).unsqueeze(0).to(device) for f in defective_files]
33
- test_imgs = [preprocess(Image.open(f).convert("RGB")).unsqueeze(0).to(device) for f in test_files]
34
-
35
- results = few_shot_fault_classification(
36
- model=model,
37
- test_images=[img.squeeze(0) for img in test_imgs],
38
- test_image_filenames=[f.name for f in test_files],
39
- nominal_images=[img.squeeze(0) for img in nominal_imgs],
40
- nominal_descriptions=[f.name for f in nominal_files],
41
- defective_images=[img.squeeze(0) for img in defective_imgs],
42
- defective_descriptions=[f.name for f in defective_files],
43
- num_few_shot_nominal_imgs=len(nominal_files),
44
- device=device
45
- )
46
-
47
- for res in results:
48
- st.write(f"**{res['image_path']}** ➜ {res['classification_result']} "
49
- f"(Nominal: {res['non_defect_prob']}, Defective: {res['defect_prob']})")
 
2
  import torch
3
  import open_clip
4
  from PIL import Image
5
+ import os
6
  from classifier import few_shot_fault_classification
7
 
8
+ # Set page configuration
9
+ st.set_page_config(
10
+ page_title="Industrial QC Image Classifier",
11
+ page_icon="🔍",
12
+ layout="wide"
13
+ )
14
 
15
+ # Initialize session state variables if they don't exist
16
+ if 'nominal_images' not in st.session_state:
17
+ st.session_state.nominal_images = []
18
+ st.session_state.nominal_descriptions = []
19
+ st.session_state.nominal_filenames = []
20
+
21
+ if 'defective_images' not in st.session_state:
22
+ st.session_state.defective_images = []
23
+ st.session_state.defective_descriptions = []
24
+ st.session_state.defective_filenames = []
25
+
26
+ if 'test_results' not in st.session_state:
27
+ st.session_state.test_results = []
28
 
29
+ if 'model' not in st.session_state:
30
+ st.session_state.device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ # Load model
32
+ with st.spinner("Loading CLIP model..."):
33
+ model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
34
+ model = model.to(st.session_state.device)
35
+ model.eval()
36
+ st.session_state.model = model
37
+ st.session_state.preprocess = preprocess
38
 
39
+ def process_uploaded_image(uploaded_file):
40
+ if uploaded_file is not None:
41
+ image = Image.open(uploaded_file)
42
+ if image.mode != "RGB":
43
+ image = image.convert("RGB")
44
+ return image
45
+ return None
46
 
47
+ def display_image_grid(images, captions, columns=5):
48
+ if not images:
49
+ return
50
+
51
+ rows = (len(images) + columns - 1) // columns
52
+
53
+ for i in range(rows):
54
+ cols = st.columns(columns)
55
+ for j in range(columns):
56
+ idx = i * columns + j
57
+ if idx < len(images):
58
+ with cols[j]:
59
+ st.image(images[idx], caption=captions[idx], use_column_width=True)
60
 
61
+ def main():
62
+ st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
63
+ st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")
64
+
65
+ # Create tabs for different functionality
66
+ tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
67
+
68
+ with tab1:
69
+ st.header("Upload Reference Images")
70
+
71
+ # Create two columns for nominal and defective images
72
+ col1, col2 = st.columns(2)
73
+
74
+ with col1:
75
+ st.subheader("Good Parts (Nominal)")
76
+
77
+ # Input for nominal description
78
+ nominal_desc = st.text_input("Nominal Description",
79
+ value="Good part without defects")
80
+
81
+ # File uploader for nominal images
82
+ nominal_files = st.file_uploader("Upload images of good parts (10 recommended)",
83
+ type=["jpg", "jpeg", "png"],
84
+ accept_multiple_files=True,
85
+ key="nominal_upload")
86
+
87
+ if nominal_files and st.button("Add Nominal Images"):
88
+ for file in nominal_files:
89
+ img = process_uploaded_image(file)
90
+ if img:
91
+ # Preprocess the image for the model
92
+ preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
93
+ st.session_state.nominal_images.append(preprocessed_img)
94
+ st.session_state.nominal_descriptions.append(nominal_desc)
95
+ st.session_state.nominal_filenames.append(file.name)
96
+
97
+ st.success(f"Added {len(nominal_files)} nominal images!")
98
+
99
+ # Display current nominal image count
100
+ st.write(f"Current nominal images: {len(st.session_state.nominal_images)}")
101
+
102
+ # Display nominal images in a grid if we have any
103
+ if st.session_state.nominal_images and st.session_state.nominal_filenames:
104
+ st.subheader("Current Nominal Images")
105
+ # We'll display the filenames instead of the actual tensor images
106
+ display_placeholders = [f"Image {i+1}" for i in range(len(st.session_state.nominal_filenames))]
107
+ st.write(", ".join(st.session_state.nominal_filenames))
108
+
109
+ if st.button("Clear Nominal Images"):
110
+ st.session_state.nominal_images = []
111
+ st.session_state.nominal_descriptions = []
112
+ st.session_state.nominal_filenames = []
113
+ st.success("Cleared all nominal images!")
114
+
115
+ with col2:
116
+ st.subheader("Defective Parts")
117
+
118
+ # Input for defective description
119
+ defective_desc = st.text_input("Defective Description",
120
+ value="Part with visible defects")
121
+
122
+ # File uploader for defective images
123
+ defective_files = st.file_uploader("Upload images of defective parts (10 recommended)",
124
+ type=["jpg", "jpeg", "png"],
125
+ accept_multiple_files=True,
126
+ key="defective_upload")
127
+
128
+ if defective_files and st.button("Add Defective Images"):
129
+ for file in defective_files:
130
+ img = process_uploaded_image(file)
131
+ if img:
132
+ # Preprocess the image for the model
133
+ preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
134
+ st.session_state.defective_images.append(preprocessed_img)
135
+ st.session_state.defective_descriptions.append(defective_desc)
136
+ st.session_state.defective_filenames.append(file.name)
137
+
138
+ st.success(f"Added {len(defective_files)} defective images!")
139
+
140
+ # Display current defective image count
141
+ st.write(f"Current defective images: {len(st.session_state.defective_images)}")
142
+
143
+ # Display defective images in a grid if we have any
144
+ if st.session_state.defective_images and st.session_state.defective_filenames:
145
+ st.subheader("Current Defective Images")
146
+ # We'll display the filenames instead of the actual tensor images
147
+ st.write(", ".join(st.session_state.defective_filenames))
148
+
149
+ if st.button("Clear Defective Images"):
150
+ st.session_state.defective_images = []
151
+ st.session_state.defective_descriptions = []
152
+ st.session_state.defective_filenames = []
153
+ st.success("Cleared all defective images!")
154
+
155
+ with tab2:
156
+ st.header("Test Image Classification")
157
+
158
+ # Check if we have enough reference images
159
+ if len(st.session_state.nominal_images) == 0 or len(st.session_state.defective_images) == 0:
160
+ st.warning("You need to upload at least one nominal image and one defective image before testing.")
161
+ else:
162
+ st.write("Upload a test image to classify it as nominal or defective.")
163
+
164
+ # File uploader for test image
165
+ test_files = st.file_uploader("Upload test images",
166
+ type=["jpg", "jpeg", "png"],
167
+ accept_multiple_files=True,
168
+ key="test_upload")
169
+
170
+ if test_files:
171
+ test_images = []
172
+ test_image_displays = []
173
+ test_filenames = []
174
+
175
+ for file in test_files:
176
+ img = process_uploaded_image(file)
177
+ if img:
178
+ test_image_displays.append(img)
179
+ # Preprocess the image for the model
180
+ preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
181
+ test_images.append(preprocessed_img.squeeze(0))
182
+ test_filenames.append(file.name)
183
+
184
+ # Display the test images
185
+ if test_image_displays:
186
+ display_image_grid(
187
+ test_image_displays,
188
+ [f"Test: {name}" for name in test_filenames]
189
+ )
190
+
191
+ if st.button("Classify Images"):
192
+ with st.spinner("Classifying images..."):
193
+ # Create output directory
194
+ os.makedirs("./results", exist_ok=True)
195
+
196
+ # Run classification
197
+ results = few_shot_fault_classification(
198
+ model=st.session_state.model,
199
+ test_images=test_images,
200
+ test_image_filenames=test_filenames,
201
+ nominal_images=[img.squeeze(0) for img in st.session_state.nominal_images],
202
+ nominal_descriptions=st.session_state.nominal_descriptions,
203
+ defective_images=[img.squeeze(0) for img in st.session_state.defective_images],
204
+ defective_descriptions=st.session_state.defective_descriptions,
205
+ num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
206
+ device=st.session_state.device,
207
+ file_path="./results",
208
+ file_name="classification_results.csv"
209
+ )
210
+
211
+ for i, result in enumerate(results):
212
+ st.session_state.test_results.append({
213
+ "image": test_image_displays[i],
214
+ "filename": test_filenames[i],
215
+ **result
216
+ })
217
+
218
+ # Display classification results
219
+ for result in results:
220
+ classification = result["classification_result"]
221
+ color = "green" if classification == "Nominal" else "red"
222
+
223
+ st.markdown(f"### {result['image_name']}: <span style='color:{color}'>{classification}</span>", unsafe_allow_html=True)
224
+
225
+ # Display probabilities
226
+ col1, col2 = st.columns(2)
227
+ with col1:
228
+ st.metric("Good Part Probability", f"{result['non_defect_prob']:.3f}")
229
+ with col2:
230
+ st.metric("Defective Part Probability", f"{result['defect_prob']:.3f}")
231
+
232
+ with tab3:
233
+ st.header("Classification Results")
234
+
235
+ if not st.session_state.test_results:
236
+ st.info("No classification results yet. Test some images in the 'Test Classification' tab.")
237
+ else:
238
+ st.write(f"Total images classified: {len(st.session_state.test_results)}")
239
+
240
+ # Display results in a table
241
+ results_df = [{
242
+ "Filename": r["filename"],
243
+ "Classification": r["classification_result"],
244
+ "Good Prob": f"{r['non_defect_prob']:.3f}",
245
+ "Defect Prob": f"{r['defect_prob']:.3f}",
246
+ "Time": r["datetime_of_operation"].split('T')[0]
247
+ } for r in st.session_state.test_results]
248
+
249
+ st.table(results_df)
250
+
251
+ # Option to clear results
252
+ if st.button("Clear Results"):
253
+ st.session_state.test_results = []
254
+ st.success("Classification results cleared!")
255
 
256
+ if __name__ == "__main__":
257
+ main()