bainskarman committed
Commit 7370c39 · verified · 1 parent: ac1a86b

Create app.py

Files changed (1):
app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
import streamlit as st
import os

api_token = os.getenv("HF_TOKEN")

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

list_llm = ["meta-llama/Llama-3.2-3B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]

# Load PDFs and split them into overlapping chunks for retrieval
def load_doc(list_file_path):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(pages)

# Embed the chunks and index them in a FAISS vector store
def create_db(splits):
    embeddings = HuggingFaceEmbeddings()
    return FAISS.from_documents(splits, embeddings)

# Build the conversational retrieval chain around a Hugging Face endpoint LLM
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )

    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()

    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )

st.title("RAG PDF Chatbot")

uploaded_files = st.file_uploader("Upload PDF files", accept_multiple_files=True, type="pdf")

if uploaded_files:
    # Save uploaded files to local disk so PyPDFLoader can read them
    file_paths = []
    os.makedirs("temp", exist_ok=True)
    for uploaded_file in uploaded_files:
        file_path = os.path.join("temp", uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        file_paths.append(file_path)

    st.session_state["doc_splits"] = load_doc(file_paths)
    st.success("Documents successfully loaded and split!")

if "vector_db" not in st.session_state and "doc_splits" in st.session_state:
    st.session_state["vector_db"] = create_db(st.session_state["doc_splits"])

llm_option = st.selectbox("Select LLM", list_llm)

temperature = st.slider("Temperature", 0.01, 1.0, 0.5, 0.1)
max_tokens = st.slider("Max Tokens", 128, 8192, 4096, 128)
top_k = st.slider("Top K", 1, 10, 3, 1)

# The chain is built once per session; changing the sliders afterwards
# has no effect until the session state is cleared
if "qa_chain" not in st.session_state and "vector_db" in st.session_state:
    st.session_state["qa_chain"] = initialize_llmchain(llm_option, temperature, max_tokens, top_k, st.session_state["vector_db"])

if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

user_input = st.text_input("Ask a question")

if st.button("Submit") and user_input:
    qa_chain = st.session_state["qa_chain"]
    # The chain's ConversationBufferMemory supplies the chat history;
    # the session-state copy is kept only for the display below
    response = qa_chain.invoke({"question": user_input, "chat_history": st.session_state["chat_history"]})

    st.session_state["chat_history"].append((user_input, response["answer"]))

    st.write("### Response:")
    st.write(response["answer"])

    st.write("### Sources:")
    for doc in response["source_documents"][:3]:
        st.write(f"Page {doc.metadata['page'] + 1}: {doc.page_content[:300]}...")

    st.write("### Chat History")
    for user_msg, bot_msg in st.session_state["chat_history"]:
        st.text(f"User: {user_msg}")
        st.text(f"Assistant: {bot_msg}")
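For anyone reproducing this Space: the commit does not pin any dependencies, so the list below is a sketch inferred from app.py's imports. The exact package set (notably the FAISS build and the embeddings backend) is an assumption, not part of this commit.

# requirements.txt (assumed; inferred from app.py's imports)
streamlit
langchain
langchain-community
faiss-cpu               # FAISS vector store used by create_db
pypdf                   # backend required by PyPDFLoader
sentence-transformers   # default backend for HuggingFaceEmbeddings
huggingface_hub         # client used by HuggingFaceEndpoint

With those installed, HF_TOKEN=<your token> streamlit run app.py starts the app; the token is read at startup via os.getenv("HF_TOKEN").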