zhixiusue committed on
Commit
ee377c9
·
verified ·
1 Parent(s): 285a9e6

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +17 -13
streamlit_app.py CHANGED
@@ -2,6 +2,10 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification
3
  import torch
4
 
 
 
 
 
5
  id_to_label = {
6
  0: 'O',
7
  1: 'B-TOPIC',
@@ -16,8 +20,8 @@ id_to_label = {
16
 
17
  @st.cache_resource
18
  def load_model():
19
- tokenizer = AutoTokenizer.from_pretrained(".")
20
- model = AutoModelForTokenClassification.from_pretrained(".")
21
  return tokenizer, model
22
 
23
  tokenizer, model = load_model()
@@ -28,8 +32,8 @@ def predict(text, model, tokenizer, id_to_label):
28
  model.eval()
29
  with torch.no_grad():
30
  outputs = model(**inputs)
31
- logits = outputs.logits
32
- predictions = torch.argmax(logits, dim=-1)
33
 
34
  word_ids = inputs.word_ids(batch_index=0)
35
  pred_labels = []
@@ -47,10 +51,9 @@ def predict(text, model, tokenizer, id_to_label):
47
 
48
  def post_process(tokens, labels):
49
  words, word_labels = [], []
50
- current_word = ""
51
- current_label = None
52
  for token, label in zip(tokens, labels):
53
- if token in ["[CLS]", "[SEP]", "[PAD]"]:
54
  continue
55
  if token.startswith("##"):
56
  current_word += token[2:]
@@ -80,13 +83,14 @@ def extract_entities(aligned_result):
80
  if prefix == "B":
81
  if current_entity:
82
  entities.append({"entity": current_entity, "text": current_text})
83
- current_entity, current_text = entity_type, word
 
84
  elif prefix == "I" and current_entity == entity_type:
85
  current_text += word
86
  else:
87
  if current_entity:
88
  entities.append({"entity": current_entity, "text": current_text})
89
- current_entity, current_text = entity_type, word
90
  if current_entity:
91
  entities.append({"entity": current_entity, "text": current_text})
92
  return entities
@@ -95,17 +99,17 @@ def extract_entities(aligned_result):
95
  st.title("🎯 Learning Condition Extractor")
96
  st.write("사용자의 학습 목표 문장에서 조건(TOPIC, STYLE, LENGTH, LANGUAGE)을 추출합니다.")
97
 
98
- user_input = st.text_input("학습 목표를 입력하세요:", value="딥러닝을 실습 위주로 30분 이내에 배우고 싶어요")
99
 
100
- if st.button("추론 시작"):
101
  tokens, pred_labels = predict(user_input, model, tokenizer, id_to_label)
102
  words, word_labels = post_process(tokens, pred_labels)
103
  aligned = align_words_labels(words, word_labels)
104
  entities = extract_entities(aligned)
105
 
106
- result_dict = {'TOPIC': None, 'STYLE': None, 'LENGTH': None, 'LANGUAGE': None}
107
  for ent in entities:
108
- result_dict[ent['entity']] = ent['text']
109
 
110
  st.subheader("📌 추출된 조건")
111
  st.json(result_dict)
 
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification
3
  import torch
4
 
5
+ # Hugging Face 모델 저장소 경로
6
+ MODEL_REPO = "zhixiusue/EduTubeNavigator"
7
+
8
+ # ID → 라벨 매핑 (사용자 정의)
9
  id_to_label = {
10
  0: 'O',
11
  1: 'B-TOPIC',
 
20
 
21
  @st.cache_resource
22
  def load_model():
23
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
24
+ model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
25
  return tokenizer, model
26
 
27
  tokenizer, model = load_model()
 
32
  model.eval()
33
  with torch.no_grad():
34
  outputs = model(**inputs)
35
+ logits = outputs.logits
36
+ predictions = torch.argmax(logits, dim=-1)
37
 
38
  word_ids = inputs.word_ids(batch_index=0)
39
  pred_labels = []
 
51
 
52
  def post_process(tokens, labels):
53
  words, word_labels = [], []
54
+ current_word, current_label = None, None
 
55
  for token, label in zip(tokens, labels):
56
+ if token in ['[CLS]', '[SEP]', '[PAD]']:
57
  continue
58
  if token.startswith("##"):
59
  current_word += token[2:]
 
83
  if prefix == "B":
84
  if current_entity:
85
  entities.append({"entity": current_entity, "text": current_text})
86
+ current_entity = entity_type
87
+ current_text = word
88
  elif prefix == "I" and current_entity == entity_type:
89
  current_text += word
90
  else:
91
  if current_entity:
92
  entities.append({"entity": current_entity, "text": current_text})
93
+ current_entity, current_text = None, ""
94
  if current_entity:
95
  entities.append({"entity": current_entity, "text": current_text})
96
  return entities
 
99
  st.title("🎯 Learning Condition Extractor")
100
  st.write("사용자의 학습 목표 문장에서 조건(TOPIC, STYLE, LENGTH, LANGUAGE)을 추출합니다.")
101
 
102
+ user_input = st.text_input("💬 학습 목표를 입력하세요", value="유튜브 영상은 실습 위주로 30분 이내에 배우고 싶어요")
103
 
104
+ if st.button("🔍 추출 시작"):
105
  tokens, pred_labels = predict(user_input, model, tokenizer, id_to_label)
106
  words, word_labels = post_process(tokens, pred_labels)
107
  aligned = align_words_labels(words, word_labels)
108
  entities = extract_entities(aligned)
109
 
110
+ result_dict = {"TOPIC": None, "STYLE": None, "LENGTH": None, "LANGUAGE": None}
111
  for ent in entities:
112
+ result_dict[ent["entity"]] = ent["text"]
113
 
114
  st.subheader("📌 추출된 조건")
115
  st.json(result_dict)