Nerva5678 commited on
Commit
92e4d9e
·
verified ·
1 Parent(s): fec0b17

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +426 -64
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,71 +1,433 @@
 
1
  import streamlit as st
2
- from langchain_community.embeddings import HuggingFaceEmbeddings
3
- from langchain_community.vectorstores import FAISS
4
- from langchain_community.llms import HuggingFaceHub
5
- from langchain.chains import RetrievalQA
6
- from langchain.text_splitter import CharacterTextSplitter
7
- from langchain.docstore.document import Document
8
- from langchain.prompts import PromptTemplate
9
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # 設定 HuggingFace token
12
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hf_your_token"
13
-
14
- # 假資料:簡單 Q&A 列表
15
- qa_data = [
16
- {"問題": "什麼是AI?", "答案": "人工智慧(AI)是一種模擬人類智能的技術。"},
17
- {"問題": "LangChain是什麼?", "答案": "LangChain 是一個用於構建基於 LLM 的應用框架。"},
18
- {"問題": "FAISS有什麼用?", "答案": "FAISS 是一個用於高效相似度搜尋的向量資料庫工具。"},
19
- ]
20
-
21
- # 將問答資料轉換為 Document 格式
22
- def build_documents(qa_data):
23
- docs = []
24
- for item in qa_data:
25
- content = f"問題:{item['問題']}\n答案:{item['答案']}"
26
- docs.append(Document(page_content=content))
27
- return docs
28
-
29
- # 向量資料庫
30
- def create_vectorstore(docs, embeddings):
31
- return FAISS.from_documents(docs, embedding=embeddings)
32
-
33
- # 建立嵌入模型
34
- def get_embedding_model():
35
- return HuggingFaceEmbeddings(model_name="text2vec-base-chinese")
36
-
37
- # 建立語言模型(ChatGLM3)
38
- def get_llm_model():
39
- return HuggingFaceHub(
40
- repo_id="THUDM/chatglm3-6b",
41
- model_kwargs={"temperature": 0.1, "max_length": 2048}
42
  )
43
 
