clip / app.py
fmegahed's picture
Update app.py
95f56a4 verified
raw
history blame
12.3 kB
import html
import os

import open_clip
import streamlit as st
import torch
from PIL import Image

from classifier import few_shot_fault_classification
# Browser-tab title and icon for the app, rendered with the wide page layout.
st.set_page_config(
    layout="wide",
    page_icon="🔍",
    page_title="Industrial QC Image Classifier",
)
# Seed the per-session lists on the first run. Each group is initialised as a
# unit, and the group's first key serves as its "already initialised" marker.
_SESSION_LIST_GROUPS = (
    ("nominal_images", "nominal_descriptions", "nominal_filenames"),
    ("defective_images", "defective_descriptions", "defective_filenames"),
    ("test_results",),
)
for _group in _SESSION_LIST_GROUPS:
    if _group[0] not in st.session_state:
        for _key in _group:
            st.session_state[_key] = []

# Load the CLIP RN50 model (OpenAI weights) on first run and keep it, its
# preprocessing transform and the chosen device in session state.
# NOTE(review): this loads a separate model copy per browser session;
# st.cache_resource could share one copy if the classifier never mutates the
# model — confirm before changing.
if "model" not in st.session_state:
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    st.session_state.device = _device
    with st.spinner("Loading CLIP model..."):
        model, _, preprocess = open_clip.create_model_and_transforms("RN50", pretrained="openai")
        # nn.Module.to() and .eval() both return the module itself.
        model = model.to(_device).eval()
        st.session_state.model = model
        st.session_state.preprocess = preprocess
def process_uploaded_image(uploaded_file):
    """Decode an uploaded file into an RGB PIL image.

    Args:
        uploaded_file: File-like object returned by ``st.file_uploader``,
            or None.

    Returns:
        The decoded image converted to RGB mode, or None when
        ``uploaded_file`` is None or cannot be decoded as an image (in which
        case a warning is shown in the UI instead of crashing the run).
    """
    if uploaded_file is None:
        return None
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; decode now so truncated/corrupt files fail here
        # rather than later during model preprocessing.
        image.load()
    except OSError:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        name = getattr(uploaded_file, "name", "uploaded file")
        st.warning(f"Skipping {name}: not a readable image.")
        return None
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
def display_image_grid(images, captions, columns=5):
    """Render ``images`` in a grid with ``columns`` cells per row.

    Args:
        images: Sequence of images accepted by ``st.image``.
        captions: Captions parallel to ``images`` (caption i labels image i).
        columns: Number of grid cells per row.

    Renders nothing when ``images`` is empty.
    """
    if not images:
        return
    total = len(images)
    row_count = (total + columns - 1) // columns
    for row in range(row_count):
        cells = st.columns(columns)
        first = row * columns
        for offset, cell in enumerate(cells):
            position = first + offset
            if position >= total:
                break  # last row may be partially filled
            with cell:
                st.image(images[position], caption=captions[position], use_column_width=True)
def _render_reference_column(kind, heading, default_description, parts_noun):
    """Render the upload, listing and clear controls for one reference set.

    Args:
        kind: ``"nominal"`` or ``"defective"``. Selects the session-state
            lists ``<kind>_images``, ``<kind>_descriptions`` and
            ``<kind>_filenames`` and is used to build the widget labels.
        heading: Subheader shown at the top of the column.
        default_description: Initial value of the description text box.
        parts_noun: Noun phrase used in the uploader label (e.g. "good parts").
    """
    label = kind.capitalize()
    state = st.session_state
    images_key = f"{kind}_images"
    descriptions_key = f"{kind}_descriptions"
    filenames_key = f"{kind}_filenames"

    st.subheader(heading)
    description = st.text_input(f"{label} Description", value=default_description)
    uploaded_files = st.file_uploader(
        f"Upload images of {parts_noun} (10 recommended)",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
        key=f"{kind}_upload",
    )
    if uploaded_files and st.button(f"Add {label} Images"):
        preprocess = state.preprocess
        added = 0
        for file in uploaded_files:
            img = process_uploaded_image(file)
            if img is None:
                continue
            # Reference tensors are stored with a leading batch dimension;
            # it is squeezed off again before classification.
            state[images_key].append(preprocess(img).unsqueeze(0))
            state[descriptions_key].append(description)
            state[filenames_key].append(file.name)
            added += 1
        # Report what was actually stored, not how many files were uploaded.
        st.success(f"Added {added} {kind} images!")

    # Reserve the summary's position above the Clear button, but fill it only
    # after the button is handled so a click is reflected in this same run
    # instead of showing the stale pre-clear count and file list.
    summary = st.container()
    cleared = False
    if state[images_key] and state[filenames_key]:
        cleared = st.button(f"Clear {label} Images")
    if cleared:
        state[images_key] = []
        state[descriptions_key] = []
        state[filenames_key] = []
    with summary:
        st.write(f"Current {kind} images: {len(state[images_key])}")
        if state[images_key] and state[filenames_key]:
            st.subheader(f"Current {label} Images")
            # Filenames are listed because the stored images are tensors.
            st.write(", ".join(state[filenames_key]))
    if cleared:
        st.success(f"Cleared all {kind} images!")


