vumichien committed
Commit 5708eb0
Parent: 84736e5

improve cache model

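For background on the caching the commit message refers to: Streamlit reruns the whole script on every interaction, so `st.cache_resource` is used to keep one un-serialized object (the model) alive per process, while `st.cache_data` memoizes serializable data such as numpy arrays. A minimal, self-contained sketch of the two decorators; the sleeps and file name are illustrative stand-ins, not code from this repo:

```python
import time
import numpy as np
import streamlit as st


@st.cache_resource  # one shared object per process; not copied, not pickled
def expensive_resource():
    time.sleep(2)  # stands in for loading a SentenceTransformer model
    return object()


@st.cache_data  # memoized by arguments; result is pickled and copied on each read
def expensive_data(path: str) -> np.ndarray:
    time.sleep(1)  # stands in for np.load on a large .npy file
    return np.zeros((100, 768), dtype=np.float32)


res = expensive_resource()        # slow only on the first run in this process
emb = expensive_data("demo.npy")  # slow only the first time this argument is seen
st.write(type(res), emb.shape)
```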
app.py CHANGED
@@ -1,10 +1,10 @@
 import streamlit as st
-import faiss
 import numpy as np
 import json
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer, util
 import time
 
+st.set_page_config(initial_sidebar_state="collapsed")
 # Load the data
 with open("data/qa_data.json", "r", encoding="utf-8") as f:
     data = json.load(f)
@@ -12,15 +12,27 @@ with open("data/qa_data.json", "r", encoding="utf-8") as f:
 questions = [item["question"] for item in data]
 answers = [item["answer"] for item in data]
 
-# Load the embedding model
-model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
 
-# Load the FAISS indexes
-index_q = faiss.read_index("data/faiss_question.index")
-index_a = faiss.read_index("data/faiss_answer.index")
+# Cache the model at the app level
+@st.cache_resource
+def load_model():
+    return SentenceTransformer("pkshatech/GLuCoSE-base-ja")
+
+
+# Cache the embedding data
+@st.cache_data
+def load_embeddings():
+    return (
+        np.load("data/question_embeddings.npy"),
+        np.load("data/answer_embeddings.npy"),
+    )
+
+
+# Load the model and embeddings once
+model = load_model()
+question_embeddings, answer_embeddings = load_embeddings()
 
 # Sidebar settings
-st.set_page_config(initial_sidebar_state="collapsed")
 with st.sidebar.expander("⚙️ 設定", expanded=False):
     threshold_q = st.slider("質問の類似度しきい値", 0.0, 1.0, 0.7, 0.01)
     threshold_a = st.slider("回答の類似度しきい値", 0.0, 1.0, 0.65, 0.01)
@@ -31,24 +43,37 @@ with st.sidebar.expander("⚙️ 設定", expanded=False):
 
 
 def search_answer(user_input):
-    """Find the best answer using FAISS"""
-    user_embedding = model.encode([user_input]).astype(np.float32)
-
-    # Search against the questions
-    D_q, I_q = index_q.search(user_embedding, 1)
-    score_q = 1 / (1 + D_q[0][0])
+    """Find the best answer using cosine similarity"""
+    # Encode with batch_size=1 and show_progress_bar=False for speed
+    user_embedding = model.encode(
+        [user_input],
+        convert_to_numpy=True,
+        batch_size=1,
+        show_progress_bar=False,
+        normalize_embeddings=True,  # pre-normalize to speed up cosine similarity
+    )
+
+    # Cosine similarity against the questions
+    cos_scores_q = util.cos_sim(user_embedding, question_embeddings)[0]
+    best_q_idx = np.argmax(cos_scores_q)
+    score_q = cos_scores_q[best_q_idx]
 
     if score_q >= threshold_q:
-        # Replace \n with markdown line breaks
-        return answers[I_q[0][0]].replace("\n", " \n"), f"質問にマッチ ({score_q:.2f})"
+        return (
+            answers[best_q_idx].replace("\n", " \n"),
+            f"質問にマッチ ({score_q:.2f})",
+        )
 
-    # Search against the answers
-    D_a, I_a = index_a.search(user_embedding, 1)
-    score_a = 1 / (1 + D_a[0][0])
+    # Cosine similarity against the answers
+    cos_scores_a = util.cos_sim(user_embedding, answer_embeddings)[0]
+    best_a_idx = np.argmax(cos_scores_a)
+    score_a = cos_scores_a[best_a_idx]
 
     if score_a >= threshold_a:
-        # Replace \n with markdown line breaks
-        return answers[I_a[0][0]].replace("\n", " \n"), f"回答にマッチ ({score_a:.2f})"
+        return (
+            answers[best_a_idx].replace("\n", " \n"),
+            f"回答にマッチ ({score_a:.2f})",
+        )
 
     return "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?", "一致なし"
 
@@ -81,6 +106,7 @@ if user_input := st.chat_input("質問を入力してください:"):
 
     with st.spinner("考え中... お待ちください。"):
         answer, info = search_answer(user_input)
+        print(info)
 
     with st.chat_message("assistant"):
         response_placeholder = st.empty()
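The FAISS lookup is replaced by a brute-force cosine-similarity search over the precomputed embedding matrices. A minimal sketch of that pattern with toy vectors; the shapes, random data, and threshold value are illustrative only:

```python
import numpy as np
from sentence_transformers import util

# Toy stand-ins for the matrices loaded from the .npy files:
# one row per stored question, one column per embedding dimension.
question_embeddings = np.random.rand(5, 768).astype(np.float32)
query_embedding = np.random.rand(1, 768).astype(np.float32)

# util.cos_sim returns a (1, 5) torch tensor of pairwise cosine similarities.
scores = util.cos_sim(query_embedding, question_embeddings)[0].cpu().numpy()

best_idx = int(np.argmax(scores))   # index of the most similar stored question
best_score = float(scores[best_idx])

threshold = 0.7  # plays the same role as threshold_q in the app
if best_score >= threshold:
    print(f"match at row {best_idx} with score {best_score:.2f}")
else:
    print("no match above threshold")
```

Because `util.cos_sim` computes true cosine similarity, the scores fall in [-1, 1], so they can be compared directly against the 0.0 to 1.0 slider thresholds.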
data/answer_embeddings.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:925632dc69ab4df0223970df60cc9054dd46e2958e597e6998514bd3b33fc703
+size 67712
data/faiss_answer.index DELETED
Binary file (67.6 kB)
 
data/faiss_question.index DELETED
Binary file (67.6 kB)
 
data/question_embeddings.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45127b6e8615f93324b2debb37305b93d3963c5f91f054f9c56def8cd00c1ca5
+size 67712
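The two `.npy` files are tracked through Git LFS, so the repository stores only the pointer shown above and the arrays themselves are fetched on checkout. A quick sanity check one might run after cloning; it assumes the working directory is the repo root and that both matrices share the shape (num_items, embedding_dim):

```python
import numpy as np

q = np.load("data/question_embeddings.npy")
a = np.load("data/answer_embeddings.npy")

# One row per entry in data/qa_data.json, same embedding dimensionality for both.
print(q.shape, a.shape, q.dtype)
assert q.shape[1] == a.shape[1]
```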
preprocess.py CHANGED
@@ -1,5 +1,4 @@
 import json
-import faiss
 import numpy as np
 from sentence_transformers import SentenceTransformer
 
@@ -14,20 +13,13 @@ answers = [item["answer"] for item in data]
 model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
 
 # Create embeddings for the questions and answers
-question_embeddings = model.encode(questions)
-answer_embeddings = model.encode(answers)
+question_embeddings = model.encode(questions, convert_to_numpy=True)
+answer_embeddings = model.encode(answers, convert_to_numpy=True)
 
-# Save the FAISS indexes
-dim = question_embeddings.shape[1]
-index_q = faiss.IndexFlatL2(dim)
-index_a = faiss.IndexFlatL2(dim)
-
-index_q.add(np.array(question_embeddings).astype(np.float32))
-index_a.add(np.array(answer_embeddings).astype(np.float32))
-
-faiss.write_index(index_q, "faiss_question.index")
-faiss.write_index(index_a, "faiss_answer.index")
+# Save the embeddings as numpy arrays
+np.save("data/question_embeddings.npy", question_embeddings)
+np.save("data/answer_embeddings.npy", answer_embeddings)
 
 # Save the original data
-with open("qa_data.json", "w", encoding="utf-8") as f:
+with open("data/qa_data.json", "w", encoding="utf-8") as f:
     json.dump(data, f, ensure_ascii=False, indent=2)
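A hedged end-to-end check of the new pipeline: run `preprocess.py` to regenerate the embedding files, then confirm they line up with the QA data that `app.py` loads. The script below is a sketch; it assumes the working directory is the repo root and that `data/qa_data.json` and the regenerated `.npy` files exist:

```python
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util

with open("data/qa_data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

question_embeddings = np.load("data/question_embeddings.npy")
assert question_embeddings.shape[0] == len(data), "one embedding row per QA pair"

# Spot-check: a stored question should match itself with a score near 1.0.
model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
probe = model.encode([data[0]["question"]], convert_to_numpy=True)
score = float(util.cos_sim(probe, question_embeddings)[0][0])
print(f"self-similarity of the first question: {score:.3f}")
```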