44
- # 建立問答鏈
45
- def build_qa_chain(llm, vectorstore):
46
- return RetrievalQA.from_chain_type(
47
- llm=llm,
48
- retriever=vectorstore.as_retriever(),
49
- chain_type="stuff",
50
- return_source_documents=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
 
53
- # Streamlit UI
54
- st.title("💬 小型知識問答機器人")
55
- st.markdown("目前使用內建知識,無需上傳 Excel")
56
-
57
- # 準備模型與資料
58
- embedding_model = get_embedding_model()
59
- llm_model = get_llm_model()
60
- documents = build_documents(qa_data)
61
- vectorstore = create_vectorstore(documents, embedding_model)
62
- qa_chain = build_qa_chain(llm_model, vectorstore)
63
-
64
- # 問答輸入
65
- user_question = st.text_input("請輸入你的問題:")
66
-
67
- if st.button("送出") and user_question:
68
- with st.spinner("思考中..."):
69
- result = qa_chain({"query": user_question})
70
- st.success("回答:")
71
- st.write(result["result"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
  import streamlit as st
3
+ import pandas as pd
4
+ import torch
 
 
 
 
 
5
  import os
6
+ import time
7
+ import logging
8
+ import subprocess
9
+ import sys
10
+
11
+ # 設定logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # 頁面配置
16
+ st.set_page_config(
17
+ page_title="Excel 問答 AI(ChatGLM 驅動)",
18
+ page_icon="🤖",
19
+ layout="wide"
20
+ )
21
+
22
+ # 應用標題與說明
23
+ st.title("🤖 Excel 問答 AI(ChatGLM 驅動)")
24
+ st.markdown("""
25
+ ### 使用說明
26
+ 1. 可直接提問一般知識,AI 將使用內建能力回答
27
+ 2. 上傳 Excel 檔案(包含「問題」和「答案」欄位)以添加專業知識
28
+ 3. 系統會優先使用您上傳的知識庫進行回答
29
+ """)
30
+
31
+ # 檢查並安裝必要套件
32
+ def install_missing_packages():
33
+ required_packages = ["sentencepiece", "protobuf", "bitsandbytes"] # 加入 bitsandbytes
34
+ for package in required_packages:
35
+ try:
36
+ __import__(package)
37
+ st.write(f"{package} 已安裝")
38
+ except ImportError:
39
+ st.write(f"安裝缺失的套件: {package}")
40
+ try:
41
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
42
+ st.write(f"{package} 已安裝成功")
43
+ except Exception as e:
44
+ st.error(f"安裝 {package} 失敗: {str(e)}")
45
+ return False
46
+ return True
47
+
48
+ # 安裝缺失的套件
49
+ if not install_missing_packages():
50
+ st.error("必要套件安裝失敗,請刷新頁面重試")
51
+ st.stop()
52
+
53
+ st.write("正在導入依賴項...")
54
+
55
+ # 依次導入並檢查每個依賴
56
+ try:
57
+ from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
58
+ st.write("成功導入 HuggingFaceEmbeddings")
59
+ except Exception as e:
60
+ st.error(f"導入 HuggingFaceEmbeddings 失敗: {str(e)}")
61
+ st.stop()
62
+
63
+ try:
64
+ from langchain_community.vectorstores import FAISS
65
+ st.write("成功導入 FAISS")
66
+ except Exception as e:
67
+ st.error(f"導入 FAISS 失敗: {str(e)}")
68
+ st.stop()
69
+
70
+ try:
71
+ from langchain_community.llms import HuggingFacePipeline
72
+ st.write("成功導入 HuggingFacePipeline")
73
+ except Exception as e:
74
+ st.error(f"導入 HuggingFacePipeline 失敗: {str(e)}")
75
+ st.stop()
76
+
77
+ try:
78
+ from langchain.chains import RetrievalQA, LLMChain
79
+ st.write("成功導入 RetrievalQA, LLMChain")
80
+ except Exception as e:
81
+ st.error(f"導入 RetrievalQA, LLMChain 失敗: {str(e)}")
82
+ st.stop()
83
+
84
+ try:
85
+ from langchain.prompts import PromptTemplate
86
+ st.write("成功導入 PromptTemplate")
87
+ except Exception as e:
88
+ st.error(f"導入 PromptTemplate 失敗: {str(e)}")
89
+ st.stop()
90
+
91
+ try:
92
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
93
+ st.write("成功導入 transformers 組件")
94
+ except Exception as e:
95
+ st.error(f"導入 transformers 組件失敗: {str(e)}")
96
+ st.stop()
97
+
98
+ try:
99
+ import bitsandbytes # 檢查 bitsandbytes
100
+ st.write("成功導入 bitsandbytes")
101
+ has_bitsandbytes = True
102
+ except ImportError:
103
+ st.warning("未安裝 bitsandbytes,將無法使用 4 位元量化。")
104
+ has_bitsandbytes = False
105
+
106
+ st.write("所有依賴項導入成功!")
107
+
108
+ # 側邊欄設定
109
+ with st.sidebar:
110
+ st.header("參數設定")
111
+
112
+ model_option = st.selectbox(
113
+ "選擇模型",
114
+ ["THUDM/chatglm3-6b", "THUDM/chatglm2-6b", "THUDM/chatglm-6b"],
115
+ index=0
116
+ )
117
+
118
+ embedding_option = st.selectbox(
119
+ "選擇嵌入模型",
120
+ ["shibing624/text2vec-base-chinese", "GanymedeNil/text2vec-large-chinese"],
121
+ index=0
122
+ )
123
 
124
+ mode = st.radio(
125
+ "回答模式",
126
+ ["混合模式(優先使用上傳資料)", "僅使用上傳資料", "僅使用模型知識"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  )
128
 
129
+ max_tokens = st.slider("最大回應長度", 128, 2048, 512)
130
+ temperature = st.slider("溫度(創造性)", 0.0, 1.0, 0.7, 0.1)
131
+ top_k = st.slider("檢索相關文檔數", 1, 5, 3)
132
+
133
+ st.markdown("---")
134
+ st.markdown("### 關於")
135
+ st.markdown("此應用使用 ChatGLM 模型結合 LangChain 框架,將您的 Excel 數據轉化為智能問答系統。同時支持一般知識問答。")
136
+
137
+ # 全局變量
138
+ @st.cache_resource
139
+ def load_embeddings(model_name):
140
+ try:
141
+ logger.info(f"加載嵌入模型: {model_name}")
142
+ st.write(f"開始加載嵌入模型: {model_name}...")
143
+ embeddings = HuggingFaceEmbeddings(model_name=model_name)
144
+ st.write(f"嵌入模型加載成功!")
145
+ return embeddings
146
+ except Exception as e:
147
+ logger.error(f"嵌入模型加載失敗: {str(e)}")
148
+ st.error(f"嵌入模型加載失敗: {str(e)}")
149
+ return None
150
+
151
+ @st.cache_resource
152
+ def load_llm(_model_name, _max_tokens, _temperature):
153
+ try:
154
+ logger.info(f"加載語言模型: {_model_name}")
155
+ st.write(f"開始加載語言模型: {_model_name}...")
156
+
157
+ # 檢查可用資源
158
+ free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
159
+ st.write(f"可用GPU記憶體: {free_memory / (1024**3):.2f} GB" if torch.cuda.is_available() else "無GPU可用,將使用CPU")
160
+
161
+ device = "cuda" if torch.cuda.is_available() else "cpu"
162
+ dtype = torch.float16
163
+
164
+ load_args = {"trust_remote_code": True, "device_map": device, "torch_dtype": dtype}
165
+
166
+ if device == "cpu":
167
+ st.warning("注意:在 CPU 上載入大型語言模型可能會非常緩慢且需要大量記憶體。")
168
+ elif has_bitsandbytes:
169
+ try:
170
+ load_args["load_in_4bit"] = True
171
+ load_args["bnb_4bit_compute_dtype"] = torch.float16
172
+ st.info("嘗試使用 4 位元量化載入模型 (需要 bitsandbytes)。")
173
+ except Exception as e:
174
+ st.warning(f"載入 4 位元量化模型失敗: {e}")
175
+ st.info("將嘗試以半精度浮點數載入。")
176
+
177
+ # 使用超時保護
178
+ with st.spinner(f"正在加載 {_model_name} 模型,這可能需要幾分鐘..."):
179
+ # 加載tokenizer
180
+ st.write("加載tokenizer...")
181
+ tokenizer = AutoTokenizer.from_pretrained(_model_name, trust_remote_code=True)
182
+ st.write("Tokenizer加載成功")
183
+
184
+ # 加載模型
185
+ st.write(f"開始加載模型到{device}...")
186
+ try:
187
+ model = AutoModelForCausalLM.from_pretrained(_model_name, **load_args)
188
+ st.write("模型加載成功!")
189
+ except Exception as e:
190
+ st.error(f"模型加載失敗: {e}")
191
+ st.error("嘗試使用不同的載入配置。")
192
+ raise e
193
+
194
+ # 創建pipeline
195
+ st.write("創建文本生成pipeline...")
196
+ pipe = pipeline(
197
+ "text-generation",
198
+ model=model,
199
+ tokenizer=tokenizer,
200
+ max_new_tokens=_max_tokens,
201
+ temperature=_temperature,
202
+ top_p=0.9,
203
+ repetition_penalty=1.1
204
+ )
205
+ st.write("Pipeline創建成功!")
206
+
207
+ return HuggingFacePipeline(pipeline=pipe)
208
+ except Exception as e:
209
+ logger.error(f"語言模型加載失敗: {str(e)}")
210
+ st.error(f"語言模型加載失敗: {str(e)}")
211
+ st.error("如果是因為記憶體不足,請考慮使用較小的模型或增加系統記憶體")
212
+ return None
213
+
214
+ # 創建向量資料庫
215
+ def create_vectorstore(texts, embeddings):
216
+ try:
217
+ st.write("開始創建向量資料庫...")
218
+ vectorstore = FAISS.from_texts(texts, embedding=embeddings)
219
+ st.write("向量資料庫創建成功!")
220
+ return vectorstore
221
+ except Exception as e:
222
+ logger.error(f"向量資料庫創建失敗: {str(e)}")
223
+ st.error(f"向量資料庫創建失敗: {str(e)}")
224
+ return None
225
+
226
+ # 創建直接問答的LLM鏈
227
+ def create_general_qa_chain(llm):
228
+ prompt_template = """請回答以下問題:
229
+
230
+ 問題: {question}
231
+
232
+ ���提供詳細且有幫助的回答:"""
233
+
234
+ prompt = PromptTemplate(
235
+ template=prompt_template,
236
+ input_variables=["question"]
237
  )
238
 
239
+ return LLMChain(llm=llm, prompt=prompt)
240
+
241
+ # 混合模式問答處理
242
+ def hybrid_qa(query, qa_chain, general_chain, confidence_threshold=0.7):
243
+ # 先嘗試使用知識庫回答
244
+ try:
245
+ st.write("嘗試從知識庫查詢答案...")
246
+ kb_result = qa_chain({"query": query})
247
+ # 檢查向量存儲的相似度分數,判斷是否有足夠相關的內容
248
+ if (hasattr(kb_result, 'source_documents') and
249
+ kb_result.get("source_documents") and
250
+ len(kb_result["source_documents"]) > 0):
251
+ # 這裡假設我們能獲取到相似度分數,實際上可能需要根據您使用的向量存儲方法調整
252
+ relevance = True # 在實際應用中,這裡應根據相似度分數確定
253
+
254
+ if relevance:
255
+ st.write("找到相關知識庫內容")
256
+ return kb_result, "knowledge_base", kb_result["source_documents"]
257
+ st.write("知識庫中未找到足夠相關的內容")
258
+ except Exception as e:
259
+ logger.warning(f"知識庫查詢失敗: {str(e)}")
260
+ st.warning(f"知識庫查詢失敗: {str(e)}")
261
+
262
+ # 如果知識庫沒有足夠相關的答案,使用一般知識模式
263
+ try:
264
+ st.write("使用模型一般知識回答...")
265
+ general_result = general_chain.run(question=query)
266
+ return {"result": general_result}, "general", []
267
+ except Exception as e:
268
+ logger.error(f"一般知識查詢失敗: {str(e)}")
269
+ st.error(f"一般知識查詢失敗: {str(e)}")
270
+ return {"result": "很抱歉,無法處理您的問題,請稍後再試。"}, "error", []
271
+
272
+ # 主應用邏輯
273
+ # 加載嵌入模型(先加載嵌入模型,因為這通常較小較快)
274
+ embeddings = None
275
+ if "embeddings" not in st.session_state:
276
+ with st.spinner("正在加載嵌入模型..."):
277
+ embeddings = load_embeddings(embedding_option)
278
+ if embeddings is not None:
279
+ st.session_state.embeddings = embeddings
280
+ else:
281
+ st.error("嵌入模型加載失敗,請刷新頁面重試")
282
+ st.stop()
283
+ else:
284
+ embeddings = st.session_state.embeddings
285
+
286
+ # 加載語言模型(不管是否上傳文件都需要)
287
+ llm = None
288
+ if "llm" not in st.session_state:
289
+ llm = load_llm(model_option, max_tokens, temperature)
290
+ if llm is not None:
291
+ st.session_state.llm = llm
292
+ else:
293
+ st.error("語言模型加載失敗,請刷新頁面重試")
294
+ st.stop()
295
+ else:
296
+ llm = st.session_state.llm
297
+
298
+ # 創建一般問答鏈
299
+ general_qa_chain = create_general_qa_chain(llm)
300
+ st.write("一般問答鏈創建成功!")
301
+
302
+ # 變數初始化
303
+ kb_qa_chain = None
304
+ has_knowledge_base = False
305
+ vectorstore = None
306
+
307
+ # 上傳Excel文件
308
+ uploaded_file = st.file_uploader("上傳你的問答 Excel(可選)", type=["xlsx"])
309
+
310
+ if uploaded_file:
311
+ # 讀取Excel文件
312
+ try:
313
+ st.write("開始讀取Excel文件...")
314
+ df = pd.read_excel(uploaded_file)
315
+
316
+ # 檢查必要欄位
317
+ if not {'問題', '答案'}.issubset(df.columns):
318
+ st.error("Excel 檔案需包含 '問題' 和 '答案' 欄位")
319
+ else:
320
+ # 顯示資料預覽
321
+ with st.expander("Excel 資料預覽"):
322
+ st.dataframe(df.head())
323
+
324
+ st.info(f"成功讀取 {len(df)} 筆問答對")
325
+
326
+ # 建立文本列表
327
+ texts = [f"問題:{q}\n答案:{a}" for q, a in zip(df['問題'], df['答案'])]
328
+
329
+ # 進度條
330
+ progress_text = "正在處理中..."
331
+ my_bar = st.progress(0, text=progress_text)
332
+
333
+ # 使用之前加載的嵌入模型
334
+ my_bar.progress(25, text="準備嵌入模型...")
335
+
336
+ # 建立向量資料庫
337
+ my_bar.progress(50, text="正在建立向量資料庫...")
338
+ vectorstore = create_vectorstore(texts, embeddings)
339
+ if vectorstore is None:
340
+ st.stop()
341
+
342
+ # 創建問答鏈
343
+ my_bar.progress(75, text="正在建立知識庫問答系統...")
344
+ kb_qa_chain = RetrievalQA.from_chain_type(
345
+ llm=llm,
346
+ retriever=vectorstore.as_retriever(search_kwargs={"k": top_k}),
347
+ chain_type="stuff",
348
+ return_source_documents=True
349
+ )
350
+
351
+ has_knowledge_base = True
352
+
353
+ my_bar.progress(100, text="準備完成!")
354
+ time.sleep(1)
355
+ my_bar.empty()
356
+
357
+ st.success("知識庫已準備就緒,請輸入您的問題")
358
+
359
+ except Exception as e:
360
+ logger.error(f"Excel 檔案處理失敗: {str(e)}")
361
+ st.error(f"Excel 檔案處理失敗: {str(e)}")
362
+
363
+ # 查詢部分
364
+ st.markdown("## 開始對話")
365
+ query = st.text_input("請輸入你的問題:")
366
+
367
+ if query:
368
+ with st.spinner("AI 思考中..."):
369
+ try:
370
+ start_time = time.time()
371
+
372
+ # 根據模式選擇問答方式
373
+ if mode == "僅使用上傳資料":
374
+ if has_knowledge_base:
375
+ st.write("使用知識庫模式回答...")
376
+ result = kb_qa_chain({"query": query})
377
+ source = "knowledge_base"
378
+ source_docs = result["source_documents"]
379
+ else:
380
+ st.warning("您選擇了僅使用上傳資料模式,但尚未上傳Excel檔案。請上傳檔案或變更模式。")
381
+ st.stop()
382
+
383
+ elif mode == "僅使用模型知識":
384
+ st.write("使用模型一般知識模式回答...")
385
+ result = {"result": general_qa_chain.run(question=query)}
386
+ source = "general"
387
+ source_docs = []
388
+
389
+ else: # 混合模式
390
+ if has_knowledge_base:
391
+ st.write("使用混合模式回答...")
392
+ result, source, source_docs = hybrid_qa(query, kb_qa_chain, general_qa_chain)
393
+ else:
394
+ st.write("未檢測到知識庫,使用模型一般知識回答...")
395
+ result = {"result": general_qa_chain.run(question=query)}
396
+ source = "general"
397
+ source_docs = []
398
+
399
+ end_time = time.time()
400
+
401
+ # 顯示回答
402
+ st.markdown("### AI 回答:")
403
+ st.markdown(result["result"])
404
+
405
+ # 根據來源顯示不同信息
406
+ if source == "knowledge_base":
407
+ st.success("✅ 回答來自您的知識庫")
408
+ # 顯示參考資料
409
+ with st.expander("參考資料"):
410
+ for i, doc in enumerate(source_docs):
411
+ st.markdown(f"**參考 {i+1}**")
412
+ st.markdown(doc.page_content)
413
+ st.markdown("---")
414
+ elif source == "general":
415
+ if has_knowledge_base:
416
+ st.info("ℹ️ 回答來自模型的一般知識(知識庫中未找到相關內容)")
417
+ else:
418
+ st.info("ℹ️ 回答來自模型的一般知識")
419
+
420
+ st.text(f"回答生成時間: {(end_time - start_time):.2f} 秒")
421
+
422
+ except Exception as e:
423
+ logger.error(f"查詢處理失敗: {str(e)}")
424
+ st.error(f"查詢處理失敗,請重試: {str(e)}")
425
+ st.error(f"錯誤詳情: {str(e)}")
426
+
427
+ # 添加會話歷史功能
428
+ if "chat_history" not in st.session_state:
429
+ st.session_state.chat_history = []
430
+
431
+ # 底部資訊
432
+ st.markdown("---")
433
+ st.markdown("Made with ❤️ | Excel 問答 AI")
requirements.txt CHANGED
@@ -10,3 +10,4 @@ protobuf>=3.20.0
10
  openpyxl>=3.1.0
11
  huggingface_hub>=0.19.0
12
  accelerate==0.25.0
 
 
10
  openpyxl>=3.1.0
11
  huggingface_hub>=0.19.0
12
  accelerate==0.25.0
13
+ bitsandbytes>=0.41.1