File size: 5,178 Bytes
fd8732a
 
 
 
 
 
07ff5a5
 
 
fd8732a
990ca41
07ff5a5
 
fd8732a
 
 
 
 
07ff5a5
fd8732a
07ff5a5
fd8732a
 
 
 
 
 
 
 
 
 
07ff5a5
fd8732a
 
 
 
 
 
 
 
 
 
 
 
 
07ff5a5
fd8732a
 
 
 
 
07ff5a5
 
 
 
 
 
 
 
 
 
 
 
 
fd8732a
07ff5a5
 
 
 
fd8732a
07ff5a5
fd8732a
07ff5a5
 
fd8732a
07ff5a5
 
 
 
fd8732a
07ff5a5
 
fd8732a
07ff5a5
fd8732a
07ff5a5
 
fd8732a
07ff5a5
 
 
 
fd8732a
 
 
07ff5a5
 
 
 
 
fd8732a
 
 
07ff5a5
 
 
 
fd8732a
 
 
 
 
 
 
 
07ff5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import cv2
import torch
import numpy as np
from transformers import CLIPProcessor, CLIPVisionModel
from PIL import Image
from torch import nn
import requests
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

# File name of the checkpoint inside the Hub repository; passed as `filename`
# to hf_hub_download.
MODEL_PATH = "pytorch_model.bin"
# Hugging Face Hub repository id; passed as `repo_id` to hf_hub_download.
REPO_ID = "Hayloo9838/uno-recognizer"
# NOTE(review): not used by any active code visible in this file.
MAPANDSTUFF = "mapandstuff.pth"

class CLIPVisionClassifier(nn.Module):
    """CLIP vision encoder followed by a bias-free linear classification head.

    Args:
        num_labels: Number of output classes (width of the returned logits).
        model_name: Hub id or local path of the CLIP checkpoint used for the
            vision backbone. Defaults to the checkpoint that was previously
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, num_labels, model_name='openai/clip-vit-large-patch14'):
        super().__init__()
        # Presumably "eager" attention is requested so that attention weights
        # can be returned from forward() -- confirm against the transformers
        # version in use.
        self.vision_model = CLIPVisionModel.from_pretrained(
            model_name,
            attn_implementation="eager",
        )
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size, num_labels, bias=False
        )
        # NOTE(review): this layer is created but never applied in forward().
        # It holds no parameters, so it does not affect checkpoint loading;
        # it is kept so the module's attributes stay unchanged.
        self.dropout = nn.Dropout(0.1)

    def forward(self, pixel_values, output_attentions=False):
        """Classify a batch of preprocessed images.

        Args:
            pixel_values: Input tensor forwarded unchanged to the vision model.
            output_attentions: When true, also return the vision model's
                attention tensors.

        Returns:
            ``logits`` alone, or ``(logits, attentions)`` when
            ``output_attentions`` is true.
        """
        outputs = self.vision_model(pixel_values, output_attentions=output_attentions)
        logits = self.classifier(outputs.pooler_output)

        if output_attentions:
            return logits, outputs.attentions
        return logits

def get_attention_map(attentions):
    """Build a normalised 2-D attention map from the last attention tensor.

    The last entry of ``attentions`` is averaged over dimension 1, then the
    row for batch item 0 / query position 0 is taken with key position 0
    dropped. The remaining values are reshaped into a square grid.
    (Presumably dimension 1 is the head axis and position 0 is the CLS
    token -- confirm against the model that produced the tensors.)

    Args:
        attentions: Sequence of attention tensors; only the last one is used.

    Returns:
        A square ``numpy`` array min-max scaled to ``[0, 1]``. When the map
        has no positive value range (e.g. all values equal) an all-zero
        array is returned instead of dividing by zero.
    """
    last_layer = attentions[-1]
    averaged = last_layer.mean(dim=1)
    patch_attention = averaged[0, 0, 1:]

    side = int(np.sqrt(patch_attention.shape[0]))
    grid = patch_attention.reshape(side, side)

    # detach() lets this also work on tensors that still track gradients.
    attention_map = grid.detach().cpu().numpy()

    low = attention_map.min()
    span = attention_map.max() - low
    if span > 0:
        return (attention_map - low) / span
    # A flat map carries no contrast; avoid the 0/0 that would produce NaNs.
    return np.zeros_like(attention_map)

def apply_heatmap(image, attention_map, new_size=None):
    """Overlay a JET-coloured attention heatmap on an image.

    Args:
        image: A PIL image (converted from RGB to BGR here) or an array that
            is used as-is (presumably already a BGR uint8 array -- confirm
            against callers).
        attention_map: 2-D float array; it is resized to the image size and
            min-max normalised again before colouring.
        new_size: Optional size passed to ``cv2.resize`` for the image (OpenCV
            takes it as ``(width, height)``). ``None`` keeps the original size.

    Returns:
        The weighted blend ``0.7 * image + 0.3 * heatmap`` as produced by
        ``cv2.addWeighted`` (BGR channel order for PIL inputs).
    """
    if isinstance(image, Image.Image):
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    if new_size is not None:
        image = cv2.resize(image, new_size)

    # cv2.resize expects (width, height), i.e. the reverse of shape[:2].
    height, width = image.shape[:2]
    scaled_map = cv2.resize(attention_map, (width, height), interpolation=cv2.INTER_LINEAR)

    low = scaled_map.min()
    span = scaled_map.max() - low
    if span > 0:
        scaled_map = (scaled_map - low) / span
    else:
        # A flat map would divide by zero; fall back to a uniform zero map.
        scaled_map = np.zeros_like(scaled_map)

    heatmap = cv2.applyColorMap(np.uint8(255 * scaled_map), cv2.COLORMAP_JET)
    return cv2.addWeighted(image, 0.7, heatmap, 0.3, 0)

def process_image_classification(image_url, timeout=30):
    """Download an image, classify it and render its attention heatmap.

    Args:
        image_url: URL of the image to fetch with ``requests``.
        timeout: Seconds passed to ``requests.get`` so a stalled server
            cannot block the call indefinitely.

    Returns:
        ``(visualization_rgb, card_name, confidence)``: the heatmap overlay
        converted to RGB, the label looked up in the reverse mapping for the
        top prediction, and that prediction's softmax probability.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    model, processor, reverse_mapping, device = load_model()

    with requests.get(image_url, stream=True, timeout=timeout) as response:
        # Fail on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # Have urllib3 undo any Content-Encoding before PIL reads the stream.
        response.raw.decode_content = True
        # convert() forces a full read while the connection is still open.
        image = Image.open(response.raw).convert('RGB')

    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    with torch.no_grad():
        logits, attentions = model(pixel_values, output_attentions=True)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Only one image is processed, so index the single batch row.
        prediction = torch.argmax(probs[0]).item()

    attention_map = get_attention_map(attentions)
    visualization = apply_heatmap(image, attention_map)

    card_name = reverse_mapping[prediction]
    confidence = probs[0][prediction].item()

    # apply_heatmap returns BGR for PIL inputs; convert back to RGB for
    # matplotlib display.
    visualization_rgb = cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB)

    return visualization_rgb, card_name, confidence

