Kiwinicki commited on
Commit
827021c
·
verified ·
1 Parent(s): 66754a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -3
app.py CHANGED
@@ -1,7 +1,154 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return f"Hello {name}!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
  import gradio as gr
2
+ import torch.nn as nn
3
+ from torch import tanh, Tensor
4
+ from abc import ABC, abstractmethod
5
+ from huggingface_hub import hf_hub_download
6
+ import torch
7
+ import json
8
+ from omegaconf import OmegaConf
9
+ from model import Generator
10
 
 
 
11
 
12
+ class BaseGenerator(ABC, nn.Module):
13
+ def __init__(self, channels: int = 3):
14
+ super().__init__()
15
+ self.channels = channels
16
+
17
+ @abstractmethod
18
+ def forward(self, x: Tensor) -> Tensor:
19
+ pass
20
+
21
+
22
+ class Generator(BaseGenerator):
23
+ def __init__(self, cfg: DictConfig):
24
+ super().__init__(cfg.channels)
25
+ self.cfg = cfg
26
+ self.model = self._construct_model()
27
+
28
+ def _construct_model(self):
29
+ initial_layer = nn.Sequential(
30
+ nn.Conv2d(
31
+ self.cfg.channels,
32
+ self.cfg.num_features,
33
+ kernel_size=7,
34
+ stride=1,
35
+ padding=3,
36
+ padding_mode="reflect",
37
+ ),
38
+ nn.ReLU(inplace=True),
39
+ )
40
+
41
+ down_blocks = nn.Sequential(
42
+ ConvBlock(
43
+ self.cfg.num_features,
44
+ self.cfg.num_features * 2,
45
+ kernel_size=3,
46
+ stride=2,
47
+ padding=1,
48
+ ),
49
+ ConvBlock(
50
+ self.cfg.num_features * 2,
51
+ self.cfg.num_features * 4,
52
+ kernel_size=3,
53
+ stride=2,
54
+ padding=1,
55
+ ),
56
+ )
57
+
58
+ residual_blocks = nn.Sequential(
59
+ *[
60
+ ResidualBlock(self.cfg.num_features * 4)
61
+ for _ in range(self.cfg.num_residuals)
62
+ ]
63
+ )
64
+
65
+ up_blocks = nn.Sequential(
66
+ ConvBlock(
67
+ self.cfg.num_features * 4,
68
+ self.cfg.num_features * 2,
69
+ down=False,
70
+ kernel_size=3,
71
+ stride=2,
72
+ padding=1,
73
+ output_padding=1,
74
+ ),
75
+ ConvBlock(
76
+ self.cfg.num_features * 2,
77
+ self.cfg.num_features,
78
+ down=False,
79
+ kernel_size=3,
80
+ stride=2,
81
+ padding=1,
82
+ output_padding=1,
83
+ ),
84
+ )
85
+
86
+ last_layer = nn.Conv2d(
87
+ self.cfg.num_features,
88
+ self.cfg.channels,
89
+ kernel_size=7,
90
+ stride=1,
91
+ padding=3,
92
+ padding_mode="reflect",
93
+ )
94
+
95
+ return nn.Sequential(
96
+ initial_layer, down_blocks, residual_blocks, up_blocks, last_layer
97
+ )
98
+
99
+ def forward(self, x: Tensor) -> Tensor:
100
+ return tanh(self.model(x))
101
+
102
+
103
+ class ConvBlock(nn.Module):
104
+ def __init__(
105
+ self, in_channels, out_channels, down=True, use_activation=True, **kwargs
106
+ ):
107
+ super().__init__()
108
+ self.conv = nn.Sequential(
109
+ nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
110
+ if down
111
+ else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
112
+ nn.InstanceNorm2d(out_channels),
113
+ nn.ReLU(inplace=True) if use_activation else nn.Identity(),
114
+ )
115
+
116
+ def forward(self, x: Tensor) -> Tensor:
117
+ return self.conv(x)
118
+
119
+
120
+ class ResidualBlock(nn.Module):
121
+ def __init__(self, channels: int):
122
+ super().__init__()
123
+ self.block = nn.Sequential(
124
+ ConvBlock(channels, channels, kernel_size=3, padding=1),
125
+ ConvBlock(
126
+ channels, channels, use_activation=False, kernel_size=3, padding=1
127
+ ),
128
+ )
129
+
130
+ def forward(self, x: Tensor) -> Tensor:
131
+ return x + self.block(x)
132
+
133
+
134
+ repo_id = "Kiwinicki/sat2map-generator"
135
+ generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
136
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
137
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
138
+
139
+
140
+ with open(config_path, "r") as f:
141
+ config_dict = json.load(f)
142
+ cfg = OmegaConf.create(config_dict)
143
+
144
+ generator = Generator(cfg)
145
+ generator.load_state_dict(torch.load(generator_path))
146
+ generator.eval()
147
+
148
+
149
+
150
+ def greet(iamge):
151
+ return image
152
+
153
+ iface = gr.Interface(fn=greet, inputs="image", outputs="image")
154
  iface.launch()