lichih committed
Commit f2b954d · verified · 1 Parent(s): e511a97

Update app.py

Files changed (1)
  1. app.py +170 -153
app.py CHANGED
@@ -1,159 +1,176 @@
-import matplotlib.pyplot as plt
-import torch
-from PIL import Image
-from torchvision import transforms
-import torch.nn.functional as F
-from typing import Literal, Any
 import gradio as gr
-import spaces
-from io import BytesIO
-
-
-class Classifier:
-    LABELS = [
-        "Panoramic",
-        "Feature",
-        "Detail",
-        "Enclosed",
-        "Focal",
-        "Ephemeral",
-        "Canopied",
-    ]
-
-    @spaces.GPU(duration=60)
-    def __init__(
-        self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0"
-    ):
-        self.device = device
-        self.model = torch.load(
-            model_path, map_location=self.device, weights_only=False
-        )
-        if hasattr(self.model, "module"):
-            self.model = self.model.module
-        self.model.eval()
-        self.preprocess = transforms.Compose(
-            [
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-                ),
-            ]
-        )
-
-    @spaces.GPU(duration=60)
-    def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
-        image = image.convert("RGB")
-        input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
-
-        with torch.no_grad():
-            logits = self.model(input_tensor)
-            probs = F.softmax(logits[:, :7], dim=1).cpu()
-
-        return draw_bar_chart(
-            {
-                "class": self.LABELS,
-                "probs": probs[0] * 100,
-            }
-        )
-
-
-def draw_bar_chart(data: dict[str, list[str | float]]):
-    classes = data["class"]
-    probabilities = data["probs"]
-
-    plt.figure(figsize=(8, 6))
-    plt.bar(classes, probabilities, color="skyblue")
-
-    plt.xlabel("Class")
-    plt.ylabel("Probability (%)")
-    plt.title("Class Probabilities")
-
-    for i, prob in enumerate(probabilities):
-        plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
-
-    plt.tight_layout()
-
-    return plt
-
-
-def get_layout():
-    demo = gr.Interface(fn=Classifier().predict, inputs="image", outputs="plot")
-    return demo
-    css = """
-    .main-title {
-        font-size: 24px;
-        font-weight: bold;
-        text-align: center;
-        margin-bottom: 20px;
-    }
-    .reference {
-        text-align: center;
-        font-size: 1.2em;
-        color: #d1d5db;
-        margin-bottom: 20px;
-    }
-    .reference a {
-        color: #FB923C;
-        text-decoration: none;
-    }
-    .reference a:hover {
-        text-decoration: underline;
-        color: #FB923C;
-    }
-    .title {
-        border-bottom: 1px solid;
-    }
-    .footer {
-        text-align: center;
-        margin-top: 30px;
-        padding-top: 20px;
-        border-top: 1px solid #ddd;
-        color: #d1d5db;
-        font-size: 14px;
-    }
-    """
-    theme = gr.themes.Base(
-        primary_hue="orange",
-        secondary_hue="orange",
-        neutral_hue="gray",
-        font=gr.themes.GoogleFont("Source Sans Pro"),
-    ).set(
-        background_fill_primary="*neutral_950",  # main background color (dark)
-        button_primary_background_fill="*primary_500",  # button color (orange)
-        body_text_color="*neutral_200",  # body text color (light)
-    )
-    # with gr.Blocks(css=css, theme=theme) as demo:
-    with gr.Blocks() as demo:
-        with gr.Column():
-            gr.HTML(
-                value=(
-                    '<div class="main-title">Litton7景觀分類模型</div>'
-                    '<div class="reference">引用資料:'
-                    '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
-                    "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
-                    "</a>"
-                    "</div>"
-                ),
-            )

-            with gr.Row(equal_height=True):
-                image_input = gr.Image(label="上傳影像", type="pil")
-                chart = gr.Image(label="分類結果")

-            start_button = gr.Button("開始分類", variant="primary")
-            gr.HTML(
-                '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
-            )
-            start_button.click(
-                fn=Classifier().predict,
-                inputs=image_input,
-                outputs=chart,
-            )

-    return demo


-if __name__ == "__main__":
-    get_layout().launch()
 
 import gradio as gr
