{ "cells": [ { "cell_type": "markdown", "id": "9267529d", "metadata": {}, "source": [ "A mini version of LISA in a Jupyter notebook for easier testing and playing around." ] }, { "cell_type": "code", "execution_count": 2, "id": "adcfdba2", "metadata": {}, "outputs": [], "source": [ "# import some packages\n", "import os\n", "\n", "from dotenv import load_dotenv\n", "from langchain.document_loaders import PyPDFLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "from langchain.embeddings import HuggingFaceEmbeddings\n", "from langchain.vectorstores import FAISS\n", "from langchain.chains import ConversationalRetrievalChain\n", "from langchain.llms import HuggingFaceTextGenInference\n", "from langchain.chains.conversation.memory import (\n", " ConversationBufferWindowMemory,\n", ")" ] }, { "cell_type": "code", "execution_count": 3, "id": "2d85c6d9", "metadata": {}, "outputs": [], "source": [ "# Set api keys\n", "load_dotenv(\"API.env\") # put all the API tokens here, such as openai, huggingface...\n", "HUGGINGFACEHUB_API_TOKEN = os.getenv(\"HUGGINGFACEHUB_API_TOKEN\")" ] }, { "cell_type": "code", "execution_count": null, "id": "ffd3db32", "metadata": {}, "outputs": [], "source": [ "# Set inference link, use this online one for easier reproduce\n", "inference_api_url = 'https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta'\n", "# Recommend using better LLMs, such as Mixtral 7x8B\n", "\n", "llm = HuggingFaceTextGenInference(\n", " verbose=True, # Provides detailed logs of operation\n", " max_new_tokens=1024, # Maximum number of token that can be generated.\n", " top_p=0.95, # Threshold for controlling randomness in text generation process. \n", " typical_p=0.95, #\n", " temperature=0.1, # For choosing probable words.\n", " inference_server_url=inference_api_url, # URL des Inferenzservers\n", " timeout=120, # Timeout for connection with the url\n", " )\n", "\n", "# Alternative, you can load model locally, e.g.:\n", "# model_path = \"where/you/store/local/models/zephyr-7b-beta\" # change this to your model path\n", "# model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\")\n", "# tokenizer = AutoTokenizer.from_pretrained(model_path)\n", "# pipe = pipeline(\n", "# \"text-generation\", model=model, tokenizer=tokenizer, max_new_tokens=1024, model_kwargs={\"temperature\":0.1}\n", "# )\n", "# llm = HuggingFacePipeline(pipeline=pipe)" ] }, { "cell_type": "code", "execution_count": 5, "id": "2d5bacd5", "metadata": {}, "outputs": [], "source": [ "# Function for reading and chunking text\n", "def load_pdf_as_docs(pdf_path, loader_module=None):\n", " \"\"\"Load and parse pdf files.\"\"\"\n", " \n", " if pdf_path.endswith('.pdf'): # single file\n", " pdf_docs = [pdf_path]\n", " else: # a directory\n", " pdf_docs = [os.path.join(pdf_path, f) for f in os.listdir(pdf_path) if f.endswith('.pdf')]\n", " \n", " docs = []\n", " \n", " if loader_module is None: # Set PDFLoader\n", " loader_module = PyPDFLoader\n", " for pdf in pdf_docs:\n", " loader = loader_module(pdf)\n", " doc = loader.load()\n", " docs.extend(doc)\n", " \n", " return docs\n", "\n", "def get_doc_chunks(docs, splitter=None):\n", " \"\"\"Split docs into chunks.\"\"\"\n", " \n", " if splitter is None:\n", " splitter = RecursiveCharacterTextSplitter(\n", " separators=[\"\\n\\n\", \"\\n\"], chunk_size=256, chunk_overlap=128\n", " )\n", " chunks = splitter.split_documents(docs)\n", " \n", " return chunks" ] }, { "cell_type": "code", "execution_count": 6, "id": "8cd31248", 
"metadata": {}, "outputs": [], "source": [ "# Specify the directory containing your PDFs\n", "directory = \"data/documents\" # change to your pdf directory\n", "\n", "# Find and parse all PDFs in the directory\n", "pdf_docs = load_pdf_as_docs(directory, PyPDFLoader)\n", "\n", "document_chunks = get_doc_chunks(pdf_docs)" ] }, { "cell_type": "code", "execution_count": null, "id": "7bf62c76", "metadata": {}, "outputs": [], "source": [ "# Set embedding\n", "embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-base-en-v1.5') # choose the one you like\n", "\n", "# Set vectorstore, e.g. FAISS\n", "texts = [\"LISA - Lithium Ion Solid-state Assistant\"]\n", "vectorstore = FAISS.from_texts(texts, embeddings) # this is a workaround as FAISS cannot be initialized by 'FAISS(embedding_function=embeddings)', waiting for Langchain fix\n", "# You may also use Chroma\n", "# vectorstore = Chroma(embedding_function=embeddings)" ] }, { "cell_type": "code", "execution_count": 8, "id": "e5796990", "metadata": {}, "outputs": [], "source": [ "# Create retrievers\n", "# Some advanced RAG, with parent document retriever, hybrid-search and rerank\n", "\n", "# 1. ParentDocumentRetriever. Note: this will take a long time (~several minutes)\n", "\n", "from langchain.storage import InMemoryStore\n", "from langchain.retrievers import ParentDocumentRetriever\n", "# For local storage, ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain\n", "store = InMemoryStore()\n", "\n", "parent_splitter = RecursiveCharacterTextSplitter(separators=[\"\\n\\n\", \"\\n\"], chunk_size=512, chunk_overlap=128)\n", "child_splitter = RecursiveCharacterTextSplitter(separators=[\"\\n\\n\", \"\\n\"], chunk_size=256, chunk_overlap=64)\n", "\n", "parent_doc_retriver = ParentDocumentRetriever(\n", " vectorstore=vectorstore,\n", " docstore=store,\n", " child_splitter=child_splitter,\n", " parent_splitter=parent_splitter,\n", ")\n", "parent_doc_retriver.add_documents(pdf_docs)" ] }, { "cell_type": "code", "execution_count": 9, "id": "bc299740", "metadata": {}, "outputs": [], "source": [ "# 2. Hybrid search\n", "from langchain.retrievers import BM25Retriever\n", "\n", "bm25_retriever = BM25Retriever.from_documents(document_chunks, k=5) # 1/2 of dense retriever, experimental value" ] }, { "cell_type": "code", "execution_count": 10, "id": "2eb8bc8f", "metadata": {}, "outputs": [], "source": [ "# 3. 
Rerank\n", "\"\"\"\n", "Ref:\n", "https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c\n", "https://github.com/langchain-ai/langchain/issues/13076\n", "good to read:\n", "https://teemukanstren.com/2023/12/25/llmrag-based-question-answering/\n", "\"\"\"\n", "from __future__ import annotations\n", "from typing import Dict, Optional, Sequence\n", "from langchain.schema import Document\n", "from langchain.pydantic_v1 import Extra, root_validator\n", "\n", "from langchain.callbacks.manager import Callbacks\n", "from langchain.retrievers.document_compressors.base import BaseDocumentCompressor\n", "\n", "from sentence_transformers import CrossEncoder\n", "\n", "model_name = \"BAAI/bge-reranker-large\"\n", "\n", "class BgeRerank(BaseDocumentCompressor):\n", " model_name:str = model_name\n", " \"\"\"Model name to use for reranking.\"\"\" \n", " top_n: int = 10 \n", " \"\"\"Number of documents to return.\"\"\"\n", " model:CrossEncoder = CrossEncoder(model_name)\n", " \"\"\"CrossEncoder instance to use for reranking.\"\"\"\n", "\n", " def bge_rerank(self,query,docs):\n", " model_inputs = [[query, doc] for doc in docs]\n", " scores = self.model.predict(model_inputs)\n", " results = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)\n", " return results[:self.top_n]\n", "\n", "\n", " class Config:\n", " \"\"\"Configuration for this pydantic object.\"\"\"\n", "\n", " extra = Extra.forbid\n", " arbitrary_types_allowed = True\n", "\n", " def compress_documents(\n", " self,\n", " documents: Sequence[Document],\n", " query: str,\n", " callbacks: Optional[Callbacks] = None,\n", " ) -> Sequence[Document]:\n", " \"\"\"\n", " Compress documents using BAAI/bge-reranker models.\n", "\n", " Args:\n", " documents: A sequence of documents to compress.\n", " query: The query to use for compressing the documents.\n", " callbacks: Callbacks to run during the compression process.\n", "\n", " Returns:\n", " A sequence of compressed documents.\n", " \"\"\"\n", " \n", " if len(documents) == 0: # to avoid empty api call\n", " return []\n", " doc_list = list(documents)\n", " _docs = [d.page_content for d in doc_list]\n", " results = self.bge_rerank(query, _docs)\n", " final_results = []\n", " for r in results:\n", " doc = doc_list[r[0]]\n", " doc.metadata[\"relevance_score\"] = r[1]\n", " final_results.append(doc)\n", " return final_results\n", " \n", " \n", "from langchain.retrievers import ContextualCompressionRetriever" ] }, { "cell_type": "code", "execution_count": 11, "id": "af780912", "metadata": {}, "outputs": [], "source": [ "# Stack all the retrievers together\n", "from langchain.retrievers import EnsembleRetriever\n", "# Ensemble all above\n", "ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5])\n", "\n", "# Rerank\n", "compressor = BgeRerank()\n", "rerank_retriever = ContextualCompressionRetriever(\n", " base_compressor=compressor, base_retriever=ensemble_retriever\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "id": "beb9ab21", "metadata": {}, "outputs": [], "source": [ "## Now begin to build Q&A system\n", "class RAGChain:\n", " def __init__(\n", " self, memory_key=\"chat_history\", output_key=\"answer\", return_messages=True\n", " ):\n", " self.memory_key = memory_key\n", " self.output_key = output_key\n", " self.return_messages = return_messages\n", "\n", " def create(self, retriver, llm):\n", " memory = ConversationBufferWindowMemory(\n", " memory_key=self.memory_key,\n", " return_messages=self.return_messages,\n", 
" output_key=self.output_key,\n", " )\n", "\n", " # https://github.com/langchain-ai/langchain/issues/4608\n", " conversation_chain = ConversationalRetrievalChain.from_llm(\n", " llm=llm,\n", " retriever=retriver,\n", " memory=memory,\n", " return_source_documents=True,\n", " rephrase_question=False, # disable rephrase, for test purpose\n", " get_chat_history=lambda x: x,\n", " )\n", " \n", " return conversation_chain\n", " \n", " \n", "rag_chain = RAGChain()\n", "lisa_qa_conversation = rag_chain.create(rerank_retriever, llm)" ] }, { "cell_type": "code", "execution_count": null, "id": "59159951", "metadata": {}, "outputs": [], "source": [ "# Now begin to ask question\n", "question = \"Please name two common solid electrolytes.\"\n", "result = lisa_qa_conversation({\"question\":question, \"chat_history\": []})\n", "print(result[\"answer\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "f5e3c7b5", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "d736960b", "metadata": {}, "outputs": [], "source": [ "# The rests are for Gradio GUI\n", "\n", "import gradio as gr\n", "import time\n", "from pathlib import Path\n", "\n", "# Gradio utils\n", "def add_text(history, text):\n", " \"\"\"Add conversation to history message.\"\"\"\n", " history = history + [(text, None)]\n", " yield history, \"\"\n", "\n", "\n", "def bot_lisa(history):\n", " \"\"\"Get answer from LLM.\"\"\"\n", " result = lisa_qa_conversation(\n", " {\n", " \"question\": history[-1][0], # or \"query\" if RetrievalQA\n", " \"chat_history\": history[:-1],\n", " }\n", " )\n", " print(f\"Answer: {result['answer']}\")\n", " print(f\"Source document: {result['source_documents']}\") # for debug\n", " # Citation post-processing\n", " answer_text = result[\"answer\"].strip()\n", " history[-1][1] = \"\" # Fake stream, TODO: implement streaming\n", " for character in result[\"answer\"].strip():\n", " time.sleep(0.002)\n", " history[-1][1] += character\n", " yield history, \"citation place holder\"\n", "\n", "\n", "def bot(history, qa_conversation):\n", " \"\"\"Get answer from LLM.\"\"\"\n", " # print(\"id of qa conver\", id(qa_conversation)) # for debug\n", " if qa_conversation is None:\n", " gr.Warning(\"Please upload a document first.\")\n", " \n", " result = qa_conversation(\n", " {\n", " \"question\": history[-1][0], # or \"query\" if RetrievalQA\n", " \"chat_history\": history[:-1],\n", " }\n", " )\n", " print(f\"Source document: {result['source_documents']}\") # for debug\n", " history[-1][1] = \"\" # Fake stream, TODO: implement streaming\n", " for character in result[\"answer\"].strip():\n", " time.sleep(0.002)\n", " history[-1][1] += character\n", " yield history\n", "\n", "\n", "# Ref: https://huggingface.co/spaces/fffiloni/langchain-chat-with-pdf\n", "def document_changes(doc_path):#, repo_id):\n", " if doc_path is None:\n", " gr.Warning(\"Please choose a document first and wait until uploaded.\")\n", " return \"Please choose a document and wait until uploaded.\", None # for langchain_status, qa_conversation\n", " \n", " print(\"now reading document\")\n", " print(f\"file is located at {doc_path[0]}\")\n", " \n", " file_extension = Path(doc_path[0]).suffix\n", " if file_extension == \".pdf\":\n", " pdf_docs = load_pdf_as_docs(doc_path[0])\n", " document_chunks = get_doc_chunks(pdf_docs)\n", " elif file_extension == \".xml\":\n", " raise\n", " # documents = load_xml_as_docs(doc_path[0])\n", " \n", " print(\"now creating vectordatabase\")\n", " \n", " texts = 
[\"LISA - Lithium Ion Solid-state Assistant\"]\n", " vectorstore = FAISS.from_texts(texts, embeddings)\n", "\n", " store = InMemoryStore()\n", "\n", " parent_splitter = RecursiveCharacterTextSplitter(separators=[\"\\n\\n\", \"\\n\"], chunk_size=512, chunk_overlap=256)\n", " child_splitter = RecursiveCharacterTextSplitter(separators=[\"\\n\\n\", \"\\n\"], chunk_size=256, chunk_overlap=128)\n", "\n", " parent_doc_retriver = ParentDocumentRetriever(\n", " vectorstore=vectorstore,\n", " docstore=store,\n", " child_splitter=child_splitter,\n", " parent_splitter=parent_splitter,\n", " )\n", " parent_doc_retriver.add_documents(pdf_docs)\n", "\n", " bm25_retriever = BM25Retriever.from_documents(document_chunks, k=5) # 1/2 of dense retriever, experimental value\n", "\n", " # Ensemble all above\n", " ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5])\n", "\n", " compressor = BgeRerank()\n", " rerank_retriever = ContextualCompressionRetriever(\n", " base_compressor=compressor, base_retriever=ensemble_retriever\n", " )\n", "\n", " rag_chain = RAGChain()\n", " qa_conversation = rag_chain.create(rerank_retriever, llm)\n", " \n", " print(\"now getting llm model\")\n", " \n", "\n", " file_name = Path(doc_path[0]).name # First file\n", " return f\"Ready for {file_name}\", qa_conversation\n", "\n", "\n", "# Main gradio UI\n", "def main():\n", " # Gradio interface\n", " with gr.Blocks() as demo:\n", " ######################################################################\n", " # LISA chat tab\n", "\n", " # Title info\n", " gr.Markdown(\"## LISA\")\n", " gr.Markdown(\"Q&A system with RAG.\")\n", "\n", " with gr.Tab(\"LISA\"):\n", " # Chatbot\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " label=\"Document Assistant (chat-history context is not supported at the moment, fixing...)\",\n", " bubble_full_width=False,\n", " show_copy_button=True,\n", " likeable=True,\n", " ) # .style(height=750)\n", " with gr.Row():\n", " with gr.Column(scale=80):\n", " user_txt = gr.Textbox(\n", " label=\"Question\",\n", " placeholder=\"Type question and press Enter\",\n", " ) # .style(container=False)\n", " with gr.Column(scale=10):\n", " submit_btn = gr.Button(\"Submit\", variant=\"primary\")\n", " with gr.Column(scale=10):\n", " clear_btn = gr.Button(\"Clear\", variant=\"stop\")\n", " # Reference (citations)\n", " with gr.Accordion(\"Advanced - Document references\", open=False):\n", " doc_citation = gr.Markdown()\n", " # alternative: https://www.gradio.app/guides/creating-a-chatbot-fast\n", " gr.Examples(\n", " examples=[\n", " \"Please name two common solid electrolytes.\",\n", " \"Please name two common oxide solid electrolytes.\",\n", " \"Please tell me what is solid-state battery.\",\n", " \"How to synthesize gc-LPSC?\",\n", " \"Please tell me the purpose of Kadi4Mat.\",\n", " \"Who is working on Kadi4Mat?\",\n", " \"Can you recommend a paper to get a deeper understanding of Kadi4Mat?\",\n", " # \"How to synthesize gc-LPSC, e.g., glass-ceramic Li5.5PS4.5Cl1.5?\",\n", " ],\n", " inputs=user_txt,\n", " outputs=chatbot,\n", " fn=add_text,\n", " # cache_examples=True,\n", " )\n", "\n", " # Manage functions\n", " user_txt.submit(add_text, [chatbot, user_txt], [chatbot, user_txt]).then(\n", " bot_lisa, chatbot, [chatbot, doc_citation]\n", " )\n", "\n", " submit_btn.click(\n", " add_text,\n", " [chatbot, user_txt],\n", " [chatbot, user_txt],\n", " # concurrency_limit=8,\n", " queue=False,\n", " ).then(bot_lisa, chatbot, [chatbot, doc_citation])\n", 
"\n", " clear_btn.click(lambda: None, None, chatbot, queue=False)\n", "\n", " ######################################################################\n", "\n", " ######################################################################\n", " # Document-based QA\n", "\n", " with gr.Tab(\"Document-based Q&A\"):\n", " qa_conversation = gr.State()\n", " \n", " with gr.Row():\n", " with gr.Column(scale=3, variant=\"load_file_panel\"):\n", " with gr.Row():\n", " gr.HTML(\n", " \"Upload a pdf/xml file, click the Load file button and when everything is ready, you can start asking questions about the document.\"\n", " )\n", " with gr.Row():\n", " uploaded_doc = gr.File(\n", " label=\"Upload pdf/xml file (single)\",\n", " file_count=\"multiple\", # For better looking, but only support 1 file\n", " file_types=[\".pdf\", \".xml\"],\n", " type=\"filepath\",\n", " height=100,\n", " )\n", "\n", " with gr.Row():\n", " langchain_status = gr.Textbox(\n", " label=\"Status\", placeholder=\"\", interactive=False\n", " )\n", " load_document = gr.Button(\"Load file\")\n", "\n", " with gr.Column(scale=7, variant=\"chat_panel\"):\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " # label=\"Document Assistant (chat-history context is not supported at the moment, fixing...)\",\n", " label=\"Document Assistant (chat-history context is not supported at the moment, fixing...)\",\n", " show_copy_button=True,\n", " likeable=True,\n", " ) # .style(height=350)\n", " docqa_question = gr.Textbox(\n", " label=\"Question\",\n", " placeholder=\"Type question and press Enter/click Submit\",\n", " )\n", " with gr.Row():\n", " with gr.Column(scale=50):\n", " docqa_submit_btn = gr.Button(\"Submit\", variant=\"primary\")\n", " with gr.Column(scale=50):\n", " docqa_clear_btn = gr.Button(\"Clear\", variant=\"stop\")\n", " \n", " gr.Examples(\n", " examples=[\n", " \"Summarize the paper\",\n", " \"Summarize the paper in 3 bullet points\",\n", " \"What are the contributions of this paper\",\n", " \"Explain the practical implications of this paper\",\n", " \"Methods used in this paper\",\n", " \"What data has been used in this paper\",\n", " \"Results of the paper\",\n", " \"Conclusions from the paper\",\n", " \"Limitations of this paper\",\n", " \"Future works suggested in this paper\",\n", " ],\n", " inputs=docqa_question,\n", " outputs=chatbot,\n", " fn=add_text,\n", " # cache_examples=True,\n", " )\n", "\n", " load_document.click(\n", " document_changes,\n", " inputs=[uploaded_doc], # , repo_id],\n", " outputs=[langchain_status, qa_conversation],#, docqa_db, docqa_retriever],\n", " queue=False,\n", " )\n", " \n", " docqa_question.submit(add_text, [chatbot, docqa_question], [chatbot, docqa_question]).then(\n", " bot, [chatbot, qa_conversation], chatbot\n", " )\n", " docqa_submit_btn.click(add_text, [chatbot, docqa_question], [chatbot, docqa_question]).then(\n", " bot, [chatbot, qa_conversation], chatbot\n", " )\n", "\n", " gr.Markdown(\"*Notes: The model may produce incorrect statements. 
  { "cell_type": "code", "execution_count": null, "id": "e2864a11", "metadata": {}, "outputs": [], "source": [] }
 ],
 "metadata": {
  "kernelspec": { "display_name": "lisa", "language": "python", "name": "python3" },
  "language_info": {
   "codemirror_mode": { "name": "ipython", "version": 3 },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}