fmegahed commited on
Commit
a1e1c29
·
verified ·
1 Parent(s): 95f56a4

Update classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +62 -8
classifier.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn.functional as F
3
  from datetime import datetime
4
  import os
 
5
 
6
  def few_shot_fault_classification(
7
  model,
@@ -17,63 +18,116 @@ def few_shot_fault_classification(
17
  file_name: str = 'image_classification_results.csv',
18
  print_one_liner: bool = False
19
  ):
 
 
 
 
20
  if not isinstance(test_images, list):
21
  test_images = [test_images]
22
  if not isinstance(test_image_filenames, list):
23
  test_image_filenames = [test_image_filenames]
24
  if not isinstance(nominal_images, list):
25
  nominal_images = [nominal_images]
 
 
26
  if not isinstance(defective_images, list):
27
  defective_images = [defective_images]
 
 
28
 
 
 
 
 
29
  csv_file = os.path.join(file_path, file_name)
30
  results = []
31
 
32
  with torch.no_grad():
 
33
  nominal_features = torch.stack([model.encode_image(img.to(device)) for img in nominal_images])
34
  nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
35
 
 
36
  defective_features = torch.stack([model.encode_image(img.to(device)) for img in defective_images])
37
  defective_features /= defective_features.norm(dim=-1, keepdim=True)
38
 
 
 
 
 
39
  for idx, test_img in enumerate(test_images):
40
  test_features = model.encode_image(test_img.to(device))
41
  test_features /= test_features.norm(dim=-1, keepdim=True)
42
 
43
- max_nominal_similarity = max_defective_similarity = -float('inf')
44
- max_nominal_idx = max_defective_idx = -1
 
 
 
45
 
 
46
  for i in range(nominal_features.shape[0]):
47
  similarity = (test_features @ nominal_features[i].T).item()
48
  if similarity > max_nominal_similarity:
49
  max_nominal_similarity = similarity
50
  max_nominal_idx = i
51
 
 
52
  for j in range(defective_features.shape[0]):
53
  similarity = (test_features @ defective_features[j].T).item()
54
  if similarity > max_defective_similarity:
55
  max_defective_similarity = similarity
56
  max_defective_idx = j
57
 
 
58
  similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
59
  probabilities = F.softmax(similarities, dim=0).tolist()
 
 
60
 
61
- classification = "Defective" if probabilities[1] > probabilities[0] else "Nominal"
 
62
 
 
63
  result = {
64
  "datetime_of_operation": datetime.now().isoformat(),
 
65
  "image_path": test_image_filenames[idx],
 
66
  "classification_result": classification,
67
- "non_defect_prob": round(probabilities[0], 3),
68
- "defect_prob": round(probabilities[1], 3),
69
  "nominal_description": nominal_descriptions[max_nominal_idx],
70
  "defective_description": defective_descriptions[max_defective_idx],
 
 
71
  }
72
-
 
73
  results.append(result)
74
 
 
75
  if print_one_liner:
76
  print(f"{test_image_filenames[idx]} → {classification} "
77
- f"(Nominal: {probabilities[0]:.3f}, Defective: {probabilities[1]:.3f})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- return results
 
2
  import torch.nn.functional as F
3
  from datetime import datetime
4
  import os
5
+ import csv
6
 
7
  def few_shot_fault_classification(
8
  model,
 
18
  file_name: str = 'image_classification_results.csv',
19
  print_one_liner: bool = False
20
  ):
21
+ """
22
+ Classify test images as nominal or defective based on similarity to nominal and defective images.
23
+ """
24
+ # Ensure inputs are lists
25
  if not isinstance(test_images, list):
26
  test_images = [test_images]
27
  if not isinstance(test_image_filenames, list):
28
  test_image_filenames = [test_image_filenames]
29
  if not isinstance(nominal_images, list):
30
  nominal_images = [nominal_images]
31
+ if not isinstance(nominal_descriptions, list):
32
+ nominal_descriptions = [nominal_descriptions]
33
  if not isinstance(defective_images, list):
34
  defective_images = [defective_images]
35
+ if not isinstance(defective_descriptions, list):
36
+ defective_descriptions = [defective_descriptions]
37
 
38
+ # Ensure the output directory exists
39
+ os.makedirs(file_path, exist_ok=True)
40
+
41
+ # Prepare full path for the CSV file
42
  csv_file = os.path.join(file_path, file_name)
43
  results = []
44
 
45
  with torch.no_grad():
46
+ # Encode nominal images
47
  nominal_features = torch.stack([model.encode_image(img.to(device)) for img in nominal_images])
48
  nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
49
 
50
+ # Encode defective images
51
  defective_features = torch.stack([model.encode_image(img.to(device)) for img in defective_images])
52
  defective_features /= defective_features.norm(dim=-1, keepdim=True)
53
 
54
+ # Prepare list to save data for CSV
55
+ csv_data = []
56
+
57
+ # Process each test image
58
  for idx, test_img in enumerate(test_images):
59
  test_features = model.encode_image(test_img.to(device))
60
  test_features /= test_features.norm(dim=-1, keepdim=True)
61
 
62
+ # Initialize variables to store max similarities and indices
63
+ max_nominal_similarity = -float('inf')
64
+ max_defective_similarity = -float('inf')
65
+ max_nominal_idx = -1
66
+ max_defective_idx = -1
67
 
68
+ # Loop through each nominal image to find max similarity
69
  for i in range(nominal_features.shape[0]):
70
  similarity = (test_features @ nominal_features[i].T).item()
71
  if similarity > max_nominal_similarity:
72
  max_nominal_similarity = similarity
73
  max_nominal_idx = i
74
 
75
+ # Loop through each defective image to find max similarity
76
  for j in range(defective_features.shape[0]):
77
  similarity = (test_features @ defective_features[j].T).item()
78
  if similarity > max_defective_similarity:
79
  max_defective_similarity = similarity
80
  max_defective_idx = j
81
 
82
+ # Convert similarities to probabilities
83
  similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
84
  probabilities = F.softmax(similarities, dim=0).tolist()
85
+ prob_not_defective = probabilities[0]
86
+ prob_defective = probabilities[1]
87
 
88
+ # Determine classification result
89
+ classification = "Defective" if prob_defective > prob_not_defective else "Nominal"
90
 
91
+ # Create result dict
92
  result = {
93
  "datetime_of_operation": datetime.now().isoformat(),
94
+ "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
95
  "image_path": test_image_filenames[idx],
96
+ "image_name": test_image_filenames[idx].split('/')[-1],
97
  "classification_result": classification,
98
+ "non_defect_prob": round(prob_not_defective, 3),
99
+ "defect_prob": round(prob_defective, 3),
100
  "nominal_description": nominal_descriptions[max_nominal_idx],
101
  "defective_description": defective_descriptions[max_defective_idx],
102
+ "max_nominal_similarity": round(max_nominal_similarity, 3),
103
+ "max_defective_similarity": round(max_defective_similarity, 3)
104
  }
105
+
106
+ csv_data.append(result)
107
  results.append(result)
108
 
109
+ # Optionally print one-liner summary for each test image
110
  if print_one_liner:
111
  print(f"{test_image_filenames[idx]} → {classification} "
112
+ f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")
113
+
114
+ # Write to CSV (append mode if file exists, write mode if not)
115
+ file_exists = os.path.isfile(csv_file)
116
+ with open(csv_file, mode='a' if file_exists else 'w', newline='') as file:
117
+ fieldnames = [
118
+ "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
119
+ "classification_result", "non_defect_prob", "defect_prob",
120
+ "nominal_description", "defective_description",
121
+ "max_nominal_similarity", "max_defective_similarity"
122
+ ]
123
+ writer = csv.DictWriter(file, fieldnames=fieldnames)
124
+
125
+ # Write header if file doesn't exist
126
+ if not file_exists:
127
+ writer.writeheader()
128
+
129
+ # Write each row of data
130
+ for row in csv_data:
131
+ writer.writerow(row)
132
 
133
+ return results