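"""Gradio demo: score how likely an image is AI-generated.

A texture-contrast classifier compares attention-encoded embeddings of
texture-rich and texture-poor patches (extracted by utils.azi_diff) and
outputs the probability that the input image is AI-generated.
"""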
import gradio as gr
import torch
import torch.nn as nn

from utils import azi_diff

class AttentionBlock(nn.Module):
    """Transformer-style encoder block: self-attention followed by a
    position-wise feed-forward network, each with a residual connection,
    dropout, and post-layer-norm."""

    def __init__(self, input_dim, num_heads, ff_dim, rate=0.2):
        super().__init__()
        # batch_first defaults to False, so inputs must be (seq_len, batch, input_dim).
        self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads)
        self.dropout1 = nn.Dropout(rate)
        self.layer_norm1 = nn.LayerNorm(input_dim)

        self.ffn = nn.Sequential(
            nn.Linear(input_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(ff_dim, input_dim),
            nn.Dropout(rate)
        )
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # Self-attention sub-layer with residual connection.
        attn_output, _ = self.attention(x, x, x)
        attn_output = self.dropout1(attn_output)
        out1 = self.layer_norm1(attn_output + x)

        # Feed-forward sub-layer with residual connection.
        ffn_output = self.ffn(out1)
        out2 = self.layer_norm2(ffn_output + out1)
        return out2
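
# Note: on PyTorch >= 1.9 the same block could pass batch_first=True to
# nn.MultiheadAttention and skip the permutes in
# TextureContrastClassifier.forward below; the flag is not a stored weight,
# so either form loads the same checkpoint.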

class TextureContrastClassifier(nn.Module):
    """Scores an image as AI-generated from the contrast between its
    texture-rich and texture-poor patch embeddings.

    input_shape is (num_patches, embedding_dim); key_dim is accepted but
    currently unused.
    """

    def __init__(self, input_shape, num_heads=4, key_dim=64, ff_dim=256, rate=0.1):
        super().__init__()
        input_dim = input_shape[1]
        # Separate attention + projection branches for the two texture streams.
        self.rich_attention_block = AttentionBlock(input_dim, num_heads, ff_dim, rate)
        self.rich_dense = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.poor_attention_block = AttentionBlock(input_dim, num_heads, ff_dim, rate)
        self.poor_dense = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # MLP head over the flattened per-patch feature differences
        # (num_patches * 128 inputs), ending in a sigmoid probability.
        self.fc = nn.Sequential(
            nn.Linear(128 * input_shape[0], 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, rich_texture, poor_texture):
        # Inputs arrive as (batch, num_patches, input_dim); nn.MultiheadAttention
        # with batch_first=False expects (num_patches, batch, input_dim).
        rich_texture = rich_texture.permute(1, 0, 2)
        poor_texture = poor_texture.permute(1, 0, 2)

        rich_attention = self.rich_attention_block(rich_texture)
        rich_attention = rich_attention.permute(1, 0, 2)  # back to batch-first
        rich_features = self.rich_dense(rich_attention)

        poor_attention = self.poor_attention_block(poor_texture)
        poor_attention = poor_attention.permute(1, 0, 2)  # back to batch-first
        poor_features = self.poor_dense(poor_attention)

        # Classify on the rich-minus-poor feature contrast, flattened per sample.
        difference = rich_features - poor_features
        difference = difference.view(difference.size(0), -1)
        output = self.fc(difference)
        return output
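
# Minimal shape check (a sketch; random tensors stand in for real embeddings):
#   m = TextureContrastClassifier((128, 256))
#   rich = torch.randn(2, 128, 256)   # (batch, num_patches, embedding_dim)
#   poor = torch.randn(2, 128, 256)
#   assert m(rich, poor).shape == (2, 1)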

# (num_patches, embedding_dim); must match the azi_diff call in inference().
input_shape = (128, 256)
model = TextureContrastClassifier(input_shape)
# Load trained weights onto CPU; inference() moves the model to the best device.
model.load_state_dict(torch.load('./model_epoch_36.pth', map_location=torch.device('cpu')))

def inference(image, model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    # azi_diff extracts per-patch embeddings for the texture-rich and
    # texture-poor regions of the image.
    tmp = azi_diff(image, patch_num=128, N=256)
    rich = tmp["total_emb"][0]
    poor = tmp["total_emb"][1]
    # Add a batch dimension: (1, patch_num, N).
    rich_texture_tensor = torch.tensor(rich, dtype=torch.float32).unsqueeze(0).to(device)
    poor_texture_tensor = torch.tensor(poor, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(rich_texture_tensor, poor_texture_tensor)
    prediction = output.cpu().numpy().flatten()[0]
    return prediction
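
# Standalone usage (a sketch; "sample.png" is a hypothetical file):
#   from PIL import Image
#   score = inference(Image.open("sample.png"), model)
#   print(f"{score * 100:.2f}% chance AI-generated")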

# Gradio Interface
def predict(image):
    prediction = inference(image, model)
    return f"{prediction * 100:.2f}% chance AI-generated"

gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text").launch()