File size: 5,530 Bytes
116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 116201b a1e1c29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import torch
import torch.nn.functional as F
from datetime import datetime
import os
import csv
# Column order of the results CSV; every result dict carries exactly these keys.
_CSV_FIELDNAMES = [
    "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
    "classification_result", "non_defect_prob", "defect_prob",
    "nominal_description", "defective_description",
    "max_nominal_similarity", "max_defective_similarity",
]


def _as_list(value):
    """Return *value* as a list: lists/tuples are copied, anything else is wrapped."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _encode_normalized(model, image, device):
    """Encode one image with ``model.encode_image`` and L2-normalise it along the last dim."""
    features = model.encode_image(image.to(device))
    # Out-of-place division so the tensor returned by the model is never mutated.
    return features / features.norm(dim=-1, keepdim=True)


def _best_match(query, references):
    """Return ``(similarity, index)`` of the reference with the largest dot product.

    Ties keep the earliest index (strict ``>`` scan). If no similarity beats
    ``-inf`` (e.g. all NaN) the result is ``(-inf, -1)``.
    """
    best_sim, best_idx = -float("inf"), -1
    for idx, ref in enumerate(references):
        # Each encoding is expected to hold a single feature vector, so the product
        # has one element; .item() raises if that assumption does not hold.
        sim = (query @ ref.T).item()
        if sim > best_sim:
            best_sim, best_idx = sim, idx
    return best_sim, best_idx


def _append_rows_to_csv(csv_file, rows):
    """Append *rows* to *csv_file*, writing the header only if the file is new or empty."""
    needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    # Explicit UTF-8 so descriptions/paths with non-ASCII characters do not depend
    # on the platform's default encoding; newline="" as required by the csv module.
    with open(csv_file, mode="a", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=_CSV_FIELDNAMES)
        if needs_header:
            writer.writeheader()
        writer.writerows(rows)


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """Classify test images as nominal or defective by similarity to reference images.

    Every image is encoded with ``model.encode_image`` and L2-normalised. Each test
    image is compared (dot product) against every nominal and every defective
    reference. The best similarity on each side is turned into a two-way softmax.
    The image is labelled "Defective" when the defective probability is strictly
    larger, and "Nominal" otherwise.

    Because the features are normalised, the similarities are cosine similarities
    in [-1, 1]. The resulting softmax values are therefore a relative score
    (roughly within [0.12, 0.88]), not calibrated probabilities.

    Args:
        model: Object exposing ``encode_image(tensor) -> tensor``. Each call is
            expected to yield a single feature vector (TODO confirm for the models used).
        test_images: A tensor, or a list/tuple of tensors, to classify.
        test_image_filenames: Filename(s) matching ``test_images`` by position.
        nominal_images: Reference image(s) of non-defective parts (at least one).
        nominal_descriptions: Description(s) indexed like ``nominal_images``.
        defective_images: Reference image(s) of defective parts (at least one).
        defective_descriptions: Description(s) indexed like ``defective_images``.
        num_few_shot_nominal_imgs: Value copied verbatim into every result row.
        device: Device the images are moved to before encoding.
        file_path: Directory for the CSV; created if missing ('' means the cwd).
        file_name: CSV file name. Rows are appended, and the header is written
            only when the file is new or empty.
        print_one_liner: If True, print a one-line summary per test image.

    Returns:
        A list of dicts (keys as in the CSV header), one per test image.

    Raises:
        ValueError: If there are no nominal or no defective reference images, or
            fewer filenames than test images.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    nominal_descriptions = _as_list(nominal_descriptions)
    defective_images = _as_list(defective_images)
    defective_descriptions = _as_list(defective_descriptions)

    # Validate before any expensive encoding so bad input fails fast and cleanly.
    if not nominal_images:
        raise ValueError("nominal_images must contain at least one image")
    if not defective_images:
        raise ValueError("defective_images must contain at least one image")
    if len(test_image_filenames) < len(test_images):
        raise ValueError(
            f"got {len(test_images)} test images but only "
            f"{len(test_image_filenames)} filenames"
        )

    # os.makedirs('') raises, while an empty path legitimately means the cwd.
    if file_path:
        os.makedirs(file_path, exist_ok=True)
    csv_file = os.path.join(file_path, file_name)

    results = []
    with torch.no_grad():
        nominal_features = [_encode_normalized(model, img, device) for img in nominal_images]
        defective_features = [_encode_normalized(model, img, device) for img in defective_images]

        for test_img, filename in zip(test_images, test_image_filenames):
            test_features = _encode_normalized(model, test_img, device)

            max_nominal_similarity, max_nominal_idx = _best_match(test_features, nominal_features)
            max_defective_similarity, max_defective_idx = _best_match(
                test_features, defective_features
            )

            # Two-way softmax over the best match on each side.
            prob_not_defective, prob_defective = F.softmax(
                torch.tensor([max_nominal_similarity, max_defective_similarity]), dim=0
            ).tolist()
            classification = "Defective" if prob_defective > prob_not_defective else "Nominal"

            result = {
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": filename,
                "image_name": os.path.basename(filename),
                "classification_result": classification,
                "non_defect_prob": round(prob_not_defective, 3),
                "defect_prob": round(prob_defective, 3),
                # NOTE(review): assumes one description per reference image; a shorter
                # description list raises IndexError when a later image wins.
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
                "max_nominal_similarity": round(max_nominal_similarity, 3),
                "max_defective_similarity": round(max_defective_similarity, 3),
            }
            results.append(result)

            if print_one_liner:
                print(f"{filename} → {classification} "
                      f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")

    _append_rows_to_csv(csv_file, results)
    return results