|
import html
import os

import open_clip
import streamlit as st
import torch
from PIL import Image

from classifier import few_shot_fault_classification


# Browser-tab title and icon; "wide" layout lets the page use the full
# browser width instead of a narrow centred column.
st.set_page_config(
    layout="wide",
    page_title="Industrial QC Image Classifier",
    page_icon="🔍",
)


def _ensure_session_lists(sentinel, *keys):
    """Initialise ``keys`` as empty lists in ``st.session_state`` on first use.

    The whole group is created only when ``sentinel`` is missing, so lists that
    belong together (images / descriptions / filenames) are always created at
    the same time and start index-aligned.
    """
    if sentinel in st.session_state:
        return
    for key in keys:
        st.session_state[key] = []


# Reference images: preprocessed tensors, their text descriptions and the
# uploaded filenames, kept as parallel lists per class.
_ensure_session_lists("nominal_images",
                      "nominal_images", "nominal_descriptions", "nominal_filenames")
_ensure_session_lists("defective_images",
                      "defective_images", "defective_descriptions", "defective_filenames")
# History of classification results shown on the Results tab.
_ensure_session_lists("test_results", "test_results")


@st.cache_resource
def _load_clip_model(device):
    """Create the OpenAI-pretrained RN50 CLIP model and its preprocessing transform.

    Decorated with ``st.cache_resource`` so the weights are loaded (and moved
    to ``device``) once per server process and shared by every browser
    session, instead of being reloaded into memory for each new session.

    Args:
        device: torch device string, "cuda" or "cpu".

    Returns:
        ``(model, preprocess)``: the model in eval mode on ``device`` and the
        image transform returned by ``open_clip.create_model_and_transforms``.
    """
    model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
    model = model.to(device)
    model.eval()
    return model, preprocess


# Expose the (shared) model through session state, which is where the rest of
# the app reads it from.
if 'model' not in st.session_state:
    st.session_state.device = "cuda" if torch.cuda.is_available() else "cpu"

    with st.spinner("Loading CLIP model..."):
        model, preprocess = _load_clip_model(st.session_state.device)

    st.session_state.model = model
    st.session_state.preprocess = preprocess


def process_uploaded_image(uploaded_file):
    """Decode an uploaded file into an RGB PIL image.

    Args:
        uploaded_file: a readable file-like object (e.g. a Streamlit upload)
            or ``None``.

    Returns:
        The decoded image in RGB mode, or ``None`` when ``uploaded_file`` is
        ``None`` or its contents cannot be decoded as an image (unknown
        format, truncated/corrupt data, or a decompression bomb).
    """
    if uploaded_file is None:
        return None

    try:
        image = Image.open(uploaded_file)
        # Image.open only parses the header; read the pixel data now so a
        # truncated/corrupt upload fails here instead of later during
        # preprocessing, where nothing catches it.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # OSError covers UnidentifiedImageError and truncated files. Returning
        # None lets callers skip this file instead of crashing the page.
        return None

    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def display_image_grid(images, captions, columns=5):
    """Render ``images`` row by row in a grid ``columns`` cells wide.

    ``captions[k]`` is shown under ``images[k]``. Each row allocates a full
    set of ``columns`` Streamlit columns, so a short last row simply leaves
    its trailing cells empty. Nothing is drawn when ``images`` is empty/None.
    """
    if not images:
        return

    total = len(images)
    row_count = (total + columns - 1) // columns  # ceiling division

    for row in range(row_count):
        first = row * columns
        cells = st.columns(columns)
        for offset, picture in enumerate(images[first:first + columns]):
            with cells[offset]:
                st.image(picture, caption=captions[first + offset], use_column_width=True)


def _render_reference_column(prefix, title, heading, default_description, part_noun):
    """Render upload / add / list / clear widgets for one class of reference images.

    Images live in ``st.session_state`` under ``{prefix}_images`` (preprocessed
    tensors with an extra leading axis added by ``unsqueeze(0)``),
    ``{prefix}_descriptions`` and ``{prefix}_filenames``; the three lists are
    appended together and cleared together so they stay index-aligned.

    Args:
        prefix: session-state key prefix and lower-case label ("nominal"/"defective").
        title: capitalised label used in widget captions ("Nominal"/"Defective").
        heading: subheader shown at the top of the column.
        default_description: initial value of the description text input.
        part_noun: phrase used in the uploader label ("good parts"/"defective parts").
    """
    images_key = f"{prefix}_images"
    descriptions_key = f"{prefix}_descriptions"
    filenames_key = f"{prefix}_filenames"

    st.subheader(heading)

    # Text description attached to every image added in this batch.
    description = st.text_input(f"{title} Description", value=default_description)

    files = st.file_uploader(f"Upload images of {part_noun} (10 recommended)",
                             type=["jpg", "jpeg", "png"],
                             accept_multiple_files=True,
                             key=f"{prefix}_upload")

    if files and st.button(f"Add {title} Images"):
        added = 0
        for file in files:
            img = process_uploaded_image(file)
            if img is None:
                continue  # unreadable upload; skip it
            st.session_state[images_key].append(st.session_state.preprocess(img).unsqueeze(0))
            st.session_state[descriptions_key].append(description)
            st.session_state[filenames_key].append(file.name)
            added += 1
        # Report what was actually stored, not how many files were uploaded.
        st.success(f"Added {added} {prefix} images!")

    st.write(f"Current {prefix} images: {len(st.session_state[images_key])}")

    if st.session_state[images_key] and st.session_state[filenames_key]:
        st.subheader(f"Current {title} Images")
        st.write(", ".join(st.session_state[filenames_key]))

        if st.button(f"Clear {title} Images"):
            st.session_state[images_key] = []
            st.session_state[descriptions_key] = []
            st.session_state[filenames_key] = []
            st.success(f"Cleared all {prefix} images!")


