fmegahed commited on
Commit
0ac2665
·
verified ·
1 Parent(s): 4f8dfb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -35
app.py CHANGED
@@ -1,49 +1,181 @@
1
  import streamlit as st
2
  import torch
3
- import open_clip
4
  from PIL import Image
5
- from classifier import few_shot_fault_classification
 
 
 
 
6
 
7
- # Load lightweight CLIP model
 
 
 
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='openai')
10
- model = model.to(device)
 
11
  model.eval()
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
 
14
 
15
- st.markdown("Upload **10 Nominal Images**, **10 Defective Images**, and one or more **Test Images** to classify.")
16
 
17
- col1, col2 = st.columns(2)
18
- with col1:
19
- nominal_files = st.file_uploader("Upload Nominal Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
20
- with col2:
21
- defective_files = st.file_uploader("Upload Defective Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
22
 
23
- test_files = st.file_uploader("Upload Test Images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
 
24
 
25
- if st.button("Classify Test Images"):
26
- if len(nominal_files) < 1 or len(defective_files) < 1 or len(test_files) < 1:
27
- st.warning("Please upload at least 1 image in each category.")
28
- else:
29
- st.info("Running classification...")
30
-
31
- nominal_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in nominal_files]
32
- defective_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in defective_files]
33
- test_imgs = [preprocess(Image.open(f).convert("RGB")).to(device) for f in test_files]
34
-
35
- results = few_shot_fault_classification(
36
- model=model,
37
- test_images=test_imgs,
38
- test_image_filenames=[f.name for f in test_files],
39
- nominal_images=nominal_imgs,
40
- nominal_descriptions=[f.name for f in nominal_files],
41
- defective_images=defective_imgs,
42
- defective_descriptions=[f.name for f in defective_files],
43
- num_few_shot_nominal_imgs=len(nominal_files),
44
- device=device
 
 
 
 
 
 
 
 
 
 
 
45
  )
46
 
47
- for res in results:
48
- st.write(f"**{res['image_path']}** {res['classification_result']} "
49
- f"(Nominal: {res['non_defect_prob']}, Defective: {res['defect_prob']})")
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
+ import clip
4
  from PIL import Image
5
+ import os
6
+ import pandas as pd
7
+ from datetime import datetime
8
+ import torch.nn.functional as F
9
+ from typing import List
10
 
11
+ # Load secrets
12
+ openai_api_key = st.secrets.get("OPENAI_API_KEY")
13
+ # You can now use openai_api_key for anything requiring OpenAI access
14
+
15
+ # Device setup
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ # Load CLIP model + preprocess from OpenAI CLIP
19
+ model, preprocess = clip.load("ViT-L/14", device=device)
20
  model.eval()
21
 
22
+ # Ensure reproducibility
23
+ torch.set_grad_enabled(False)
24
+
25
+ # Import the few-shot classification function
26
+ # --- COPY YOUR FUNCTION DEFINITION BELOW DIRECTLY OR PUT IT IN A SEPARATE FILE ---
27
+ def few_shot_fault_classification(
28
+ test_images: List[Image.Image],
29
+ test_image_filenames: List[str],
30
+ nominal_images: List[Image.Image],
31
+ nominal_descriptions: List[str],
32
+ defective_images: List[Image.Image],
33
+ defective_descriptions: List[str],
34
+ num_few_shot_nominal_imgs: int,
35
+ file_path: str = '.',
36
+ file_name: str = 'image_classification_results.csv',
37
+ print_one_liner: bool = False
38
+ ):
39
+ if not isinstance(test_images, list): test_images = [test_images]
40
+ if not isinstance(test_image_filenames, list): test_image_filenames = [test_image_filenames]
41
+ if not isinstance(nominal_images, list): nominal_images = [nominal_images]
42
+ if not isinstance(nominal_descriptions, list): nominal_descriptions = [nominal_descriptions]
43
+ if not isinstance(defective_images, list): defective_images = [defective_images]
44
+ if not isinstance(defective_descriptions, list): defective_descriptions = [defective_descriptions]
45
+
46
+ csv_file = os.path.join(file_path, file_name)
47
+ results = []
48
+
49
+ with torch.no_grad():
50
+ nominal_features = torch.stack([model.encode_image(img).to(device) for img in nominal_images])
51
+ nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
52
+
53
+ defective_features = torch.stack([model.encode_image(img).to(device) for img in defective_images])
54
+ defective_features /= defective_features.norm(dim=-1, keepdim=True)
55
+
56
+ csv_data = []
57
+
58
+ for idx, test_img in enumerate(test_images):
59
+ test_features = model.encode_image(test_img).to(device)
60
+ test_features /= test_features.norm(dim=-1, keepdim=True)
61
+
62
+ max_nom_sim, max_def_sim = -float('inf'), -float('inf')
63
+ max_nom_idx, max_def_idx = -1, -1
64
+
65
+ for i in range(nominal_features.shape[0]):
66
+ sim = (test_features @ nominal_features[i].T).item()
67
+ if sim > max_nom_sim:
68
+ max_nom_sim, max_nom_idx = sim, i
69
+
70
+ for j in range(defective_features.shape[0]):
71
+ sim = (test_features @ defective_features[j].T).item()
72
+ if sim > max_def_sim:
73
+ max_def_sim, max_def_idx = sim, j
74
+
75
+ similarities = torch.tensor([max_nom_sim, max_def_sim])
76
+ probabilities = F.softmax(similarities, dim=0).tolist()
77
+ prob_nom, prob_def = probabilities
78
+
79
+ classification = "Defective" if prob_def > prob_nom else "Nominal"
80
+
81
+ csv_data.append({
82
+ "datetime_of_operation": datetime.now().isoformat(),
83
+ "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
84
+ "image_path": test_image_filenames[idx],
85
+ "image_name": test_image_filenames[idx].split('/')[-1],
86
+ "classification_result": classification,
87
+ "non_defect_prob": round(prob_nom, 3),
88
+ "defect_prob": round(prob_def, 3),
89
+ "nominal_description": nominal_descriptions[max_nom_idx],
90
+ "defective_description": defective_descriptions[max_def_idx] if defective_images else "N/A"
91
+ })
92
+
93
+ if print_one_liner:
94
+ print(f"{test_image_filenames[idx]} classified as {classification} "
95
+ f"(Nominal Prob: {prob_nom:.3f}, Defective Prob: {prob_def:.3f})")
96
+
97
+ file_exists = os.path.isfile(csv_file)
98
+ with open(csv_file, mode='a' if file_exists else 'w', newline='') as file:
99
+ import csv
100
+ fieldnames = [
101
+ "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
102
+ "classification_result", "non_defect_prob", "defect_prob",
103
+ "nominal_description", "defective_description"
104
+ ]
105
+ writer = csv.DictWriter(file, fieldnames=fieldnames)
106
+ if not file_exists:
107
+ writer.writeheader()
108
+ for row in csv_data:
109
+ writer.writerow(row)
110
+
111
+ return ""
112
+
113
+ # Initialize app state
114
+ if 'nominal_images' not in st.session_state:
115
+ st.session_state.nominal_images = []
116
+ if 'defective_images' not in st.session_state:
117
+ st.session_state.defective_images = []
118
+ if 'test_images' not in st.session_state:
119
+ st.session_state.test_images = []
120
+ if 'results' not in st.session_state:
121
+ st.session_state.results = []
122
+
123
+ st.set_page_config(page_title="Few-Shot Fault Detection", layout="wide")
124
  st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")
125
+ st.markdown("Upload **Nominal Images** (good parts), **Defective Images** (bad parts), and **Test Images** to classify.")
126
 
127
+ tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"])
128
 
129
+ # --- Tab 1: Upload Reference Images ---
130
+ with tab1:
131
+ st.header("Upload Reference Images")
 
 
132
 
133
+ nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
134
+ defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
135
 
136
+ if nominal_files:
137
+ st.session_state.nominal_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in nominal_files]
138
+ st.session_state.nominal_descriptions = [file.name for file in nominal_files]
139
+ st.success(f"Uploaded {len(nominal_files)} nominal images.")
140
+
141
+ if defective_files:
142
+ st.session_state.defective_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in defective_files]
143
+ st.session_state.defective_descriptions = [file.name for file in defective_files]
144
+ st.success(f"Uploaded {len(defective_files)} defective images.")
145
+
146
+ # --- Tab 2: Classify Test Images ---
147
+ with tab2:
148
+ st.header("Upload Test Image(s)")
149
+
150
+ test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
151
+
152
+ if st.button("🔍 Run Classification") and test_files:
153
+ test_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in test_files]
154
+ test_filenames = [file.name for file in test_files]
155
+
156
+ few_shot_fault_classification(
157
+ test_images=test_images,
158
+ test_image_filenames=test_filenames,
159
+ nominal_images=st.session_state.nominal_images,
160
+ nominal_descriptions=st.session_state.nominal_descriptions,
161
+ defective_images=st.session_state.defective_images,
162
+ defective_descriptions=st.session_state.defective_descriptions,
163
+ num_few_shot_nominal_imgs=len(st.session_state.nominal_images),
164
+ file_path=".",
165
+ file_name="streamlit_results.csv",
166
+ print_one_liner=False
167
  )
168
 
169
+ st.success("Classification complete!")
170
+ st.session_state.results = "streamlit_results.csv"
171
+
172
+ # --- Tab 3: View/Download Results ---
173
+ with tab3:
174
+ st.header("Classification Results")
175
+
176
+ if os.path.exists("streamlit_results.csv"):
177
+ df = pd.read_csv("streamlit_results.csv")
178
+ st.dataframe(df)
179
+ st.download_button("📥 Download Results", data=df.to_csv(index=False), file_name="classification_results.csv", mime="text/csv")
180
+ else:
181
+ st.info("No results yet. Please classify some test images.")