# Source: clip/classifier.py (commit 116201b "Create classifier.py", verified; author: fmegahed; 3.19 kB)
import torch
import torch.nn.functional as F
from datetime import datetime
import os
def _as_list(value):
    """Return a list: lists as-is, tuples converted, anything else wrapped in a one-element list."""
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [value]


def _encode_normalized(model, images, device):
    """Encode each image with ``model.encode_image``, stack the results and L2-normalize along the last dim."""
    features = torch.stack([model.encode_image(img.to(device)) for img in images])
    # Out-of-place division: never mutates tensors the model or the caller may still hold.
    return features / features.norm(dim=-1, keepdim=True)


def _best_match(query, bank):
    """Return ``(similarity, index)`` of the row of *bank* with the largest dot product with *query*.

    Each stacked reference embedding is flattened to one row, and *query* is flattened
    to one vector (it must hold exactly one embedding). Ties resolve to the first
    index, as ``torch.argmax`` documents.
    """
    scores = bank.reshape(bank.shape[0], -1) @ query.reshape(-1)
    best_idx = int(torch.argmax(scores))
    return scores[best_idx].item(), best_idx


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """Classify test images as "Nominal" or "Defective" by their closest few-shot reference.

    Every image is embedded with ``model.encode_image`` and L2-normalized. Each test
    embedding is compared by dot product (cosine similarity) against all nominal and
    all defective reference embeddings. The best similarity of each group is kept,
    and the two winners are converted to probabilities with a softmax.

    Args:
        model: Object exposing ``encode_image(tensor)``. Presumably a CLIP model;
            TODO confirm against callers.
        test_images: One image tensor, or a list/tuple of them. Each is assumed to
            encode to a single embedding (e.g. a batch of one); TODO confirm.
        test_image_filenames: Name(s) reported as ``image_path``, matched to
            ``test_images`` by position. A single value is wrapped in a list.
        nominal_images: Reference image tensor(s) of defect-free samples.
        nominal_descriptions: Descriptions indexed like ``nominal_images``.
        defective_images: Reference image tensor(s) of defective samples.
        defective_descriptions: Descriptions indexed like ``defective_images``.
        num_few_shot_nominal_imgs: Not used by this function; kept for API compatibility.
        device: Device each image is moved to before encoding.
        file_path: Not used (no file is written); kept for API compatibility.
        file_name: Not used (no file is written); kept for API compatibility.
        print_one_liner: If True, print one summary line per test image.

    Returns:
        A list with one dict per test image, containing:
        ``datetime_of_operation`` (naive local time, ISO format), ``image_path``,
        ``classification_result`` ("Nominal" or "Defective"), ``non_defect_prob``
        and ``defect_prob`` (rounded to 3 decimals), and ``nominal_description`` /
        ``defective_description`` (descriptions of the best-matching references).

    Raises:
        RuntimeError: from ``torch.stack`` if a reference image collection is empty.
        IndexError: if there are fewer filenames than test images, or a description
            list is too short for the selected reference index.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    defective_images = _as_list(defective_images)

    results = []
    with torch.no_grad():
        # Reference embeddings are computed once and reused for every test image.
        nominal_features = _encode_normalized(model, nominal_images, device)
        defective_features = _encode_normalized(model, defective_images, device)

        for idx, test_img in enumerate(test_images):
            test_features = _encode_normalized(model, [test_img], device)
            max_nominal_similarity, max_nominal_idx = _best_match(test_features, nominal_features)
            max_defective_similarity, max_defective_idx = _best_match(test_features, defective_features)

            # Softmax over two raw cosine similarities (no temperature scaling).
            # Their gap is at most 2, so the winning probability never exceeds ~0.881.
            similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
            probabilities = F.softmax(similarities, dim=0).tolist()
            classification = "Defective" if probabilities[1] > probabilities[0] else "Nominal"

            filename = test_image_filenames[idx]
            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "image_path": filename,
                "classification_result": classification,
                "non_defect_prob": round(probabilities[0], 3),
                "defect_prob": round(probabilities[1], 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
            })
            if print_one_liner:
                print(f"{filename}: {classification} "
                      f"(Nominal: {probabilities[0]:.3f}, Defective: {probabilities[1]:.3f})")
    return results