lichih commited on
Commit
14a458d
Β·
verified Β·
1 Parent(s): a0f4a5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -18
app.py CHANGED
@@ -1,15 +1,16 @@
1
- import matplotlib.pyplot as plt
2
- import torch
3
  from PIL import Image
 
4
  from torchvision import transforms
5
- import torch.nn.functional as F
6
  from typing import Literal, Any
7
  import gradio as gr
 
 
8
  import spaces
9
- from io import BytesIO
 
10
 
11
 
12
- class Litton7Classifier:
13
  LABELS = [
14
  "Panoramic",
15
  "Feature",
@@ -47,7 +48,7 @@ class Litton7Classifier:
47
  )
48
 
49
  @spaces.GPU(duration=60)
50
- def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]:
51
  image = image.convert("RGB")
52
  input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
53
 
@@ -67,19 +68,20 @@ def draw_bar_chart(data: dict[str, list[str | float]]):
67
  classes = data["class"]
68
  probabilities = data["probs"]
69
 
70
- plt.figure(figsize=(8, 6))
71
- plt.bar(classes, probabilities, color="skyblue")
 
72
 
73
- plt.xlabel("Class")
74
- plt.ylabel("Probability (%)")
75
- plt.title("Class Probabilities")
76
 
77
  for i, prob in enumerate(probabilities):
78
- plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
79
 
80
- plt.tight_layout()
81
 
82
- return plt
83
 
84
 
85
  def get_layout():
@@ -138,17 +140,17 @@ def get_layout():
138
  "</div>"
139
  ),
140
  )
141
-
142
  with gr.Row(equal_height=True):
143
  image_input = gr.Image(label="δΈŠε‚³ε½±εƒ", type="pil")
144
- chart = gr.Image(label="εˆ†ι‘žη΅ζžœ")
145
 
146
  start_button = gr.Button("ι–‹ε§‹εˆ†ι‘ž", variant="primary")
147
  gr.HTML(
148
  '<div class="footer">Β© 2024 LCL η‰ˆζ¬Šζ‰€ζœ‰<br>ι–‹η™Όθ€…οΌšδ½•η«‹ζ™Ίγ€ζ₯Šε“²ηΏ</div>',
149
  )
150
  start_button.click(
151
- fn=Litton7Classifier().predict,
152
  inputs=image_input,
153
  outputs=chart,
154
  )
@@ -157,4 +159,4 @@ def get_layout():
157
 
158
 
159
  if __name__ == "__main__":
160
- get_layout().launch()
 
 
 
1
  from PIL import Image
2
+ from io import BytesIO
3
  from torchvision import transforms
 
4
  from typing import Literal, Any
5
  import gradio as gr
6
+ from matplotlib.figure import Figure
7
+ import matplotlib.pyplot as plt
8
  import spaces
9
+ import torch
10
+ import torch.nn.functional as F
11
 
12
 
13
+ class Classifier:
14
  LABELS = [
15
  "Panoramic",
16
  "Feature",
 
48
  )
49
 
50
  @spaces.GPU(duration=60)
51
+ def predict(self, image: Image.Image) -> Figure:
52
  image = image.convert("RGB")
53
  input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
54
 
 
68
  classes = data["class"]
69
  probabilities = data["probs"]
70
 
71
+ #fig = plt.figure()
72
+ fig, ax = plt.subplots(figsize=(8, 6))
73
+ ax.bar(classes, probabilities, color="skyblue")
74
 
75
+ ax.set_xlabel("Class")
76
+ ax.set_ylabel("Probability (%)")
77
+ ax.set_title("Class Probabilities")
78
 
79
  for i, prob in enumerate(probabilities):
80
+ ax.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom")
81
 
82
+ fig.tight_layout()
83
 
84
+ return fig
85
 
86
 
87
  def get_layout():
 
140
  "</div>"
141
  ),
142
  )
143
+
144
  with gr.Row(equal_height=True):
145
  image_input = gr.Image(label="δΈŠε‚³ε½±εƒ", type="pil")
146
+ chart = gr.Plot(label="εˆ†ι‘žη΅ζžœ")
147
 
148
  start_button = gr.Button("ι–‹ε§‹εˆ†ι‘ž", variant="primary")
149
  gr.HTML(
150
  '<div class="footer">Β© 2024 LCL η‰ˆζ¬Šζ‰€ζœ‰<br>ι–‹η™Όθ€…οΌšδ½•η«‹ζ™Ίγ€ζ₯Šε“²ηΏ</div>',
151
  )
152
  start_button.click(
153
+ fn=Classifier().predict,
154
  inputs=image_input,
155
  outputs=chart,
156
  )
 
159
 
160
 
161
  if __name__ == "__main__":
162
+ get_layout().launch()