def _render_classification_tab():
    """Render the "Test Classification" tab: upload test images and classify them.

    Requires at least one nominal and one defective reference image. Every
    classified image is appended to ``st.session_state.test_results`` for the
    Results tab.
    """
    st.header("Test Image Classification")

    if not st.session_state.nominal_images or not st.session_state.defective_images:
        st.warning("You need to upload at least one nominal image and one defective image before testing.")
        return

    st.write("Upload a test image to classify it as nominal or defective.")

    test_files = st.file_uploader("Upload test images",
                                  type=["jpg", "jpeg", "png"],
                                  accept_multiple_files=True,
                                  key="test_upload")
    if not test_files:
        return

    test_images = []          # preprocessed tensors passed to the classifier
    test_image_displays = []  # RGB PIL images shown in the grid / stored with results
    test_filenames = []
    for file in test_files:
        img = process_uploaded_image(file)
        if img is None:
            continue
        test_image_displays.append(img)
        test_images.append(st.session_state.preprocess(img))
        test_filenames.append(file.name)

    if not test_images:
        st.warning("None of the uploaded test images could be read.")
        return

    display_image_grid(test_image_displays, [f"Test: {name}" for name in test_filenames])

    if not st.button("Classify Images"):
        return

    with st.spinner("Classifying images..."):
        os.makedirs("./results", exist_ok=True)

        # Reference tensors were stored with an extra leading axis; drop it
        # so they match the test tensors.
        results = few_shot_fault_classification(
            model=st.session_state.model,
            test_images=test_images,
            test_image_filenames=test_filenames,
            nominal_images=[img.squeeze(0) for img in st.session_state.nominal_images],
            nominal_descriptions=st.session_state.nominal_descriptions,
            defective_images=[img.squeeze(0) for img in st.session_state.defective_images],
            defective_descriptions=st.session_state.defective_descriptions,
            num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
            device=st.session_state.device,
            file_path="./results",
            file_name="classification_results.csv",
        )

    # NOTE(review): assumes results come back in the same order as test_images.
    for image, filename, result in zip(test_image_displays, test_filenames, results):
        st.session_state.test_results.append({
            "image": image,
            "filename": filename,
            **result,
        })

    for result in results:
        classification = result["classification_result"]
        color = "green" if classification == "Nominal" else "red"

        # The image name comes from the user-supplied upload filename and is
        # rendered with unsafe_allow_html=True, so it must be HTML-escaped.
        safe_name = html.escape(str(result['image_name']))
        safe_classification = html.escape(str(classification))
        st.markdown(f"### {safe_name}: <span style='color:{color}'>{safe_classification}</span>",
                    unsafe_allow_html=True)

        col1, col2 = st.columns(2)
        with col1:
            st.metric("Good Part Probability", f"{result['non_defect_prob']:.3f}")
        with col2:
            st.metric("Defective Part Probability", f"{result['defect_prob']:.3f}")


def _render_results_tab():
    """Render the "Results" tab: a table of every classification stored this session."""
    st.header("Classification Results")

    if not st.session_state.test_results:
        st.info("No classification results yet. Test some images in the 'Test Classification' tab.")
        return

    st.write(f"Total images classified: {len(st.session_state.test_results)}")

    # NOTE(review): datetime_of_operation looks like an ISO-8601 string; only
    # the part before 'T' (the date) is shown — confirm against classifier.
    rows = [{
        "Filename": r["filename"],
        "Classification": r["classification_result"],
        "Good Prob": f"{r['non_defect_prob']:.3f}",
        "Defect Prob": f"{r['defect_prob']:.3f}",
        "Time": r["datetime_of_operation"].split('T')[0],
    } for r in st.session_state.test_results]

    st.table(rows)

    if st.button("Clear Results"):
        st.session_state.test_results = []
        st.success("Classification results cleared!")


def main():
    """Render the whole app: reference upload, test classification and results tabs."""
    st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
    st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")

    tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])

    with tab1:
        st.header("Upload Reference Images")
        col1, col2 = st.columns(2)
        with col1:
            _render_reference_column(
                prefix="nominal",
                title="Nominal",
                heading="Good Parts (Nominal)",
                default_description="Good part without defects",
                part_noun="good parts",
            )
        with col2:
            _render_reference_column(
                prefix="defective",
                title="Defective",
                heading="Defective Parts",
                default_description="Part with visible defects",
                part_noun="defective parts",
            )

    with tab2:
        _render_classification_tab()

    with tab3:
        _render_results_tab()


# Script entry point: render the page when this file is executed directly.
if __name__ == "__main__":
    main()