lichih commited on
Commit
240657a
·
verified ·
1 Parent(s): 7da1b80
Litton-7type-visual-landscape-model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:384af7f997e31c82009a65338e5061a6217d2e0e4cf82855ac03fd9bf68f7650
3
+ size 236604255
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ from typing import Literal, Any
7
+ import gradio as gr
8
+
9
+
10
+ class Classifier:
11
+ LABELS = [
12
+ "Panoramic",
13
+ "Feature",
14
+ "Detail",
15
+ "Enclosed",
16
+ "Focal",
17
+ "Ephemeral",
18
+ "Canopied",
19
+ ]
20
+
21
+ def __init__(
22
+ self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0"
23
+ ):
24
+ self.device = device
25
+ self.model = torch.load(
26
+ model_path, map_location=self.device, weights_only=False
27
+ )
28
+ if hasattr(self.model, "module"):
29
+ self.model = self.model.module
30
+ self.model.eval()
31
+ self.preprocess = transforms.Compose(
32
+ [
33
+ transforms.Resize(256),
34
+ transforms.CenterCrop(224),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize(
37
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
38
+ ),
39
+ ]
40
+ )
41
+
42
+ def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
43
+ image = image.convert("RGB")
44
+ input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
45
+
46
+ with torch.no_grad():
47
+ logits = self.model(input_tensor)
48
+ probs = F.softmax(logits[:, :7], dim=1).cpu()
49
+
50
+ # probs = pd.DataFrame(
51
+ # {
52
+ # "class": self.LABELS,
53
+ # "probs": probs[0] * 100,
54
+ # }
55
+ # )
56
+ return draw_bar_chart(
57
+ {
58
+ "class": self.LABELS,
59
+ "probs": probs[0] * 100,
60
+ }
61
+ )
62
+
63
+
64
+ def draw_bar_chart(data: dict[str, list[str | float]]):
65
+ classes = data["class"]
66
+ probabilities = data["probs"]
67
+
68
+ plt.figure(figsize=(8, 6))
69
+ plt.bar(classes, probabilities, color="skyblue")
70
+
71
+ plt.xlabel("Class")
72
+ plt.ylabel("Probability (%)")
73
+ plt.title("Class Probabilities")
74
+
75
+ for i, prob in enumerate(probabilities):
76
+ plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
77
+
78
+ plt.tight_layout()
79
+
80
+ return plt
81
+
82
+
83
+ def get_layout():
84
+ css = """
85
+ .main-title {
86
+ font-size: 24px;
87
+ font-weight: bold;
88
+ text-align: center;
89
+ margin-bottom: 20px;
90
+ }
91
+ .reference {
92
+ text-align: center;
93
+ font-size: 1.2em;
94
+ color: #d1d5db;
95
+ margin-bottom: 20px;
96
+ }
97
+ .reference a {
98
+ color: #FB923C;
99
+ text-decoration: none;
100
+ }
101
+ .reference a:hover {
102
+ text-decoration: underline;
103
+ color: #FB923C;
104
+ }
105
+ .title {
106
+ border-bottom: 1px solid;
107
+ }
108
+ .footer {
109
+ text-align: center;
110
+ margin-top: 30px;
111
+ padding-top: 20px;
112
+ border-top: 1px solid #ddd;
113
+ color: #d1d5db;
114
+ font-size: 14px;
115
+ }
116
+ """
117
+ theme = gr.themes.Base(
118
+ primary_hue="orange",
119
+ secondary_hue="orange",
120
+ neutral_hue="gray",
121
+ font=gr.themes.GoogleFont("Source Sans Pro"),
122
+ ).set(
123
+ background_fill_primary="*neutral_950", # 主背景色(深黑)
124
+ button_primary_background_fill="*primary_500", # 按鈕顏色(橘色)
125
+ body_text_color="*neutral_200", # 文字顏色(淺色)
126
+ )
127
+ with gr.Blocks(css=css, theme=theme) as demo:
128
+ gr.HTML(
129
+ value=(
130
+ '<div class="main-title">Litton7景觀分類模型</div>'
131
+ '<div class="reference">引用資料:'
132
+ '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
133
+ "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
134
+ "</a>"
135
+ "</div>"
136
+ ),
137
+ )
138
+
139
+ with gr.Row():
140
+ image_input = gr.Image(label="上傳影像", type="pil")
141
+ bar_chart = gr.Plot(label="分類結果")
142
+ start_button = gr.Button("開始分類", variant="primary")
143
+
144
+ start_button.click(
145
+ fn=Classifier().predict,
146
+ inputs=image_input,
147
+ outputs=bar_chart,
148
+ )
149
+
150
+ gr.HTML(
151
+ '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
152
+ )
153
+
154
+ return demo
155
+
156
+
157
+ if __name__ == "__main__":
158
+ app = get_layout()
159
+ app.queue().launch(server_name="0.0.0.0")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ pillow
3
+ gradio==5.5.0
4
+ torch==2.5.1
5
+ torchvision