lichih commited on
Commit
8a589e9
·
verified ·
1 Parent(s): 83784d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -159
app.py CHANGED
@@ -1,159 +1,162 @@
1
- import matplotlib.pyplot as plt
2
- import torch
3
- from PIL import Image
4
- from torchvision import transforms
5
- import torch.nn.functional as F
6
- from typing import Literal, Any
7
- import gradio as gr
8
-
9
-
10
- class Classifier:
11
- LABELS = [
12
- "Panoramic",
13
- "Feature",
14
- "Detail",
15
- "Enclosed",
16
- "Focal",
17
- "Ephemeral",
18
- "Canopied",
19
- ]
20
-
21
- def __init__(
22
- self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0"
23
- ):
24
- self.device = device
25
- self.model = torch.load(
26
- model_path, map_location=self.device, weights_only=False
27
- )
28
- if hasattr(self.model, "module"):
29
- self.model = self.model.module
30
- self.model.eval()
31
- self.preprocess = transforms.Compose(
32
- [
33
- transforms.Resize(256),
34
- transforms.CenterCrop(224),
35
- transforms.ToTensor(),
36
- transforms.Normalize(
37
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
38
- ),
39
- ]
40
- )
41
-
42
- def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
43
- image = image.convert("RGB")
44
- input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
45
-
46
- with torch.no_grad():
47
- logits = self.model(input_tensor)
48
- probs = F.softmax(logits[:, :7], dim=1).cpu()
49
-
50
- # probs = pd.DataFrame(
51
- # {
52
- # "class": self.LABELS,
53
- # "probs": probs[0] * 100,
54
- # }
55
- # )
56
- return draw_bar_chart(
57
- {
58
- "class": self.LABELS,
59
- "probs": probs[0] * 100,
60
- }
61
- )
62
-
63
-
64
- def draw_bar_chart(data: dict[str, list[str | float]]):
65
- classes = data["class"]
66
- probabilities = data["probs"]
67
-
68
- plt.figure(figsize=(8, 6))
69
- plt.bar(classes, probabilities, color="skyblue")
70
-
71
- plt.xlabel("Class")
72
- plt.ylabel("Probability (%)")
73
- plt.title("Class Probabilities")
74
-
75
- for i, prob in enumerate(probabilities):
76
- plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
77
-
78
- plt.tight_layout()
79
-
80
- return plt
81
-
82
-
83
- def get_layout():
84
- css = """
85
- .main-title {
86
- font-size: 24px;
87
- font-weight: bold;
88
- text-align: center;
89
- margin-bottom: 20px;
90
- }
91
- .reference {
92
- text-align: center;
93
- font-size: 1.2em;
94
- color: #d1d5db;
95
- margin-bottom: 20px;
96
- }
97
- .reference a {
98
- color: #FB923C;
99
- text-decoration: none;
100
- }
101
- .reference a:hover {
102
- text-decoration: underline;
103
- color: #FB923C;
104
- }
105
- .title {
106
- border-bottom: 1px solid;
107
- }
108
- .footer {
109
- text-align: center;
110
- margin-top: 30px;
111
- padding-top: 20px;
112
- border-top: 1px solid #ddd;
113
- color: #d1d5db;
114
- font-size: 14px;
115
- }
116
- """
117
- theme = gr.themes.Base(
118
- primary_hue="orange",
119
- secondary_hue="orange",
120
- neutral_hue="gray",
121
- font=gr.themes.GoogleFont("Source Sans Pro"),
122
- ).set(
123
- background_fill_primary="*neutral_950", # 主背景色(深黑)
124
- button_primary_background_fill="*primary_500", # 按鈕顏色(橘色)
125
- body_text_color="*neutral_200", # 文字顏色(淺色)
126
- )
127
- with gr.Blocks(css=css, theme=theme) as demo:
128
- gr.HTML(
129
- value=(
130
- '<div class="main-title">Litton7景觀分類模型</div>'
131
- '<div class="reference">引用資料:'
132
- '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
133
- "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
134
- "</a>"
135
- "</div>"
136
- ),
137
- )
138
-
139
- with gr.Row():
140
- image_input = gr.Image(label="上傳影像", type="pil")
141
- bar_chart = gr.Plot(label="分類結果")
142
- start_button = gr.Button("開始分類", variant="primary")
143
-
144
- start_button.click(
145
- fn=Classifier().predict,
146
- inputs=image_input,
147
- outputs=bar_chart,
148
- )
149
-
150
- gr.HTML(
151
- '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
152
- )
153
-
154
- return demo
155
-
156
-
157
- if __name__ == "__main__":
158
- app = get_layout()
159
- app.queue().launch(server_name="0.0.0.0")
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ from typing import Literal, Any
7
+ import gradio as gr
8
+ import space
9
+
10
+
11
+ class Classifier:
12
+ LABELS = [
13
+ "Panoramic",
14
+ "Feature",
15
+ "Detail",
16
+ "Enclosed",
17
+ "Focal",
18
+ "Ephemeral",
19
+ "Canopied",
20
+ ]
21
+
22
+ @space.GPU(duration=60)
23
+ def __init__(
24
+ self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0"
25
+ ):
26
+ self.device = device
27
+ self.model = torch.load(
28
+ model_path, map_location=self.device, weights_only=False
29
+ )
30
+ if hasattr(self.model, "module"):
31
+ self.model = self.model.module
32
+ self.model.eval()
33
+ self.preprocess = transforms.Compose(
34
+ [
35
+ transforms.Resize(256),
36
+ transforms.CenterCrop(224),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize(
39
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
40
+ ),
41
+ ]
42
+ )
43
+
44
+ @space.GPU(duration=60)
45
+ def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
46
+ image = image.convert("RGB")
47
+ input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
48
+
49
+ with torch.no_grad():
50
+ logits = self.model(input_tensor)
51
+ probs = F.softmax(logits[:, :7], dim=1).cpu()
52
+
53
+ # probs = pd.DataFrame(
54
+ # {
55
+ # "class": self.LABELS,
56
+ # "probs": probs[0] * 100,
57
+ # }
58
+ # )
59
+ return draw_bar_chart(
60
+ {
61
+ "class": self.LABELS,
62
+ "probs": probs[0] * 100,
63
+ }
64
+ )
65
+
66
+
67
+ def draw_bar_chart(data: dict[str, list[str | float]]):
68
+ classes = data["class"]
69
+ probabilities = data["probs"]
70
+
71
+ plt.figure(figsize=(8, 6))
72
+ plt.bar(classes, probabilities, color="skyblue")
73
+
74
+ plt.xlabel("Class")
75
+ plt.ylabel("Probability (%)")
76
+ plt.title("Class Probabilities")
77
+
78
+ for i, prob in enumerate(probabilities):
79
+ plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
80
+
81
+ plt.tight_layout()
82
+
83
+ return plt
84
+
85
+
86
+ def get_layout():
87
+ css = """
88
+ .main-title {
89
+ font-size: 24px;
90
+ font-weight: bold;
91
+ text-align: center;
92
+ margin-bottom: 20px;
93
+ }
94
+ .reference {
95
+ text-align: center;
96
+ font-size: 1.2em;
97
+ color: #d1d5db;
98
+ margin-bottom: 20px;
99
+ }
100
+ .reference a {
101
+ color: #FB923C;
102
+ text-decoration: none;
103
+ }
104
+ .reference a:hover {
105
+ text-decoration: underline;
106
+ color: #FB923C;
107
+ }
108
+ .title {
109
+ border-bottom: 1px solid;
110
+ }
111
+ .footer {
112
+ text-align: center;
113
+ margin-top: 30px;
114
+ padding-top: 20px;
115
+ border-top: 1px solid #ddd;
116
+ color: #d1d5db;
117
+ font-size: 14px;
118
+ }
119
+ """
120
+ theme = gr.themes.Base(
121
+ primary_hue="orange",
122
+ secondary_hue="orange",
123
+ neutral_hue="gray",
124
+ font=gr.themes.GoogleFont("Source Sans Pro"),
125
+ ).set(
126
+ background_fill_primary="*neutral_950", # 主背景色(深黑)
127
+ button_primary_background_fill="*primary_500", # 按鈕顏色(橘色)
128
+ body_text_color="*neutral_200", # 文字顏色(淺色)
129
+ )
130
+ with gr.Blocks(css=css, theme=theme) as demo:
131
+ gr.HTML(
132
+ value=(
133
+ '<div class="main-title">Litton7景觀分類模型</div>'
134
+ '<div class="reference">引用資料:'
135
+ '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
136
+ "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
137
+ "</a>"
138
+ "</div>"
139
+ ),
140
+ )
141
+
142
+ with gr.Row():
143
+ image_input = gr.Image(label="上傳影像", type="pil")
144
+ bar_chart = gr.Plot(label="分類結果")
145
+ start_button = gr.Button("開始分類", variant="primary")
146
+
147
+ start_button.click(
148
+ fn=Classifier().predict,
149
+ inputs=image_input,
150
+ outputs=bar_chart,
151
+ )
152
+
153
+ gr.HTML(
154
+ '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
155
+ )
156
+
157
+ return demo
158
+
159
+
160
+ if __name__ == "__main__":
161
+ app = get_layout()
162
+ app.queue().launch(server_name="0.0.0.0")