zhixiusue committed
Commit 5def19a · verified · 1 Parent(s): 7f029d3

Update streamlit_app.py

Files changed (1): streamlit_app.py (+92 -67)
streamlit_app.py CHANGED
@@ -1,21 +1,22 @@
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForTokenClassification
+from sentence_transformers import SentenceTransformer
+from youtube_transcript_api import YouTubeTranscriptApi
+from googleapiclient.discovery import build
+from torch.nn.functional import cosine_similarity
 import torch

-# ✅ Hugging Face model repo
+# Hugging Face model
 MODEL_REPO = "zhixiusue/EduTubeNavigator"

-# ✅ ID to label mapping
-id_to_label = {
-    0: 'O',
-    1: 'B-TOPIC',
-    2: 'I-TOPIC',
-    3: 'B-STYLE',
-    4: 'I-STYLE',
-    5: 'B-LENGTH',
-    6: 'I-LENGTH',
-    7: 'B-LANGUAGE',
-    8: 'I-LANGUAGE'
+# YouTube API Key
+YOUTUBE_API_KEY = "AIzaSyA8SG7--MfQvWET6UOam0PVAcC5MDm4sbc"
+youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)
+
+# ID → label mapping
+token_label_map = {
+    0: 'O', 1: 'B-TOPIC', 2: 'I-TOPIC', 3: 'B-STYLE', 4: 'I-STYLE',
+    5: 'B-LENGTH', 6: 'I-LENGTH', 7: 'B-LANGUAGE', 8: 'I-LANGUAGE'
 }

 @st.cache_resource
@@ -24,69 +25,93 @@ def load_model():
     model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
     return tokenizer, model

-tokenizer, model = load_model()
-
-def predict(text, model, tokenizer, id_to_label):
-    words = text.split()
-    inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True, max_length=128)
-
+def predict_entities(text, tokenizer, model):
+    tokens = list(text)
+    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True, max_length=128)
     model.eval()
     with torch.no_grad():
         outputs = model(**inputs)
-    predictions = torch.argmax(outputs.logits, dim=-1)
-
+    logits = outputs.logits
+    predictions = torch.argmax(logits, dim=-1)[0]
     word_ids = inputs.word_ids(batch_index=0)
-    aligned = []
-    for idx, word_idx in enumerate(word_ids):
-        if word_idx is not None:
-            word = words[word_idx]
-            label = id_to_label[predictions[0][idx].item()]
-            aligned.append((word, label))
-
-    return aligned
-
-def extract_entities(aligned_result):
-    entities = []
-    current_entity, current_text = None, ""
-
-    for word, label in aligned_result:
-        if label == "O":
-            if current_entity:
-                entities.append({"entity": current_entity, "text": current_text})
-                current_entity, current_text = None, ""
+    entities = {}
+    current_entity = ""
+    current_type = ""
+    for idx, word_id in enumerate(word_ids):
+        if word_id is None:
             continue
-
-        prefix, entity_type = label.split("-", 1)
-        if prefix == "B":
-            if current_entity:
-                entities.append({"entity": current_entity, "text": current_text})
-            current_entity = entity_type
-            current_text = word
-        elif prefix == "I" and current_entity == entity_type:
-            current_text += word
+        label = token_label_map[predictions[idx].item()]
+        if label.startswith("B-"):
+            if current_type:
+                entities[current_type] = current_entity
+            current_type = label[2:]
+            current_entity = tokens[word_id]
+        elif label.startswith("I-") and label[2:] == current_type:
+            current_entity += tokens[word_id]
         else:
-            if current_entity:
-                entities.append({"entity": current_entity, "text": current_text})
-                current_entity, current_text = None, ""
-
-    if current_entity:
-        entities.append({"entity": current_entity, "text": current_text})
+            if current_type:
+                entities[current_type] = current_entity
+            current_type = ""
+            current_entity = ""
+    if current_type:
+        entities[current_type] = current_entity
     return entities

-# ✅ Streamlit UI
-st.title("🎯 Learning Condition Extractor")
-st.write("Extracts the conditions (TOPIC, STYLE, LENGTH, LANGUAGE) from the user's learning-goal sentence.")
-
-user_input = st.text_input("💬 Enter your learning goal", value="I want to learn from YouTube videos hands-on within 30 minutes")
-
-if st.button("🔍 Start extraction"):
-    aligned = predict(user_input, model, tokenizer, id_to_label)
-    entities = extract_entities(aligned)
-
-    # Organize the results
-    result_dict = {"TOPIC": None, "STYLE": None, "LENGTH": None, "LANGUAGE": None}
-    for ent in entities:
-        result_dict[ent["entity"]] = ent["text"]
-
-    st.subheader("📌 Extracted conditions")
-    st.json(result_dict)
+def search_youtube_videos(query, max_results=10):
+    response = youtube.search().list(q=query, part="snippet", type="video", maxResults=max_results).execute()
+    results = []
+    for item in response['items']:
+        results.append({
+            'video_id': item['id']['videoId'],
+            'title': item['snippet']['title'],
+            'description': item['snippet']['description']
+        })
+    return results
+
+def get_transcript(video_id):
+    try:
+        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['ko', 'en'])
+        return " ".join([t['text'] for t in transcript])
+    except:
+        return ""
+
+@st.cache_resource
+def load_embedder():
+    return SentenceTransformer('snunlp/KR-SBERT-V40K-klueNLI-augSTS')
+
+def embed_texts(embedder, texts):
+    return embedder.encode(texts, convert_to_tensor=True)
+
+def recommend_video(embedder, user_conditions, video_infos):
+    user_text = " ".join([v for v in user_conditions.values() if v])
+    user_embedding = embed_texts(embedder, [user_text])[0]
+
+    scored = []
+    for video in video_infos:
+        video_text = video['title'] + " " + video['description'] + " " + get_transcript(video['video_id'])
+        video_embedding = embed_texts(embedder, [video_text])[0]
+        score = cosine_similarity(user_embedding, video_embedding, dim=0).item()
+        scored.append((score, video))
+    return sorted(scored, reverse=True, key=lambda x: x[0])[:3]
+
+# Streamlit UI
+st.title("🎯 Learning-Condition-Based YouTube Recommender")
+st.write("Enter a learning goal and the app extracts your conditions and recommends YouTube videos.")
+
+user_input = st.text_input("💬 Enter your learning goal", "I want to learn deep learning hands-on within 30 minutes")
+if st.button("🔍 Start recommendation"):
+    tokenizer, model = load_model()
+    embedder = load_embedder()
+
+    entities = predict_entities(user_input, tokenizer, model)
+    st.subheader("📌 Extracted conditions")
+    st.json(entities)

+    search_query = " ".join([v for v in entities.values() if v])
+    video_candidates = search_youtube_videos(search_query)
+    top_recommendations = recommend_video(embedder, entities, video_candidates)
+
+    st.subheader("📺 Recommended YouTube videos")
+    for score, video in top_recommendations:
+        st.markdown(f"**{video['title']}**")
+        st.markdown(f"🔗 [Link](https://www.youtube.com/watch?v={video['video_id']})")
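One caution about the configuration added in the first hunk: the YouTube API key is committed in plain text, so it ships with the repository. The usual Streamlit pattern is to read such values from st.secrets, which keeps the key out of version control. A minimal sketch of that variant, assuming a YOUTUBE_API_KEY entry in .streamlit/secrets.toml (the entry name is illustrative, not part of this commit):

import streamlit as st
from googleapiclient.discovery import build

# Assumed entry in .streamlit/secrets.toml (not in this commit):
#   YOUTUBE_API_KEY = "..."
YOUTUBE_API_KEY = st.secrets["YOUTUBE_API_KEY"]
youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)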
 
 
 
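The committed recommend_video encodes the query once but calls the embedder once per candidate, so a ten-video search triggers ten separate forward passes. sentence-transformers can encode a list of texts in one batched call, which is usually faster. A sketch of that variant under the same data shapes (recommend_video_batched and top_k are illustrative names, not from the commit; get_transcript is the helper defined in the diff):

from torch.nn.functional import cosine_similarity

def recommend_video_batched(embedder, user_conditions, video_infos, top_k=3):
    # Same query construction as the committed recommend_video.
    user_text = " ".join(v for v in user_conditions.values() if v)
    # Collect every candidate's text first, then encode all of them in one call.
    video_texts = [
        video['title'] + " " + video['description'] + " " + get_transcript(video['video_id'])
        for video in video_infos
    ]
    embeddings = embedder.encode([user_text] + video_texts, convert_to_tensor=True)
    user_embedding, video_embeddings = embeddings[0], embeddings[1:]
    # One broadcasted cosine-similarity call scores the query against all candidates.
    scores = cosine_similarity(user_embedding.unsqueeze(0), video_embeddings, dim=1)
    scored = sorted(zip(scores.tolist(), video_infos), key=lambda pair: pair[0], reverse=True)
    return scored[:top_k]

The scores should match the committed per-video loop, since each text still gets its own embedding; only the number of encoder calls changes.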