lichih commited on
Commit
9717fe0
·
verified ·
1 Parent(s): 173336d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -10
app.py CHANGED
@@ -1,11 +1,15 @@
1
  from PIL import Image
2
  from io import BytesIO
 
3
  from torchvision import transforms
 
4
  from typing import Literal, Any
 
5
  import gradio as gr
6
- from matplotlib.figure import Figure
7
  import matplotlib.pyplot as plt
 
8
  import spaces
 
9
  import torch
10
  import torch.nn.functional as F
11
 
@@ -19,11 +23,28 @@ LABELS = [
19
  "Ephemeral",
20
  "Canopied",
21
  ]
 
 
22
 
23
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  model = torch.load(
26
- "Litton-7type-visual-landscape-model.pth", map_location=device, weights_only=False
27
  ).module
28
  model.eval()
29
  preprocess = transforms.Compose(
@@ -73,6 +94,14 @@ def draw_bar_chart(data: dict[str, list[str | float]]):
73
  return fig
74
 
75
 
 
 
 
 
 
 
 
 
76
  def get_layout():
77
  css = """
78
  .main-title {
@@ -106,6 +135,10 @@ def get_layout():
106
  color: #d1d5db;
107
  font-size: 14px;
108
  }
 
 
 
 
109
  """
110
  theme = gr.themes.Base(
111
  primary_hue="orange",
@@ -153,22 +186,56 @@ def get_layout():
153
  )
154
 
155
  with gr.Row(equal_height=True):
156
- image_input = gr.Image(label="上傳影像", type="pil")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  chart = gr.Plot(label="分類結果")
158
 
159
- start_button = gr.Button("開始分類", variant="primary")
160
  gr.HTML(
161
  '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
162
  )
163
 
164
- start_button.click(
165
- fn=predict,
166
- inputs=image_input,
167
- outputs=chart,
168
- )
 
 
 
 
169
 
170
  return demo
171
 
172
 
173
  if __name__ == "__main__":
174
- get_layout().queue().launch()
 
1
  from PIL import Image
2
  from io import BytesIO
3
+ from matplotlib.figure import Figure
4
  from torchvision import transforms
5
+ from tqdm import tqdm
6
  from typing import Literal, Any
7
+ from urllib.request import urlopen
8
  import gradio as gr
 
9
  import matplotlib.pyplot as plt
10
+ import os
11
  import spaces
12
+ import sys
13
  import torch
14
  import torch.nn.functional as F
15
 
 
23
  "Ephemeral",
24
  "Canopied",
25
  ]
26
+ MODELFILE = "Litton-7type-visual-landscape-model.pth"
27
+
28
 
29
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
30
 
31
+ if not os.path.exists(MODELFILE):
32
+ model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"
33
+
34
+ print(f"fetch model from {model_url}...", file=sys.stderr)
35
+
36
+ with urlopen(model_url) as resp:
37
+ progress = tqdm(total=int(resp["Content-Length"]), desc="Downloading")
38
+ with open(MODELFILE, "wb") as modelfile:
39
+ while True:
40
+ chunk = resp.read(1024)
41
+ if len(chunk) == 0:
42
+ break
43
+ modelfile.write(chunk)
44
+ progress.update(len(chunk))
45
+
46
  model = torch.load(
47
+ MODELFILE, map_location=device, weights_only=False
48
  ).module
49
  model.eval()
50
  preprocess = transforms.Compose(
 
94
  return fig
95
 
96
 
97
+ def choose_example(imgpath: str) -> gr.Image:
98
+ img = Image.open(imgpath)
99
+ width, height = img.size
100
+ ratio = 512 / max(width, height)
101
+ img = img.resize((int(width * ratio), int(height * ratio)))
102
+ return gr.Image(value=img, label="輸入影像(不支援 SVG 格式)", type="pil")
103
+
104
+
105
  def get_layout():
106
  css = """
107
  .main-title {
 
135
  color: #d1d5db;
136
  font-size: 14px;
137
  }
138
+ .example-image {
139
+ height: 220px;
140
+ padding: 25px;
141
+ }
142
  """
143
  theme = gr.themes.Base(
144
  primary_hue="orange",
 
186
  )
187
 
188
  with gr.Row(equal_height=True):
189
+ with gr.Group():
190
+ img = gr.Image(label="上傳影像", type="pil", height="256px")
191
+ gr.Label("範例影像", show_label=False)
192
+ with gr.Row():
193
+ ex1 = gr.Image(
194
+ value="examples/beach.jpg",
195
+ show_label=False,
196
+ type="filepath",
197
+ elem_classes="example-image",
198
+ interactive=False,
199
+ show_download_button=False,
200
+ show_fullscreen_button=False,
201
+ )
202
+ ex2 = gr.Image(
203
+ value="examples/field.jpg",
204
+ show_label=False,
205
+ type="filepath",
206
+ elem_classes="example-image",
207
+ interactive=False,
208
+ show_download_button=False,
209
+ show_fullscreen_button=False,
210
+ )
211
+ ex3 = gr.Image(
212
+ value="examples/sky.jpg",
213
+ show_label=False,
214
+ type="filepath",
215
+ elem_classes="example-image",
216
+ interactive=False,
217
+ show_download_button=False,
218
+ show_fullscreen_button=False,
219
+ )
220
  chart = gr.Plot(label="分類結果")
221
 
222
+ start_button = gr.Button("開始", variant="primary")
223
  gr.HTML(
224
  '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
225
  )
226
 
227
+ start_button.click(
228
+ fn=predict,
229
+ inputs=img,
230
+ outputs=chart,
231
+ )
232
+
233
+ ex1.select(fn=choose_example, inputs=ex1, outputs=img)
234
+ ex2.select(fn=choose_example, inputs=ex2, outputs=img)
235
+ ex3.select(fn=choose_example, inputs=ex3, outputs=img)
236
 
237
  return demo
238
 
239
 
240
  if __name__ == "__main__":
241
+ get_layout().launch()