terada/init-package

#1
by terapyon - opened
.gitattributes CHANGED
@@ -33,5 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-saved_model/stop_words/Japanese_selection.txt filter=lfs diff=lfs merge=lfs -text
-saved_model/topic/trained_model.bin filter=lfs diff=lfs merge=lfs -text
 
 
.gitignore DELETED
@@ -1,160 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/#use-with-ide
-.pdm.toml
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
 
README.md CHANGED
@@ -17,40 +17,3 @@ license: unknown
 - Python 3.11
 - Streamlit 1.33
 
-
-### Virtual environment
-
-Install inside a venv.
-venv is part of Python's standard library.
-
-https://docs.python.org/ja/3/tutorial/venv.html
-
-
-```sh
-% cd (any directory)
-% python3 -m venv venv
-% source venv/bin/activate
-```
-
-### Installation
-
-Download the package from GitHub and install it:
-
-```sh
-(venv) % git clone https://github.com/awarefy/amp.git
-(venv) % cd amp
-(venv) % pip install -r requirements.txt -c constraints.txt
-```
-
-## How to run
-
-```
-(venv) % streamlit run app.py
-```
-
-## Checking the app
-
-On startup, the default browser opens so you can check the app there.
-
-If no browser opens, load the URL (including the port) printed to the console in your browser.
-
 
app.py CHANGED
@@ -1,18 +1,16 @@
-import streamlit as st
+import random
 
-from inference import classify_ma, get_word_attn, infer_topic
-from visualization import heatmap, html_hext
+import streamlit as st
 
 ID2CAT = {
-    0: "マイクロアグレッションではない可能性が高い",
-    1: "マイクロアグレッションの可能性が高い",
+    0: "マイクロアグレッションではない",
+    1: "マイクロアグレッションである",
 }
 explanation_text = """
-このマイクロアグレッションチェッカーは、機械学習(AI技術のようなもの)によって、マイクロアグレッションらしい表現を検出できるように設計されています。
+このマイクロアグレッションチェッカーは、機械学習(AI技術のようなもの)によって、マイクロアグレッションらしい言語を検出できるように設計されています。
 """
 attention_text = """
-【結果を見る際の注意点】
-この技術は「文中にマイクロアグレッションに結びつく要素が含まれているかどうか」を判定するモデルであり、
+この技術は「文中にマイクロアグレッションに結びつく要素が含まれているかどうか」を判定するモデルになっています。
 必ずしも「この文章の書き手がマイクロアグレッションをしている」ことを明確に示すものではありません。
 
 判定結果を元に、改めて人間同士で「なぜ/どのようにしてマイクロアグレッションたりうるか」議論をするために利用してください。
@@ -22,31 +20,21 @@ provide_by = """提供元: オールマイノリティプロジェクト
 [https://all-minorities.com/](https://all-minorities.com/)
 """
 
+
 st.title("マイクロアグレッション判別モデル")
 st.markdown(explanation_text)
 
 user_input = st.text_input("文章を入力してください:", key="user_input")
 
 
-if st.button("判定", key="run"):
+if st.button("判定"):
     if not user_input:
-        st.warning("入力が空です。何か入力してください。")
+        st.write("入力が空です。何か入力してください。")
     else:
-        pred_class, input_ids, attention_list = classify_ma(user_input)
-        st.markdown(f"判定結果: **{ID2CAT[pred_class]}**")
-        if pred_class == 1:
-            topic_dist, ll = infer_topic(user_input)
-            words_atten = get_word_attn(input_ids, attention_list)
-            html_hext_result = html_hext(((word, attn) for word, attn in words_atten))
-            st.markdown(html_hext_result, unsafe_allow_html=True)
-
-            data = topic_dist.reshape(-1, 1)
-            st.plotly_chart(heatmap(data), use_container_width=True)
-
-        st.divider()
         st.markdown(attention_text)
+        st.divider()
+        random_id = random.randint(0, 1)
+        st.markdown(f"判定結果: **{ID2CAT[random_id]}**")
+        st.divider()
 
-
-
-st.divider()
-st.markdown(provide_by)
+st.markdown(provide_by)
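Net effect of this diff: the UI stays, but the model call is swapped for a coin flip. A minimal runnable sketch of the resulting flow, assembled from the `+` lines above (the long explanatory strings and the `provide_by` block are abbreviated here for brevity):

```python
import random

import streamlit as st

ID2CAT = {
    0: "マイクロアグレッションではない",
    1: "マイクロアグレッションである",
}
attention_text = "【結果を見る際の注意点】…"  # abbreviated; full text in the diff above
provide_by = "提供元: オールマイノリティプロジェクト [https://all-minorities.com/](https://all-minorities.com/)"

st.title("マイクロアグレッション判別モデル")
user_input = st.text_input("文章を入力してください:", key="user_input")

if st.button("判定"):
    if not user_input:
        st.write("入力が空です。何か入力してください。")
    else:
        st.markdown(attention_text)
        st.divider()
        random_id = random.randint(0, 1)  # placeholder until real inference returns as a package
        st.markdown(f"判定結果: **{ID2CAT[random_id]}**")
        st.divider()

st.markdown(provide_by)
```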
 
 
 
inference.py DELETED
@@ -1,223 +0,0 @@
-import os
-import re
-from pathlib import Path
-from typing import Generator
-from unicodedata import normalize
-
-import numpy as np
-import streamlit as st
-import tomotopy as tp  # type: ignore
-import torch
-import torch.nn as nn
-import transformers as T  # type: ignore
-from huggingface_hub import PyTorchModelHubMixin  # type: ignore
-from scipy import stats  # type: ignore
-from sudachipy import dictionary, tokenizer  # type: ignore
-
-HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
-
-MODELS_PATH = Path(__file__).parent / "saved_model"
-# model_base_path = MODELS_PATH / "two_class"
-MODEL_BASE = "awarefy/awarefy-two_class-trained-"
-topic_model_trained = MODELS_PATH / "topic" / "trained_model.bin"
-japanese_selection_path = MODELS_PATH / "stop_words" / "Japanese_selection.txt"
-
-# GPU selection
-if torch.cuda.is_available():
-    gpu = 0
-    # gpu = -1  # For debugging
-else:
-    gpu = -1  # use -1 (process on CPU) when no GPU is available
-
-
-# cls_num = 3
-max_length = 512
-k_folds = 10
-bert_model_name = "cl-tohoku/bert-base-japanese-v3"
-device = torch.device(f"cuda:{gpu}" if gpu >= 0 else "cpu")
-
-
-# BERT model definition
-class BertClassifier(nn.Module, PyTorchModelHubMixin):
-    def __init__(self, cls_num: int):
-        super().__init__()
-        self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
-        self.fc = nn.Linear(768, cls_num, bias=True)
-
-        nn.init.normal_(self.fc.weight, std=0.02)
-        nn.init.normal_(self.fc.bias, 0)
-
-    def forward(self, input_ids, masks):
-        result = self.bert(input_ids, masks)
-
-        vec = result[0]
-        _ = result[1]
-        attentions = result[2]
-
-        vec = vec[:, 0, :]
-        vec = vec.view(-1, 768)
-        output = self.fc(vec)
-        return output, _, attentions
-
-
-# Japanese stopword removal helper
-def load_stopwords() -> set[str]:
-    with open(japanese_selection_path, "r", encoding="utf-8") as f:
-        # stopwords = [w.strip() for w in f]
-        # stopwords = set(stopwords)
-        stopwords = {w.strip() for w in f if w.strip()}
-    return stopwords
-
-
-class SudachiTokenizer:
-    def __init__(self, split_mode="C"):
-        self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
-        self.stopwords = load_stopwords()
-        if split_mode == "A":
-            self.mode = tokenizer.Tokenizer.SplitMode.A
-        elif split_mode == "B":
-            self.mode = tokenizer.Tokenizer.SplitMode.B
-        else:
-            self.mode = tokenizer.Tokenizer.SplitMode.C
-        # regex matching hiragana-only strings
-        self.kana_re = re.compile("^[ぁ-ゖ]+$")
-        # stopwords
-        self.stopwords = load_stopwords()
-
-    def get_wakati(self, text: str) -> list[str]:
-        wakati_list = []
-        normalized_wakati_list = []
-        pos_list = []
-        normalized_text = normalize("NFKC", text)
-        tmp = re.sub(r'[0-9]', '', normalized_text)
-        tmp = re.sub(r'[0-9]', '', tmp)
-        tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
-        tmp = re.sub(r'[a-zA-Z]', '', tmp)
-        # remove emoji
-        tmp = re.sub(r'[❓]', "", tmp)
-        for m in self.tokenizer_obj.tokenize(tmp, self.mode):
-            word = m.surface()
-            pos = m.part_of_speech()[0]
-            normalized_word = m.normalized_form()
-            wakati_list.append(word)
-            normalized_wakati_list.append(normalized_word)
-            pos_list.append(pos)
-        # restrict to nouns, verbs, and adjectives
-        target_pos = ["名詞", "動詞", "形容詞"]
-        # target_pos = ["名詞", "形容詞"]
-        token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
-        # normalize alphabetic characters to lowercase
-        token_list = [t.lower() for t in token_list]
-        # drop hiragana-only words
-        # token_list = [t for t in token_list if not self.kana_re.match(t)]
-        # remove stopwords
-        token_list = [t for t in token_list if t not in self.stopwords]
-        return token_list
-
-
-def make_traind_model():
-    trained_models = []
-    for k in range(k_folds):
-        k = k + 1
-        # model_path = model_base_path / f"trained_model{k}.pt"
-        # trained_model = copy.deepcopy(bert_model)
-        # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
-        # trained_models.append(trained_model)
-        model_name = MODEL_BASE + str(k)
-        trained_model = BertClassifier.from_pretrained(model_name, token=HF_AUTH_TOKEN).to(device)
-        print(f"Got model {model_name}")
-        trained_models.append(trained_model)
-    return trained_models
-
-
-@st.cache_resource
-def init_models():
-    # bert_model = BertClassifier(cls_num=1)  # single output node
-    # bert_model.eval()
-    # bert_model.to(device)
-
-    tokenizer_sudachi = SudachiTokenizer(split_mode="C")
-    # tokenizer setup (here the BERT tokenizer is named tokenizer_c2)
-    tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
-    # trained_models = make_traind_model(bert_model)
-    trained_models = make_traind_model()
-    return tokenizer_sudachi, tokenizer_c2, trained_models
-
-
-tokenizer_sudachi, tokenizer_c2, trained_models = init_models()
-
-
-# compute the attention map
-def f_a(sentences: list[str], tokenizer_c2, model, device):
-    encoded = tokenizer_c2.batch_encode_plus(
-        sentences,
-        padding="max_length",
-        max_length=max_length,
-        truncation=True,
-        return_attention_mask=True
-    )
-
-    input_ids = torch.tensor(encoded["input_ids"]).to(device)
-    attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
-
-    with torch.no_grad():
-        outputs, _, attentions = model(input_ids, attention_mask)
-    # return input_ids.detach().cpu(), attentions[-1].detach().cpu()
-    return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()
-
-
-def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
-    # declare a zero tensor as long as the sequence
-    seq_len = attention_weight.size()[2]
-    all_attens = torch.zeros(seq_len)
-
-    # sum the results of all 12 multi-head attentions;
-    # the first 0 assumes input_ids holds a single sentence,
-    # the second 0 selects the [CLS] token's attention.
-    for i in range(12):
-        all_attens += attention_weight[0, i, 0, :]
-
-    for word, attn in zip(input_ids.flatten(), all_attens):
-        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[CLS]":
-            continue
-        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[SEP]":
-            break
-        converted_word = tokenizer_c2.convert_ids_to_tokens([word.numpy().tolist()])[0]
-        yield converted_word, attn
-
-
-def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
-    normalized_sentence = normalize("NFKC", sentence)
-    tmp = re.sub(r'[0-9]', '', normalized_sentence)
-    tmp = re.sub(r'[0-9]', '', tmp)
-    tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]', '', tmp)
-    tmp = re.sub(r'[a-zA-Z]', '', tmp)
-    # remove emoji
-    tmp = re.sub(r'[❓]', "", tmp)
-
-    attention_list, output_list = [], []
-    for trained_model in trained_models:
-        input_ids, attention, output = f_a([tmp], tokenizer_c2, trained_model, device)
-        attention_list.append(attention)
-        output_list.append(output)
-
-    # majority vote over the 10 models' predictions
-    outputs = np.concatenate(output_list)
-    prob_column = torch.sigmoid(torch.tensor(outputs))
-    pred_column = torch.ge(prob_column, 0.5).float()
-    ensemble_pred, count = stats.mode(pred_column)
-
-    # mean of the 10 attention maps
-    attentions = torch.concat(attention_list)
-    mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
-    return ensemble_pred.item(), input_ids, mean_attention
-
-
-# load the topic model and run inference
-def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
-    model_trained = tp.CTModel.load(str(topic_model_trained))
-    new_word_list = tokenizer_sudachi.get_wakati(new_text)
-    new_doc = model_trained.make_doc(new_word_list)
-    topic_dist, ll = model_trained.infer(new_doc)
-    return topic_dist, ll
-
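The deleted `classify_ma` shows the ensemble scheme this PR removes: each of the ten fold models emits one logit, which is squashed through a sigmoid, thresholded at 0.5 to a per-fold vote, and the votes are resolved by majority with `stats.mode`. A minimal standalone sketch of just that step (the logits below are made-up example values, not model output):

```python
import torch
from scipy import stats

# One logit per fold model -- made-up example values.
logits = torch.tensor([[0.3], [1.2], [-0.5], [0.9], [2.1],
                       [0.1], [-1.0], [0.7], [1.5], [0.4]])
probs = torch.sigmoid(logits)             # logit -> probability
preds = torch.ge(probs, 0.5).float()      # per-fold vote: 1 if p >= 0.5
ensemble_pred, count = stats.mode(preds)  # majority vote across the 10 folds
print(int(ensemble_pred.item()), int(count.item()))  # -> 1 8
```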
requirements-dev.txt CHANGED
@@ -1,4 +1,3 @@
 -r requirements.txt
 ruff
 mypy
-pytest
 
requirements.txt CHANGED
@@ -1,14 +1 @@
 streamlit
-numpy
-pandas
-plotly
-transformers
-scipy
-torch
-fugashi
-unidic-lite
-sudachipy
-sudachidict_full
-sudachidict_core
-tomotopy
-
 
saved_model/stop_words/Japanese_selection.txt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b9654e2f6f739a61285f80538e1076d938f54f090974d8f872ad59b246a66da8
-size 2202
 
 
 
 
saved_model/topic/trained_model.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:46c0cc05fcb664763839ca099f04aa5275a269cfbea8847f33214bc73affdcce
-size 695117
 
 
 
 
tests/__pycache__/test_app.cpython-311-pytest-8.1.1.pyc DELETED
Binary file (4.87 kB)
 
tests/test_app.py DELETED
@@ -1,30 +0,0 @@
-import sys
-from pathlib import Path
-
-import pytest
-from streamlit.testing.v1 import AppTest
-
-sys.path.append(str(Path(__file__).parent.parent))
-
-
-def test_text_no_input():
-    at = AppTest.from_file("app.py").run()
-    at.button[0].click().run()
-    assert at.warning[0].value == "入力が空です。何か入力してください。"
-
-
-def test_text_with_input():
-    at = AppTest.from_file("app.py").run()
-    # at.text_input[0].assert_exists()
-    at.text_input[0].input("test").run()
-    at.button[0].click().run()
-    assert "判定結果: **マイクロアグレッションで" in at.markdown[2].value
-
-
-@pytest.mark.skip(reason="まだ実装していないのでランダムに返ってくる")
-def test_aggression():
-    at = AppTest.from_file("app.py").run()
-    text = "サンプルの入力文字列NHKの番組を見ていると,発達障害者の才能を特集されることが多い。それを見ていると自分もそのような才能を期待されているように感じる"
-    at.text_input[0].input(text).run()
-    at.button[0].click().run()
-    assert "提供元: " not in at.markdown[3].value
 
visualization.py DELETED
@@ -1,31 +0,0 @@
-from typing import Iterable
-
-import numpy as np
-import plotly.express as px  # type: ignore
-
-
-def highlight(word: str, attn: float) -> str:
-    color = "#%02X%02X%02X" % (255, int(255 * (1 - attn)), int(255 * (1 - attn)))
-    return f'<span style="background-color: {color}">{word}</span>'
-
-
-def html_hext(words_attn: Iterable[tuple[str, float]]) -> str:
-    return " ".join(highlight(word, attn) for word, attn in words_attn)
-
-
-def heatmap(data: np.ndarray):
-    y_labels = [
-        "嘲笑や特性を理解されない",
-        "特性や能力への攻撃",
-        "学校や職場で受け入れられない",
-        "特性をおかしいとみなされる",
-        "障害への差別や苦悩をなかったことにされる",
-        "うまくコミュニケーションがとれない",
-        "障害について理解されない",
-        "侮蔑される,認められない",
-        "周囲の理解不足",
-        "障害をなかったことにされる,責められる",
-    ]
-    fig = px.imshow(data, labels=dict(x="判定", y="名称"), y=y_labels)
-    fig.update_layout(coloraxis_colorbar=dict(title="得点"))
-    return fig
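The deleted `highlight` helper maps an attention weight in [0, 1] to a white-to-red background (higher attention, deeper red) by fading the green and blue channels. A quick standalone check of the color arithmetic, with made-up attention values:

```python
def highlight(word: str, attn: float) -> str:
    # red stays at 255; green and blue fade toward 0 as attn grows
    color = "#%02X%02X%02X" % (255, int(255 * (1 - attn)), int(255 * (1 - attn)))
    return f'<span style="background-color: {color}">{word}</span>'

print(highlight("自分", 0.8))  # <span style="background-color: #FF3333">自分</span>
print(highlight("も", 0.1))    # <span style="background-color: #FFE5E5">も</span>
```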