+from transformers import pipeline
+
+pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
+
+def predict(image):
+    predictions = pipeline(image)
+    return {p["label"]: p["score"] for p in predictions}
+
+gr.Interface(
+    predict,
+    inputs=gr.inputs.Image(label="Upload hot dog candidate", type="filepath"),
+    outputs=gr.outputs.Label(num_top_classes=2),
+    title="Hot Dog? Or Not?",
+    allow_flagging="manual"
+).launch()
+
+# import matplotlib.pyplot as plt
+# import torch
+# from PIL import Image
+# from torchvision import transforms
+# import torch.nn.functional as F
+# from typing import Literal, Any
+# import gradio as gr
+# import spaces
+# from io import BytesIO
+
+
+# class Classifier:
+#     LABELS = [
+#         "Panoramic",
+#         "Feature",
+#         "Detail",
+#         "Enclosed",
+#         "Focal",
+#         "Ephemeral",
+#         "Canopied",
+#     ]
+
+#     @spaces.GPU(duration=60)
+#     def __init__(
+#         self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0"
+#     ):
+#         self.device = device
+#         self.model = torch.load(
+#             model_path, map_location=self.device, weights_only=False
+#         )
+#         if hasattr(self.model, "module"):
+#             self.model = self.model.module
+#         self.model.eval()
+#         self.preprocess = transforms.Compose(
+#             [
+#                 transforms.Resize(256),
+#                 transforms.CenterCrop(224),
+#                 transforms.ToTensor(),
+#                 transforms.Normalize(
+#                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+#                 ),
+#             ]
+#         )
+
+#     @spaces.GPU(duration=60)
+#     def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
+#         image = image.convert("RGB")
+#         input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
+
+#         with torch.no_grad():
+#             logits = self.model(input_tensor)
+#             probs = F.softmax(logits[:, :7], dim=1).cpu()
+
+#         return draw_bar_chart(
+#             {
+#                 "class": self.LABELS,
+#                 "probs": probs[0] * 100,
+#             }
+#         )
+
+
+# def draw_bar_chart(data: dict[str, list[str | float]]):
+#     classes = data["class"]
+#     probabilities = data["probs"]
+
+#     plt.figure(figsize=(8, 6))
+#     plt.bar(classes, probabilities, color="skyblue")
+
+#     plt.xlabel("Class")
+#     plt.ylabel("Probability (%)")
+#     plt.title("Class Probabilities")
+
+#     for i, prob in enumerate(probabilities):
+#         plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
+
+#     plt.tight_layout()
+
+#     return plt
+
+
+# def get_layout():
+#     demo = gr.Interface(fn=Classifier().predict, inputs="image", outputs="plot")
+#     return demo
+#     css = """
+#     .main-title {
+#         font-size: 24px;
+#         font-weight: bold;
+#         text-align: center;
+#         margin-bottom: 20px;
+#     }
+#     .reference {
+#         text-align: center;
+#         font-size: 1.2em;
+#         color: #d1d5db;
+#         margin-bottom: 20px;
+#     }
+#     .reference a {
+#         color: #FB923C;
+#         text-decoration: none;
+#     }
+#     .reference a:hover {
+#         text-decoration: underline;
+#         color: #FB923C;
+#     }
+#     .title {
+#         border-bottom: 1px solid;
+#     }
+#     .footer {
+#         text-align: center;
+#         margin-top: 30px;
+#         padding-top: 20px;
+#         border-top: 1px solid #ddd;
+#         color: #d1d5db;
+#         font-size: 14px;
+#     }
+#     """
+#     theme = gr.themes.Base(
+#         primary_hue="orange",
+#         secondary_hue="orange",
+#         neutral_hue="gray",
+#         font=gr.themes.GoogleFont("Source Sans Pro"),
+#     ).set(
+#         background_fill_primary="*neutral_950",  # main background color (dark)
+#         button_primary_background_fill="*primary_500",  # button color (orange)
+#         body_text_color="*neutral_200",  # body text color (light)
+#     )
+#     # with gr.Blocks(css=css, theme=theme) as demo:
+#     with gr.Blocks() as demo:
+#         with gr.Column():
+#             gr.HTML(
+#                 value=(
+#                     '<div class="main-title">Litton7景觀分類模型</div>'
+#                     '<div class="reference">引用資料:'
+#                     '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
+#                     "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
+#                     "</a>"
+#                     "</div>"
+#                 ),
+#             )

+#             with gr.Row(equal_height=True):
+#                 image_input = gr.Image(label="上傳影像", type="pil")
+#                 chart = gr.Image(label="分類結果")

+#             start_button = gr.Button("開始分類", variant="primary")
+#             gr.HTML(
+#                 '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
+#             )
+#             start_button.click(
+#                 fn=Classifier().predict,
+#                 inputs=image_input,
+#                 outputs=chart,
+#             )

+#     return demo


+# if __name__ == "__main__":
+#     get_layout().launch()