File size: 5,530 Bytes
116201b
 
 
 
a1e1c29
116201b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1e1c29
 
 
 
116201b
 
 
 
 
 
a1e1c29
 
116201b
 
a1e1c29
 
116201b
a1e1c29
 
 
 
116201b
 
 
 
a1e1c29
116201b
 
 
a1e1c29
116201b
 
 
a1e1c29
 
 
 
116201b
 
 
 
a1e1c29
 
 
 
 
116201b
a1e1c29
116201b
 
 
 
 
 
a1e1c29
116201b
 
 
 
 
 
a1e1c29
116201b
 
a1e1c29
 
116201b
a1e1c29
 
116201b
a1e1c29
116201b
 
a1e1c29
116201b
a1e1c29
116201b
a1e1c29
 
116201b
 
a1e1c29
 
116201b
a1e1c29
 
116201b
 
a1e1c29
116201b
 
a1e1c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116201b
a1e1c29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import torch.nn.functional as F
from datetime import datetime
import os
import csv

# Column order of the results CSV; these are also the keys of each result dict.
_RESULT_FIELDNAMES = [
    "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
    "classification_result", "non_defect_prob", "defect_prob",
    "nominal_description", "defective_description",
    "max_nominal_similarity", "max_defective_similarity",
]


def _as_list(value):
    """Return *value* as a list: lists/tuples are copied, anything else is wrapped."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _encode_normalized(model, image, device):
    """Encode one image with ``model.encode_image`` and L2-normalize its last dimension."""
    features = model.encode_image(image.to(device))
    # Out-of-place division so a tensor returned (or cached) by the model is never mutated.
    return features / features.norm(dim=-1, keepdim=True)


def _best_match(query_features, reference_features):
    """Return ``(similarity, index)`` of the reference most similar to the query.

    Similarity is the matrix product of the already-normalized features, read
    as a Python float via ``.item()``. Ties keep the earliest reference.
    """
    best_similarity = -float('inf')
    best_idx = -1
    for i, ref in enumerate(reference_features):
        similarity = (query_features @ ref.T).item()
        if similarity > best_similarity:
            best_similarity, best_idx = similarity, i
    return best_similarity, best_idx


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """
    Classify test images as nominal or defective by similarity to reference images.

    Every image is encoded with ``model.encode_image`` and L2-normalized along
    its last dimension. For single feature vectors, the dot product of two
    encodings is therefore their cosine similarity.

    For each test image, the best-matching nominal and defective references
    are found. The two maximum similarities are turned into probabilities with
    a 2-way softmax. The image is labelled "Defective" only if its defect
    probability is strictly higher; ties resolve to "Nominal".

    Results are appended to ``file_path/file_name`` as CSV and also returned.
    A header row is written when the file is new or empty.

    Args:
        model: Object exposing ``encode_image(tensor) -> tensor``; presumably a
            CLIP-style model -- TODO confirm. Each call must yield a single
            feature vector (e.g. shape ``(1, D)`` or ``(D,)``), because the
            similarity is read with ``.item()``.
        test_images: Tensor, or list/tuple of tensors, to classify.
        test_image_filenames: Path (or list/tuple of paths) parallel to
            ``test_images``; fills the ``image_path``/``image_name`` columns.
        nominal_images: Non-empty tensor or list/tuple of nominal references.
        nominal_descriptions: Descriptions indexed in parallel with
            ``nominal_images``; the best match's entry is reported.
        defective_images: Non-empty tensor or list/tuple of defective references.
        defective_descriptions: Descriptions indexed in parallel with
            ``defective_images``.
        num_few_shot_nominal_imgs: Recorded in the output only; not used in
            the computation.
        device: Device each image is moved to before encoding.
        file_path: Output directory (created if missing).
        file_name: CSV file name inside ``file_path``.
        print_one_liner: If True, print a one-line summary per test image.

    Returns:
        A list with one dict per test image, keyed by the CSV column names.

    Raises:
        ValueError: If there are no nominal or no defective reference images,
            or fewer filenames than test images.
    """
    # Accept a single item or a list/tuple for every collection argument.
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    nominal_descriptions = _as_list(nominal_descriptions)
    defective_images = _as_list(defective_images)
    defective_descriptions = _as_list(defective_descriptions)

    if not nominal_images:
        raise ValueError("at least one nominal reference image is required")
    if not defective_images:
        raise ValueError("at least one defective reference image is required")
    if len(test_image_filenames) < len(test_images):
        raise ValueError(
            f"got {len(test_image_filenames)} filenames for {len(test_images)} test images"
        )

    # Ensure the output directory exists
    os.makedirs(file_path, exist_ok=True)
    csv_file = os.path.join(file_path, file_name)
    results = []

    with torch.no_grad():
        # Reference encodings are computed once and reused for every test image.
        nominal_features = [_encode_normalized(model, img, device) for img in nominal_images]
        defective_features = [_encode_normalized(model, img, device) for img in defective_images]

        for test_img, filename in zip(test_images, test_image_filenames):
            test_features = _encode_normalized(model, test_img, device)

            max_nominal_similarity, max_nominal_idx = _best_match(test_features, nominal_features)
            max_defective_similarity, max_defective_idx = _best_match(test_features, defective_features)

            # NOTE: no temperature / logit scale is applied. For unit vectors the
            # similarities lie in [-1, 1], so these probabilities stay within
            # roughly [0.12, 0.88].
            similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
            prob_not_defective, prob_defective = F.softmax(similarities, dim=0).tolist()

            classification = "Defective" if prob_defective > prob_not_defective else "Nominal"

            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": filename,
                "image_name": os.path.basename(filename),
                "classification_result": classification,
                "non_defect_prob": round(prob_not_defective, 3),
                "defect_prob": round(prob_defective, 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
                "max_nominal_similarity": round(max_nominal_similarity, 3),
                "max_defective_similarity": round(max_defective_similarity, 3),
            })

            # Optionally print a one-line summary for each test image
            if print_one_liner:
                print(f"{filename} -> {classification} "
                      f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")

    # Append mode creates the file if needed. The header is written when the
    # file is missing or empty, so a pre-created empty file still gets one.
    write_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    with open(csv_file, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=_RESULT_FIELDNAMES)
        if write_header:
            writer.writeheader()
        writer.writerows(results)

    return results