Update app.py
Browse files
app.py
CHANGED
@@ -2,48 +2,256 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
import open_clip
|
4 |
from PIL import Image
|
|
|
5 |
from classifier import few_shot_fault_classification
|
6 |
|
7 |
-
#
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
22 |
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
test_imgs = [preprocess(Image.open(f).convert("RGB")).unsqueeze(0).to(device) for f in test_files]
|
34 |
-
|
35 |
-
results = few_shot_fault_classification(
|
36 |
-
model=model,
|
37 |
-
test_images=[img.squeeze(0) for img in test_imgs],
|
38 |
-
test_image_filenames=[f.name for f in test_files],
|
39 |
-
nominal_images=[img.squeeze(0) for img in nominal_imgs],
|
40 |
-
nominal_descriptions=[f.name for f in nominal_files],
|
41 |
-
defective_images=[img.squeeze(0) for img in defective_imgs],
|
42 |
-
defective_descriptions=[f.name for f in defective_files],
|
43 |
-
num_few_shot_nominal_imgs=len(nominal_files),
|
44 |
-
device=device
|
45 |
-
)
|
46 |
-
|
47 |
-
for res in results:
|
48 |
-
st.write(f"**{res['image_path']}** ➜ {res['classification_result']} "
|
49 |
-
f"(Nominal: {res['non_defect_prob']}, Defective: {res['defect_prob']})")
|
|
|
2 |
import torch
|
3 |
import open_clip
|
4 |
from PIL import Image
|
5 |
+
import os
|
6 |
from classifier import few_shot_fault_classification
|
7 |
|
8 |
+
# Set page configuration
|
9 |
+
st.set_page_config(
|
10 |
+
page_title="Industrial QC Image Classifier",
|
11 |
+
page_icon="🔍",
|
12 |
+
layout="wide"
|
13 |
+
)
|
14 |
|
15 |
+
# Initialize session state variables if they don't exist
|
16 |
+
if 'nominal_images' not in st.session_state:
|
17 |
+
st.session_state.nominal_images = []
|
18 |
+
st.session_state.nominal_descriptions = []
|
19 |
+
st.session_state.nominal_filenames = []
|
20 |
+
|
21 |
+
if 'defective_images' not in st.session_state:
|
22 |
+
st.session_state.defective_images = []
|
23 |
+
st.session_state.defective_descriptions = []
|
24 |
+
st.session_state.defective_filenames = []
|
25 |
+
|
26 |
+
if 'test_results' not in st.session_state:
|
27 |
+
st.session_state.test_results = []
|
28 |
|
29 |
+
if 'model' not in st.session_state:
|
30 |
+
st.session_state.device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
+
# Load model
|
32 |
+
with st.spinner("Loading CLIP model..."):
|
33 |
+
model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
|
34 |
+
model = model.to(st.session_state.device)
|
35 |
+
model.eval()
|
36 |
+
st.session_state.model = model
|
37 |
+
st.session_state.preprocess = preprocess
|
38 |
|
39 |
+
def process_uploaded_image(uploaded_file):
|
40 |
+
if uploaded_file is not None:
|
41 |
+
image = Image.open(uploaded_file)
|
42 |
+
if image.mode != "RGB":
|
43 |
+
image = image.convert("RGB")
|
44 |
+
return image
|
45 |
+
return None
|
46 |
|
47 |
+
def display_image_grid(images, captions, columns=5):
|
48 |
+
if not images:
|
49 |
+
return
|
50 |
+
|
51 |
+
rows = (len(images) + columns - 1) // columns
|
52 |
+
|
53 |
+
for i in range(rows):
|
54 |
+
cols = st.columns(columns)
|
55 |
+
for j in range(columns):
|
56 |
+
idx = i * columns + j
|
57 |
+
if idx < len(images):
|
58 |
+
with cols[j]:
|
59 |
+
st.image(images[idx], caption=captions[idx], use_column_width=True)
|
60 |
|
61 |
+
def main():
|
62 |
+
st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
|
63 |
+
st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")
|
64 |
+
|
65 |
+
# Create tabs for different functionality
|
66 |
+
tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
|
67 |
+
|
68 |
+
with tab1:
|
69 |
+
st.header("Upload Reference Images")
|
70 |
+
|
71 |
+
# Create two columns for nominal and defective images
|
72 |
+
col1, col2 = st.columns(2)
|
73 |
+
|
74 |
+
with col1:
|
75 |
+
st.subheader("Good Parts (Nominal)")
|
76 |
+
|
77 |
+
# Input for nominal description
|
78 |
+
nominal_desc = st.text_input("Nominal Description",
|
79 |
+
value="Good part without defects")
|
80 |
+
|
81 |
+
# File uploader for nominal images
|
82 |
+
nominal_files = st.file_uploader("Upload images of good parts (10 recommended)",
|
83 |
+
type=["jpg", "jpeg", "png"],
|
84 |
+
accept_multiple_files=True,
|
85 |
+
key="nominal_upload")
|
86 |
+
|
87 |
+
if nominal_files and st.button("Add Nominal Images"):
|
88 |
+
for file in nominal_files:
|
89 |
+
img = process_uploaded_image(file)
|
90 |
+
if img:
|
91 |
+
# Preprocess the image for the model
|
92 |
+
preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
|
93 |
+
st.session_state.nominal_images.append(preprocessed_img)
|
94 |
+
st.session_state.nominal_descriptions.append(nominal_desc)
|
95 |
+
st.session_state.nominal_filenames.append(file.name)
|
96 |
+
|
97 |
+
st.success(f"Added {len(nominal_files)} nominal images!")
|
98 |
+
|
99 |
+
# Display current nominal image count
|
100 |
+
st.write(f"Current nominal images: {len(st.session_state.nominal_images)}")
|
101 |
+
|
102 |
+
# Display nominal images in a grid if we have any
|
103 |
+
if st.session_state.nominal_images and st.session_state.nominal_filenames:
|
104 |
+
st.subheader("Current Nominal Images")
|
105 |
+
# We'll display the filenames instead of the actual tensor images
|
106 |
+
display_placeholders = [f"Image {i+1}" for i in range(len(st.session_state.nominal_filenames))]
|
107 |
+
st.write(", ".join(st.session_state.nominal_filenames))
|
108 |
+
|
109 |
+
if st.button("Clear Nominal Images"):
|
110 |
+
st.session_state.nominal_images = []
|
111 |
+
st.session_state.nominal_descriptions = []
|
112 |
+
st.session_state.nominal_filenames = []
|
113 |
+
st.success("Cleared all nominal images!")
|
114 |
+
|
115 |
+
with col2:
|
116 |
+
st.subheader("Defective Parts")
|
117 |
+
|
118 |
+
# Input for defective description
|
119 |
+
defective_desc = st.text_input("Defective Description",
|
120 |
+
value="Part with visible defects")
|
121 |
+
|
122 |
+
# File uploader for defective images
|
123 |
+
defective_files = st.file_uploader("Upload images of defective parts (10 recommended)",
|
124 |
+
type=["jpg", "jpeg", "png"],
|
125 |
+
accept_multiple_files=True,
|
126 |
+
key="defective_upload")
|
127 |
+
|
128 |
+
if defective_files and st.button("Add Defective Images"):
|
129 |
+
for file in defective_files:
|
130 |
+
img = process_uploaded_image(file)
|
131 |
+
if img:
|
132 |
+
# Preprocess the image for the model
|
133 |
+
preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
|
134 |
+
st.session_state.defective_images.append(preprocessed_img)
|
135 |
+
st.session_state.defective_descriptions.append(defective_desc)
|
136 |
+
st.session_state.defective_filenames.append(file.name)
|
137 |
+
|
138 |
+
st.success(f"Added {len(defective_files)} defective images!")
|
139 |
+
|
140 |
+
# Display current defective image count
|
141 |
+
st.write(f"Current defective images: {len(st.session_state.defective_images)}")
|
142 |
+
|
143 |
+
# Display defective images in a grid if we have any
|
144 |
+
if st.session_state.defective_images and st.session_state.defective_filenames:
|
145 |
+
st.subheader("Current Defective Images")
|
146 |
+
# We'll display the filenames instead of the actual tensor images
|
147 |
+
st.write(", ".join(st.session_state.defective_filenames))
|
148 |
+
|
149 |
+
if st.button("Clear Defective Images"):
|
150 |
+
st.session_state.defective_images = []
|
151 |
+
st.session_state.defective_descriptions = []
|
152 |
+
st.session_state.defective_filenames = []
|
153 |
+
st.success("Cleared all defective images!")
|
154 |
+
|
155 |
+
with tab2:
|
156 |
+
st.header("Test Image Classification")
|
157 |
+
|
158 |
+
# Check if we have enough reference images
|
159 |
+
if len(st.session_state.nominal_images) == 0 or len(st.session_state.defective_images) == 0:
|
160 |
+
st.warning("You need to upload at least one nominal image and one defective image before testing.")
|
161 |
+
else:
|
162 |
+
st.write("Upload a test image to classify it as nominal or defective.")
|
163 |
+
|
164 |
+
# File uploader for test image
|
165 |
+
test_files = st.file_uploader("Upload test images",
|
166 |
+
type=["jpg", "jpeg", "png"],
|
167 |
+
accept_multiple_files=True,
|
168 |
+
key="test_upload")
|
169 |
+
|
170 |
+
if test_files:
|
171 |
+
test_images = []
|
172 |
+
test_image_displays = []
|
173 |
+
test_filenames = []
|
174 |
+
|
175 |
+
for file in test_files:
|
176 |
+
img = process_uploaded_image(file)
|
177 |
+
if img:
|
178 |
+
test_image_displays.append(img)
|
179 |
+
# Preprocess the image for the model
|
180 |
+
preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
|
181 |
+
test_images.append(preprocessed_img.squeeze(0))
|
182 |
+
test_filenames.append(file.name)
|
183 |
+
|
184 |
+
# Display the test images
|
185 |
+
if test_image_displays:
|
186 |
+
display_image_grid(
|
187 |
+
test_image_displays,
|
188 |
+
[f"Test: {name}" for name in test_filenames]
|
189 |
+
)
|
190 |
+
|
191 |
+
if st.button("Classify Images"):
|
192 |
+
with st.spinner("Classifying images..."):
|
193 |
+
# Create output directory
|
194 |
+
os.makedirs("./results", exist_ok=True)
|
195 |
+
|
196 |
+
# Run classification
|
197 |
+
results = few_shot_fault_classification(
|
198 |
+
model=st.session_state.model,
|
199 |
+
test_images=test_images,
|
200 |
+
test_image_filenames=test_filenames,
|
201 |
+
nominal_images=[img.squeeze(0) for img in st.session_state.nominal_images],
|
202 |
+
nominal_descriptions=st.session_state.nominal_descriptions,
|
203 |
+
defective_images=[img.squeeze(0) for img in st.session_state.defective_images],
|
204 |
+
defective_descriptions=st.session_state.defective_descriptions,
|
205 |
+
num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
|
206 |
+
device=st.session_state.device,
|
207 |
+
file_path="./results",
|
208 |
+
file_name="classification_results.csv"
|
209 |
+
)
|
210 |
+
|
211 |
+
for i, result in enumerate(results):
|
212 |
+
st.session_state.test_results.append({
|
213 |
+
"image": test_image_displays[i],
|
214 |
+
"filename": test_filenames[i],
|
215 |
+
**result
|
216 |
+
})
|
217 |
+
|
218 |
+
# Display classification results
|
219 |
+
for result in results:
|
220 |
+
classification = result["classification_result"]
|
221 |
+
color = "green" if classification == "Nominal" else "red"
|
222 |
+
|
223 |
+
st.markdown(f"### {result['image_name']}: <span style='color:{color}'>{classification}</span>", unsafe_allow_html=True)
|
224 |
+
|
225 |
+
# Display probabilities
|
226 |
+
col1, col2 = st.columns(2)
|
227 |
+
with col1:
|
228 |
+
st.metric("Good Part Probability", f"{result['non_defect_prob']:.3f}")
|
229 |
+
with col2:
|
230 |
+
st.metric("Defective Part Probability", f"{result['defect_prob']:.3f}")
|
231 |
+
|
232 |
+
with tab3:
|
233 |
+
st.header("Classification Results")
|
234 |
+
|
235 |
+
if not st.session_state.test_results:
|
236 |
+
st.info("No classification results yet. Test some images in the 'Test Classification' tab.")
|
237 |
+
else:
|
238 |
+
st.write(f"Total images classified: {len(st.session_state.test_results)}")
|
239 |
+
|
240 |
+
# Display results in a table
|
241 |
+
results_df = [{
|
242 |
+
"Filename": r["filename"],
|
243 |
+
"Classification": r["classification_result"],
|
244 |
+
"Good Prob": f"{r['non_defect_prob']:.3f}",
|
245 |
+
"Defect Prob": f"{r['defect_prob']:.3f}",
|
246 |
+
"Time": r["datetime_of_operation"].split('T')[0]
|
247 |
+
} for r in st.session_state.test_results]
|
248 |
+
|
249 |
+
st.table(results_df)
|
250 |
+
|
251 |
+
# Option to clear results
|
252 |
+
if st.button("Clear Results"):
|
253 |
+
st.session_state.test_results = []
|
254 |
+
st.success("Classification results cleared!")
|
255 |
|
256 |
+
if __name__ == "__main__":
|
257 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|