fmegahed commited on
Commit
4f8dfb9
·
verified ·
1 Parent(s): a1e1c29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -245
app.py CHANGED
@@ -2,256 +2,48 @@ import streamlit as st
2
  import torch
3
  import open_clip
4
  from PIL import Image
5
- import os
6
  from classifier import few_shot_fault_classification
7
 
8
- # Set page configuration
9
- st.set_page_config(
10
- page_title="Industrial QC Image Classifier",
11
- page_icon="🔍",
12
- layout="wide"
13
- )
14
 
15
- # Initialize session state variables if they don't exist
16
- if 'nominal_images' not in st.session_state:
17
- st.session_state.nominal_images = []
18
- st.session_state.nominal_descriptions = []
19
- st.session_state.nominal_filenames = []
20
-
21
- if 'defective_images' not in st.session_state:
22
- st.session_state.defective_images = []
23
- st.session_state.defective_descriptions = []
24
- st.session_state.defective_filenames = []
25
-
26
- if 'test_results' not in st.session_state:
27
- st.session_state.test_results = []
28
 
29
- if 'model' not in st.session_state:
30
- st.session_state.device = "cuda" if torch.cuda.is_available() else "cpu"
31
- # Load model
32
- with st.spinner("Loading CLIP model..."):
33
- model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
34
- model = model.to(st.session_state.device)
35
- model.eval()
36
- st.session_state.model = model
37
- st.session_state.preprocess = preprocess
38
 
39
- def process_uploaded_image(uploaded_file):
40
- if uploaded_file is not None:
41
- image = Image.open(uploaded_file)
42
- if image.mode != "RGB":
43
- image = image.convert("RGB")
44
- return image
45
- return None
46
 
47
- def display_image_grid(images, captions, columns=5):
48
- if not images:
49
- return
50
-
51
- rows = (len(images) + columns - 1) // columns
52
-
53
- for i in range(rows):
54
- cols = st.columns(columns)
55
- for j in range(columns):
56
- idx = i * columns + j
57
- if idx < len(images):
58
- with cols[j]:
59
- st.image(images[idx], caption=captions[idx], use_column_width=True)
60
 
