Quentin Fisch
feat(demo): add demo files
efb5688
raw
history blame contribute delete
2.14 kB
"""
Gradio UI for Mistral 7B with RAG
"""
import os
from typing import List
import gradio as gr
from langchain_core.runnables.base import RunnableSequence
import numpy as np
from confluence_rag import generate_rag_chain, load_pdf, store_vector, load_multiple_pdf
def initialize_chain(file: gr.File) -> RunnableSequence:
"""
Initializes the chain with the given file.
If no file is provided, the llm is used without RAG.
Args:
file (gr.File): file to initialize the chain with
Returns:
RunnableSequence: the chain
"""
if file is None:
return generate_rag_chain()
if len(file) == 1:
pdf = load_pdf(file[0].name)
else:
pdf = load_multiple_pdf([f.name for f in file])
retriever = store_vector(pdf)
return generate_rag_chain(retriever)
def invoke_chain(message: str, history: List[str], file: gr.File = None) -> str:
"""
Invokes the chain with the given message and updates the chain if a new file is provided.
Args:
message (str): message to invoke the chain with
history (List[str]): history of messages
file (gr.File, optional): file to update the chain with. Defaults to None.
Returns:
str: the response of the chain
"""
# Check if file is provided and exists
if file is not None and not np.all([os.path.exists(f.name) for f in file]) or len(file) == 0:
return "Error: File not found."
if file is not None and not np.all([f.name.endswith(".pdf") for f in file]):
return "Error: File is not a pdf."
chain = initialize_chain(file)
return chain.invoke(message)
def create_demo() -> gr.Interface:
"""
Creates and returns a Gradio Chat Interface.
Returns:
gr.Interface: the Gradio Chat Interface
"""
return gr.ChatInterface(
invoke_chain,
additional_inputs=[gr.File(label="File", file_count='multiple')],
title="Mistral 7B with RAG",
description="Ask questions to Mistral about your pdf document.",
theme="soft",
)
if __name__ == "__main__":
demo = create_demo()
demo.launch()