Update app.py
Browse files
app.py
CHANGED
@@ -2,256 +2,48 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
import open_clip
|
4 |
from PIL import Image
|
5 |
-
import os
|
6 |
from classifier import few_shot_fault_classification
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
)
|
14 |
|
15 |
-
|
16 |
-
if 'nominal_images' not in st.session_state:
|
17 |
-
st.session_state.nominal_images = []
|
18 |
-
st.session_state.nominal_descriptions = []
|
19 |
-
st.session_state.nominal_filenames = []
|
20 |
-
|
21 |
-
if 'defective_images' not in st.session_state:
|
22 |
-
st.session_state.defective_images = []
|
23 |
-
st.session_state.defective_descriptions = []
|
24 |
-
st.session_state.defective_filenames = []
|
25 |
-
|
26 |
-
if 'test_results' not in st.session_state:
|
27 |
-
st.session_state.test_results = []
|
28 |
|
29 |
-
|
30 |
-
st.session_state.device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
-
# Load model
|
32 |
-
with st.spinner("Loading CLIP model..."):
|
33 |
-
model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
|
34 |
-
model = model.to(st.session_state.device)
|
35 |
-
model.eval()
|
36 |
-
st.session_state.model = model
|
37 |
-
st.session_state.preprocess = preprocess
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
return image
|
45 |
-
return None
|
46 |
|
47 |
-
|
48 |
-
if not images:
|
49 |
-
return
|
50 |
-
|
51 |
-
rows = (len(images) + columns - 1) // columns
|
52 |
-
|
53 |
-
for i in range(rows):
|
54 |
-
cols = st.columns(columns)
|
55 |
-
for j in range(columns):
|
56 |
-
idx = i * columns + j
|
57 |
-
if idx < len(images):
|
58 |
-
with cols[j]:
|
59 |
-
st.image(images[idx], caption=captions[idx], use_column_width=True)
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
|
67 |
-
|
68 |
-
with tab1:
|
69 |
-
st.header("Upload Reference Images")
|
70 |
-
|
71 |
-
# Create two columns for nominal and defective images
|
72 |
-
col1, col2 = st.columns(2)
|
73 |
-
|
74 |
-
with col1:
|
75 |
-
st.subheader("Good Parts (Nominal)")
|
76 |
-
|
77 |
-
# Input for nominal description
|
78 |
-
nominal_desc = st.text_input("Nominal Description",
|
79 |
-
value="Good part without defects")
|
80 |
-
|
81 |
-
# File uploader for nominal images
|
82 |
-
nominal_files = st.file_uploader("Upload images of good parts (10 recommended)",
|
83 |
-
type=["jpg", "jpeg", "png"],
|
84 |
-
accept_multiple_files=True,
|
85 |
-
key="nominal_upload")
|
86 |
-
|
87 |
-
if nominal_files and st.button("Add Nominal Images"):
|
88 |
-
for file in nominal_files:
|
89 |
-
img = process_uploaded_image(file)
|
90 |
-
if img:
|
91 |
-
# Preprocess the image for the model
|
92 |
-
preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
|
93 |
-
st.session_state.nominal_images.append(preprocessed_img)
|
94 |
-
st.session_state.nominal_descriptions.append(nominal_desc)
|
95 |
-
st.session_state.nominal_filenames.append(file.name)
|
96 |
-
|
97 |
-
st.success(f"Added {len(nominal_files)} nominal images!")
|
98 |
-
|
99 |
-
# Display current nominal image count
|
100 |
-
st.write(f"Current nominal images: {len(st.session_state.nominal_images)}")
|
101 |
-
|
102 |
-
# Display nominal images in a grid if we have any
|
103 |
-
if st.session_state.nominal_images and st.session_state.nominal_filenames:
|
104 |
-
st.subheader("Current Nominal Images")
|
105 |
-
# We'll display the filenames instead of the actual tensor images
|
106 |
-
display_placeholders = [f"Image {i+1}" for i in range(len(st.session_state.nominal_filenames))]
|
107 |
-
st.write(", ".join(st.session_state.nominal_filenames))
|
108 |
-
|
109 |
-
if st.button("Clear Nominal Images"):
|
110 |
-
st.session_state.nominal_images = []
|
111 |
-
st.session_state.nominal_descriptions = []
|
112 |
-
st.session_state.nominal_filenames = []
|
113 |
-
st.success("Cleared all nominal images!")
|
114 |
-
|
115 |
-
with col2:
|
116 |
-
st.subheader("Defective Parts")
|
117 |
-
|
118 |
-
# Input for defective description
|
119 |
-
defective_desc = st.text_input("Defective Description",
|
120 |
-
value="Part with visible defects")
|
121 |
-
|
122 |
-
# File uploader for defective images
|
123 |
-
defective_files = st.file_uploader("Upload images of defective parts (10 recommended)",
|
124 |
-
type=["jpg", "jpeg", "png"],
|
125 |
-
accept_multiple_files=True,
|
126 |
-
key="defective_upload")
|
127 |
-
|
128 |
-
if defective_files and st.button("Add Defective Images"):
|
129 |
-
for file in defective_files:
|
130 |
-
img = process_uploaded_image(file)
|
131 |
-
if img:
|
132 |
-
# Preprocess the image for the model
|
133 |
-
preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
|
134 |
-
st.session_state.defective_images.append(preprocessed_img)
|
135 |
-
st.session_state.defective_descriptions.append(defective_desc)
|
136 |
-
st.session_state.defective_filenames.append(file.name)
|
137 |
-
|
138 |
-
st.success(f"Added {len(defective_files)} defective images!")
|
139 |
-
|
140 |
-
# Display current defective image count
|
141 |
-
st.write(f"Current defective images: {len(st.session_state.defective_images)}")
|
142 |
-
|
143 |
-
# Display defective images in a grid if we have any
|
144 |
-
if st.session_state.defective_images and st.session_state.defective_filenames:
|
145 |
-
st.subheader("Current Defective Images")
|
146 |
-
# We'll display the filenames instead of the actual tensor images
|
147 |
-
st.write(", ".join(st.session_state.defective_filenames))
|
148 |
-
|
149 |
-
if st.button("Clear Defective Images"):
|
150 |
-
st.session_state.defective_images = []
|
151 |
-
st.session_state.defective_descriptions = []
|
152 |
-
st.session_state.defective_filenames = []
|
153 |
-
st.success("Cleared all defective images!")
|
154 |
-
|
155 |
-
with tab2:
|
156 |
-
st.header("Test Image Classification")
|
157 |
-
|
158 |
-
# Check if we have enough reference images
|
159 |
-
if len(st.session_state.nominal_images) == 0 or len(st.session_state.defective_images) == 0:
|
160 |
-
st.warning("You need to upload at least one nominal image and one defective image before testing.")
|
161 |
-
else:
|
162 |
-
st.write("Upload a test image to classify it as nominal or defective.")
|
163 |
-
|
164 |
-
# File uploader for test image
|
165 |
-
test_files = st.file_uploader("Upload test images",
|
166 |
-
type=["jpg", "jpeg", "png"],
|
167 |
-
accept_multiple_files=True,
|
168 |
-
key="test_upload")
|
169 |
-
|
170 |
-
if test_files:
|
171 |
-
test_images = []
|
172 |
-
test_image_displays = []
|
173 |
-
test_filenames = []
|
174 |
-
|
175 |
-
for file in test_files:
|
176 |
-
img = process_uploaded_image(file)
|
177 |
-
if img:
|
178 |
-
test_image_displays.append(img)
|
179 |
-
# Preprocess the image for the model
|
180 |
-
preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
|
181 |
-
test_images.append(preprocessed_img.squeeze(0))
|
182 |
-
test_filenames.append(file.name)
|
183 |
-
|
184 |
-
# Display the test images
|
185 |
-
if test_image_displays:
|
186 |
-
display_image_grid(
|
187 |
-
test_image_displays,
|
188 |
-
[f"Test: {name}" for name in test_filenames]
|
189 |
-
)
|
190 |
-
|
191 |
-
if st.button("Classify Images"):
|
192 |
-
with st.spinner("Classifying images..."):
|
193 |
-
# Create output directory
|
194 |
-
os.makedirs("./results", exist_ok=True)
|
195 |
-
|
196 |
-
# Run classification
|
197 |
-
results = few_shot_fault_classification(
|
198 |
-
model=st.session_state.model,
|
199 |
-
test_images=test_images,
|
200 |
-
test_image_filenames=test_filenames,
|
201 |
-
nominal_images=[img.squeeze(0) for img in st.session_state.nominal_images],
|
202 |
-
nominal_descriptions=st.session_state.nominal_descriptions,
|
203 |
-
defective_images=[img.squeeze(0) for img in st.session_state.defective_images],
|
204 |
-
defective_descriptions=st.session_state.defective_descriptions,
|
205 |
-
num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
|
206 |
-
device=st.session_state.device,
|
207 |
-
file_path="./results",
|
208 |
-
file_name="classification_results.csv"
|
209 |
-
)
|
210 |
-
|
211 |
-
for i, result in enumerate(results):
|
212 |
-
st.session_state.test_results.append({
|
213 |
-
"image": test_image_displays[i],
|
214 |
-
"filename": test_filenames[i],
|
215 |
-
**result
|
216 |
-
})
|
217 |
-
|
218 |
-
# Display classification results
|
219 |
-
for result in results:
|
220 |
-
classification = result["classification_result"]
|
221 |
-
color = "green" if classification == "Nominal" else "red"
|
222 |
-
|
223 |
-
st.markdown(f"### {result['image_name']}: <span style='color:{color}'>{classification}</span>", unsafe_allow_html=True)
|
224 |
-
|
225 |
-
# Display probabilities
|
226 |
-
col1, col2 = st.columns(2)
|
227 |
-
with col1:
|
228 |
-
st.metric("Good Part Probability", f"{result['non_defect_prob']:.3f}")
|
229 |
-
with col2:
|
230 |
-
st.metric("Defective Part Probability", f"{result['defect_prob']:.3f}")
|
231 |
-
|
232 |
-
with tab3:
|
233 |
-
st.header("Classification Results")
|
234 |
-
|
235 |
-
if not st.session_state.test_results:
|
236 |
-
st.info("No classification results yet. Test some images in the 'Test Classification' tab.")
|
237 |
-
else:
|
238 |
-
st.write(f"Total images classified: {len(st.session_state.test_results)}")
|
239 |
-
|
240 |
-
# Display results in a table
|
241 |
-
results_df = [{
|
242 |
-
"Filename": r["filename"],
|
243 |
-
"Classification": r["classification_result"],
|
244 |
-
"Good Prob": f"{r['non_defect_prob']:.3f}",
|
245 |
-
"Defect Prob": f"{r['defect_prob']:.3f}",
|
246 |
-
"Time": r["datetime_of_operation"].split('T')[0]
|
247 |
-
} for r in st.session_state.test_results]
|
248 |
-
|
249 |
-
st.table(results_df)
|
250 |
-
|
251 |
-
# Option to clear results
|
252 |
-
if st.button("Clear Results"):
|
253 |
-
st.session_state.test_results = []
|
254 |
-
st.success("Classification results cleared!")
|
255 |
|
256 |
-
|
257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import open_clip
|
4 |
from PIL import Image
|
|
|
5 |
from classifier import few_shot_fault_classification
|
6 |
|
7 |
+
# Load lightweight CLIP model
|
8 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
+
model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
|
10 |
+
model = model.to(device)
|
11 |
+
model.eval()
|
|
|
12 |
|
13 |
+
st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
+
st.markdown("Upload **10 Nominal Images**, **10 Defective Images**, and one or more **Test Images** to classify.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
+
col1, col2 = st.columns(2)
|
18 |
+
with col1:
|
19 |
+
nominal_files = st.file_uploader("Upload Nominal Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
|
20 |
+
with col2:
|
21 |
+
defective_files = st.file_uploader("Upload Defective Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
|
|
|
|
|
22 |
|
23 |
+
test_files = st.file_uploader("Upload Test Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
if st.button("Classify Test Images"):
|
26 |
+
if len(nominal_files) < 1 or len(defective_files) < 1 or len(test_files) < 1:
|
27 |
+
st.warning("Please upload at least 1 image in each category.")
|
28 |
+
else:
|
29 |
+
st.info("Running classification...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
nominal_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in nominal_files]
|
32 |
+
defective_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in defective_files]
|
33 |
+
test_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in test_files]
|
34 |
+
|
35 |
+
results = few_shot_fault_classification(
|
36 |
+
model=model,
|
37 |
+
test_images=test_imgs,
|
38 |
+
test_image_filenames=[f.name for f in test_files],
|
39 |
+
nominal_images=nominal_imgs,
|
40 |
+
nominal_descriptions=[f.name for f in nominal_files],
|
41 |
+
defective_images=defective_imgs,
|
42 |
+
defective_descriptions=[f.name for f in defective_files],
|
43 |
+
num_few_shot_nominal_imgs=len(nominal_files),
|
44 |
+
device=device
|
45 |
+
)
|
46 |
+
|
47 |
+
for res in results:
|
48 |
+
st.write(f"**{res['image_path']}** ➜ {res['classification_result']} "
|
49 |
+
f"(Nominal: {res['non_defect_prob']}, Defective: {res['defect_prob']})")
|