terada/init-package #1
opened by terapyon

Files changed:
- .gitattributes +0 -2
- .gitignore +0 -160
- README.md +0 -37
- app.py +14 -26
- inference.py +0 -223
- requirements-dev.txt +0 -1
- requirements.txt +0 -13
- saved_model/stop_words/Japanese_selection.txt +0 -3
- saved_model/topic/trained_model.bin +0 -3
- tests/__pycache__/test_app.cpython-311-pytest-8.1.1.pyc +0 -0
- tests/test_app.py +0 -30
- visualization.py +0 -31
.gitattributes
CHANGED
@@ -33,5 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-saved_model/stop_words/Japanese_selection.txt filter=lfs diff=lfs merge=lfs -text
-saved_model/topic/trained_model.bin filter=lfs diff=lfs merge=lfs -text
.gitignore
DELETED
@@ -1,160 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/#use-with-ide
-.pdm.toml
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
README.md
CHANGED
@@ -17,40 +17,3 @@ license: unknown
 - Python 3.11
 - Streamlit 1.33
 
-
-### 仮想環境
-
-venvを用いてインストールを行います。
-venvは、Pythonの標準ライブラリです。
-
-https://docs.python.org/ja/3/tutorial/venv.html
-
-
-```sh
-% cd (任意のフォルダ)
-% python3 -m venv venv
-% source venv/bin/activate
-```
-
-### インストール
-
-GitHubからパッケージをダウンロードしてインストール
-
-```sh
-(venv) % git clone https://github.com/awarefy/amp.git
-(venv) % cd amp
-(venv) % pip install -r requirements.txt -c constraints.txt
-```
-
-## 起動方法
-
-```
-(venv) % streamlit run app.py
-```
-
-## 表示確認
-
-起動すると、デフォルトブラウザが立ち上がり表示確認ができる。
-
-もし、ブラウザが立ち上がらない場合は、コンソールに表示されるポート付URLをブラウザで呼び出す。
-
app.py
CHANGED
@@ -1,18 +1,16 @@
-import
+import random
 
-
-from visualization import heatmap, html_hext
+import streamlit as st
 
 ID2CAT = {
-    0: "
-    1: "
+    0: "マイクロアグレッションではない",
+    1: "マイクロアグレッションである",
 }
 explanation_text = """
-このマイクロアグレッションチェッカーは、機械学習(AI
+このマイクロアグレッションチェッカーは、機械学習(AI技術のようなもの)によって、マイクロアグレッションらしい言語を検出できるように設計されています。
 """
 attention_text = """
-
-この技術は「文中にマイクロアグレッションに結びつく要素が含まれているかどうか」を判定するモデルであり、
+この技術は「文中にマイクロアグレッションに結びつく要素が含まれているかどうか」を判定するモデルになっています。
 必ずしも「この文章の書き手がマイクロアグレッションをしている」ことを明確に示すものではありません。
 
 判定結果を元に、改めて人間同士で「なぜ/どのようにしてマイクロアグレッションたりうるか」議論をするために利用してください。
@@ -22,31 +20,21 @@ provide_by = """提供元: オールマイノリティプロジェクト
 [https://all-minorities.com/](https://all-minorities.com/)
 """
 
+
 st.title("マイクロアグレッション判別モデル")
 st.markdown(explanation_text)
 
 user_input = st.text_input("文章を入力してください:", key="user_input")
 
 
-if st.button("判定"
+if st.button("判定"):
     if not user_input:
-        st.
+        st.write("入力が空です。何か入力してください。")
     else:
-        pred_class, input_ids, attention_list = classify_ma(user_input)
-        st.markdown(f"判定結果: **{ID2CAT[pred_class]}**")
-        if pred_class == 1:
-            topic_dist, ll = infer_topic(user_input)
-            words_atten = get_word_attn(input_ids, attention_list)
-            html_hext_result = html_hext(((word, attn) for word, attn in words_atten))
-            st.markdown(html_hext_result, unsafe_allow_html=True)
-
-            data = topic_dist.reshape(-1, 1)
-            st.plotly_chart(heatmap(data), use_container_width=True)
-
-        st.divider()
         st.markdown(attention_text)
+        st.divider()
+        random_id = random.randint(0, 1)
+        st.markdown(f"判定結果: **{ID2CAT[random_id]}**")
+        st.divider()
 
-
-
-st.divider()
-st.markdown(provide_by)
+st.markdown(provide_by)
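With inference stubbed out by this PR, the 判定 button now returns a coin flip over ID2CAT. A purely hypothetical sketch (placeholder_verdict and the seeded Random are illustrations, not part of the PR) of isolating that stub so demos and tests become reproducible:

```python
import random

ID2CAT = {0: "マイクロアグレッションではない", 1: "マイクロアグレッションである"}

# Hypothetical helper, not in the PR: seed the generator so the demo
# verdict is reproducible across reruns instead of a fresh coin flip.
rng = random.Random(0)

def placeholder_verdict() -> str:
    """Stand-in verdict until real inference is wired back in."""
    return ID2CAT[rng.randint(0, 1)]

print(placeholder_verdict())
```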
inference.py
DELETED
@@ -1,223 +0,0 @@
-import os
-import re
-from pathlib import Path
-from typing import Generator
-from unicodedata import normalize
-
-import numpy as np
-import streamlit as st
-import tomotopy as tp  # type: ignore
-import torch
-import torch.nn as nn
-import transformers as T  # type: ignore
-from huggingface_hub import PyTorchModelHubMixin  # type: ignore
-from scipy import stats  # type: ignore
-from sudachipy import dictionary, tokenizer  # type: ignore
-
-HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
-
-MODELS_PATH = Path(__file__).parent / "saved_model"
-# model_base_path = MODELS_PATH / "two_class"
-MODEL_BASE = "awarefy/awarefy-two_class-trained-"
-topic_model_trained = MODELS_PATH / "topic" / "trained_model.bin"
-japanese_selection_path = MODELS_PATH / "stop_words" / "Japanese_selection.txt"
-
-# GPUの指定
-if torch.cuda.is_available():
-    gpu = 0
-    # gpu = -1  # For debugging
-else:
-    gpu = -1  # gpu = -1  # GPUが使用できなければ(CPUで処理)-1を指定
-
-
-# cls_num = 3
-max_length = 512
-k_folds = 10
-bert_model_name = "cl-tohoku/bert-base-japanese-v3"
-device = torch.device(f"cuda:{gpu}" if gpu>=0 else "cpu")
-
-
-#BERTモデルの定義
-class BertClassifier(nn.Module, PyTorchModelHubMixin):
-    def __init__(self, cls_num: int):
-        super().__init__()
-        self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
-        self.fc = nn.Linear(768, cls_num, bias=True)
-
-        nn.init.normal_(self.fc.weight, std=0.02)
-        nn.init.normal_(self.fc.bias, 0)
-
-    def forward(self, input_ids, masks):
-        result = self.bert(input_ids, masks)
-
-        vec = result[0]
-        _ = result[1]
-        attentions = result[2]
-
-        vec = vec[:, 0, :]
-        vec = vec.view(-1, 768)
-        output = self.fc(vec)
-        return output, _, attentions
-
-
-#日本語Stopwords除去関数
-def load_stopwords() -> set[str]:
-    with open(japanese_selection_path, "r", encoding="utf-8") as f:
-        # stopwords = [w.strip() for w in f]
-        # stopwords = set(stopwords)
-        stopwords = {w.strip() for w in f if w.strip()}
-    return stopwords
-
-
-class SudachiTokenizer:
-    def __init__(self, split_mode="C"):
-        self.tokenizer_obj = dictionary.Dictionary(dict_type="full").create()
-        self.stopwords = load_stopwords()
-        if split_mode == "A":
-            self.mode = tokenizer.Tokenizer.SplitMode.C
-        elif split_mode == "B":
-            self.mode = tokenizer.Tokenizer.SplitMode.B
-        else:
-            self.mode = tokenizer.Tokenizer.SplitMode.C
-        # ひらがなのみの文字列にマッチする正規表現
-        self.kana_re = re.compile("^[ぁ-ゖ]+$")
-        #Stopwords
-        self.stopwords = load_stopwords()
-
-    def get_wakati(self, text: str) -> list[str]:
-        wakati_list = []
-        normalized_wakati_list = []
-        pos_list = []
-        normalized_text = normalize("NFKC", text)
-        tmp = re.sub(r'[0-9]','',normalized_text)
-        tmp = re.sub(r'[0-9]', '', tmp)
-        tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]','',tmp)
-        tmp = re.sub(r'[a-zA-Z]','',tmp)
-        #絵文字除去
-        tmp = re.sub(r'[❓]', "", tmp)
-        for m in self.tokenizer_obj.tokenize(tmp, self.mode):
-            word = m.surface()
-            pos = m.part_of_speech()[0]
-            normalized_word = m.normalized_form()
-            wakati_list.append(word)
-            normalized_wakati_list.append(normalized_word)
-            pos_list.append(pos)
-        #名詞,動詞,形容詞のみに絞り込み
-        target_pos = ["名詞", "動詞", "形容詞"]
-        #target_pos = ["名詞", "形容詞"]
-        token_list = [t for t, p in zip(wakati_list, pos_list) if p in target_pos]
-        #アルファベットを小文字に統一
-        token_list = [t.lower() for t in token_list]
-        #ひらがなのみの単語を除く
-        #token_list = [t for t in token_list if not self.kana_re.match(t)]
-        #ストップワード除去
-        token_list = [t for t in token_list if t not in self.stopwords]
-        return token_list
-
-
-def make_traind_model():
-    trained_models = []
-    for k in range(k_folds):
-        k = k + 1
-        # model_path = model_base_path / f"trained_model{k}.pt"
-        # trained_model = copy.deepcopy(bert_model)
-        # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
-        # trained_models.append(trained_model)
-        model_name = MODEL_BASE + str(k)
-        trained_model = BertClassifier.from_pretrained(model_name, token=HF_AUTH_TOKEN).to(device)
-        print(f"Got model {model_name}")
-        trained_models.append(trained_model)
-    return trained_models
-
-
-@st.cache_resource
-def init_models():
-    # bert_model = BertClassifier(cls_num=1) #出力ノードを1に設定
-    # bert_model.eval()
-    # bert_model.to(device)
-
-    tokenizer_sudachi = SudachiTokenizer(split_mode="C")
-    #Tokenizerの設定(ここではtokenizerをtokenizer_c2にしている)
-    tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
-    # trained_models = make_traind_model(bert_model)
-    trained_models = make_traind_model()
-    return tokenizer_sudachi, tokenizer_c2, trained_models
-
-
-tokenizer_sudachi, tokenizer_c2, trained_models = init_models()
-
-
-# Attentionマップを算出する関数の定義
-def f_a(sentences: list[str], tokenizer_c2, model, device):
-    encoded = tokenizer_c2.batch_encode_plus(
-        sentences,
-        padding="max_length",
-        max_length=max_length,
-        truncation=True,
-        return_attention_mask=True
-    )
-
-    input_ids = torch.tensor(encoded["input_ids"]).to(device)
-    attention_mask = torch.tensor(encoded["attention_mask"]).to(device)
-
-    with torch.no_grad():
-        outputs, _, attentions = model(input_ids, attention_mask)
-    #return input_ids.detach().cpu(), attentions[-1].detach().cpu()
-    return input_ids.detach().cpu(), attentions[-1].detach().cpu(), outputs.detach().cpu()
-
-
-def get_word_attn(input_ids, attention_weight) -> Generator[tuple[str, float], None, None]:
-    # 文章の長さ分のzero tensorを宣言
-    seq_len = attention_weight.size()[2]
-    all_attens = torch.zeros(seq_len)
-
-    # 12個のMulti Head Attentionの結果を全部足し合わせる
-    # 最初の0はinput_idsは1文章だけを想定しているため
-    # 次の0はCLSトークンのAttention結果を取得している、という意味です。
-    for i in range(12):
-        all_attens += attention_weight[0, i, 0, :]
-
-    for word, attn in zip(input_ids.flatten(), all_attens):
-        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[CLS]":
-            continue
-        if tokenizer_c2.convert_ids_to_tokens(word.tolist()) == "[SEP]":
-            break
-        converted_word = tokenizer_c2.convert_ids_to_tokens([word.numpy().tolist()])[0]
-        yield converted_word, attn
-
-
-def classify_ma(sentence: str) -> tuple[int, torch.Tensor, torch.Tensor]:
-    normalized_sentence = normalize("NFKC", sentence)
-    tmp = re.sub(r'[0-9]','',normalized_sentence)
-    tmp = re.sub(r'[0-9]', '', tmp)
-    tmp = re.sub(r'[、。:()「」%『』()?!%→+`.・×,〜~—+=♪/!?]','',tmp)
-    tmp = re.sub(r'[a-zA-Z]','',tmp)
-    #絵文字除去
-    tmp = re.sub(r'[❓]', "", tmp)
-
-    attention_list, output_list = [], []
-    for trained_model in trained_models:
-        input_ids, attention, output = f_a([tmp], tokenizer_c2, trained_model, device)
-        attention_list.append(attention)
-        output_list.append(output)
-
-    #出力された10個の予測値の多数決を算出
-    outputs = np.concatenate(output_list)
-    prob_column = torch.sigmoid(torch.tensor(outputs))
-    pred_column = torch.ge(prob_column, 0.5).float()
-    ensemble_pred, count = stats.mode(pred_column)
-
-    #出力された10個のattention mapの平均値を算出
-    attentions = torch.concat(attention_list)
-    mean_attention = torch.mean(attentions, dim=0).unsqueeze(dim=0)
-    return ensemble_pred.item(), input_ids, mean_attention
-
-
-#モデルのロードとinferの関数化
-def infer_topic(new_text: str) -> tuple[np.ndarray, float]:
-    model_trained = tp.CTModel.load(str(topic_model_trained))
-    new_word_list = tokenizer_sudachi.get_wakati(new_text)
-    new_doc = model_trained.make_doc(new_word_list)
-    topic_dist, ll = model_trained.infer(new_doc)
-    return topic_dist, ll
-
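The deleted classify_ma ensembled the ten fold models by majority vote: each fold's logit goes through a sigmoid, is thresholded at 0.5, and scipy.stats.mode picks the winning class, while the ten attention maps are averaged. A self-contained sketch of that voting step, with dummy logits standing in for the fold outputs (the values below are assumptions for illustration):

```python
import numpy as np
import torch
from scipy import stats

# Dummy per-fold logits standing in for the 10 BertClassifier outputs.
fold_logits = np.array([[0.8], [-0.3], [1.2], [0.1], [-0.9],
                        [0.4], [0.6], [-0.2], [0.7], [0.05]])

probs = torch.sigmoid(torch.tensor(fold_logits))  # per-fold P(class 1)
pred_column = torch.ge(probs, 0.5).float()        # per-fold 0/1 vote
ensemble_pred, count = stats.mode(pred_column)    # majority vote across folds

print(int(ensemble_pred.item()))  # -> 1 (seven of ten folds vote for class 1)
```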
requirements-dev.txt
CHANGED
@@ -1,4 +1,3 @@
 -r requirements.txt
 ruff
 mypy
-pytest
requirements.txt
CHANGED
@@ -1,14 +1 @@
 streamlit
-numpy
-pandas
-plotly
-transformers
-scipy
-torch
-fugashi
-unidic-lite
-sudachipy
-sudachidict_full
-sudachidict_core
-tomotopy
-
saved_model/stop_words/Japanese_selection.txt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b9654e2f6f739a61285f80538e1076d938f54f090974d8f872ad59b246a66da8
-size 2202
saved_model/topic/trained_model.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:46c0cc05fcb664763839ca099f04aa5275a269cfbea8847f33214bc73affdcce
-size 695117
tests/__pycache__/test_app.cpython-311-pytest-8.1.1.pyc
DELETED
Binary file (4.87 kB)
tests/test_app.py
DELETED
@@ -1,30 +0,0 @@
-import sys
-from pathlib import Path
-
-import pytest
-from streamlit.testing.v1 import AppTest
-
-sys.path.append(str(Path(__file__).parent.parent))
-
-
-def test_text_no_input():
-    at = AppTest.from_file("app.py").run()
-    at.button[0].click().run()
-    assert at.warning[0].value == "入力が空です。何か入力してください。"
-
-
-def test_text_with_input():
-    at = AppTest.from_file("app.py").run()
-    # at.text_input[0].assert_exists()
-    at.text_input[0].input("test").run()
-    at.button[0].click().run()
-    assert "判定結果: **マイクロアグレッションで" in at.markdown[2].value
-
-
-@pytest.mark.skip(reason="まだ実装していないのでランダムに返ってくる")
-def test_aggression():
-    at = AppTest.from_file("app.py").run()
-    text = "サンプルの入力文字列NHKの番組を見ていると,発達障害者の才能を特集されることが多い。それを見ていると自分もそのような才能を期待されているように感じる"
-    at.text_input[0].input(text).run()
-    at.button[0].click().run()
-    assert "提供元: " not in at.markdown[3].value
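The deleted tests drove the UI headlessly through Streamlit's AppTest harness (the same API the file above used). A minimal sketch of a test matching the new placeholder behavior, assuming app.py from this PR is importable from the test's working directory:

```python
from streamlit.testing.v1 import AppTest

def test_placeholder_verdict():
    # Run the app, type something, and press the 判定 button.
    at = AppTest.from_file("app.py").run()
    at.text_input[0].input("テスト").run()
    at.button[0].click().run()
    # With the random stub, the verdict is one of the two ID2CAT labels,
    # so only the presence of a verdict line can be asserted deterministically.
    assert any("判定結果" in md.value for md in at.markdown)
```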
visualization.py
DELETED
@@ -1,31 +0,0 @@
-from typing import Iterable
-
-import numpy as np
-import plotly.express as px  # type: ignore
-
-
-def highlight(word: str, attn: float) -> str:
-    color = "#%02X%02X%02X" % (255, int(255 * (1 - attn)), int(255 * (1 - attn)))
-    return f'<span style="background-color: {color}">{word}</span>'
-
-
-def html_hext(words_attn: Iterable[tuple[str, float]]) -> str:
-    return " ".join(highlight(word, attn) for word, attn in words_attn)
-
-
-def heatmap(data: np.ndarray):
-    y_labels = [
-        "嘲笑や特性を理解されない",
-        "特性や能力への攻撃",
-        "学校や職場で受け入れられない",
-        "特性をおかしいとみなされる",
-        "障害への差別や苦悩をなかったことにされる",
-        "うまくコミュニケーションがとれない",
-        "障害について理解されない",
-        "侮蔑される,認められない",
-        "周囲の理解不足",
-        "障害をなかったことにされる,責められる",
-    ]
-    fig = px.imshow(data, labels=dict(x="判定", y="名称"), y=y_labels)
-    fig.update_layout(coloraxis_colorbar=dict(title="得点"))
-    return fig
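For reference, the deleted html_hext turned (token, attention) pairs into inline-highlighted HTML, shading tokens a deeper red as attention grows (red fixed at 255, green and blue scaled by 1 - attn). A usage sketch, assuming the pre-PR visualization.py is still importable and attention values lie in [0, 1]:

```python
# Example tokens and attention weights are illustrative, not from the model.
from visualization import html_hext

html = html_hext([("社会", 0.05), ("常識", 0.9), ("でしょ", 0.4)])
# 0.05 -> #FFF2F2 (faint), 0.9 -> #FF1919 (strong), 0.4 -> #FF9999
print(html)
```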