def _render_test_tab():
    """Upload test images, classify them and record the results.

    Requires at least one nominal and one defective reference image. Each
    classification result is shown inline and appended, together with its
    PIL image and filename, to ``st.session_state.test_results``.
    """
    st.header("Test Image Classification")
    state = st.session_state
    if not state.nominal_images or not state.defective_images:
        st.warning("You need to upload at least one nominal image and one defective image before testing.")
        return

    st.write("Upload a test image to classify it as nominal or defective.")
    test_files = st.file_uploader(
        "Upload test images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
        key="test_upload",
    )
    if not test_files:
        return

    test_image_displays = []
    test_filenames = []
    for file in test_files:
        img = process_uploaded_image(file)
        if img is not None:
            test_image_displays.append(img)
            test_filenames.append(file.name)
    if not test_image_displays:
        # Nothing decodable to show or classify.
        return
    display_image_grid(test_image_displays, [f"Test: {name}" for name in test_filenames])

    if not st.button("Classify Images"):
        return
    with st.spinner("Classifying images..."):
        os.makedirs("./results", exist_ok=True)
        # Preprocess only when classifying, not on every rerun. Tensors are
        # passed without a batch dimension, like the squeezed references.
        preprocess = state.preprocess
        test_images = [preprocess(img) for img in test_image_displays]
        results = few_shot_fault_classification(
            model=state.model,
            test_images=test_images,
            test_image_filenames=test_filenames,
            nominal_images=[img.squeeze(0) for img in state.nominal_images],
            nominal_descriptions=state.nominal_descriptions,
            defective_images=[img.squeeze(0) for img in state.defective_images],
            defective_descriptions=state.defective_descriptions,
            num_few_shot_nominal_imgs=len(state.nominal_images),
            device=state.device,
            file_path="./results",
            file_name="classification_results.csv",
        )
        # NOTE(review): assumes the classifier returns one result per test
        # image, in input order — confirm against classifier.py.
        for display, filename, result in zip(test_image_displays, test_filenames, results):
            state.test_results.append({
                "image": display,
                "filename": filename,
                **result,
            })

    for result in results:
        classification = result["classification_result"]
        color = "green" if classification == "Nominal" else "red"
        # image_name is presumably derived from the uploaded filename (user
        # input), so escape it before embedding it in raw HTML.
        name = html.escape(str(result["image_name"]))
        st.markdown(
            f"### {name}: <span style='color:{color}'>{classification}</span>",
            unsafe_allow_html=True,
        )
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Good Part Probability", f"{result['non_defect_prob']:.3f}")
        with col2:
            st.metric("Defective Part Probability", f"{result['defect_prob']:.3f}")


def _render_results_tab():
    """Show accumulated classification results as a table, with a Clear button."""
    st.header("Classification Results")
    state = st.session_state
    # Filled after the Clear button is handled so a click shows immediately.
    summary = st.container()
    cleared = False
    if state.test_results:
        cleared = st.button("Clear Results")
    if cleared:
        state.test_results = []
    with summary:
        if not state.test_results:
            st.info("No classification results yet. Test some images in the 'Test Classification' tab.")
        else:
            st.write(f"Total images classified: {len(state.test_results)}")
            st.table([
                {
                    "Filename": r["filename"],
                    "Classification": r["classification_result"],
                    "Good Prob": f"{r['non_defect_prob']:.3f}",
                    "Defect Prob": f"{r['defect_prob']:.3f}",
                    # Shows only the text before 'T'; assumes an ISO-8601
                    # timestamp string — confirm against classifier.py.
                    "Time": r["datetime_of_operation"].split('T')[0],
                }
                for r in state.test_results
            ])
    if cleared:
        st.success("Classification results cleared!")


def main():
    """Render the app: reference uploads, test classification and results tabs."""
    st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
    st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")
    tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
    with tab1:
        st.header("Upload Reference Images")
        col1, col2 = st.columns(2)
        with col1:
            _render_reference_column(
                kind="nominal",
                heading="Good Parts (Nominal)",
                default_description="Good part without defects",
                parts_noun="good parts",
            )
        with col2:
            _render_reference_column(
                kind="defective",
                heading="Defective Parts",
                default_description="Part with visible defects",
                parts_noun="defective parts",
            )
    with tab2:
        _render_test_tab()
    with tab3:
        _render_results_tab()
# Entry point: render the UI when this module is executed as the main script
# (rather than imported).
if __name__ == "__main__":
    main()