61
- def main():
62
- st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
63
- st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")
64
-
65
- # Create tabs for different functionality
66
- tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
67
-
68
- with tab1:
69
- st.header("Upload Reference Images")
70
-
71
- # Create two columns for nominal and defective images
72
- col1, col2 = st.columns(2)
73
-
74
- with col1:
75
- st.subheader("Good Parts (Nominal)")
76
-
77
- # Input for nominal description
78
- nominal_desc = st.text_input("Nominal Description",
79
- value="Good part without defects")
80
-
81
- # File uploader for nominal images
82
- nominal_files = st.file_uploader("Upload images of good parts (10 recommended)",
83
- type=["jpg", "jpeg", "png"],
84
- accept_multiple_files=True,
85
- key="nominal_upload")
86
-
87
- if nominal_files and st.button("Add Nominal Images"):
88
- for file in nominal_files:
89
- img = process_uploaded_image(file)
90
- if img:
91
- # Preprocess the image for the model
92
- preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
93
- st.session_state.nominal_images.append(preprocessed_img)
94
- st.session_state.nominal_descriptions.append(nominal_desc)
95
- st.session_state.nominal_filenames.append(file.name)
96
-
97
- st.success(f"Added {len(nominal_files)} nominal images!")
98
-
99
- # Display current nominal image count
100
- st.write(f"Current nominal images: {len(st.session_state.nominal_images)}")
101
-
102
- # Display nominal images in a grid if we have any
103
- if st.session_state.nominal_images and st.session_state.nominal_filenames:
104
- st.subheader("Current Nominal Images")
105
- # We'll display the filenames instead of the actual tensor images
106
- display_placeholders = [f"Image {i+1}" for i in range(len(st.session_state.nominal_filenames))]
107
- st.write(", ".join(st.session_state.nominal_filenames))
108
-
109
- if st.button("Clear Nominal Images"):
110
- st.session_state.nominal_images = []
111
- st.session_state.nominal_descriptions = []
112
- st.session_state.nominal_filenames = []
113
- st.success("Cleared all nominal images!")
114
-
115
- with col2:
116
- st.subheader("Defective Parts")
117
-
118
- # Input for defective description
119
- defective_desc = st.text_input("Defective Description",
120
- value="Part with visible defects")
121
-
122
- # File uploader for defective images
123
- defective_files = st.file_uploader("Upload images of defective parts (10 recommended)",
124
- type=["jpg", "jpeg", "png"],
125
- accept_multiple_files=True,
126
- key="defective_upload")
127
-
128
- if defective_files and st.button("Add Defective Images"):
129
- for file in defective_files:
130
- img = process_uploaded_image(file)
131
- if img:
132
- # Preprocess the image for the model
133
- preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
134
- st.session_state.defective_images.append(preprocessed_img)
135
- st.session_state.defective_descriptions.append(defective_desc)
136
- st.session_state.defective_filenames.append(file.name)
137
-
138
- st.success(f"Added {len(defective_files)} defective images!")
139
-
140
- # Display current defective image count
141
- st.write(f"Current defective images: {len(st.session_state.defective_images)}")
142
-
143
- # Display defective images in a grid if we have any
144
- if st.session_state.defective_images and st.session_state.defective_filenames:
145
- st.subheader("Current Defective Images")
146
- # We'll display the filenames instead of the actual tensor images
147
- st.write(", ".join(st.session_state.defective_filenames))
148
-
149
- if st.button("Clear Defective Images"):
150
- st.session_state.defective_images = []
151
- st.session_state.defective_descriptions = []
152
- st.session_state.defective_filenames = []
153
- st.success("Cleared all defective images!")
154
-
155
- with tab2:
156
- st.header("Test Image Classification")
157
-
158
- # Check if we have enough reference images
159
- if len(st.session_state.nominal_images) == 0 or len(st.session_state.defective_images) == 0:
160
- st.warning("You need to upload at least one nominal image and one defective image before testing.")
161
- else:
162
- st.write("Upload a test image to classify it as nominal or defective.")
163
-
164
- # File uploader for test image
165
- test_files = st.file_uploader("Upload test images",
166
- type=["jpg", "jpeg", "png"],
167
- accept_multiple_files=True,
168
- key="test_upload")
169
-
170
- if test_files:
171
- test_images = []
172
- test_image_displays = []
173
- test_filenames = []
174
-
175
- for file in test_files:
176
- img = process_uploaded_image(file)
177
- if img:
178
- test_image_displays.append(img)
179
- # Preprocess the image for the model
180
- preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
181
- test_images.append(preprocessed_img.squeeze(0))
182
- test_filenames.append(file.name)
183
-
184
- # Display the test images
185
- if test_image_displays:
186
- display_image_grid(
187
- test_image_displays,
188
- [f"Test: {name}" for name in test_filenames]
189
- )
190
-
191
- if st.button("Classify Images"):
192
- with st.spinner("Classifying images..."):
193
- # Create output directory
194
- os.makedirs("./results", exist_ok=True)
195
-
196
- # Run classification
197
- results = few_shot_fault_classification(
198
- model=st.session_state.model,
199
- test_images=test_images,
200
- test_image_filenames=test_filenames,
201
- nominal_images=[img.squeeze(0) for img in st.session_state.nominal_images],
202
- nominal_descriptions=st.session_state.nominal_descriptions,
203
- defective_images=[img.squeeze(0) for img in st.session_state.defective_images],
204
- defective_descriptions=st.session_state.defective_descriptions,
205
- num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
206
- device=st.session_state.device,
207
- file_path="./results",
208
- file_name="classification_results.csv"
209
- )
210
-
211
- for i, result in enumerate(results):
212
- st.session_state.test_results.append({
213
- "image": test_image_displays[i],
214
- "filename": test_filenames[i],
215
- **result
216
- })
217
-
218
- # Display classification results
219
- for result in results:
220
- classification = result["classification_result"]
221
- color = "green" if classification == "Nominal" else "red"
222
-
223
- st.markdown(f"### {result['image_name']}: <span style='color:{color}'>{classification}</span>", unsafe_allow_html=True)
224
-
225
- # Display probabilities
226
- col1, col2 = st.columns(2)
227
- with col1:
228
- st.metric("Good Part Probability", f"{result['non_defect_prob']:.3f}")
229
- with col2:
230
- st.metric("Defective Part Probability", f"{result['defect_prob']:.3f}")
231
-
232
- with tab3:
233
- st.header("Classification Results")
234
-
235
- if not st.session_state.test_results:
236
- st.info("No classification results yet. Test some images in the 'Test Classification' tab.")
237
- else:
238
- st.write(f"Total images classified: {len(st.session_state.test_results)}")
239
-
240
- # Display results in a table
241
- results_df = [{
242
- "Filename": r["filename"],
243
- "Classification": r["classification_result"],
244
- "Good Prob": f"{r['non_defect_prob']:.3f}",
245
- "Defect Prob": f"{r['defect_prob']:.3f}",
246
- "Time": r["datetime_of_operation"].split('T')[0]
247
- } for r in st.session_state.test_results]
248
-
249
- st.table(results_df)
250
-
251
- # Option to clear results
252
- if st.button("Clear Results"):
253
- st.session_state.test_results = []
254
- st.success("Classification results cleared!")
255
 
256
- if __name__ == "__main__":
257
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import open_clip
4
  from PIL import Image
 
5
  from classifier import few_shot_fault_classification
6
 
7
+ # Load lightweight CLIP model
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
10
+ model = model.to(device)
11
+ model.eval()
 
12
 
13
+ st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ st.markdown("Upload **10 Nominal Images**, **10 Defective Images**, and one or more **Test Images** to classify.")
 
 
 
 
 
 
 
 
16
 
17
+ col1, col2 = st.columns(2)
18
+ with col1:
19
+ nominal_files = st.file_uploader("Upload Nominal Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
20
+ with col2:
21
+ defective_files = st.file_uploader("Upload Defective Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
 
 
22
 
23
+ test_files = st.file_uploader("Upload Test Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ if st.button("Classify Test Images"):
26
+ if len(nominal_files) < 1 or len(defective_files) < 1 or len(test_files) < 1:
27
+ st.warning("Please upload at least 1 image in each category.")
28
+ else:
29
+ st.info("Running classification...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ nominal_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in nominal_files]
32
+ defective_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in defective_files]
33
+ test_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in test_files]
34
+
35
+ results = few_shot_fault_classification(
36
+ model=model,
37
+ test_images=test_imgs,
38
+ test_image_filenames=[f.name for f in test_files],
39
+ nominal_images=nominal_imgs,
40
+ nominal_descriptions=[f.name for f in nominal_files],
41
+ defective_images=defective_imgs,
42
+ defective_descriptions=[f.name for f in defective_files],
43
+ num_few_shot_nominal_imgs=len(nominal_files),
44
+ device=device
45
+ )
46
+
47
+ for res in results:
48
+ st.write(f"**{res['image_path']}** ➜ {res['classification_result']} "
49
+ f"(Nominal: {res['non_defect_prob']}, Defective: {res['defect_prob']})")