KeerthiVM committed
Commit b228643 · 1 Parent(s): 73f833f

Initial commit

evo_vit.py ADDED
"""
evo_vit.py: EvoViT, a Vision Transformer whose transformer-layer configuration
was generated automatically (file generated 2024-11-10 15:29:50).
"""
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Splits an image into non-overlapping patches and linearly embeds them."""

    def __init__(self, img_size, patch_size, in_channels, embed_dim, hidden_dim):
        super(PatchEmbedding, self).__init__()
        # Note: the projection maps into hidden_dim, not embed_dim; embed_dim is
        # kept in the signature for compatibility with the generated code.
        self.patch_embed = nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        # (batch_size, hidden_dim, H/ps, W/ps) -> (batch_size, num_patches, hidden_dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        return x


class PositionalEncoding(nn.Module):
    """Adds a learned positional embedding to each patch token."""

    def __init__(self, num_patches, embed_dim, hidden_dim):
        super(PositionalEncoding, self).__init__()
        # embed_dim is unused; the encoding lives in hidden_dim, matching PatchEmbedding.
        self.positional_encoding = nn.Parameter(torch.randn(1, num_patches, hidden_dim))

    def forward(self, x):
        return x + self.positional_encoding


class TransformerLayer(nn.Module):
    """Standard post-norm Transformer encoder block: self-attention followed by an MLP."""

    def __init__(self, hidden_dim, num_heads, mlp_dim, dropout_rate):
        super(TransformerLayer, self).__init__()
        # batch_first=True so inputs are (batch, seq, dim); without it,
        # nn.MultiheadAttention would treat the batch axis as the sequence axis.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout_rate),
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.mlp(x))
        return x


# EvoViTModel: Vision Transformer assembled from the generated layer configuration.
class EvoViTModel(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim, num_classes, hidden_dim):
        super(EvoViTModel, self).__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim, hidden_dim)
        self.position_encoding = PositionalEncoding(self.patch_embed.num_patches, embed_dim, hidden_dim)
        self.sigmoid = nn.Sigmoid()
        # Dynamically generated transformer-layer initialization. The generated
        # configuration fixes hidden_dim at 512, regardless of the constructor argument.
        self.transformer_layer_0 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.20362387412323335)
        self.transformer_layer_1 = TransformerLayer(num_heads=8, mlp_dim=3072, hidden_dim=512, dropout_rate=0.29859399476669696)
        self.transformer_layer_2 = TransformerLayer(num_heads=16, mlp_dim=4096, hidden_dim=512, dropout_rate=0.24029622136332746)
        self.transformer_layer_3 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.22640265738407994)
        self.transformer_layer_4 = TransformerLayer(num_heads=16, mlp_dim=3072, hidden_dim=512, dropout_rate=0.2969787366320388)
        self.transformer_layer_5 = TransformerLayer(num_heads=16, mlp_dim=2048, hidden_dim=512, dropout_rate=0.11264741089870321)
        self.transformer_layer_6 = TransformerLayer(num_heads=8, mlp_dim=4096, hidden_dim=512, dropout_rate=0.25324312813345734)
        self.transformer_layer_7 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.17729069086242882)
        self.transformer_layer_8 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.2531553780827078)
        self.transformer_layer_9 = TransformerLayer(num_heads=16, mlp_dim=2048, hidden_dim=512, dropout_rate=0.17372554665581236)
        self.transformer_layer_10 = TransformerLayer(num_heads=16, mlp_dim=3072, hidden_dim=512, dropout_rate=0.25217233180956183)
        self.transformer_layer_11 = TransformerLayer(num_heads=8, mlp_dim=4096, hidden_dim=512, dropout_rate=0.24459590331387862)
        self.transformer_layer_12 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.17589263405869232)
        # 512 matches the fixed hidden_dim; num_classes replaces the hard-coded 48
        # so the constructor argument is actually honored.
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        # Cast the input to the parameter dtype (e.g. for half-precision inference).
        expected_dtype = self.patch_embed.patch_embed.weight.dtype
        if x.dtype != expected_dtype:
            x = x.to(expected_dtype)

        x = self.patch_embed(x)
        x = self.position_encoding(x)
        # Dynamically generated forward pass through the 13 evolved layers.
        for i in range(13):
            x = getattr(self, f"transformer_layer_{i}")(x)
        # No [CLS] token is prepended, so the first patch token serves as the
        # classification token.
        x = self.classifier(x[:, 0])
        # probs = self.sigmoid(x)
        # return probs
        return x
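

# Minimal usage sketch (not part of the original commit). The hyperparameters
# below are assumptions for illustration: img_size=224, patch_size=16, RGB
# input, and the 512-dim hidden size the generated layers expect.
if __name__ == "__main__":
    model = EvoViTModel(img_size=224, patch_size=16, in_channels=3,
                        embed_dim=512, num_classes=48, hidden_dim=512)
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 48])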