File size: 3,185 Bytes
116201b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
import torch.nn.functional as F
from datetime import datetime
import os
def _as_list(value):
    """Return *value* as a list.

    Lists are returned unchanged, tuples are converted to lists, and any other
    value (e.g. a single tensor or a single filename string) is wrapped in a
    one-element list.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [value]


def _as_descriptions(descriptions, count):
    """Make a description argument indexable once per reference image.

    A bare string is broadcast to *count* entries. Indexing the bare string
    directly would otherwise yield a single character instead of the text.
    Any other sequence is returned unchanged.
    """
    if isinstance(descriptions, str):
        return [descriptions] * count
    return descriptions


def _encode_normalized(model, images, device):
    """Encode *images* with ``model.encode_image`` and L2-normalise each row.

    Returns a tensor of shape ``(len(images), D)``. Each image's embedding is
    flattened to a vector, which assumes ``encode_image`` yields exactly one
    embedding per image (e.g. shape ``(1, D)``) -- TODO confirm for the model
    in use.
    """
    feats = torch.stack([model.encode_image(img.to(device)) for img in images])
    feats = feats.reshape(len(images), -1)
    return feats / feats.norm(dim=-1, keepdim=True)


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """Classify test images as "Nominal" or "Defective" by nearest reference image.

    Every image is embedded with ``model.encode_image`` and L2-normalised. A
    test image is then compared by cosine similarity against every nominal and
    every defective reference embedding. The best similarity from each set is
    turned into a two-way probability with a softmax.

    Args:
        model: Object exposing ``encode_image(tensor) -> tensor``. Presumably a
            CLIP-style image encoder (not verifiable from here).
        test_images: A tensor, or a list/tuple of tensors, to classify.
        test_image_filenames: A name, or a list/tuple of names, reported as
            ``image_path``. There must be at least as many as test images.
        nominal_images: One or more defect-free reference tensors.
        nominal_descriptions: Per-image descriptions, indexed like
            ``nominal_images``. A single string is used for every image.
        defective_images: One or more defective reference tensors.
        defective_descriptions: Per-image descriptions, indexed like
            ``defective_images``. A single string is used for every image.
        num_few_shot_nominal_imgs: Accepted for interface compatibility; not
            used by this function.
        device: Device the images are moved to before encoding.
        file_path, file_name: Accepted for interface compatibility; no file
            is written by this function.
        print_one_liner: If True, print one summary line per test image.

    Returns:
        A list with one dict per test image, containing the keys
        ``datetime_of_operation`` (naive local ISO timestamp), ``image_path``,
        ``classification_result``, ``non_defect_prob``, ``defect_prob``
        (both rounded to 3 decimals), ``nominal_description`` and
        ``defective_description`` (the descriptions of the closest reference
        image in each set).

    Raises:
        ValueError: If either reference set is empty, or if fewer filenames
            than test images are given.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    defective_images = _as_list(defective_images)
    nominal_descriptions = _as_descriptions(nominal_descriptions, len(nominal_images))
    defective_descriptions = _as_descriptions(defective_descriptions, len(defective_images))

    # Fail before any (potentially expensive) encoding work is done.
    if not nominal_images or not defective_images:
        raise ValueError(
            "nominal_images and defective_images must each contain at least one image"
        )
    if len(test_image_filenames) < len(test_images):
        raise ValueError(
            f"got {len(test_images)} test images but only "
            f"{len(test_image_filenames)} filenames"
        )

    results = []
    with torch.no_grad():
        # Reference embeddings are computed once and reused for every test image.
        nominal_features = _encode_normalized(model, nominal_images, device)
        defective_features = _encode_normalized(model, defective_images, device)

        for test_img, filename in zip(test_images, test_image_filenames):
            test_vec = model.encode_image(test_img.to(device)).reshape(-1)
            test_vec = test_vec / test_vec.norm()

            # Both sides are unit-norm, so the dot product is the cosine similarity.
            # torch.max(dim=...) reports the first index on ties, matching a
            # strict ">" linear scan.
            nominal_best, nominal_idx = (nominal_features @ test_vec).max(dim=0)
            defective_best, defective_idx = (defective_features @ test_vec).max(dim=0)

            # Softmax over raw cosine similarities (no temperature). Because the
            # difference of two cosines lies in [-2, 2], the probabilities stay
            # within roughly [0.12, 0.88].
            similarities = torch.tensor([nominal_best.item(), defective_best.item()])
            probabilities = F.softmax(similarities, dim=0).tolist()
            classification = "Defective" if probabilities[1] > probabilities[0] else "Nominal"

            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "image_path": filename,
                "classification_result": classification,
                "non_defect_prob": round(probabilities[0], 3),
                "defect_prob": round(probabilities[1], 3),
                "nominal_description": nominal_descriptions[nominal_idx.item()],
                "defective_description": defective_descriptions[defective_idx.item()],
            })
            if print_one_liner:
                print(f"{filename} → {classification} "
                      f"(Nominal: {probabilities[0]:.3f}, Defective: {probabilities[1]:.3f})")
    return results
|