import cv2
import torch
import numpy as np
from transformers import CLIPProcessor, CLIPVisionModel
from PIL import Image
from torch import nn
import requests
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

MODEL_PATH = "pytorch_model.bin"
REPO_ID = "Hayloo9838/uno-recognizer"
MAPANDSTUFF = "mapandstuff.pth"  # defined but unused by this script

class CLIPVisionClassifier(nn.Module):
    """CLIP ViT-L/14 vision tower topped with a single linear classification head."""

    def __init__(self, num_labels):
        super().__init__()
        # Eager attention is required so the model can return attention weights
        # for the heatmap visualization below.
        self.vision_model = CLIPVisionModel.from_pretrained(
            'openai/clip-vit-large-patch14', attn_implementation="eager")
        self.classifier = nn.Linear(self.vision_model.config.hidden_size, num_labels, bias=False)
        self.dropout = nn.Dropout(0.1)

    def forward(self, pixel_values, output_attentions=False):
        outputs = self.vision_model(pixel_values, output_attentions=output_attentions)
        # Dropout is a no-op in eval mode; applied here so the head follows the
        # usual pool -> dropout -> linear pattern (it was defined but unused before).
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.classifier(pooled_output)

        if output_attentions:
            return logits, outputs.attentions
        return logits

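# Shape note (a sketch, assuming the ViT-L/14 default 224x224 input): pixel_values
# is (1, 3, 224, 224); the vision tower sees 1 CLS token plus a 16x16 grid of
# patch tokens, so each attention tensor in `attentions` has shape
# (batch, num_heads, 257, 257).
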
def get_attention_map(attentions):
    """Turn the last layer's CLS-token attention into a normalized 2D map."""
    attention = attentions[-1]         # last layer: (batch, heads, seq, seq)
    attention = attention.mean(dim=1)  # average over attention heads
    attention = attention[0, 0, 1:]    # CLS token's attention to the patch tokens

    # Patch tokens form a square grid, so the sequence folds back into 2D.
    num_patches = int(np.sqrt(attention.shape[0]))
    attention_map = attention.reshape(num_patches, num_patches)
    attention_map = attention_map.cpu().numpy()

    # Min-max normalize to [0, 1] for colormapping.
    attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
    return attention_map

def apply_heatmap(image, attention_map, new_size=None):
    """Overlay a JET-colormapped attention map on the image; returns a BGR array."""
    if isinstance(image, Image.Image):
        image = np.array(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    if new_size is not None:
        image = cv2.resize(image, new_size)

    # Upsample the patch-grid map to the image resolution (cv2 sizes are (w, h),
    # hence the reversed shape) and renormalize after interpolation.
    attention_map_resized = cv2.resize(attention_map, image.shape[:2][::-1], interpolation=cv2.INTER_LINEAR)
    attention_map_resized = (attention_map_resized - attention_map_resized.min()) / (
        attention_map_resized.max() - attention_map_resized.min())
    heatmap = cv2.applyColorMap(np.uint8(255 * attention_map_resized), cv2.COLORMAP_JET)

    # Blend 70% image with 30% heatmap.
    return cv2.addWeighted(image, 0.7, heatmap, 0.3, 0)

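# Usage sketch (names here are illustrative, not part of the pipeline):
#   overlay = apply_heatmap(img, get_attention_map(attentions), new_size=(448, 448))
#   cv2.imwrite("overlay.png", overlay)  # output is BGR, so imwrite needs no conversion
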
def process_image_classification(image_url):
    # Note: the model is reloaded on every call; see the caching sketch below.
    model, processor, reverse_mapping, device = load_model()

    image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')

    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    with torch.no_grad():
        logits, attentions = model(pixel_values, output_attentions=True)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        prediction = torch.argmax(probs).item()

    attention_map = get_attention_map(attentions)
    visualization = apply_heatmap(image, attention_map)

    card_name = reverse_mapping[prediction]
    confidence = probs[0][prediction].item()

    # apply_heatmap returns BGR (OpenCV convention); convert for matplotlib.
    visualization_rgb = cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB)
    return visualization_rgb, card_name, confidence

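# Minimal caching sketch for repeated calls (functools is standard library;
# load_model takes no arguments, so lru_cache can memoize its result):
#   from functools import lru_cache
#   cached_load_model = lru_cache(maxsize=1)(load_model)
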
def load_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fetch the checkpoint from the Hugging Face Hub (cached after first download).
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
    checkpoint = torch.load(model_path, map_location=device)

    # The checkpoint stores label -> index; invert it for index -> label lookups.
    label_mapping = checkpoint['label_mapping']
    reverse_mapping = {v: k for k, v in label_mapping.items()}

    model = CLIPVisionClassifier(len(label_mapping))
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
    return model, processor, reverse_mapping, device

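# Convenience sketch, not part of the original script: the same flow for a local
# file instead of a URL. The helper name and `image_path` argument are illustrative.
def process_local_image(image_path):
    model, processor, reverse_mapping, device = load_model()
    image = Image.open(image_path).convert('RGB')
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        logits, attentions = model(pixel_values, output_attentions=True)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        prediction = torch.argmax(probs).item()
    visualization = apply_heatmap(image, get_attention_map(attentions))
    return (cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB),
            reverse_mapping[prediction],
            probs[0][prediction].item())
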
if __name__ == "__main__":
    image_url = "https://www.shutterstock.com/image-vector/hand-hold-reverse-card-symbol-600w-2360073097.jpg"
    visualization, card_name, confidence = process_image_classification(image_url)

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(visualization)
    plt.title("Heatmap on Image")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.text(0.5, 0.5, f"Predicted Card: {card_name}\nConfidence: {confidence:.2%}",
             fontsize=12, ha='center', va='center')
    plt.axis('off')
    plt.show()
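    # Optional sketch: to also save the figure, call plt.savefig before plt.show(),
    # e.g. plt.savefig("uno_prediction.png", bbox_inches="tight").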