Update streamlit_app.py
streamlit_app.py  CHANGED  (+17 -13)
@@ -2,6 +2,10 @@ import streamlit as st
 from transformers import AutoTokenizer, AutoModelForTokenClassification
 import torch
 
+# Hugging Face model repository path
+MODEL_REPO = "zhixiusue/EduTubeNavigator"
+
+# ID → label mapping (user-defined)
 id_to_label = {
     0: 'O',
     1: 'B-TOPIC',
@@ -16,8 +20,8 @@ id_to_label = {
 
 @st.cache_resource
 def load_model():
-    tokenizer = AutoTokenizer.from_pretrained(
-    model = AutoModelForTokenClassification.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
+    model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
     return tokenizer, model
 
 tokenizer, model = load_model()
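Note on this hunk: pointing both from_pretrained calls at a single MODEL_REPO constant keeps the tokenizer and model in sync, and @st.cache_resource makes the download and load happen once per server process instead of on every Streamlit rerun. A minimal sketch of the same load outside Streamlit, using only the repo name visible in the diff:

    from transformers import AutoTokenizer, AutoModelForTokenClassification

    MODEL_REPO = "zhixiusue/EduTubeNavigator"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
    model.eval()  # inference only, matching the app's predict()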
@@ -28,8 +32,8 @@ def predict(text, model, tokenizer, id_to_label):
     model.eval()
     with torch.no_grad():
         outputs = model(**inputs)
-
-
+    logits = outputs.logits
+    predictions = torch.argmax(logits, dim=-1)
 
     word_ids = inputs.word_ids(batch_index=0)
     pred_labels = []
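The two added lines make the decoding step explicit: outputs.logits has shape (batch, seq_len, num_labels), and an argmax over the last dimension picks one label id per subword token. A toy sketch of that step and of the word_ids alignment the following lines rely on (shapes are illustrative):

    import torch

    logits = torch.randn(1, 6, 9)               # fake (batch, seq_len, num_labels) output
    predictions = torch.argmax(logits, dim=-1)  # (1, 6): one label id per subword
    # inputs.word_ids(batch_index=0) returns, for each subword position, the
    # index of the source word (None for special tokens like [CLS]/[SEP]);
    # predict() uses it to keep one label per original word.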
@@ -47,10 +51,9 @@ def predict(text, model, tokenizer, id_to_label):
 
 def post_process(tokens, labels):
     words, word_labels = [], []
-    current_word =
-    current_label = None
+    current_word, current_label = None, None
     for token, label in zip(tokens, labels):
-        if token in [
+        if token in ['[CLS]', '[SEP]', '[PAD]']:
             continue
         if token.startswith("##"):
             current_word += token[2:]
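post_process reverses WordPiece tokenization: pieces prefixed with "##" are glued back onto the current word, special tokens are skipped, and each merged word keeps its first subword's label. The tail of the function falls outside the hunk, so the flush logic in this standalone sketch is an assumption:

    tokens = ['[CLS]', 'learn', 'trans', '##form', '##ers', '[SEP]']
    labels = ['O', 'B-TOPIC', 'B-TOPIC', 'I-TOPIC', 'I-TOPIC', 'O']

    words, word_labels = [], []
    current_word, current_label = None, None
    for token, label in zip(tokens, labels):
        if token in ['[CLS]', '[SEP]', '[PAD]']:
            continue
        if token.startswith('##'):
            current_word += token[2:]        # continuation piece: extend the word
        else:
            if current_word is not None:     # flush the finished word (assumed)
                words.append(current_word)
                word_labels.append(current_label)
            current_word, current_label = token, label
    if current_word is not None:             # flush the last word (assumed)
        words.append(current_word)
        word_labels.append(current_label)

    print(words)        # ['learn', 'transformers']
    print(word_labels)  # ['O', 'B-TOPIC']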
@@ -80,13 +83,14 @@ def extract_entities(aligned_result):
         if prefix == "B":
             if current_entity:
                 entities.append({"entity": current_entity, "text": current_text})
-            current_entity
+            current_entity = entity_type
+            current_text = word
         elif prefix == "I" and current_entity == entity_type:
             current_text += word
         else:
             if current_entity:
                 entities.append({"entity": current_entity, "text": current_text})
-            current_entity, current_text =
+            current_entity, current_text = None, ""
     if current_entity:
         entities.append({"entity": current_entity, "text": current_text})
     return entities
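With the assignments restored, extract_entities is a standard BIO decoder: "B-" opens a span, "I-" of the same type extends it, and anything else closes it. A self-contained version for experimentation; the (word, label) pair shape of aligned_result is an assumption, since align_words_labels is not shown in the diff:

    def decode_bio(pairs):
        entities, current_entity, current_text = [], None, ""
        for word, label in pairs:
            prefix, _, entity_type = label.partition("-")
            if prefix == "B":
                if current_entity:
                    entities.append({"entity": current_entity, "text": current_text})
                current_entity, current_text = entity_type, word
            elif prefix == "I" and current_entity == entity_type:
                current_text += word
            else:
                if current_entity:
                    entities.append({"entity": current_entity, "text": current_text})
                current_entity, current_text = None, ""
        if current_entity:
            entities.append({"entity": current_entity, "text": current_text})
        return entities

    print(decode_bio([("Python", "B-TOPIC"), ("basics", "I-TOPIC"), ("fast", "O")]))
    # [{'entity': 'TOPIC', 'text': 'Pythonbasics'}]

Note that current_text += word concatenates with no separator, which suits Korean words; for space-separated languages a " " join would read better.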
@@ -95,17 +99,17 @@ def extract_entities(aligned_result):
 st.title("🎯 Learning Condition Extractor")
 st.write("Extracts the conditions (TOPIC, STYLE, LENGTH, LANGUAGE) from the user's learning-goal sentence.")
 
-user_input = st.text_input("Enter your learning goal
+user_input = st.text_input("💬 Enter your learning goal", value="I want to learn from YouTube videos, hands-on, within 30 minutes")
 
-if st.button("
+if st.button("🚀 Start extraction"):
     tokens, pred_labels = predict(user_input, model, tokenizer, id_to_label)
     words, word_labels = post_process(tokens, pred_labels)
     aligned = align_words_labels(words, word_labels)
     entities = extract_entities(aligned)
 
-    result_dict = {
+    result_dict = {"TOPIC": None, "STYLE": None, "LENGTH": None, "LANGUAGE": None}
     for ent in entities:
-        result_dict[ent[
+        result_dict[ent["entity"]] = ent["text"]
 
     st.subheader("📝 Extracted conditions")
     st.json(result_dict)
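The handler funnels whatever the model found into a fixed four-key dict, so st.json always renders all four conditions, with None for any the model missed. One edge worth noting: result_dict[ent["entity"]] raises a KeyError if the model ever emits an entity type outside those four keys, so a membership guard is a cheap safety net (a sketch, not part of the diff):

    result_dict = {"TOPIC": None, "STYLE": None, "LENGTH": None, "LANGUAGE": None}
    entities = [{"entity": "TOPIC", "text": "Python"},      # illustrative model output
                {"entity": "LENGTH", "text": "30 minutes"}]
    for ent in entities:
        if ent["entity"] in result_dict:                    # guard against unseen labels
            result_dict[ent["entity"]] = ent["text"]
    print(result_dict)
    # {'TOPIC': 'Python', 'STYLE': None, 'LENGTH': '30 minutes', 'LANGUAGE': None}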
|