ymcmy commited on
Commit
48f6c67
·
verified ·
1 Parent(s): d4029cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -0
app.py CHANGED
@@ -10,6 +10,31 @@ from scipy.interpolate import griddata
10
  import matplotlib.pyplot as plt
11
  from utils import azi_diff
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class TextureContrastClassifier(nn.Module):
14
  def __init__(self, input_shape, num_heads=4, key_dim=64, ff_dim=256, rate=0.1):
15
  super(TextureContrastClassifier, self).__init__()
 
10
  import matplotlib.pyplot as plt
11
  from utils import azi_diff
12
 
13
+ class AttentionBlock(nn.Module):
14
+ def __init__(self, input_dim, num_heads, ff_dim, rate=0.2):
15
+ super(AttentionBlock, self).__init__()
16
+ self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads)
17
+ self.dropout1 = nn.Dropout(rate)
18
+ self.layer_norm1 = nn.LayerNorm(input_dim)
19
+
20
+ self.ffn = nn.Sequential(
21
+ nn.Linear(input_dim, ff_dim),
22
+ nn.ReLU(),
23
+ nn.Dropout(rate),
24
+ nn.Linear(ff_dim, input_dim),
25
+ nn.Dropout(rate)
26
+ )
27
+ self.layer_norm2 = nn.LayerNorm(input_dim)
28
+
29
+ def forward(self, x):
30
+ attn_output, _ = self.attention(x, x, x)
31
+ attn_output = self.dropout1(attn_output)
32
+ out1 = self.layer_norm1(attn_output + x)
33
+
34
+ ffn_output = self.ffn(out1)
35
+ out2 = self.layer_norm2(ffn_output + out1)
36
+ return out2
37
+
38
  class TextureContrastClassifier(nn.Module):
39
  def __init__(self, input_shape, num_heads=4, key_dim=64, ff_dim=256, rate=0.1):
40
  super(TextureContrastClassifier, self).__init__()