ymcmy committed
Commit e9681bf · verified · 1 Parent(s): 02a3bb5

Create app.py

Files changed (1)
  app.py +240 -0
app.py ADDED
@@ -0,0 +1,240 @@
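# app.py — Gradio demo that scores how likely an input image is to be AI-generated.
# Rough pipeline implemented below: the image is cut into 16x16 patches; patches are ranked
# by a pixel-variation score and split into texture-"rich" and texture-"poor" halves; each
# patch becomes an azimuthally averaged FFT spectrum plus a sinusoidal positional embedding;
# the two embedding sets are fed to an attention-based classifier whose sigmoid output is
# reported as a percentage.
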
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import cv2
import PIL.Image
from scipy.interpolate import griddata

def RGB2gray(rgb):
    # Luminosity-weighted conversion of an (H, W, 3) RGB array to grayscale.
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return gray

def img_to_patches(img: PIL.Image.Image) -> tuple:
    # Split a PIL image into 16x16 patches and return, for each patch, its grayscale
    # version, its color version, and the (y, x) coordinate of its center.
    patch_size = 16
    img = img.convert('RGB')  # ensure the image is in RGB format

    grayscale_imgs = []
    imgs = []
    coordinates = []

    for i in range(0, img.height, patch_size):
        for j in range(0, img.width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            img_color = np.asarray(img.crop(box))
            grayscale_image = cv2.cvtColor(src=img_color, code=cv2.COLOR_RGB2GRAY)
            grayscale_imgs.append(grayscale_image.astype(dtype=np.int32))
            imgs.append(img_color)
            center_coord = (i + patch_size // 2, j + patch_size // 2)
            coordinates.append(center_coord)

    return grayscale_imgs, imgs, coordinates, (img.height, img.width)

def get_l1(v):
    # Sum of absolute differences between horizontally adjacent pixels.
    return np.sum(np.abs(v[:, :-1] - v[:, 1:]))

def get_l2(v):
    # Sum of absolute differences between vertically adjacent pixels.
    return np.sum(np.abs(v[:-1, :] - v[1:, :]))

def get_l3l4(v):
    # Sum of absolute differences along both diagonal directions.
    l3 = np.sum(np.abs(v[:-1, :-1] - v[1:, 1:]))
    l4 = np.sum(np.abs(v[1:, :-1] - v[:-1, 1:]))
    return l3 + l4

def get_pixel_var_degree_for_patch(patch: np.ndarray) -> int:
    # Pixel-variation degree of a patch: total absolute difference between neighbouring
    # pixels in the horizontal, vertical, and diagonal directions.
    l1 = get_l1(patch)
    l2 = get_l2(patch)
    l3l4 = get_l3l4(patch)
    return l1 + l2 + l3l4

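# Worked example with hypothetical values: for a 2x2 patch [[0, 10], [20, 30]],
# l1 = |0-10| + |20-30| = 20, l2 = |0-20| + |10-30| = 40,
# l3 = |0-30| = 30, l4 = |20-10| = 10, giving a variation degree of 100.
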
def get_rich_poor_patches(img: PIL.Image.Image, coloured=True):
    # Rank patches by pixel-variation degree and split them into a texture-"rich" half
    # (high variation) and a texture-"poor" half (low variation).
    gray_scale_patches, color_patches, coordinates, img_size = img_to_patches(img)
    var_with_patch = []
    for i, patch in enumerate(gray_scale_patches):
        if coloured:
            var_with_patch.append((get_pixel_var_degree_for_patch(patch), color_patches[i], coordinates[i]))
        else:
            var_with_patch.append((get_pixel_var_degree_for_patch(patch), patch, coordinates[i]))

    var_with_patch.sort(reverse=True, key=lambda x: x[0])
    mid_point = len(var_with_patch) // 2
    r_patch = [(patch, coor) for var, patch, coor in var_with_patch[:mid_point]]
    p_patch = [(patch, coor) for var, patch, coor in var_with_patch[mid_point:]]
    p_patch.reverse()  # poorest-texture patches first
    return r_patch, p_patch, img_size

def azimuthalAverage(image, center=None):
    # Azimuthally average a 2D array around `center` (defaults to the image centre):
    # pixels are sorted by radius, grouped into integer-radius bins, and the mean value
    # of each bin is returned as a 1D radial profile.
    y, x = np.indices(image.shape)
    if center is None:
        center = np.array([(x.max() - x.min()) / 2.0, (y.max() - y.min()) / 2.0])
    r = np.hypot(x - center[0], y - center[1])
    ind = np.argsort(r.flat)
    r_sorted = r.flat[ind]
    i_sorted = image.flat[ind]
    r_int = r_sorted.astype(int)
    deltar = r_int[1:] - r_int[:-1]          # non-zero where the integer radius changes
    rind = np.where(deltar)[0]               # index of the last pixel in each radius bin
    nr = rind[1:] - rind[:-1]                # number of pixels in each bin
    csim = np.cumsum(i_sorted, dtype=float)  # cumulative sum, used to get per-bin sums
    tbin = csim[rind[1:]] - csim[rind[:-1]]
    radial_prof = tbin / nr
    return radial_prof

def azimuthal_integral(img, epsilon=1e-8, N=50):
    # Map a patch to a 1D spectral signature: 2D FFT -> shifted log-magnitude spectrum ->
    # azimuthal average -> cubic interpolation onto N points -> min-max normalisation.
    if len(img.shape) == 3 and img.shape[2] == 3:
        img = RGB2gray(img)
    f = np.fft.fft2(img)
    fshift = np.fft.fftshift(f)
    fshift += epsilon  # avoid log(0)
    magnitude_spectrum = 20 * np.log(np.abs(fshift))
    psd1D = azimuthalAverage(magnitude_spectrum)
    points = np.linspace(0, N, num=psd1D.size)
    xi = np.linspace(0, N, num=N)
    interpolated = griddata(points, psd1D, xi, method='cubic')
    interpolated = (interpolated - np.min(interpolated)) / (np.max(interpolated) - np.min(interpolated))
    return interpolated.astype(np.float32)

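# Each patch therefore becomes a length-N vector scaled to [0, 1]; the callers below use
# N=256 so that it matches the embedding dimension expected by the classifier.
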
def positional_emb(coor, im_size, N):
    # Sinusoidal positional embedding of a patch centre, built from its
    # height- and width-normalised coordinates.
    img_height, img_width = im_size
    center_y, center_x = coor
    normalized_y = center_y / img_height
    normalized_x = center_x / img_width
    pos_emb = np.zeros(N)
    indices = np.arange(N)
    div_term = 10000 ** (2 * (indices // 2) / N)
    pos_emb[0::2] = np.sin(normalized_y / div_term[0::2]) + np.sin(normalized_x / div_term[0::2])
    pos_emb[1::2] = np.cos(normalized_y / div_term[1::2]) + np.cos(normalized_x / div_term[1::2])
    return pos_emb

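# In formula form, with y_n = center_y / img_height, x_n = center_x / img_width and k = i // 2:
#   pos_emb[2k]   = sin(y_n / 10000^(2k / N)) + sin(x_n / 10000^(2k / N))
#   pos_emb[2k+1] = cos(y_n / 10000^(2k / N)) + cos(x_n / 10000^(2k / N))
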
def azi_diff(img: PIL.Image.Image, patch_num, N):
    # Build fixed-size embedding matrices for the rich- and poor-texture patch sets.
    # If a set has fewer than patch_num patches, its patches are cycled (idx % len).
    r, p, im_size = get_rich_poor_patches(img)
    r_len = len(r)
    p_len = len(p)
    patch_emb_r = np.zeros((patch_num, N))
    patch_emb_p = np.zeros((patch_num, N))
    positional_emb_r = np.zeros((patch_num, N))
    positional_emb_p = np.zeros((patch_num, N))
    coor_r = []
    coor_p = []
    if r_len != 0:
        for idx in range(patch_num):
            tmp_patch1 = r[idx % r_len][0]
            tmp_coor1 = r[idx % r_len][1]
            patch_emb_r[idx] = azimuthal_integral(tmp_patch1, N=N)
            positional_emb_r[idx] = positional_emb(tmp_coor1, im_size, N)
            coor_r.append(tmp_coor1)
    if p_len != 0:
        for idx in range(patch_num):
            tmp_patch2 = p[idx % p_len][0]
            tmp_coor2 = p[idx % p_len][1]
            patch_emb_p[idx] = azimuthal_integral(tmp_patch2, N=N)
            positional_emb_p[idx] = positional_emb(tmp_coor2, im_size, N)
            coor_p.append(tmp_coor2)
    output = {"total_emb": [patch_emb_r + positional_emb_r / 5, patch_emb_p + positional_emb_p / 5],
              "positional_emb": [positional_emb_r / 5, positional_emb_p / 5],
              "coor": [coor_r, coor_p],
              "image_size": im_size}
    return output

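# "total_emb" holds two (patch_num, N) arrays (spectral embedding plus the positional
# embedding scaled by 1/5): one for the rich-texture patches and one for the poor-texture
# patches. With the values used below (patch_num=128, N=256) each array is 128x256,
# matching the classifier's input_shape.
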
class AttentionBlock(nn.Module):
    # Transformer-style encoder block: multi-head self-attention and a feed-forward
    # network, each followed by dropout, a residual connection, and layer normalisation.
    def __init__(self, input_dim, num_heads, ff_dim, rate=0.1):
        super(AttentionBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads)
        self.dropout1 = nn.Dropout(rate)
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(ff_dim, input_dim),
            nn.Dropout(rate)
        )
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x is (seq_len, batch, input_dim) — nn.MultiheadAttention's default layout.
        attn_output, _ = self.attention(x, x, x)
        attn_output = self.dropout1(attn_output)
        out1 = self.layer_norm1(attn_output + x)
        ffn_output = self.ffn(out1)
        out2 = self.layer_norm2(ffn_output + out1)
        return out2

class TextureContrastClassifier(nn.Module):
    # Classifies an image from the contrast between its rich- and poor-texture embeddings:
    # each set passes through its own attention block and dense projection, and the
    # classifier head operates on the (rich - poor) difference.
    def __init__(self, input_shape, num_heads=4, key_dim=64, ff_dim=256, rate=0.1):
        super(TextureContrastClassifier, self).__init__()
        input_dim = input_shape[1]
        self.rich_attention_block = AttentionBlock(input_dim, num_heads, ff_dim, rate)
        self.rich_dense = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.poor_attention_block = AttentionBlock(input_dim, num_heads, ff_dim, rate)
        self.poor_dense = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * input_shape[0], 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, rich_texture, poor_texture):
        # Inputs are (batch, patch_num, input_dim); MultiheadAttention expects sequence-first.
        rich_texture = rich_texture.permute(1, 0, 2)
        poor_texture = poor_texture.permute(1, 0, 2)
        rich_attention = self.rich_attention_block(rich_texture)
        rich_attention = rich_attention.permute(1, 0, 2)
        rich_features = self.rich_dense(rich_attention)
        poor_attention = self.poor_attention_block(poor_texture)
        poor_attention = poor_attention.permute(1, 0, 2)
        poor_features = self.poor_dense(poor_attention)
        difference = rich_features - poor_features
        difference = difference.view(difference.size(0), -1)
        output = self.fc(difference)
        return output

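# Shape flow for the values used below (batch B, 128 patches, 256-dim embeddings):
# (B, 128, 256) -> attention (sequence-first) -> (B, 128, 256) -> dense -> (B, 128, 128);
# the rich/poor difference is flattened to (B, 128*128) before the fully connected head,
# which ends in a sigmoid, so the output is a single score in [0, 1].
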
input_shape = (128, 256)  # (patch_num, N) — must match the azi_diff call in inference()
model = TextureContrastClassifier(input_shape)
# The checkpoint is assumed to be shipped alongside app.py.
model.load_state_dict(torch.load('model_epoch_45.pth', map_location=torch.device('cpu')))

def inference(image, model):
    # Run the classifier on a single PIL image and return the raw sigmoid score.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    tmp = azi_diff(image, patch_num=128, N=256)
    rich = tmp["total_emb"][0]
    poor = tmp["total_emb"][1]
    rich_texture_tensor = torch.tensor(rich, dtype=torch.float32).unsqueeze(0).to(device)
    poor_texture_tensor = torch.tensor(poor, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(rich_texture_tensor, poor_texture_tensor)
    prediction = output.cpu().numpy().flatten()[0]
    return prediction

# Gradio interface
def predict(image):
    prediction = inference(image, model)
    return f"{prediction * 100:.2f}% chance AI-generated"

gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text").launch()
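
# Example of calling predict() directly (a sketch; "sample.jpg" is a hypothetical local file):
#   img = PIL.Image.open("sample.jpg")
#   print(predict(img))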