File size: 12,262 Bytes
b8d3576
 
 
 
95f56a4
b8d3576
 
95f56a4
 
 
 
 
 
b8d3576
95f56a4
 
 
 
 
 
 
 
 
 
 
 
 
b8d3576
95f56a4
 
 
 
 
 
 
 
 
b8d3576
95f56a4
 
 
 
 
 
 
b8d3576
95f56a4
 
 
 
 
 
 
 
 
 
 
 
 
b8d3576
95f56a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8d3576
95f56a4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import streamlit as st
import torch
import open_clip
from PIL import Image
import os
from classifier import few_shot_fault_classification

# Browser-tab title and icon, plus the wide layout used by the multi-column UI.
st.set_page_config(page_title="Industrial QC Image Classifier", page_icon="🔍", layout="wide")

# Initialise per-session state the first time the script runs in a session.
# Reference images are stored as three parallel lists (tensor, description,
# filename) that are always appended to together, so index i of each list
# refers to the same upload.
if "nominal_images" not in st.session_state:
    st.session_state["nominal_images"] = []
    st.session_state["nominal_descriptions"] = []
    st.session_state["nominal_filenames"] = []

if "defective_images" not in st.session_state:
    st.session_state["defective_images"] = []
    st.session_state["defective_descriptions"] = []
    st.session_state["defective_filenames"] = []

if "test_results" not in st.session_state:
    st.session_state["test_results"] = []

# Load the CLIP RN50 backbone and its matching preprocessing transform once,
# then keep both (and the chosen device) in session state for later reruns.
if "model" not in st.session_state:
    if torch.cuda.is_available():
        st.session_state.device = "cuda"
    else:
        st.session_state.device = "cpu"
    with st.spinner("Loading CLIP model..."):
        model, _, preprocess = open_clip.create_model_and_transforms("RN50", pretrained="openai")
        model = model.to(st.session_state.device)
        # Inference only: put layers such as BatchNorm/Dropout in eval behaviour.
        model.eval()
        st.session_state.model = model
        st.session_state.preprocess = preprocess

def process_uploaded_image(uploaded_file):
    """Decode an uploaded file into an RGB PIL image.

    Args:
        uploaded_file: A file-like object as returned by ``st.file_uploader``,
            or ``None``.

    Returns:
        The decoded image, converted to RGB if needed. Returns ``None`` when no
        file was given or the file could not be decoded as an image; in the
        latter case a warning naming the skipped file is shown in the app.
    """
    if uploaded_file is None:
        return None
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy (it only identifies the file); force a full decode
        # now so truncated/corrupt uploads are caught here instead of crashing
        # later inside model preprocessing.
        image.load()
    except OSError as exc:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot identify, and OSError for truncated image data.
        name = getattr(uploaded_file, "name", "uploaded file")
        st.warning(f"Skipping {name}: could not read it as an image ({exc}).")
        return None
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image

def display_image_grid(images, captions, columns=5):
    """Show ``images`` laid out in rows of ``columns`` Streamlit columns.

    Args:
        images: Sequence of images accepted by ``st.image``.
        captions: Captions aligned index-for-index with ``images``.
        columns: Number of columns per row; the last row may be partly empty.

    Returns:
        None. Nothing is rendered when ``images`` is empty.
    """
    if not images:
        return

    count = len(images)
    # Ceiling division: a partially filled last row still needs its own row.
    row_total = (count + columns - 1) // columns
    for row in range(row_total):
        slots = st.columns(columns)
        first = row * columns
        for offset in range(columns):
            position = first + offset
            if position >= count:
                continue
            with slots[offset]:
                st.image(images[position], caption=captions[position], use_column_width=True)

def _render_reference_column(prefix, title, header, default_description, uploader_label):
    """Render the upload/add/list/clear controls for one class of reference images.

    Uploaded images are preprocessed with ``st.session_state.preprocess`` and
    appended, with their description and filename, to the session-state lists
    ``<prefix>_images``, ``<prefix>_descriptions`` and ``<prefix>_filenames``.

    Args:
        prefix: Lower-case class name used for session-state keys, the uploader
            widget key and user messages (``"nominal"`` or ``"defective"``).
        title: Capitalised class name used in widget labels.
        header: Subheader shown at the top of the column.
        default_description: Initial value of the description text box.
        uploader_label: Label of the file-uploader widget.
    """
    images_key = f"{prefix}_images"
    descriptions_key = f"{prefix}_descriptions"
    filenames_key = f"{prefix}_filenames"

    st.subheader(header)

    # Every image added in one batch is tagged with the current description.
    description = st.text_input(f"{title} Description", value=default_description)

    uploaded_files = st.file_uploader(
        uploader_label,
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
        key=f"{prefix}_upload",
    )

    if uploaded_files and st.button(f"Add {title} Images"):
        added = 0
        for file in uploaded_files:
            img = process_uploaded_image(file)
            if img is None:
                continue
            # Stored with a leading batch dimension; it is squeezed off again
            # when the references are passed to the classifier.
            preprocessed_img = st.session_state.preprocess(img).unsqueeze(0)
            st.session_state[images_key].append(preprocessed_img)
            st.session_state[descriptions_key].append(description)
            st.session_state[filenames_key].append(file.name)
            added += 1
        # Report what was actually stored, not how many files were selected.
        st.success(f"Added {added} {prefix} images!")

    st.write(f"Current {prefix} images: {len(st.session_state[images_key])}")

    if st.session_state[images_key] and st.session_state[filenames_key]:
        st.subheader(f"Current {title} Images")
        # Stored entries are model-ready tensors, so list the filenames instead.
        st.write(", ".join(st.session_state[filenames_key]))

        if st.button(f"Clear {title} Images"):
            st.session_state[images_key] = []
            st.session_state[descriptions_key] = []
            st.session_state[filenames_key] = []
            st.success(f"Cleared all {prefix} images!")


