|
import csv
import os
from datetime import datetime

import torch
import torch.nn.functional as F
|
|
|
# Column order of the CSV log; identical to the keys of each result dict.
_CSV_FIELDNAMES = (
    "datetime_of_operation",
    "image_path",
    "classification_result",
    "non_defect_prob",
    "defect_prob",
    "nominal_description",
    "defective_description",
)


def _as_list(value):
    """Return *value* as a list: lists/tuples are copied, any other object is wrapped."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _encode_normalized(model, images, device):
    """Encode each image with ``model.encode_image`` and return unit-norm rows.

    The result has one row per image (shape ``(len(images), D)``). Each image's
    embedding is flattened into a single row, so this assumes ``encode_image``
    yields exactly one embedding per image (e.g. shape ``(1, D)``) -- TODO
    confirm for models other than the one this was written for.
    """
    features = torch.stack([model.encode_image(img.to(device)) for img in images])
    features = features.reshape(features.shape[0], -1)
    return features / features.norm(dim=-1, keepdim=True)


def _append_results_to_csv(csv_file, results):
    """Append result rows to *csv_file*; the header is written only when the
    file is new or empty, so repeated runs accumulate into one log."""
    write_header = not os.path.exists(csv_file) or os.path.getsize(csv_file) == 0
    with open(csv_file, "a", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=_CSV_FIELDNAMES)
        if write_header:
            writer.writeheader()
        writer.writerows(results)


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False,
    write_csv: bool = True,
):
    """Classify test images as "Nominal" or "Defective" by nearest reference embedding.

    Every image is embedded with ``model.encode_image`` and L2-normalised. For each
    test image, the cosine similarity to its best-matching nominal reference and to
    its best-matching defective reference are passed through a softmax. Whichever
    class has the higher probability wins. Because cosine similarities lie in
    [-1, 1] and no temperature is applied, the probabilities always fall within
    roughly [0.12, 0.88].

    Args:
        model: Object exposing ``encode_image(tensor) -> tensor``.
        test_images: One image tensor or a list/tuple of them.
        test_image_filenames: Name(s) reported for each test image. Must match
            ``test_images`` one-to-one.
        nominal_images: Reference image(s) of defect-free parts (at least one).
        nominal_descriptions: Description per nominal image, indexed in parallel
            with ``nominal_images``. A single string is treated as a one-item list.
        defective_images: Reference image(s) of defective parts (at least one).
        defective_descriptions: Description per defective image, indexed in
            parallel with ``defective_images``. A single string is treated as a
            one-item list.
        num_few_shot_nominal_imgs: Not used by this function; kept for API
            compatibility.
        device: Device every image is moved to before encoding.
        file_path: Directory of the CSV log.
        file_name: File name of the CSV log.
        print_one_liner: If True, print a one-line summary per test image.
        write_csv: If True (default), append the results to
            ``os.path.join(file_path, file_name)``.

    Returns:
        A list with one dict per test image. Its keys are the CSV columns:
        timestamp, image path, classification, both probabilities rounded to 3
        decimals, and the descriptions of the best-matching nominal and
        defective references.

    Raises:
        ValueError: If the filename count differs from the test-image count,
            either reference set is empty, or a description list is shorter
            than its image list.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    defective_images = _as_list(defective_images)
    # Indexing a bare string would yield single characters, not the description.
    if isinstance(nominal_descriptions, str):
        nominal_descriptions = [nominal_descriptions]
    if isinstance(defective_descriptions, str):
        defective_descriptions = [defective_descriptions]

    # Validate up front so no partial work is done on bad input.
    if len(test_image_filenames) != len(test_images):
        raise ValueError(
            f"got {len(test_images)} test images but "
            f"{len(test_image_filenames)} filenames"
        )
    if not nominal_images or not defective_images:
        raise ValueError(
            "at least one nominal and one defective reference image is required"
        )
    if len(nominal_descriptions) < len(nominal_images):
        raise ValueError("fewer nominal descriptions than nominal images")
    if len(defective_descriptions) < len(defective_images):
        raise ValueError("fewer defective descriptions than defective images")

    csv_file = os.path.join(file_path, file_name)
    results = []

    with torch.no_grad():
        nominal_features = _encode_normalized(model, nominal_images, device)
        defective_features = _encode_normalized(model, defective_images, device)

        for filename, test_img in zip(test_image_filenames, test_images):
            test_features = _encode_normalized(model, [test_img], device)[0]

            # One matmul per reference set; argmax returns the first maximum,
            # the same tie-breaking as a strict ">" scan.
            nominal_sims = nominal_features @ test_features
            defective_sims = defective_features @ test_features
            max_nominal_idx = int(torch.argmax(nominal_sims))
            max_defective_idx = int(torch.argmax(defective_sims))

            similarities = torch.tensor([
                nominal_sims[max_nominal_idx].item(),
                defective_sims[max_defective_idx].item(),
            ])
            probabilities = F.softmax(similarities, dim=0).tolist()

            classification = "Defective" if probabilities[1] > probabilities[0] else "Nominal"

            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "image_path": filename,
                "classification_result": classification,
                "non_defect_prob": round(probabilities[0], 3),
                "defect_prob": round(probabilities[1], 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
            })

            if print_one_liner:
                print(f"{filename} → {classification} "
                      f"(Nominal: {probabilities[0]:.3f}, Defective: {probabilities[1]:.3f})")

    # Persist to the CSV log that file_path/file_name describe.
    if write_csv and results:
        _append_results_to_csv(csv_file, results)

    return results
|
|