fmegahed commited on
Commit
116201b
·
verified ·
1 Parent(s): 3b9284a

Create classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +79 -0
classifier.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from datetime import datetime
4
+ import os
5
+
6
+ def few_shot_fault_classification(
7
+ model,
8
+ test_images,
9
+ test_image_filenames,
10
+ nominal_images,
11
+ nominal_descriptions,
12
+ defective_images,
13
+ defective_descriptions,
14
+ num_few_shot_nominal_imgs: int,
15
+ device="cpu",
16
+ file_path: str = '.',
17
+ file_name: str = 'image_classification_results.csv',
18
+ print_one_liner: bool = False
19
+ ):
20
+ if not isinstance(test_images, list):
21
+ test_images = [test_images]
22
+ if not isinstance(test_image_filenames, list):
23
+ test_image_filenames = [test_image_filenames]
24
+ if not isinstance(nominal_images, list):
25
+ nominal_images = [nominal_images]
26
+ if not isinstance(defective_images, list):
27
+ defective_images = [defective_images]
28
+
29
+ csv_file = os.path.join(file_path, file_name)
30
+ results = []
31
+
32
+ with torch.no_grad():
33
+ nominal_features = torch.stack([model.encode_image(img.to(device)) for img in nominal_images])
34
+ nominal_features /= nominal_features.norm(dim=-1, keepdim=True)
35
+
36
+ defective_features = torch.stack([model.encode_image(img.to(device)) for img in defective_images])
37
+ defective_features /= defective_features.norm(dim=-1, keepdim=True)
38
+
39
+ for idx, test_img in enumerate(test_images):
40
+ test_features = model.encode_image(test_img.to(device))
41
+ test_features /= test_features.norm(dim=-1, keepdim=True)
42
+
43
+ max_nominal_similarity = max_defective_similarity = -float('inf')
44
+ max_nominal_idx = max_defective_idx = -1
45
+
46
+ for i in range(nominal_features.shape[0]):
47
+ similarity = (test_features @ nominal_features[i].T).item()
48
+ if similarity > max_nominal_similarity:
49
+ max_nominal_similarity = similarity
50
+ max_nominal_idx = i
51
+
52
+ for j in range(defective_features.shape[0]):
53
+ similarity = (test_features @ defective_features[j].T).item()
54
+ if similarity > max_defective_similarity:
55
+ max_defective_similarity = similarity
56
+ max_defective_idx = j
57
+
58
+ similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
59
+ probabilities = F.softmax(similarities, dim=0).tolist()
60
+
61
+ classification = "Defective" if probabilities[1] > probabilities[0] else "Nominal"
62
+
63
+ result = {
64
+ "datetime_of_operation": datetime.now().isoformat(),
65
+ "image_path": test_image_filenames[idx],
66
+ "classification_result": classification,
67
+ "non_defect_prob": round(probabilities[0], 3),
68
+ "defect_prob": round(probabilities[1], 3),
69
+ "nominal_description": nominal_descriptions[max_nominal_idx],
70
+ "defective_description": defective_descriptions[max_defective_idx],
71
+ }
72
+
73
+ results.append(result)
74
+
75
+ if print_one_liner:
76
+ print(f"{test_image_filenames[idx]} → {classification} "
77
+ f"(Nominal: {probabilities[0]:.3f}, Defective: {probabilities[1]:.3f})")
78
+
79
+ return results