terapyon commited on
Commit
4ea53ee
·
1 Parent(s): f179595

一通りの表示を作った

Browse files
Files changed (3) hide show
  1. app.py +13 -4
  2. requirements.txt +3 -0
  3. visualization.py +31 -0
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import random
2
 
 
3
  import streamlit as st
4
 
 
 
5
  ID2CAT = {
6
  0: "マイクロアグレッションではない",
7
  1: "マイクロアグレッションである",
@@ -20,7 +23,6 @@ provide_by = """提供元: オールマイノリティプロジェクト
20
  [https://all-minorities.com/](https://all-minorities.com/)
21
  """
22
 
23
-
24
  st.title("マイクロアグレッション判別モデル")
25
  st.markdown(explanation_text)
26
 
@@ -33,8 +35,15 @@ if st.button("判定"):
33
  else:
34
  st.markdown(attention_text)
35
  st.divider()
36
- random_id = random.randint(0, 1)
37
  st.markdown(f"判定結果: **{ID2CAT[random_id]}**")
38
- st.divider()
 
 
 
 
 
 
39
 
40
- st.markdown(provide_by)
 
 
1
  import random
2
 
3
+ import numpy as np
4
  import streamlit as st
5
 
6
+ from visualization import heatmap, html_hext
7
+
8
  ID2CAT = {
9
  0: "マイクロアグレッションではない",
10
  1: "マイクロアグレッションである",
 
23
  [https://all-minorities.com/](https://all-minorities.com/)
24
  """
25
 
 
26
  st.title("マイクロアグレッション判別モデル")
27
  st.markdown(explanation_text)
28
 
 
35
  else:
36
  st.markdown(attention_text)
37
  st.divider()
38
+ random_id = random.randint(0, 1) # TODO: make dummy data
39
  st.markdown(f"判定結果: **{ID2CAT[random_id]}**")
40
+ if random_id == 1:
41
+ rng = np.random.default_rng() # TODO: make dummy data
42
+ html_hext_result = html_hext(((f"単語{i}", rng.random()) for i in range(7))) # TODO: make dummy data
43
+ st.markdown(html_hext_result, unsafe_allow_html=True)
44
+ data = rng.random((10, 1)).reshape(-1, 1) # TODO: make dummy data
45
+ st.plotly_chart(heatmap(data), use_container_width=True)
46
+
47
 
48
+ st.divider()
49
+ st.markdown(provide_by)
requirements.txt CHANGED
@@ -1 +1,4 @@
1
  streamlit
 
 
 
 
1
  streamlit
2
+ numpy
3
+ pandas
4
+ plotly
visualization.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable
2
+
3
+ import numpy as np
4
+ import plotly.express as px # type: ignore
5
+
6
+
7
+ def highlight(word: str, attn: float) -> str:
8
+ color = "#%02X%02X%02X" % (255, int(255 * (1 - attn)), int(255 * (1 - attn)))
9
+ return f'<span style="background-color: {color}">{word}</span>'
10
+
11
+
12
+ def html_hext(words_attn: Iterable[tuple[str, float]]) -> str:
13
+ return " ".join(highlight(word, attn) for word, attn in words_attn)
14
+
15
+
16
+ def heatmap(data: np.ndarray):
17
+ y_labels = [
18
+ "嘲笑や特性を理解されない",
19
+ "特性や能力への攻撃",
20
+ "学校や職場で受け入れられない",
21
+ "特性をおかしいとみなされる",
22
+ "障害への差別や苦悩をなかったことにされる",
23
+ "うまくコミュニケーションがとれない",
24
+ "障害について理解されない",
25
+ "侮蔑される,認められない",
26
+ "周囲の理解不足",
27
+ "障害をなかったことにされる,責められる",
28
+ ]
29
+ fig = px.imshow(data, labels=dict(x="判定", y="名称"), y=y_labels)
30
+ fig.update_layout(coloraxis_colorbar=dict(title="得点"))
31
+ return fig