Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1 |
from PIL import Image
|
2 |
from io import BytesIO
|
|
|
3 |
from torchvision import transforms
|
|
|
4 |
from typing import Literal, Any
|
|
|
5 |
import gradio as gr
|
6 |
-
from matplotlib.figure import Figure
|
7 |
import matplotlib.pyplot as plt
|
|
|
8 |
import spaces
|
|
|
9 |
import torch
|
10 |
import torch.nn.functional as F
|
11 |
|
@@ -19,11 +23,28 @@ LABELS = [
|
|
19 |
"Ephemeral",
|
20 |
"Canopied",
|
21 |
]
|
|
|
|
|
22 |
|
23 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
model = torch.load(
|
26 |
-
|
27 |
).module
|
28 |
model.eval()
|
29 |
preprocess = transforms.Compose(
|
@@ -73,6 +94,14 @@ def draw_bar_chart(data: dict[str, list[str | float]]):
|
|
73 |
return fig
|
74 |
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
def get_layout():
|
77 |
css = """
|
78 |
.main-title {
|
@@ -106,6 +135,10 @@ def get_layout():
|
|
106 |
color: #d1d5db;
|
107 |
font-size: 14px;
|
108 |
}
|
|
|
|
|
|
|
|
|
109 |
"""
|
110 |
theme = gr.themes.Base(
|
111 |
primary_hue="orange",
|
@@ -153,22 +186,56 @@ def get_layout():
|
|
153 |
)
|
154 |
|
155 |
with gr.Row(equal_height=True):
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
chart = gr.Plot(label="分類結果")
|
158 |
|
159 |
-
start_button = gr.Button("
|
160 |
gr.HTML(
|
161 |
'<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
|
162 |
)
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
169 |
|
170 |
return demo
|
171 |
|
172 |
|
173 |
if __name__ == "__main__":
|
174 |
-
get_layout().
|
|
|
1 |
from PIL import Image
|
2 |
from io import BytesIO
|
3 |
+
from matplotlib.figure import Figure
|
4 |
from torchvision import transforms
|
5 |
+
from tqdm import tqdm
|
6 |
from typing import Literal, Any
|
7 |
+
from urllib.request import urlopen
|
8 |
import gradio as gr
|
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
+
import os
|
11 |
import spaces
|
12 |
+
import sys
|
13 |
import torch
|
14 |
import torch.nn.functional as F
|
15 |
|
|
|
23 |
"Ephemeral",
|
24 |
"Canopied",
|
25 |
]
|
26 |
+
MODELFILE = "Litton-7type-visual-landscape-model.pth"
|
27 |
+
|
28 |
|
29 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
30 |
|
31 |
+
if not os.path.exists(MODELFILE):
|
32 |
+
model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"
|
33 |
+
|
34 |
+
print(f"fetch model from {model_url}...", file=sys.stderr)
|
35 |
+
|
36 |
+
with urlopen(model_url) as resp:
|
37 |
+
progress = tqdm(total=int(resp["Content-Length"]), desc="Downloading")
|
38 |
+
with open(MODELFILE, "wb") as modelfile:
|
39 |
+
while True:
|
40 |
+
chunk = resp.read(1024)
|
41 |
+
if len(chunk) == 0:
|
42 |
+
break
|
43 |
+
modelfile.write(chunk)
|
44 |
+
progress.update(len(chunk))
|
45 |
+
|
46 |
model = torch.load(
|
47 |
+
MODELFILE, map_location=device, weights_only=False
|
48 |
).module
|
49 |
model.eval()
|
50 |
preprocess = transforms.Compose(
|
|
|
94 |
return fig
|
95 |
|
96 |
|
97 |
+
def choose_example(imgpath: str) -> gr.Image:
|
98 |
+
img = Image.open(imgpath)
|
99 |
+
width, height = img.size
|
100 |
+
ratio = 512 / max(width, height)
|
101 |
+
img = img.resize((int(width * ratio), int(height * ratio)))
|
102 |
+
return gr.Image(value=img, label="輸入影像(不支援 SVG 格式)", type="pil")
|
103 |
+
|
104 |
+
|
105 |
def get_layout():
|
106 |
css = """
|
107 |
.main-title {
|
|
|
135 |
color: #d1d5db;
|
136 |
font-size: 14px;
|
137 |
}
|
138 |
+
.example-image {
|
139 |
+
height: 220px;
|
140 |
+
padding: 25px;
|
141 |
+
}
|
142 |
"""
|
143 |
theme = gr.themes.Base(
|
144 |
primary_hue="orange",
|
|
|
186 |
)
|
187 |
|
188 |
with gr.Row(equal_height=True):
|
189 |
+
with gr.Group():
|
190 |
+
img = gr.Image(label="上傳影像", type="pil", height="256px")
|
191 |
+
gr.Label("範例影像", show_label=False)
|
192 |
+
with gr.Row():
|
193 |
+
ex1 = gr.Image(
|
194 |
+
value="examples/beach.jpg",
|
195 |
+
show_label=False,
|
196 |
+
type="filepath",
|
197 |
+
elem_classes="example-image",
|
198 |
+
interactive=False,
|
199 |
+
show_download_button=False,
|
200 |
+
show_fullscreen_button=False,
|
201 |
+
)
|
202 |
+
ex2 = gr.Image(
|
203 |
+
value="examples/field.jpg",
|
204 |
+
show_label=False,
|
205 |
+
type="filepath",
|
206 |
+
elem_classes="example-image",
|
207 |
+
interactive=False,
|
208 |
+
show_download_button=False,
|
209 |
+
show_fullscreen_button=False,
|
210 |
+
)
|
211 |
+
ex3 = gr.Image(
|
212 |
+
value="examples/sky.jpg",
|
213 |
+
show_label=False,
|
214 |
+
type="filepath",
|
215 |
+
elem_classes="example-image",
|
216 |
+
interactive=False,
|
217 |
+
show_download_button=False,
|
218 |
+
show_fullscreen_button=False,
|
219 |
+
)
|
220 |
chart = gr.Plot(label="分類結果")
|
221 |
|
222 |
+
start_button = gr.Button("開始", variant="primary")
|
223 |
gr.HTML(
|
224 |
'<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
|
225 |
)
|
226 |
|
227 |
+
start_button.click(
|
228 |
+
fn=predict,
|
229 |
+
inputs=img,
|
230 |
+
outputs=chart,
|
231 |
+
)
|
232 |
+
|
233 |
+
ex1.select(fn=choose_example, inputs=ex1, outputs=img)
|
234 |
+
ex2.select(fn=choose_example, inputs=ex2, outputs=img)
|
235 |
+
ex3.select(fn=choose_example, inputs=ex3, outputs=img)
|
236 |
|
237 |
return demo
|
238 |
|
239 |
|
240 |
if __name__ == "__main__":
|
241 |
+
get_layout().launch()
|