import logging

import gradio as gr
from dotenv import load_dotenv
from googletrans import Translator
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import UnstructuredPDFLoader, PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import Document

# Load environment variables (e.g. HUGGINGFACEHUB_API_TOKEN) from a .env file.
load_dotenv()
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Embedding model used to index PDF chunks (via langchain_huggingface).
embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
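# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings and is light
# enough for the free CPU Basic hardware this Space runs on.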
# Available LLM models: display names and the corresponding Hugging Face repo ids.
list_llm_simple = ["Gemma 7B (Italian)", "Mistral 7B"]
list_llm = ["google/gemma-7b-it", "mistralai/Mistral-7B-Instruct-v0.2"]
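# The model Radio in demo() is created with type="index", so entries in
# list_llm_simple and list_llm must stay aligned position by position.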
def initialize_database(document, chunk_size, chunk_overlap, progress=gr.Progress()):
    logger.info("Initializing database...")
    documents = []
    splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    for file in document:
        try:
            loader = UnstructuredPDFLoader(file.name)
            docs = loader.load()
        except ImportError:
            logger.warning("UnstructuredPDFLoader is not available; falling back to PyPDFLoader.")
            try:
                loader = PyPDFLoader(file.name)
                docs = loader.load()
            except ImportError:
                logger.error("Unable to load the PDF: neither 'unstructured' nor 'pypdf' is installed.")
                # Three values: the click handler maps them to [vector_db, collection_name, db_progress].
                return None, None, "Errore: Pacchetti necessari non installati. Esegui 'pip install unstructured pypdf' e riprova."
        for doc in docs:
            text_chunks = splitter.split_text(doc.page_content)
            for chunk in text_chunks:
                documents.append(Document(page_content=chunk, metadata={"filename": file.name, "page": doc.metadata.get("page", 0)}))
    if not documents:
        return None, None, "Errore: Nessun documento caricato correttamente."
    vectorstore = Chroma.from_documents(documents, embedding_function)
    progress(0.5, desc="Indexing documents...")
    logger.info("Database initialized successfully.")
    # collection_name is unused downstream; return a fixed name so the three
    # Gradio outputs line up.
    return vectorstore, "pdf_collection", "Initialized"
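# Rough sizing illustration (assumed figures, not measured): with the default
# chunk_size=600 and chunk_overlap=40, each split advances ~560 characters, so
# a 10,000-character page yields on the order of 18 chunks. CharacterTextSplitter
# splits on paragraph separators first, so actual counts vary with the text.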
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, language, progress=gr.Progress()):
    # `language` precedes `progress` so the six Gradio inputs bind positionally.
    logger.info("Initializing LLM chain...")
    llm_name = list_llm[llm_option]
    logger.info("llm_name: %s", llm_name)
    # Force the model that matches the session language: the Italian-tuned
    # Gemma for Italian, Mistral otherwise.
    if language == "italian":
        default_llm = "google/gemma-7b-it"
    else:
        default_llm = "mistralai/Mistral-7B-Instruct-v0.2"
    if llm_name != default_llm:
        logger.info("Using default LLM %s for %s", default_llm, language)
        llm_name = default_llm
    # Call the model through the Hugging Face Inference API and wrap it in a
    # conversational retrieval chain; conversation() below relies on the
    # "answer" and "source_documents" keys this chain returns.
    llm = HuggingFaceEndpoint(
        repo_id=llm_name,
        temperature=llm_temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        return_source_documents=True,
        get_chat_history=lambda h: h,  # history arrives pre-formatted as a string
    )
    progress(1.0, desc="Done")
    logger.info("LLM chain initialized successfully.")
    return qa_chain, "Complete!"
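# Note: HuggingFaceEndpoint authenticates against the hosted Inference API via
# the HUGGINGFACEHUB_API_TOKEN environment variable (loaded by load_dotenv
# above); the free endpoints can take a while to respond, as the UI text warns.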
def format_chat_history(message, history):
chat_history = ""
for item in history:
chat_history += f"\nUser: {item[0]}\nAI: {item[1]}"
chat_history += f"\n\nUser: {message}"
return chat_history
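# Example of the string this produces for history=[("Ciao", "Ciao! Come posso aiutarti?")]
# and message="Di cosa parla il documento?":
#
#   User: Ciao
#   AI: Ciao! Come posso aiutarti?
#
#   User: Di cosa parla il documento?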
def translate_text(text, src_lang, dest_lang):
translator = Translator()
result = translator.translate(text, src=src_lang, dest=dest_lang)
return result.text
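# googletrans is an unofficial client for the Google Translate web endpoint and
# can break without notice, which is why conversation() falls back to the
# untranslated answer on any exception.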
def conversation(qa_chain, message, history, language):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # The UI is Italian: when the non-Italian model was used, translate its
    # (English) answer into Italian, falling back to the raw answer on error.
    if language != "italian":
        try:
            translated_response = translate_text(response_answer, "en", "it")
        except Exception as e:
            logger.error(f"Error translating response: {e}")
            translated_response = response_answer
    else:
        translated_response = response_answer
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    # PDF loaders report 0-based page numbers; convert to 1-based for display.
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
    new_history = history + [(message, translated_response)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
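# The nine values returned above line up positionally with the outputs list
# wired to msg.submit / submit_btn.click in demo() below.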
def demo():
with gr.Blocks(theme="base") as demo:
vector_db = gr.State()
qa_chain = gr.State()
collection_name = gr.State()
        language = gr.State(value="italian")  # kept in sync with the language selector below
        gr.Markdown(
            """<center><h2>Chatbot basato su PDF</h2>
            <h3>Fai domande sui tuoi documenti PDF</h3></center>""")
gr.Markdown(
"""<b>Note:</b> Questo assistente AI, utilizzando Langchain e LLM open-source, esegue retrieval-augmented generation (RAG) dai tuoi documenti PDF. \
L'interfaccia utente mostra esplicitamente più passaggi per aiutare a comprendere il flusso di lavoro RAG.
Questo chatbot tiene conto delle domande precedenti quando genera risposte (tramite memoria conversazionale), e include riferimenti al documento per scopi di chiarezza.<br>
<br><b>Avviso:</b> Questo spazio utilizza l'hardware CPU Basic gratuito da Hugging Face. Alcuni passaggi e modelli LLM utilizzati qui sotto (endpoint di inferenza gratuiti) possono richiedere del tempo per generare una risposta.
""")
with gr.Tab("Step 1 - Carica PDF"):
with gr.Row():
document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carica i tuoi documenti PDF (singolo o multiplo)")
with gr.Tab("Step 2 - Processa documento"):
with gr.Row():
                # Named db_choice (not db_btn) so the "Genera database vettoriale"
                # button below does not shadow this Radio.
                db_choice = gr.Radio(["ChromaDB"], label="Tipo di database vettoriale", value="ChromaDB", type="index", info="Scegli il tuo database vettoriale")
with gr.Accordion("Opzioni avanzate - Divisore testo documento", open=False):
with gr.Row():
                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Dimensione chunk", info="Dimensione chunk", interactive=True)
                with gr.Row():
                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Sovrapposizione chunk", info="Sovrapposizione chunk", interactive=True)
with gr.Row():
db_progress = gr.Textbox(label="Inizializzazione database vettoriale", value="Nessuno")
with gr.Row():
db_btn = gr.Button("Genera database vettoriale")
with gr.Tab("Step 3 - Inizializza catena QA"):
with gr.Row():
                llm_btn = gr.Radio(list_llm_simple, label="Modelli LLM", value=list_llm_simple[0], type="index", info="Scegli il tuo modello LLM")
with gr.Accordion("Opzioni avanzate - Modello LLM", open=False):
with gr.Row():
                    slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperatura", info="Temperatura del modello", interactive=True)
                with gr.Row():
                    slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Token massimi", info="Token massimi del modello", interactive=True)
                with gr.Row():
                    slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Campioni top-k", info="Campioni top-k del modello", interactive=True)
with gr.Row():
                llm_progress = gr.Textbox(value="Nessuno", label="Inizializzazione catena QA")
with gr.Row():
qachain_btn = gr.Button("Inizializza catena Question Answering")
with gr.Tab("Step 4 - Chatbot"):
chatbot = gr.Chatbot(height=300)
with gr.Accordion("Avanzate - Riferimenti documento", open=False):
with gr.Row():
doc_source1 = gr.Textbox(label="Riferimento 1", lines=2, container=True, scale=20)
source1_page = gr.Number(label="Pagina", scale=1)
with gr.Row():
doc_source2 = gr.Textbox(label="Riferimento 2", lines=2, container=True, scale=20)
source2_page = gr.Number(label="Pagina", scale=1)
with gr.Row():
doc_source3 = gr.Textbox(label="Riferimento 3", lines=2, container=True, scale=20)
source3_page = gr.Number(label="Pagina", scale=1)
with gr.Row():
msg = gr.Textbox(placeholder="Digita un messaggio (es. 'Di cosa parla questo documento?')", container=True)
with gr.Row():
submit_btn = gr.Button("Invia messaggio")
clear_btn = gr.ClearButton([msg, chatbot], value="Pulisci conversazione")
            with gr.Row():
                language_selector = gr.Radio(choices=["italiano", "inglese"], value="italiano", label="Lingua")
            # Keep the language State in sync with the selector.
            language_selector.change(lambda choice: "italian" if choice == "italiano" else "english", \
                inputs=language_selector, \
                outputs=language, \
                queue=False)
# Preprocessing events
db_btn.click(initialize_database, \
inputs=[document, slider_chunk_size, slider_chunk_overlap], \
outputs=[vector_db, collection_name, db_progress])
qachain_btn.click(initialize_LLM, \
inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, language], \
outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
inputs=None, \
outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
queue=False)
# Chatbot events
msg.submit(conversation, \
inputs=[qa_chain, msg, chatbot, language], \
outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
queue=False)
submit_btn.click(conversation, \
inputs=[qa_chain, msg, chatbot, language], \
outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
queue=False)
clear_btn.click(lambda:[None,"",0,"",0,"",0], \
inputs=None, \
outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
queue=False)
demo.queue().launch(debug=True)
if __name__ == "__main__":
demo()