import torch
import torch.nn.functional as F
from datetime import datetime
import os


def _as_list(value):
    """Return *value* as a list: lists pass through, tuples are converted, anything else is wrapped."""
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [value]


def _l2_normalize(features):
    """Scale *features* to unit L2 norm along the last dimension."""
    return features / features.norm(dim=-1, keepdim=True)


def _encode_references(model, images, device):
    """Encode every image with ``model.encode_image``, stack the results and L2-normalize them."""
    return _l2_normalize(torch.stack([model.encode_image(img.to(device)) for img in images]))


def _best_match(query, references):
    """Return ``(score, index)`` of the reference with the largest ``query @ ref.T`` score.

    References are scanned in order and only a strictly greater score replaces
    the current best, so ties resolve to the earliest reference. If no score
    compares greater than ``-inf`` (e.g. all scores are NaN), ``(-inf, -1)``
    is returned.
    """
    best_score, best_idx = -float("inf"), -1
    for i, ref in enumerate(references):
        # .item() requires the product to hold exactly one element.
        score = (query @ ref.T).item()
        if score > best_score:
            best_score, best_idx = score, i
    return best_score, best_idx


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False,
    logit_scale: float = 1.0,
):
    """Classify test images as "Nominal" or "Defective" by nearest few-shot reference.

    Each test image is embedded with ``model.encode_image`` and L2-normalized.
    It is then compared, by dot product (cosine similarity of the normalized
    embeddings), against every nominal and every defective reference embedding.
    The best nominal score and the best defective score are multiplied by
    ``logit_scale`` and passed through a softmax to give
    ``[non_defect_prob, defect_prob]``.

    Args:
        model: Object exposing ``encode_image(tensor)``. Presumably a CLIP-style
            model whose output per image is a single embedding row, e.g. shape
            ``(1, D)``; TODO confirm against callers.
        test_images: A tensor, or a list/tuple of tensors, to classify.
        test_image_filenames: A name, or a list/tuple of names, aligned by
            position with ``test_images``. These are used only in the output.
        nominal_images: A reference tensor, or a list/tuple of tensors, for
            defect-free samples. Must be non-empty.
        nominal_descriptions: Text per nominal reference, indexed by the
            matched reference position. A single ``str`` is treated as a
            one-element list.
        defective_images: Same as ``nominal_images``, but for defective samples.
        defective_descriptions: Same as ``nominal_descriptions``, but for
            defective references.
        num_few_shot_nominal_imgs: Currently unused by this function. Kept for
            API compatibility.
        device: Device that every image is moved to before encoding.
        file_path: Currently unused. This function writes no file; callers
            must persist the returned results themselves.
        file_name: Currently unused (see ``file_path``).
        print_one_liner: If True, print a one-line summary per test image.
        logit_scale: Positive factor applied to the two similarities before
            the softmax. The default ``1.0`` keeps the unscaled behaviour.
            Larger values (CLIP-style models commonly use ~100) give more
            decisive probabilities. The predicted label does not depend on it.

    Returns:
        One dict per test image, in input order, with these keys:
        ``datetime_of_operation`` (naive local-time ISO-8601 string),
        ``image_path``, ``classification_result``, ``non_defect_prob``,
        ``defect_prob`` (both rounded to 3 decimals), and the
        ``nominal_description`` / ``defective_description`` of the
        best-matching references.

    Raises:
        ValueError: If ``logit_scale`` is not a positive number.
    """
    # `not (x > 0)` also rejects NaN.
    if not logit_scale > 0:
        raise ValueError(f"logit_scale must be positive, got {logit_scale!r}")

    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    defective_images = _as_list(defective_images)
    # Descriptions are indexed below; a bare string would otherwise yield a
    # single character instead of the whole description.
    if isinstance(nominal_descriptions, str):
        nominal_descriptions = [nominal_descriptions]
    if isinstance(defective_descriptions, str):
        defective_descriptions = [defective_descriptions]

    results = []
    with torch.no_grad():
        # Reference embeddings are computed once and reused for every test image.
        nominal_features = _encode_references(model, nominal_images, device)
        defective_features = _encode_references(model, defective_images, device)

        for idx, test_img in enumerate(test_images):
            test_features = _l2_normalize(model.encode_image(test_img.to(device)))

            max_nominal_similarity, max_nominal_idx = _best_match(test_features, nominal_features)
            max_defective_similarity, max_defective_idx = _best_match(test_features, defective_features)

            similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
            probabilities = F.softmax(similarities * logit_scale, dim=0).tolist()
            # Ties go to "Nominal": only a strictly larger defect probability flags a defect.
            classification = "Defective" if probabilities[1] > probabilities[0] else "Nominal"

            filename = test_image_filenames[idx]
            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "image_path": filename,
                "classification_result": classification,
                "non_defect_prob": round(probabilities[0], 3),
                "defect_prob": round(probabilities[1], 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
            })

            if print_one_liner:
                print(f"{filename} → {classification} "
                      f"(Nominal: {probabilities[0]:.3f}, Defective: {probabilities[1]:.3f})")

    return results