Kadi-IAM committed on
Commit 1a20a59 · 1 Parent(s): 4e765a8

Clean code and add readme

Files changed (12)
  1. LISA_mini.ipynb +23 -25
  2. README.md +31 -0
  3. app.py +25 -20
  4. documents.py +51 -130
  5. embeddings.py +26 -15
  6. llms.py +16 -34
  7. preprocess_documents.py +9 -4
  8. ragchain.py +18 -5
  9. requirements.txt +1 -1
  10. rerank.py +3 -2
  11. retrievers.py +12 -7
  12. vectorestores.py +8 -3
LISA_mini.ipynb CHANGED
@@ -1,8 +1,16 @@
1
  {
2
  "cells": [
 
 
 
 
 
 
 
 
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "adcfdba2",
7
  "metadata": {},
8
  "outputs": [],
@@ -18,14 +26,13 @@
18
  "from langchain.chains import ConversationalRetrievalChain\n",
19
  "from langchain.llms import HuggingFaceTextGenInference\n",
20
  "from langchain.chains.conversation.memory import (\n",
21
- " ConversationBufferMemory,\n",
22
  " ConversationBufferWindowMemory,\n",
23
  ")"
24
  ]
25
  },
26
  {
27
  "cell_type": "code",
28
- "execution_count": 2,
29
  "id": "2d85c6d9",
30
  "metadata": {},
31
  "outputs": [],
@@ -68,7 +75,7 @@
68
  },
69
  {
70
  "cell_type": "code",
71
- "execution_count": 4,
72
  "id": "2d5bacd5",
73
  "metadata": {},
74
  "outputs": [],
@@ -107,7 +114,7 @@
107
  },
108
  {
109
  "cell_type": "code",
110
- "execution_count": 5,
111
  "id": "8cd31248",
112
  "metadata": {},
113
  "outputs": [],
@@ -140,21 +147,12 @@
140
  },
141
  {
142
  "cell_type": "code",
143
- "execution_count": 7,
144
- "id": "73d560de",
145
- "metadata": {},
146
- "outputs": [],
147
- "source": [
148
- "# Create retrievers"
149
- ]
150
- },
151
- {
152
- "cell_type": "code",
153
- "execution_count": 12,
154
  "id": "e5796990",
155
  "metadata": {},
156
  "outputs": [],
157
  "source": [
 
158
  "# Some advanced RAG, with parent document retriever, hybrid-search and rerank\n",
159
  "\n",
160
  "# 1. ParentDocumentRetriever. Note: this will take a long time (~several minutes)\n",
@@ -178,7 +176,7 @@
178
  },
179
  {
180
  "cell_type": "code",
181
- "execution_count": 11,
182
  "id": "bc299740",
183
  "metadata": {},
184
  "outputs": [],
@@ -191,7 +189,7 @@
191
  },
192
  {
193
  "cell_type": "code",
194
- "execution_count": null,
195
  "id": "2eb8bc8f",
196
  "metadata": {},
197
  "outputs": [],
@@ -214,7 +212,7 @@
214
  "\n",
215
  "from sentence_transformers import CrossEncoder\n",
216
  "\n",
217
- "model_name = \"BAAI/bge-reranker-large\" #\n",
218
  "\n",
219
  "class BgeRerank(BaseDocumentCompressor):\n",
220
  " model_name:str = model_name\n",
@@ -273,7 +271,7 @@
273
  },
274
  {
275
  "cell_type": "code",
276
- "execution_count": 14,
277
  "id": "af780912",
278
  "metadata": {},
279
  "outputs": [],
@@ -283,7 +281,7 @@
283
  "# Ensemble all above\n",
284
  "ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5])\n",
285
  "\n",
286
- "# Re-rank\n",
287
  "compressor = BgeRerank()\n",
288
  "rerank_retriever = ContextualCompressionRetriever(\n",
289
  " base_compressor=compressor, base_retriever=ensemble_retriever\n",
@@ -292,7 +290,7 @@
292
  },
293
  {
294
  "cell_type": "code",
295
- "execution_count": 15,
296
  "id": "beb9ab21",
297
  "metadata": {},
298
  "outputs": [],
@@ -307,7 +305,7 @@
307
  " self.return_messages = return_messages\n",
308
  "\n",
309
  " def create(self, retriver, llm):\n",
310
- " memory = ConversationBufferWindowMemory( # ConversationBufferMemory(\n",
311
  " memory_key=self.memory_key,\n",
312
  " return_messages=self.return_messages,\n",
313
  " output_key=self.output_key,\n",
@@ -634,7 +632,7 @@
634
  ],
635
  "metadata": {
636
  "kernelspec": {
637
- "display_name": "Python 3 (ipykernel)",
638
  "language": "python",
639
  "name": "python3"
640
  },
@@ -648,7 +646,7 @@
648
  "name": "python",
649
  "nbconvert_exporter": "python",
650
  "pygments_lexer": "ipython3",
651
- "version": "3.10.13"
652
  }
653
  },
654
  "nbformat": 4,
 
1
  {
2
  "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9267529d",
6
+ "metadata": {},
7
+ "source": [
8
+ "A mini version of LISA in a Jupyter notebook for easier testing and playing around."
9
+ ]
10
+ },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 2,
14
  "id": "adcfdba2",
15
  "metadata": {},
16
  "outputs": [],
 
26
  "from langchain.chains import ConversationalRetrievalChain\n",
27
  "from langchain.llms import HuggingFaceTextGenInference\n",
28
  "from langchain.chains.conversation.memory import (\n",
 
29
  " ConversationBufferWindowMemory,\n",
30
  ")"
31
  ]
32
  },
33
  {
34
  "cell_type": "code",
35
+ "execution_count": 3,
36
  "id": "2d85c6d9",
37
  "metadata": {},
38
  "outputs": [],
 
75
  },
76
  {
77
  "cell_type": "code",
78
+ "execution_count": 5,
79
  "id": "2d5bacd5",
80
  "metadata": {},
81
  "outputs": [],
 
114
  },
115
  {
116
  "cell_type": "code",
117
+ "execution_count": 6,
118
  "id": "8cd31248",
119
  "metadata": {},
120
  "outputs": [],
 
147
  },
148
  {
149
  "cell_type": "code",
150
+ "execution_count": 8,
 
 
 
 
 
 
 
 
 
 
151
  "id": "e5796990",
152
  "metadata": {},
153
  "outputs": [],
154
  "source": [
155
+ "# Create retrievers\n",
156
  "# Some advanced RAG, with parent document retriever, hybrid-search and rerank\n",
157
  "\n",
158
  "# 1. ParentDocumentRetriever. Note: this will take a long time (~several minutes)\n",
 
176
  },
177
  {
178
  "cell_type": "code",
179
+ "execution_count": 9,
180
  "id": "bc299740",
181
  "metadata": {},
182
  "outputs": [],
 
189
  },
190
  {
191
  "cell_type": "code",
192
+ "execution_count": 10,
193
  "id": "2eb8bc8f",
194
  "metadata": {},
195
  "outputs": [],
 
212
  "\n",
213
  "from sentence_transformers import CrossEncoder\n",
214
  "\n",
215
+ "model_name = \"BAAI/bge-reranker-large\"\n",
216
  "\n",
217
  "class BgeRerank(BaseDocumentCompressor):\n",
218
  " model_name:str = model_name\n",
 
271
  },
272
  {
273
  "cell_type": "code",
274
+ "execution_count": 11,
275
  "id": "af780912",
276
  "metadata": {},
277
  "outputs": [],
 
281
  "# Ensemble all above\n",
282
  "ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5])\n",
283
  "\n",
284
+ "# Rerank\n",
285
  "compressor = BgeRerank()\n",
286
  "rerank_retriever = ContextualCompressionRetriever(\n",
287
  " base_compressor=compressor, base_retriever=ensemble_retriever\n",
 
290
  },
291
  {
292
  "cell_type": "code",
293
+ "execution_count": 12,
294
  "id": "beb9ab21",
295
  "metadata": {},
296
  "outputs": [],
 
305
  " self.return_messages = return_messages\n",
306
  "\n",
307
  " def create(self, retriver, llm):\n",
308
+ " memory = ConversationBufferWindowMemory(\n",
309
  " memory_key=self.memory_key,\n",
310
  " return_messages=self.return_messages,\n",
311
  " output_key=self.output_key,\n",
 
632
  ],
633
  "metadata": {
634
  "kernelspec": {
635
+ "display_name": "lisa",
636
  "language": "python",
637
  "name": "python3"
638
  },
 
646
  "name": "python",
647
  "nbconvert_exporter": "python",
648
  "pygments_lexer": "ipython3",
649
+ "version": "3.11.10"
650
  }
651
  },
652
  "nbformat": 4,
README.md CHANGED
@@ -11,3 +11,34 @@ startup_duration_timeout: 2h
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ LISA (Lithium Ion Solid-state Assistant) is a question-and-answer (Q&A) research assistant designed for efficient knowledge management with a primary focus on battery science, yet versatile enough to support broader scientific domains. Built on a Retrieval-Augmented Generation (RAG) architecture, LISA uses advanced Large Language Models (LLMs) to provide reliable, detailed answers to research questions.
16
+
17
+ DEMO: https://huggingface.co/spaces/Kadi-IAM/LISA
18
+
19
+ ### Installation
20
+ 1. Clone the Repository:
21
+ ```bash
22
+ git clone "link of this repo"
23
+ cd LISA
24
+ ```
25
+
26
+ 2. Install Dependencies:
27
+ ```bash
28
+ pip install -r requirements.txt
29
+ ```
30
+
31
+ 3. Set Up the Knowledge Base
32
+ Populate the knowledge base with relevant documents or research papers. Ensure that the documents are in a format (pdf or xml) compatible with the RAG pipeline. By default, documents should be located at `data/documents`. After running the following command, some cache files are saved into `data/db` (see the sketch after these steps for how the pieces fit together). ATTENTION: pickle is used to save these caches; be careful with potential security risks.
33
+ ```bash
34
+ python preprocess_documents.py
35
+ ```
36
+
37
+ 4. Run LISA
38
+ Once setup is complete, run the following command to launch LISA:
39
+ ```bash
40
+ python app.py
41
+ ```
42
+
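For orientation, here is a minimal sketch of how the modules in this repository fit together end to end. It is not part of this commit: it mirrors the flow visible in `app.py` (BM25 plus parent-document retrieval, BGE reranking, a Groq LLM, and the citation-aware RAG chain), and anything not shown in the diff (the embedding helper chosen here and the default arguments of `get_parent_doc_retriever`) is an assumption.

```python
# Minimal sketch (not part of this commit), mirroring app.py; names and
# defaults not visible in the diff are assumptions.
from documents import load_pdf_as_docs, get_doc_chunks
from embeddings import get_hf_embeddings  # app.py may use a different helper
from vectorestores import get_faiss_vectorestore
from retrievers import get_parent_doc_retriever
from rerank import BgeRerank
from ragchain import RAGChain
from llms import get_groq_chat
from langchain.retrievers import (
    BM25Retriever,
    ContextualCompressionRetriever,
    EnsembleRetriever,
)

# Parse and chunk the documents
docs = load_pdf_as_docs("./data/documents")
document_chunks = get_doc_chunks(docs)

# Dense (parent-document) and sparse (BM25) retrieval
embeddings = get_hf_embeddings()
vectorstore = get_faiss_vectorestore(embeddings)
parent_doc_retriever = get_parent_doc_retriever(docs, vectorstore)  # assumed defaults
bm25_retriever = BM25Retriever.from_documents(document_chunks, k=5)
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, parent_doc_retriever], weights=[0.5, 0.5]
)

# Cross-encoder rerank on top of the ensemble
rerank_retriever = ContextualCompressionRetriever(
    base_compressor=BgeRerank(), base_retriever=ensemble_retriever
)

# LLM and citation-aware RAG chain
llm = get_groq_chat(model_name="llama-3.1-70b-versatile")
lisa_qa_conversation = RAGChain().create(rerank_retriever, llm, add_citation=True)

# Ask a question; the chain's memory supplies the chat history
result = lisa_qa_conversation({"question": "What is a solid-state electrolyte?"})
print(result["answer"])
```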
43
+ ### About
44
+ For more information on our work in intelligent research data management systems, please visit [KadiAI](https://kadi.iam.kit.edu/kadi-ai).
app.py CHANGED
@@ -1,12 +1,15 @@
1
  import os
2
  import time
3
  import re
4
- from pathlib import Path
5
- from dotenv import load_dotenv
6
  import pickle
7
 
8
-
9
- import gradio as gr
10
 
11
  from huggingface_hub import login
12
  from langchain.vectorstores import FAISS
@@ -15,24 +18,21 @@ from llms import get_groq_chat
15
  from documents import load_pdf_as_docs, load_xml_as_docs
16
  from vectorestores import get_faiss_vectorestore
17
 
18
-
19
  # For debug
20
  # from langchain.globals import set_debug
21
  # set_debug(True)
22
 
23
-
24
  # Load and set env variables
25
  load_dotenv()
26
 
 
27
  HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
28
  login(HUGGINGFACEHUB_API_TOKEN)
29
  TAVILY_API_KEY = os.environ["TAVILY_API_KEY"] # Search engine
30
 
31
- # Other settings
32
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
33
 
34
  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
35
-
36
  database_root = "./data/db"
37
  document_path = "./data/documents"
38
 
@@ -80,12 +80,13 @@ from langchain.retrievers import BM25Retriever, EnsembleRetriever
80
 
81
  bm25_retriever = BM25Retriever.from_documents(
82
  document_chunks, k=5
83
- ) # 1/2 of dense retriever, experimental value
84
 
85
- # Ensemble all above
86
  ensemble_retriever = EnsembleRetriever(
87
  retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5]
88
  )
 
89
  # Reranker
90
  from rerank import BgeRerank
91
 
@@ -98,7 +99,7 @@ print("rerank loaded")
98
  llm = get_groq_chat(model_name="llama-3.1-70b-versatile")
99
 
100
 
101
- # # # Create conversation qa chain (Note: conversation is not supported yet)
102
  from ragchain import RAGChain
103
 
104
  rag_chain = RAGChain()
@@ -108,13 +109,11 @@ lisa_qa_conversation = rag_chain.create(rerank_retriever, llm, add_citation=True
108
  from langchain_community.retrievers import TavilySearchAPIRetriever
109
  from langchain.chains import RetrievalQAWithSourcesChain
110
 
111
- web_search_retriever = TavilySearchAPIRetriever(
112
- k=4
113
- ) # , include_raw_content=True)#, include_raw_content=True)
114
  web_qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
115
  llm, retriever=web_search_retriever, return_source_documents=True
116
  )
117
- print("chain loaded")
118
 
119
 
120
  # Gradio utils
@@ -136,7 +135,7 @@ def add_text(history, text):
136
 
137
 
138
  def postprocess_remove_cite_misinfo(text, allowed_max_cite_num=6):
139
- """Exp.-based removal of misinfo. of citations."""
140
 
141
  # Remove trailing references at end of text
142
  if "References:\n[" in text:
@@ -480,7 +479,7 @@ def main():
480
  # flag_web_search = gr.Checkbox(label="Search web", info="Search information from Internet")
481
  gr.Markdown("More in DEV...")
482
 
483
- # Manage functions
484
  user_txt.submit(check_input_text, user_txt, None).success(
485
  add_text, [chatbot, user_txt], [chatbot, user_txt]
486
  ).then(bot_lisa, [chatbot, flag_web_search], [chatbot, doc_citation])
@@ -575,6 +574,7 @@ def main():
575
  with gr.Tab("Setting"):
576
  gr.Markdown("More in DEV...")
577
 
 
578
  load_document.click(
579
  document_changes,
580
  inputs=[uploaded_doc], # , repo_id],
@@ -606,8 +606,9 @@ def main():
606
  )
607
 
608
  ##########################
609
- # Preview tab
610
  with gr.Tab("Preview feature 🔬"):
 
611
  with gr.Tab("Vision LM 🖼"):
612
  vision_tmp_link = (
613
  "https://kadi-iam-lisa-vlm.hf.space/" # vision model link
@@ -620,6 +621,7 @@ def main():
620
  )
621
  # gr.Markdown("placeholder")
622
 
 
623
  with gr.Tab("KadiChat 💬"):
624
  kadichat_tmp_link = (
625
  "https://kadi-iam-kadichat.hf.space/" # vision model link
@@ -631,9 +633,12 @@ def main():
631
  )
632
  )
633
 
 
634
  with gr.Tab("RAG enhanced with Knowledge Graph (dev) 🔎"):
635
  kg_tmp_link = "https://kadi-iam-kadikgraph.static.hf.space/index.html"
636
- gr.Markdown("[If rendering fails, look at the graph here](https://kadi-iam-kadikgraph.static.hf.space)")
 
 
637
  with gr.Blocks(css="""footer {visibility: hidden};""") as preview_tab:
638
  gr.HTML(
639
  """<iframe
 
1
+ """
2
+ Main app for the LISA RAG chatbot, based on langchain.
3
+ """
4
+
5
  import os
6
  import time
7
  import re
8
+ import gradio as gr
 
9
  import pickle
10
 
11
+ from pathlib import Path
12
+ from dotenv import load_dotenv
13
 
14
  from huggingface_hub import login
15
  from langchain.vectorstores import FAISS
 
18
  from documents import load_pdf_as_docs, load_xml_as_docs
19
  from vectorestores import get_faiss_vectorestore
20
 
 
21
  # For debug
22
  # from langchain.globals import set_debug
23
  # set_debug(True)
24
 
 
25
  # Load and set env variables
26
  load_dotenv()
27
 
28
+ # Set API keys
29
  HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
30
  login(HUGGINGFACEHUB_API_TOKEN)
31
  TAVILY_API_KEY = os.environ["TAVILY_API_KEY"] # Search engine
32
 
 
 
33
 
34
  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
35
+ # Set database path
36
  database_root = "./data/db"
37
  document_path = "./data/documents"
38
 
 
80
 
81
  bm25_retriever = BM25Retriever.from_documents(
82
  document_chunks, k=5
83
+ ) # k = 1/2 of dense retriever, experimental value
84
 
85
+ # Ensemble all above retrievers
86
  ensemble_retriever = EnsembleRetriever(
87
  retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5]
88
  )
89
+
90
  # Reranker
91
  from rerank import BgeRerank
92
 
 
99
  llm = get_groq_chat(model_name="llama-3.1-70b-versatile")
100
 
101
 
102
+ # Create conversation qa chain (Note: conversation is not supported yet)
103
  from ragchain import RAGChain
104
 
105
  rag_chain = RAGChain()
 
109
  from langchain_community.retrievers import TavilySearchAPIRetriever
110
  from langchain.chains import RetrievalQAWithSourcesChain
111
 
112
+ web_search_retriever = TavilySearchAPIRetriever(k=4) # , include_raw_content=True)
 
 
113
  web_qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
114
  llm, retriever=web_search_retriever, return_source_documents=True
115
  )
116
+ print("chains loaded")
117
 
118
 
119
  # Gradio utils
 
135
 
136
 
137
  def postprocess_remove_cite_misinfo(text, allowed_max_cite_num=6):
138
+ """Heuristic removal of misinfo. of citations."""
139
 
140
  # Remove trailing references at end of text
141
  if "References:\n[" in text:
 
479
  # flag_web_search = gr.Checkbox(label="Search web", info="Search information from Internet")
480
  gr.Markdown("More in DEV...")
481
 
482
+ # Action functions
483
  user_txt.submit(check_input_text, user_txt, None).success(
484
  add_text, [chatbot, user_txt], [chatbot, user_txt]
485
  ).then(bot_lisa, [chatbot, flag_web_search], [chatbot, doc_citation])
 
574
  with gr.Tab("Setting"):
575
  gr.Markdown("More in DEV...")
576
 
577
+ # Actions
578
  load_document.click(
579
  document_changes,
580
  inputs=[uploaded_doc], # , repo_id],
 
606
  )
607
 
608
  ##########################
609
+ # Preview tabs
610
  with gr.Tab("Preview feature 🔬"):
611
+ # VLM model
612
  with gr.Tab("Vision LM 🖼"):
613
  vision_tmp_link = (
614
  "https://kadi-iam-lisa-vlm.hf.space/" # vision model link
 
621
  )
622
  # gr.Markdown("placeholder")
623
 
624
+ # OAuth2 linkage to Kadi-demo
625
  with gr.Tab("KadiChat 💬"):
626
  kadichat_tmp_link = (
627
  "https://kadi-iam-kadichat.hf.space/" # vision model link
 
633
  )
634
  )
635
 
636
+ # Knowledge graph-enhanced RAG
637
  with gr.Tab("RAG enhanced with Knowledge Graph (dev) 🔎"):
638
  kg_tmp_link = "https://kadi-iam-kadikgraph.static.hf.space/index.html"
639
+ gr.Markdown(
640
+ "[If rendering fails, look at the graph here](https://kadi-iam-kadikgraph.static.hf.space)"
641
+ )
642
  with gr.Blocks(css="""footer {visibility: hidden};""") as preview_tab:
643
  gr.HTML(
644
  """<iframe
documents.py CHANGED
@@ -1,25 +1,30 @@
1
  import os
2
- import shutil
3
 
4
  from langchain.document_loaders import (
5
  PyMuPDFLoader,
6
  )
7
  from langchain.docstore.document import Document
8
-
9
- from langchain.vectorstores import Chroma
10
-
11
  from langchain.text_splitter import (
12
- RecursiveCharacterTextSplitter,
13
  SpacyTextSplitter,
14
  )
15
 
 
16
  def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
17
  """Load and parse pdf file(s)."""
18
-
19
- if pdf_path.endswith('.pdf'): # single file
20
  pdf_docs = [pdf_path]
21
  else: # a directory
22
- pdf_docs = [os.path.join(pdf_path, f) for f in os.listdir(pdf_path) if f.endswith('.pdf')]
23
 
24
  if load_kwargs is None:
25
  load_kwargs = {}
@@ -31,180 +36,96 @@ def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
31
  loader = loader_module(pdf, **load_kwargs)
32
  doc = loader.load()
33
  docs.extend(doc)
34
-
35
  return docs
36
 
 
37
  def load_xml_as_docs(xml_path, loader_module=None, load_kwargs=None):
38
  """Load and parse xml file(s)."""
39
-
40
  from bs4 import BeautifulSoup
41
  from unstructured.cleaners.core import group_broken_paragraphs
42
-
43
- if xml_path.endswith('.xml'): # single file
44
  xml_docs = [xml_path]
45
  else: # a directory
46
- xml_docs = [os.path.join(xml_path, f) for f in os.listdir(xml_path) if f.endswith('.xml')]
47
-
48
  if load_kwargs is None:
49
  load_kwargs = {}
50
 
51
  docs = []
52
  for xml_file in xml_docs:
53
- # print("now reading file...")
54
  with open(xml_file) as fp:
55
- soup = BeautifulSoup(fp, features="xml") # txt is simply the a string with your XML file
 
 
56
  pageText = soup.findAll(string=True)
57
- parsed_text = '\n'.join(pageText) # or " ".join, seems similar
58
- # # Clean text
59
  parsed_text_grouped = group_broken_paragraphs(parsed_text)
60
-
61
  # get metadata
62
  try:
63
  from lxml import etree as ET
 
64
  tree = ET.parse(xml_file)
65
 
66
  # Define namespace
67
  ns = {"tei": "http://www.tei-c.org/ns/1.0"}
68
  # Read Author personal names as an example
69
- pers_name_elements = tree.xpath("tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:author/tei:persName", namespaces=ns)
 
 
 
70
  first_per = pers_name_elements[0].text
71
  author_info = first_per + " et al"
72
 
73
- title_elements = tree.xpath("tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:title", namespaces=ns)
 
 
74
  title = title_elements[0].text
75
 
76
  # Combine source info
77
  source_info = "_".join([author_info, title])
78
  except:
79
  source_info = "unknown"
80
-
81
- # maybe even better TODO: discuss with Jens
82
  # first_author = soup.find("author")
83
  # publication_year = soup.find("date", attrs={'type': 'published'})
84
  # title = soup.find("title")
85
  # source_info = [first_author, publication_year, title]
86
  # source_info_str = "_".join([info.text.strip() if info is not None else "unknown" for info in source_info])
87
-
88
- doc = [Document(page_content=parsed_text_grouped, metadata={"source": source_info})]#, metadata={"source": "local"})
 
 
 
 
89
 
90
  docs.extend(doc)
91
-
92
  return docs
93
 
94
 
95
  def get_doc_chunks(docs, splitter=None):
96
  """Split docs into chunks."""
97
-
98
  if splitter is None:
99
- # splitter = RecursiveCharacterTextSplitter(
100
  # # separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256
101
  # separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128
102
  # )
 
103
  splitter = SpacyTextSplitter.from_tiktoken_encoder(
104
  chunk_size=512,
105
  chunk_overlap=128,
106
  )
107
  chunks = splitter.split_documents(docs)
108
-
109
- return chunks
110
-
111
 
112
- def persist_vectorstore(document_chunks, embeddings, persist_directory="db", overwrite=False):
113
- # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
114
- # vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
115
- if overwrite:
116
- shutil.rmtree(persist_directory) # Empty and reset db
117
- db = Chroma.from_documents(documents=document_chunks, embedding=embeddings, persist_directory=persist_directory)
118
- # db.delete_collection()
119
- db.persist()
120
- # db = None
121
- # db = Chroma(persist_directory="db", embedding_function = embeddings, client_settings=CHROMA_SETTINGS)
122
- # vectorstore = FAISS.from_documents(documents=document_chunks, embedding=embeddings)
123
- return db
124
-
125
-
126
- class VectorstoreManager:
127
-
128
- def __init__(self):
129
- self.vectorstore_class = Chroma
130
-
131
- def create_db(self, embeddings):
132
- db = self.vectorstore_class(embedding_function=embeddings)
133
-
134
- self.db = db
135
- return db
136
-
137
-
138
- def load_db(self, persist_directory, embeddings):
139
- """Load local vectorestore."""
140
-
141
- db = self.vectorstore_class(persist_directory=persist_directory, embedding_function=embeddings)
142
- self.db = db
143
-
144
- return db
145
-
146
- def create_db_from_documents(self, document_chunks, embeddings, persist_directory="db", overwrite=False):
147
- """Create db from documents."""
148
-
149
- if overwrite:
150
- shutil.rmtree(persist_directory) # Empty and reset db
151
- db = self.vectorstore_class.from_documents(documents=document_chunks, embedding=embeddings, persist_directory=persist_directory)
152
- self.db = db
153
-
154
- return db
155
-
156
- def persist_db(self, persist_directory="db"):
157
- """Persist db."""
158
-
159
- assert self.db
160
- self.db.persist() # Chroma
161
-
162
- class RetrieverManager:
163
- # some other retrievers Using Advanced Retrievers in LangChain https://www.comet.com/site/blog/using-advanced-retrievers-in-langchain/
164
-
165
- def __init__(self, vectorstore, k=10):
166
-
167
- self.vectorstore = vectorstore
168
- self.retriever = vectorstore.as_retriever(search_kwargs={"k": k}) #search_kwargs={"k": 8}),
169
-
170
- def get_rerank_retriver(self, base_retriever=None):
171
-
172
- if base_retriever is None:
173
- base_retriever = self.retriever
174
- # with rerank
175
- from rerank import BgeRerank
176
- from langchain.retrievers import ContextualCompressionRetriever
177
-
178
- compressor = BgeRerank()
179
- compression_retriever = ContextualCompressionRetriever(
180
- base_compressor=compressor, base_retriever=base_retriever
181
- )
182
-
183
- return compression_retriever
184
-
185
- def get_parent_doc_retriver(self, documents, store_file="./store_location"):
186
- # TODO need better design
187
- # Ref: explain how it works: https://clusteredbytes.pages.dev/posts/2023/langchain-parent-document-retriever/
188
- from langchain.storage.file_system import LocalFileStore
189
- from langchain.storage import InMemoryStore
190
- from langchain.storage._lc_store import create_kv_docstore
191
- from langchain.retrievers import ParentDocumentRetriever
192
- # Ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
193
- # fs = LocalFileStore("./store_location")
194
- # store = create_kv_docstore(fs)
195
- docstore = InMemoryStore()
196
-
197
- # TODO: how to better set this?
198
- parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
199
- child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128)
200
-
201
- retriever = ParentDocumentRetriever(
202
- vectorstore=self.vectorstore,
203
- docstore=docstore,
204
- child_splitter=child_splitter,
205
- parent_splitter=parent_splitter,
206
- search_kwargs={"k":10} # Better settings?
207
- )
208
- retriever.add_documents(documents)#, ids=None)
209
-
210
- return retriever
 
1
+ """
2
+ Parse documents; currently pdf and xml are supported.
3
+ """
4
+
5
  import os
 
6
 
7
  from langchain.document_loaders import (
8
  PyMuPDFLoader,
9
  )
10
  from langchain.docstore.document import Document
 
 
 
11
  from langchain.text_splitter import (
12
+ # RecursiveCharacterTextSplitter,
13
  SpacyTextSplitter,
14
  )
15
 
16
+
17
  def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
18
  """Load and parse pdf file(s)."""
19
+
20
+ if pdf_path.endswith(".pdf"): # single file
21
  pdf_docs = [pdf_path]
22
  else: # a directory
23
+ pdf_docs = [
24
+ os.path.join(pdf_path, f)
25
+ for f in os.listdir(pdf_path)
26
+ if f.endswith(".pdf")
27
+ ]
28
 
29
  if load_kwargs is None:
30
  load_kwargs = {}
 
36
  loader = loader_module(pdf, **load_kwargs)
37
  doc = loader.load()
38
  docs.extend(doc)
39
+
40
  return docs
41
 
42
+
43
  def load_xml_as_docs(xml_path, loader_module=None, load_kwargs=None):
44
  """Load and parse xml file(s)."""
45
+
46
  from bs4 import BeautifulSoup
47
  from unstructured.cleaners.core import group_broken_paragraphs
48
+
49
+ if xml_path.endswith(".xml"): # single file
50
  xml_docs = [xml_path]
51
  else: # a directory
52
+ xml_docs = [
53
+ os.path.join(xml_path, f)
54
+ for f in os.listdir(xml_path)
55
+ if f.endswith(".xml")
56
+ ]
57
+
58
  if load_kwargs is None:
59
  load_kwargs = {}
60
 
61
  docs = []
62
  for xml_file in xml_docs:
 
63
  with open(xml_file) as fp:
64
+ soup = BeautifulSoup(
65
+ fp, features="xml"
66
+ ) # txt is simply a string with your XML file
67
  pageText = soup.findAll(string=True)
68
+ parsed_text = "\n".join(pageText) # or " ".join, seems similar
69
+ # Clean text
70
  parsed_text_grouped = group_broken_paragraphs(parsed_text)
71
+
72
  # get metadata
73
  try:
74
  from lxml import etree as ET
75
+
76
  tree = ET.parse(xml_file)
77
 
78
  # Define namespace
79
  ns = {"tei": "http://www.tei-c.org/ns/1.0"}
80
  # Read Author personal names as an example
81
+ pers_name_elements = tree.xpath(
82
+ "tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:author/tei:persName",
83
+ namespaces=ns,
84
+ )
85
  first_per = pers_name_elements[0].text
86
  author_info = first_per + " et al"
87
 
88
+ title_elements = tree.xpath(
89
+ "tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:title", namespaces=ns
90
+ )
91
  title = title_elements[0].text
92
 
93
  # Combine source info
94
  source_info = "_".join([author_info, title])
95
  except:
96
  source_info = "unknown"
97
+
98
+ # maybe even better parsing method. TODO: discuss with TUD
99
  # first_author = soup.find("author")
100
  # publication_year = soup.find("date", attrs={'type': 'published'})
101
  # title = soup.find("title")
102
  # source_info = [first_author, publication_year, title]
103
  # source_info_str = "_".join([info.text.strip() if info is not None else "unknown" for info in source_info])
104
+
105
+ doc = [
106
+ Document(
107
+ page_content=parsed_text_grouped, metadata={"source": source_info}
108
+ )
109
+ ]
110
 
111
  docs.extend(doc)
112
+
113
  return docs
114
 
115
 
116
  def get_doc_chunks(docs, splitter=None):
117
  """Split docs into chunks."""
118
+
119
  if splitter is None:
120
+ # splitter = RecursiveCharacterTextSplitter( # original default
121
  # # separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256
122
  # separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128
123
  # )
124
+ # Spacy seems better
125
  splitter = SpacyTextSplitter.from_tiktoken_encoder(
126
  chunk_size=512,
127
  chunk_overlap=128,
128
  )
129
  chunks = splitter.split_documents(docs)
 
 
 
130
 
131
+ return chunks
 
 
 
 
 
 
 
 
 
 
 
 
embeddings.py CHANGED
@@ -1,39 +1,50 @@
 
 
 
1
 
2
  import torch
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
 
5
 
6
  def get_hf_embeddings(model_name=None):
7
- """Get huggingface embedding."""
8
-
9
  if model_name is None:
10
- # Some candiates:
11
  # "BAAI/bge-m3" (good, though large and slow)
12
- # "BAAI/bge-base-en-v1.5" -> seems not that good with current settings
13
- # "sentence-transformers/all-mpnet-base-v2", "maidalun1020/bce-embedding-base_v1", "intfloat/multilingual-e5-large"
14
- # Ref: https://huggingface.co/spaces/mteb/leaderboard, https://huggingface.co/maidalun1020/bce-embedding-base_v1
15
- model_name = "BAAI/bge-large-en-v1.5" # or ""
16
-
 
 
 
17
  embeddings = HuggingFaceEmbeddings(model_name=model_name)
18
-
19
  return embeddings
20
 
21
 
22
- def get_jinaai_embeddings(model_name="jinaai/jina-embeddings-v2-base-en", device="auto"):
 
 
23
  """Get jinaai embedding."""
24
-
25
  # device: cpu or cuda
26
  if device == "auto":
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  # For jinaai. Ref: https://github.com/langchain-ai/langchain/issues/6080
29
  from transformers import AutoModel
30
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # -> will yield error, need bug fixing
 
 
 
31
 
32
  model_name = model_name
33
- model_kwargs = {'device': device, "trust_remote_code": True}
34
  embeddings = HuggingFaceEmbeddings(
35
  model_name=model_name,
36
  model_kwargs=model_kwargs,
37
  )
38
-
39
- return embeddings
 
1
+ """
2
+ Load embedding models from huggingface.
3
+ """
4
 
5
  import torch
6
  from langchain.embeddings import HuggingFaceEmbeddings
7
 
8
 
9
  def get_hf_embeddings(model_name=None):
10
+ """Get huggingface embedding by name."""
11
+
12
  if model_name is None:
13
+ # Some candidates
14
  # "BAAI/bge-m3" (good, though large and slow)
15
+ # "BAAI/bge-base-en-v1.5" -> also good
16
+ # "sentence-transformers/all-mpnet-base-v2"
17
+ # "maidalun1020/bce-embedding-base_v1"
18
+ # "intfloat/multilingual-e5-large"
19
+ # Ref: https://huggingface.co/spaces/mteb/leaderboard
20
+ # https://huggingface.co/maidalun1020/bce-embedding-base_v1
21
+ model_name = "BAAI/bge-large-en-v1.5"
22
+
23
  embeddings = HuggingFaceEmbeddings(model_name=model_name)
24
+
25
  return embeddings
26
 
27
 
28
+ def get_jinaai_embeddings(
29
+ model_name="jinaai/jina-embeddings-v2-base-en", device="auto"
30
+ ):
31
  """Get jinaai embedding."""
32
+
33
  # device: cpu or cuda
34
  if device == "auto":
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
  # For jinaai. Ref: https://github.com/langchain-ai/langchain/issues/6080
37
  from transformers import AutoModel
38
+
39
+ model = AutoModel.from_pretrained(
40
+ model_name, trust_remote_code=True
41
+ ) # -> will yield error, need bug fixing
42
 
43
  model_name = model_name
44
+ model_kwargs = {"device": device, "trust_remote_code": True}
45
  embeddings = HuggingFaceEmbeddings(
46
  model_name=model_name,
47
  model_kwargs=model_kwargs,
48
  )
49
+
50
+ return embeddings
llms.py CHANGED
@@ -1,22 +1,22 @@
1
- # from langchain import HuggingFaceHub, LLMChain
2
- from langchain.llms import HuggingFacePipeline
 
 
3
  from transformers import (
4
- AutoModelForCausalLM,
5
  AutoTokenizer,
6
  pipeline,
7
  )
8
- from transformers import LlamaForCausalLM, AutoModelForCausalLM, LlamaTokenizer
9
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
10
  from langchain_groq import ChatGroq
11
-
12
-
13
- from langchain.chat_models import ChatOpenAI
14
  from langchain.llms import HuggingFaceTextGenInference
15
 
 
 
16
 
17
  def get_llm_hf_online(inference_api_url=""):
18
  """Get LLM using huggingface inference."""
19
-
20
  if not inference_api_url: # default api url
21
  inference_api_url = (
22
  "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
@@ -35,20 +35,16 @@ def get_llm_hf_online(inference_api_url=""):
35
 
36
 
37
  def get_llm_hf_local(model_path):
38
- """Get local LLM."""
39
-
40
- model = LlamaForCausalLM.from_pretrained(
41
- model_path, device_map="auto"
42
- )
43
  tokenizer = AutoTokenizer.from_pretrained(model_path)
44
 
45
- # print('making a pipeline...')
46
- # max_length has typically been deprecated for max_new_tokens
47
  pipe = pipeline(
48
  "text-generation",
49
  model=model,
50
  tokenizer=tokenizer,
51
- max_new_tokens=1024, # better setting?
52
  model_kwargs={"temperature": 0.1}, # better setting?
53
  )
54
  llm = HuggingFacePipeline(pipeline=pipe)
@@ -56,22 +52,8 @@ def get_llm_hf_local(model_path):
56
  return llm
57
 
58
 
59
-
60
- def get_llm_openai_chat(model_name, inference_server_url):
61
- """Get openai-like LLM."""
62
-
63
- llm = ChatOpenAI(
64
- model=model_name,
65
- openai_api_key="EMPTY",
66
- openai_api_base=inference_server_url,
67
- max_tokens=1024, # better setting?
68
- temperature=0,
69
- )
70
-
71
- return llm
72
-
73
-
74
- def get_groq_chat(model_name="llama-3.1-70b-versatile"):
75
 
76
  llm = ChatGroq(temperature=0, model_name=model_name)
77
- return llm
 
1
+ """
2
+ Load LLMs from huggingface, Groq, etc.
3
+ """
4
+
5
  from transformers import (
6
+ # AutoModelForCausalLM,
7
  AutoTokenizer,
8
  pipeline,
9
  )
10
+ from langchain.llms import HuggingFacePipeline
 
11
  from langchain_groq import ChatGroq
 
 
 
12
  from langchain.llms import HuggingFaceTextGenInference
13
 
14
+ # from langchain.chat_models import ChatOpenAI # oai model
15
+
16
 
17
  def get_llm_hf_online(inference_api_url=""):
18
  """Get LLM using huggingface inference."""
19
+
20
  if not inference_api_url: # default api url
21
  inference_api_url = (
22
  "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
 
35
 
36
 
37
  def get_llm_hf_local(model_path):
38
+ """Get local LLM from huggingface."""
39
+
40
+ model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")
 
 
41
  tokenizer = AutoTokenizer.from_pretrained(model_path)
42
 
 
 
43
  pipe = pipeline(
44
  "text-generation",
45
  model=model,
46
  tokenizer=tokenizer,
47
+ max_new_tokens=2048, # better setting?
48
  model_kwargs={"temperature": 0.1}, # better setting?
49
  )
50
  llm = HuggingFacePipeline(pipeline=pipe)
 
52
  return llm
53
 
54
 
55
+ def get_groq_chat(model_name="llama-3.1-70b-versatile"):
56
+ """Get LLM from Groq."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  llm = ChatGroq(temperature=0, model_name=model_name)
59
+ return llm
preprocess_documents.py CHANGED
@@ -1,15 +1,17 @@
1
  """
2
- Load and parse files (pdf) in the data/documents and save cached pkl files.
3
  """
4
 
5
  import os
6
  import pickle
7
 
8
  from dotenv import load_dotenv
9
-
10
-
11
  from huggingface_hub import login
12
-
13
  from documents import load_pdf_as_docs, get_doc_chunks
14
  from embeddings import get_jinaai_embeddings
15
 
@@ -23,11 +25,14 @@ login(HUGGINGFACEHUB_API_TOKEN)
23
 
24
 
25
  def save_to_pickle(obj, filename):
 
 
26
  with open(filename, "wb") as file:
27
  pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)
28
 
29
 
30
  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
 
31
  database_root = "./data/db"
32
  document_path = "./data/documents"
33
 
 
1
  """
2
+ Load and parse files (pdf) in the "data/documents" and save cached pkl files.
3
+ It will load and parse files and save 4 caches:
4
+ 1. "docs.pkl" for loaded text documents
5
+ 2. "docs_chunks.pkl" for chunked text
6
+ 3. "docstore.pkl" for small-to-big retriever
7
+ 4. faiss_index for the FAISS vector store
8
  """
9
 
10
  import os
11
  import pickle
12
 
13
  from dotenv import load_dotenv
 
 
14
  from huggingface_hub import login
 
15
  from documents import load_pdf_as_docs, get_doc_chunks
16
  from embeddings import get_jinaai_embeddings
17
 
 
25
 
26
 
27
  def save_to_pickle(obj, filename):
28
+ """Save obj to disk using pickle."""
29
+
30
  with open(filename, "wb") as file:
31
  pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)
32
 
33
 
34
  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
35
+ # Set database path; should be the same as defined in "app.py"
36
  database_root = "./data/db"
37
  document_path = "./data/documents"
38
 
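For reference, a minimal sketch (not part of this commit) of reading these caches back, as `app.py` presumably does at startup. The file names come from the docstring above; the choice of embedding helper and the FAISS loading flag are assumptions, not something this diff shows.

```python
# Minimal sketch (not part of this commit) of loading the caches written by
# preprocess_documents.py; file names follow the docstring above.
import os
import pickle

from langchain.vectorstores import FAISS
from embeddings import get_jinaai_embeddings


def load_from_pickle(filename):
    """Counterpart of save_to_pickle()."""
    with open(filename, "rb") as file:
        return pickle.load(file)


database_root = "./data/db"

docs = load_from_pickle(os.path.join(database_root, "docs.pkl"))
document_chunks = load_from_pickle(os.path.join(database_root, "docs_chunks.pkl"))
docstore = load_from_pickle(os.path.join(database_root, "docstore.pkl"))

embeddings = get_jinaai_embeddings(device="auto")
vectorstore = FAISS.load_local(
    os.path.join(database_root, "faiss_index"),
    embeddings,
    allow_dangerous_deserialization=True,  # the index metadata is pickled
)
```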
ragchain.py CHANGED
@@ -1,3 +1,7 @@
1
  from langchain.chains import LLMChain
2
 
3
  from langchain.prompts import (
@@ -11,17 +15,17 @@ from langchain.chains import ConversationalRetrievalChain
11
  from langchain.chains.conversation.memory import (
12
  ConversationBufferWindowMemory,
13
  )
14
-
15
-
16
  from langchain.chains import StuffDocumentsChain
17
 
18
 
19
  def get_cite_combine_docs_chain(llm):
 
20
 
21
  # Ref: https://github.com/langchain-ai/langchain/issues/7239
22
  # Function to format each document with an index, source, and content.
23
  def format_document(doc, index, prompt):
24
  """Format a document into a string based on a prompt template."""
 
25
  # Create a dictionary with document content and metadata.
26
  base_info = {
27
  "page_content": doc.page_content,
@@ -40,7 +44,11 @@ def get_cite_combine_docs_chain(llm):
40
 
41
  # Custom chain class to handle document combination with source indices.
42
  class StuffDocumentsWithIndexChain(StuffDocumentsChain):
 
 
43
  def _get_inputs(self, docs, **kwargs):
 
 
44
  # Format each document and combine them.
45
  doc_strings = [
46
  format_document(doc, i, self.document_prompt)
@@ -58,6 +66,7 @@ def get_cite_combine_docs_chain(llm):
58
  )
59
  return inputs
60
 
 
61
  # Ref: https://huggingface.co/spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
62
  # Define a chat prompt with instructions for citing documents.
63
  combine_doc_prompt = PromptTemplate(
@@ -103,6 +112,8 @@ def get_cite_combine_docs_chain(llm):
103
 
104
 
105
  class RAGChain:
 
 
106
  def __init__(
107
  self, memory_key="chat_history", output_key="answer", return_messages=True
108
  ):
@@ -111,14 +122,17 @@ class RAGChain:
111
  self.return_messages = return_messages
112
 
113
  def create(self, retriever, llm, add_citation=False):
114
- memory = ConversationBufferWindowMemory( # ConversationBufferMemory(
 
 
 
115
  k=2,
116
  memory_key=self.memory_key,
117
  return_messages=self.return_messages,
118
  output_key=self.output_key,
119
  )
120
 
121
- # https://github.com/langchain-ai/langchain/issues/4608
122
  conversation_chain = ConversationalRetrievalChain.from_llm(
123
  llm=llm,
124
  retriever=retriever,
@@ -127,7 +141,6 @@ class RAGChain:
127
  rephrase_question=False, # disable rephrase, for test purpose
128
  get_chat_history=lambda x: x,
129
  # return_generated_question=True, # for debug
130
- # verbose=True,
131
  # combine_docs_chain_kwargs={"prompt": PROMPT}, # additional prompt control
132
  # condense_question_prompt=CONDENSE_QUESTION_PROMPT, # additional prompt control
133
  )
 
1
+ """
2
+ Main RAG chain based on langchain.
3
+ """
4
+
5
  from langchain.chains import LLMChain
6
 
7
  from langchain.prompts import (
 
15
  from langchain.chains.conversation.memory import (
16
  ConversationBufferWindowMemory,
17
  )
 
 
18
  from langchain.chains import StuffDocumentsChain
19
 
20
 
21
  def get_cite_combine_docs_chain(llm):
22
+ """Get doc chain which adds metadata to text chunks."""
23
 
24
  # Ref: https://github.com/langchain-ai/langchain/issues/7239
25
  # Function to format each document with an index, source, and content.
26
  def format_document(doc, index, prompt):
27
  """Format a document into a string based on a prompt template."""
28
+
29
  # Create a dictionary with document content and metadata.
30
  base_info = {
31
  "page_content": doc.page_content,
 
44
 
45
  # Custom chain class to handle document combination with source indices.
46
  class StuffDocumentsWithIndexChain(StuffDocumentsChain):
47
+ """Custom chain class to handle document combination with source indices."""
48
+
49
  def _get_inputs(self, docs, **kwargs):
50
+ """Overwrite _get_inputs to add metadata for text chunks."""
51
+
52
  # Format each document and combine them.
53
  doc_strings = [
54
  format_document(doc, i, self.document_prompt)
 
66
  )
67
  return inputs
68
 
69
+ # Main prompt for RAG chain with citation
70
  # Ref: https://huggingface.co/spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
71
  # Define a chat prompt with instructions for citing documents.
72
  combine_doc_prompt = PromptTemplate(
 
112
 
113
 
114
  class RAGChain:
115
+ """Main RAG chain."""
116
+
117
  def __init__(
118
  self, memory_key="chat_history", output_key="answer", return_messages=True
119
  ):
 
122
  self.return_messages = return_messages
123
 
124
  def create(self, retriever, llm, add_citation=False):
125
+ """Create a rag chain instance."""
126
+
127
+ # Memory is kept for later support of conversational chat
128
+ memory = ConversationBufferWindowMemory( # Or ConversationBufferMemory
129
  k=2,
130
  memory_key=self.memory_key,
131
  return_messages=self.return_messages,
132
  output_key=self.output_key,
133
  )
134
 
135
+ # Ref: https://github.com/langchain-ai/langchain/issues/4608
136
  conversation_chain = ConversationalRetrievalChain.from_llm(
137
  llm=llm,
138
  retriever=retriever,
 
141
  rephrase_question=False, # disable rephrase, for test purpose
142
  get_chat_history=lambda x: x,
143
  # return_generated_question=True, # for debug
 
144
  # combine_docs_chain_kwargs={"prompt": PROMPT}, # additional prompt control
145
  # condense_question_prompt=CONDENSE_QUESTION_PROMPT, # additional prompt control
146
  )
requirements.txt CHANGED
@@ -5,7 +5,7 @@ langchain-community==0.2.4
5
  text-generation
6
  pypdf
7
  pymupdf
8
- gradio
9
  faiss-cpu
10
  chromadb
11
  rank-bm25
 
5
  text-generation
6
  pypdf
7
  pymupdf
8
+ gradio==4.44.1
9
  faiss-cpu
10
  chromadb
11
  rank-bm25
rerank.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- Retrank with cross encoder.
 
3
  https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c
4
  https://github.com/langchain-ai/langchain/issues/13076
5
  """
@@ -7,7 +8,7 @@ https://github.com/langchain-ai/langchain/issues/13076
7
  from __future__ import annotations
8
  from typing import Optional, Sequence
9
  from langchain.schema import Document
10
- from langchain.pydantic_v1 import Extra, root_validator
11
 
12
  from langchain.callbacks.manager import Callbacks
13
  from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 
1
  """
2
+ Rerank with cross encoder.
3
+ Ref:
4
  https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c
5
  https://github.com/langchain-ai/langchain/issues/13076
6
  """
 
8
  from __future__ import annotations
9
  from typing import Optional, Sequence
10
  from langchain.schema import Document
11
+ from langchain.pydantic_v1 import Extra
12
 
13
  from langchain.callbacks.manager import Callbacks
14
  from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
retrievers.py CHANGED
@@ -1,7 +1,10 @@
1
  import os
2
 
3
  from langchain.text_splitter import (
4
- CharacterTextSplitter,
5
  RecursiveCharacterTextSplitter,
6
  SpacyTextSplitter,
7
  )
@@ -9,6 +12,7 @@ from langchain.text_splitter import (
9
  from rerank import BgeRerank
10
  from langchain.retrievers import ContextualCompressionRetriever
11
 
 
12
  def get_parent_doc_retriever(
13
  documents,
14
  vectorstore,
@@ -40,12 +44,14 @@ def get_parent_doc_retriever(
40
  from langchain_rag.storage import SQLStore
41
 
42
  # Instantiate the SQLStore with the root path
43
- docstore = SQLStore(namespace="test", db_url="sqlite:///parent_retrieval_db.db") # TODO: WIP
 
 
44
  else:
45
  docstore = docstore # TODO: add check
46
- # raise # TODO implement
47
 
48
- # TODO: how to better set this?
49
  # parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
50
  # child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=64)
51
  parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
@@ -62,11 +68,11 @@ def get_parent_doc_retriever(
62
  docstore=docstore,
63
  child_splitter=child_splitter,
64
  parent_splitter=parent_splitter,
65
- search_kwargs={"k": k}, # Better settings?
66
  )
67
 
68
  if add_documents:
69
- retriever.add_documents(documents) # , ids=None)
70
 
71
  if save_vectorstore:
72
  vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))
@@ -80,7 +86,6 @@ def get_parent_doc_retriever(
80
 
81
  save_to_pickle(docstore, os.path.join(save_path_root, "docstore.pkl"))
82
 
83
-
84
  return retriever
85
 
86
 
 
1
+ """
2
+ Retrievers for text chunks.
3
+ """
4
+
5
  import os
6
 
7
  from langchain.text_splitter import (
 
8
  RecursiveCharacterTextSplitter,
9
  SpacyTextSplitter,
10
  )
 
12
  from rerank import BgeRerank
13
  from langchain.retrievers import ContextualCompressionRetriever
14
 
15
+
16
  def get_parent_doc_retriever(
17
  documents,
18
  vectorstore,
 
44
  from langchain_rag.storage import SQLStore
45
 
46
  # Instantiate the SQLStore with the root path
47
+ docstore = SQLStore(
48
+ namespace="test", db_url="sqlite:///parent_retrieval_db.db"
49
+ ) # TODO: WIP
50
  else:
51
  docstore = docstore # TODO: add check
52
+ # raise # TODO implement other docstores
53
 
54
+ # TODO: how to better set these values?
55
  # parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
56
  # child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=64)
57
  parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
 
68
  docstore=docstore,
69
  child_splitter=child_splitter,
70
  parent_splitter=parent_splitter,
71
+ search_kwargs={"k": k},
72
  )
73
 
74
  if add_documents:
75
+ retriever.add_documents(documents)
76
 
77
  if save_vectorstore:
78
  vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))
 
86
 
87
  save_to_pickle(docstore, os.path.join(save_path_root, "docstore.pkl"))
88
 
 
89
  return retriever
90
 
91
 
vectorestores.py CHANGED
@@ -1,8 +1,13 @@
1
- from langchain.vectorstores import Chroma, FAISS
2
 
3
  def get_faiss_vectorestore(embeddings):
4
  # Add extra text to init
5
  texts = ["LISA - Lithium Ion Solid-state Assistant"]
6
  vectorstore = FAISS.from_texts(texts, embeddings)
7
-
8
- return vectorstore
 
1
+ """
2
+ Vector stores.
3
+ """
4
+
5
+ from langchain.vectorstores import FAISS
6
+
7
 
8
  def get_faiss_vectorestore(embeddings):
9
  # Add extra text to init
10
  texts = ["LISA - Lithium Ion Solid-state Assistant"]
11
  vectorstore = FAISS.from_texts(texts, embeddings)
12
+
13
+ return vectorstore