def _render_test_tab():
    """Render the test-image uploader, preview grid and classification action.

    Requires at least one nominal and one defective reference image. Each
    classified image is appended to ``st.session_state.test_results``.
    """
    from html import escape

    st.header("Test Image Classification")

    # Classification needs both reference classes.
    if not st.session_state.nominal_images or not st.session_state.defective_images:
        st.warning("You need to upload at least one nominal image and one defective image before testing.")
        return

    st.write("Upload a test image to classify it as nominal or defective.")

    test_files = st.file_uploader(
        "Upload test images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
        key="test_upload",
    )
    if not test_files:
        return

    test_images = []          # preprocessed tensors passed to the classifier
    test_image_displays = []  # decoded PIL images shown in the UI
    test_filenames = []
    for file in test_files:
        img = process_uploaded_image(file)
        if img is None:
            continue
        test_image_displays.append(img)
        test_images.append(st.session_state.preprocess(img))
        test_filenames.append(file.name)

    if not test_image_displays:
        return

    display_image_grid(test_image_displays, [f"Test: {name}" for name in test_filenames])

    if not st.button("Classify Images"):
        return

    with st.spinner("Classifying images..."):
        # Ensure the output directory passed as file_path exists.
        os.makedirs("./results", exist_ok=True)

        results = few_shot_fault_classification(
            model=st.session_state.model,
            test_images=test_images,
            test_image_filenames=test_filenames,
            nominal_images=[img.squeeze(0) for img in st.session_state.nominal_images],
            nominal_descriptions=st.session_state.nominal_descriptions,
            defective_images=[img.squeeze(0) for img in st.session_state.defective_images],
            defective_descriptions=st.session_state.defective_descriptions,
            num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
            device=st.session_state.device,
            file_path="./results",
            file_name="classification_results.csv"
        )

        # Assumes results come back in the same order as test_images.
        for image, filename, result in zip(test_image_displays, test_filenames, results):
            st.session_state.test_results.append({
                "image": image,
                "filename": filename,
                **result
            })

        for result in results:
            classification = result["classification_result"]
            color = "green" if classification == "Nominal" else "red"
            # The name comes from an uploaded filename and is rendered as raw
            # HTML, so escape it (and the label) before interpolation.
            safe_name = escape(str(result['image_name']))
            safe_label = escape(str(classification))
            st.markdown(f"### {safe_name}: <span style='color:{color}'>{safe_label}</span>", unsafe_allow_html=True)

            prob_col_good, prob_col_bad = st.columns(2)
            with prob_col_good:
                st.metric("Good Part Probability", f"{result['non_defect_prob']:.3f}")
            with prob_col_bad:
                st.metric("Defective Part Probability", f"{result['defect_prob']:.3f}")


def _render_results_tab():
    """Render a table of every classification stored in session state, with a clear button."""
    st.header("Classification Results")

    if not st.session_state.test_results:
        st.info("No classification results yet. Test some images in the 'Test Classification' tab.")
        return

    st.write(f"Total images classified: {len(st.session_state.test_results)}")

    rows = [
        {
            "Filename": r["filename"],
            "Classification": r["classification_result"],
            "Good Prob": f"{r['non_defect_prob']:.3f}",
            "Defect Prob": f"{r['defect_prob']:.3f}",
            # Keeps only the text before "T" (presumably an ISO-8601 timestamp,
            # which leaves the date) — TODO confirm the classifier's format.
            "Time": r["datetime_of_operation"].split('T')[0],
        }
        for r in st.session_state.test_results
    ]
    st.table(rows)

    if st.button("Clear Results"):
        st.session_state.test_results = []
        st.success("Classification results cleared!")


def main():
    """Build the three-tab UI: reference upload, test classification, and results."""
    st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
    st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")

    tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])

    with tab1:
        st.header("Upload Reference Images")
        col1, col2 = st.columns(2)
        with col1:
            _render_reference_column(
                prefix="nominal",
                title="Nominal",
                header="Good Parts (Nominal)",
                default_description="Good part without defects",
                uploader_label="Upload images of good parts (10 recommended)",
            )
        with col2:
            _render_reference_column(
                prefix="defective",
                title="Defective",
                header="Defective Parts",
                default_description="Part with visible defects",
                uploader_label="Upload images of defective parts (10 recommended)",
            )

    with tab2:
        _render_test_tab()

    with tab3:
        _render_results_tab()


if __name__ == "__main__":
    main()