import torch
from torch import nn
from transformers import ViltForQuestionAnswering

from utils.dl.common.model import set_module


def vilt_b_32(num_classes):
    """
    ViLT-B/32 for VQA.

    Settings are based on the VQAv2 dataset (3129 answer classes):

    1. half of the classes are used for LoRA adaptation;
    2. the same half is used for DA evaluation (with corruptions generating
       the domain shifts), while the other half is used for CL evaluation.
    """
    
    model = ViltForQuestionAnswering.from_pretrained('dandelin/vilt-b32-mlm-itm')

    # ViltForQuestionAnswering's head is nn.Sequential(Linear, LayerNorm, GELU,
    # Linear); replace the final Linear (index 3) with a fresh num_classes-way
    # output layer.
    linear = model.classifier[3]
    new_linear = nn.Linear(linear.in_features, num_classes, bias=True)
    set_module(model, 'classifier.3', new_linear)

    return model
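

# A hedged sketch (illustration only; not code used elsewhere in this file) of
# the class split described in the docstring above: half of the 3129 VQAv2
# answer classes go to LoRA adaptation + DA evaluation, the other half to CL
# evaluation. `answer_classes` is an assumed ordered list of answer strings.
def split_vqa_classes(answer_classes):
    half = (len(answer_classes) + 1) // 2  # 3129 classes -> 1565, matching vilt_b_32(1565) below
    da_classes = answer_classes[:half]     # LoRA adaptation + DA evaluation
    cl_classes = answer_classes[half:]     # CL evaluation
    return da_classes, cl_classes


# Another hedged sketch (an assumption; this file does not contain the actual
# LoRA code): a minimal LoRA wrapper showing how a pretrained ViLT linear layer
# could be adapted. set_module above could install it over, e.g., an attention
# projection inside the ViLT encoder.
class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base  # pretrained layer, kept frozen
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # update starts at zero, so the wrapped layer is initially unchanged
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))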


if __name__ == '__main__':
    model = vilt_b_32(1565)  # presumably half of the 3129 VQAv2 answer classes (ceil(3129 / 2) = 1565)
    
    print(model)
    
    from transformers import ViltProcessor, ViltModel
    from PIL import Image
    import requests

    # prepare image and text
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    text = "hello world"

    processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
    base_model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm-itm")  # bare backbone, no VQA head; avoid shadowing the VQA model above

    inputs = processor(image, text, return_tensors="pt")
    
    print(inputs)
    
    outputs = base_model(**inputs)
    last_hidden_states = outputs.last_hidden_state
    
    print(last_hidden_states.shape)
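
    # A hedged sketch added to the original smoke test (assuming the processor's
    # outputs match ViltForQuestionAnswering's forward signature): run the same
    # inputs through the modified VQA model built above and check that the new
    # head emits num_classes logits.
    model.eval()
    with torch.no_grad():
        vqa_logits = model(**inputs).logits
    print(vqa_logits.shape)  # expected: torch.Size([1, 1565])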