def load_model():
    """Download the checkpoint and assemble everything needed for inference.

    Returns:
        ``(model, processor, reverse_mapping, device)``: the classifier in
        eval mode on ``device``, the CLIP processor, a dict mapping the
        values of the checkpoint's ``label_mapping`` back to its keys, and
        the selected ``torch.device`` (CUDA when available, otherwise CPU).
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # The checkpoint file from the Hub holds both the weights and the label
    # mapping; no separate mapping file is downloaded.
    checkpoint_file = hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
    # NOTE(review): torch.load may unpickle arbitrary objects unless
    # weights_only=True is in effect, so this trusts the contents of REPO_ID.
    checkpoint = torch.load(checkpoint_file, map_location=device)

    label_mapping = checkpoint['label_mapping']
    reverse_mapping = dict(zip(label_mapping.values(), label_mapping.keys()))

    model = CLIPVisionClassifier(len(label_mapping))
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')

    return model, processor, reverse_mapping, device

if __name__ == "__main__":
    # Demo: classify one sample image and show the heatmap next to the result.
    image_url = "https://www.shutterstock.com/image-vector/hand-hold-reverse-card-symbol-600w-2360073097.jpg"
    visualization, card_name, confidence = process_image_classification(image_url)

    _, (heatmap_ax, summary_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: attention overlay.
    heatmap_ax.imshow(visualization)
    heatmap_ax.set_title("Heatmap on Image")
    heatmap_ax.axis('off')

    # Right panel: predicted label and confidence as centred text.
    summary = f"Predicted Card: {card_name}\nConfidence: {confidence:.2%}"
    summary_ax.text(0.5, 0.5, summary, fontsize=12, ha='center', va='center')
    summary_ax.axis('off')

    plt.show()