from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
import gradio as gr
import os
from googletrans import Translator
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import UnstructuredPDFLoader, PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import Document
from langchain.memory import ConversationBufferMemory
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Updated initialization of HuggingFaceEmbeddings
embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# List of available LLM models
list_llm_simple = ["Gemma 7B (Italian)", "Mistral 7B"]
list_llm = ["google/gemma-7b-it", "mistralai/Mistral-7B-Instruct-v0.2"]
def initialize_database(document, chunk_size, chunk_overlap, progress=gr.Progress()):
    logger.info("Initializing database...")
    documents = []
    splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    for file in document:
        try:
            loader = UnstructuredPDFLoader(file.name)
            docs = loader.load()
        except ImportError:
            logger.warning("UnstructuredPDFLoader non disponibile. Tentativo di utilizzo di PyPDFLoader.")
            try:
                loader = PyPDFLoader(file.name)
                docs = loader.load()
            except ImportError:
                logger.error("Impossibile caricare il documento PDF. Assicurati di aver installato 'unstructured' o 'pypdf'.")
                return None, None, "Errore: Pacchetti necessari non installati. Esegui 'pip install unstructured pypdf' e riprova."
        for doc in docs:
            text_chunks = splitter.split_text(doc.page_content)
            for chunk in text_chunks:
                documents.append(Document(page_content=chunk, metadata={"filename": file.name, "page": doc.metadata.get("page", 0)}))
    if not documents:
        return None, None, "Errore: Nessun documento caricato correttamente."
    vectorstore = Chroma.from_documents(documents, embedding_function)
    progress(0.5)
    logger.info("Database initialized successfully.")
    return vectorstore, None, "Initialized"  # None is the collection-name placeholder (second output)
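# ------------------------------------------------------------------
# Step 3 - QA chain setup: download the chosen Gemma checkpoint (a gated
# model, so the HF_TOKEN environment variable must hold a valid Hugging
# Face access token), wrap it in a transformers text-generation pipeline,
# and build a ConversationalRetrievalChain over the Chroma retriever.
# Wrapping the local model with HuggingFacePipeline is one reasonable way
# to expose it to LangChain; other adapters would also work.
# ------------------------------------------------------------------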
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, language="italiano", progress=gr.Progress()):
    logger.info("Initializing LLM chain...")
    # Pick the default model based on the selected language
    # (note: llm_option from the UI is currently not used for this choice)
    if language == "italiano":
        default_llm = "google/gemma-7b-it"
    else:
        default_llm = "google/gemma-7b"  # English version
    # Load the tokenizer and model, authenticating via the HF_TOKEN environment variable
    try:
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is not set")
        tokenizer = AutoTokenizer.from_pretrained(default_llm, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(default_llm, token=hf_token)
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        return None, "Failed to initialize LLM"
    # Resize the embedding matrix if the tokenizer vocabulary outgrew the model's
    if len(tokenizer) > model.config.vocab_size:
        model.resize_token_embeddings(len(tokenizer))
    # Wrap the raw transformers model in a text-generation pipeline so LangChain can call it
    generation_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_tokens,
        temperature=llm_temperature,
        top_k=top_k,
        do_sample=True,
    )
    llm = HuggingFacePipeline(pipeline=generation_pipeline)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        return_source_documents=True,  # needed by conversation() to show document references
        verbose=False,
    )
    progress(1.0)
    logger.info("LLM chain initialized successfully.")
    return qa_chain, "Complete!"
def format_chat_history(message, history):
    # ConversationalRetrievalChain expects the chat history as a list of
    # (user, assistant) tuples; the current message is passed separately
    # as the "question" input, so it is not appended here.
    return [(str(user_msg), str(bot_msg)) for user_msg, bot_msg in history]
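# Helper: translate text with googletrans (a free, unofficial Google
# Translate client; calls can fail intermittently, so callers should catch
# exceptions).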
def translate_text(text, src_lang, dest_lang):
    translator = Translator()
    result = translator.translate(text, src=src_lang, dest=dest_lang)
    return result.text
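# Step 4 - Chat handler: query the QA chain with the new message and the
# running history, optionally translate the answer to Italian, and return
# the updated chat plus up to three source passages with their page numbers
# for the "Riferimenti documento" accordion.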
def conversation(qa_chain, message, history, language):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    if language != "italiano":
        try:
            translated_response = translate_text(response_answer, "en", "it")
        except Exception as e:
            logger.error(f"Error translating response: {e}")
            translated_response = response_answer
    else:
        translated_response = response_answer
    # Extract up to three source passages and their (1-based) page numbers
    response_sources = response["source_documents"]
    def get_source(idx):
        if idx < len(response_sources):
            doc = response_sources[idx]
            return doc.page_content.strip(), doc.metadata.get("page", 0) + 1
        return "", 0
    response_source1, response_source1_page = get_source(0)
    response_source2, response_source2_page = get_source(1)
    response_source3, response_source3_page = get_source(2)
    new_history = history + [(message, translated_response)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        language = gr.State(value="italiano")  # default interface language
        gr.Markdown(
            """<center><h2>Chatbot basato su PDF</h2></center>
            <h3>Fai domande sui tuoi documenti PDF</h3>""")
        gr.Markdown(
            """<b>Note:</b> Questo assistente AI, utilizzando Langchain e LLM open-source, esegue retrieval-augmented generation (RAG) dai tuoi documenti PDF. \
            L'interfaccia utente mostra esplicitamente più passaggi per aiutare a comprendere il flusso di lavoro RAG.
            Questo chatbot tiene conto delle domande precedenti quando genera risposte (tramite memoria conversazionale), e include riferimenti al documento per scopi di chiarezza.<br>
            <br><b>Avviso:</b> Questo spazio utilizza l'hardware CPU Basic gratuito da Hugging Face. Alcuni passaggi e modelli LLM utilizzati qui sotto (endpoint di inferenza gratuiti) possono richiedere del tempo per generare una risposta.
            """)
with gr.Tab("Step 1 - Carica PDF"): | |
with gr.Row(): | |
document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carica i tuoi documenti PDF (singolo o multiplo)") | |
with gr.Tab("Step 2 - Processa documento"): | |
with gr.Row(): | |
db_btn = gr.Radio(["ChromaDB"], label="Tipo di database vettoriale", value = "ChromaDB", type="index", info="Scegli il tuo database vettoriale") | |
with gr.Accordion("Opzioni avanzate - Divisore testo documento", open=False): | |
with gr.Row(): | |
slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Dimensione chunk", info="Dimensione chunk", interactive=True) | |
with gr.Row(): | |
slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label=" Sovrapposizione chunk", info="Sovrapposizione chunk", interactive=True) | |
with gr.Row(): | |
db_progress = gr.Textbox(label="Inizializzazione database vettoriale", value="Nessuno") | |
with gr.Row(): | |
db_btn = gr.Button("Genera database vettoriale") | |
with gr.Tab("Step 3 - Inizializza catena QA"): | |
with gr.Row(): | |
llm_btn = gr.Radio(list_llm_simple, \ | |
label="Modelli LLM", value = list_llm_simple[0], type="index", info="Scegli il tuo modello LLM") | |
with gr.Accordion("Opzioni avanzate - Modello LLM", open=False): | |
with gr.Row(): | |
slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperatura", info="Temperatura del modello", interactive=True) | |
with gr.Row(): | |
slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Token massimi", info="Token massimi del modello", interactive=True) | |
with gr.Row(): | |
slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="Campioni top-k", info="Campioni top-k del modello", interactive=True) | |
with gr.Row(): | |
llm_progress = gr.Textbox(value="Nessuno",label="Inizializzazione catena QA") | |
with gr.Row(): | |
qachain_btn = gr.Button("Inizializza catena Question Answering") | |
with gr.Tab("Step 4 - Chatbot"): | |
chatbot = gr.Chatbot(height=300) | |
with gr.Accordion("Avanzate - Riferimenti documento", open=False): | |
with gr.Row(): | |
doc_source1 = gr.Textbox(label="Riferimento 1", lines=2, container=True, scale=20) | |
source1_page = gr.Number(label="Pagina", scale=1) | |
with gr.Row(): | |
doc_source2 = gr.Textbox(label="Riferimento 2", lines=2, container=True, scale=20) | |
source2_page = gr.Number(label="Pagina", scale=1) | |
with gr.Row(): | |
doc_source3 = gr.Textbox(label="Riferimento 3", lines=2, container=True, scale=20) | |
source3_page = gr.Number(label="Pagina", scale=1) | |
with gr.Row(): | |
msg = gr.Textbox(placeholder="Digita un messaggio (es. 'Di cosa parla questo documento?')", container=True) | |
with gr.Row(): | |
submit_btn = gr.Button("Invia messaggio") | |
clear_btn = gr.ClearButton([msg, chatbot], value="Pulisci conversazione") | |
with gr.Row(): | |
language_selector = gr.Radio(choices=["italiano", "inglese"], value="italiano", label="Lingua") | |
        # Preprocessing events
        db_btn.click(initialize_database,
            inputs=[document, slider_chunk_size, slider_chunk_overlap],
            outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, language],
            outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)
        # Keep the language state in sync with the selector (previously the selector had no effect)
        language_selector.change(lambda choice: choice,
            inputs=language_selector,
            outputs=language,
            queue=False)
        # Chatbot events
        msg.submit(conversation,
            inputs=[qa_chain, msg, chatbot, language],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)
        submit_btn.click(conversation,
            inputs=[qa_chain, msg, chatbot, language],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)
    demo.queue().launch(debug=True)
if __name__ == "__main__":
    demo()
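# To run locally (assuming the Space's dependencies are installed, e.g.
# gradio, langchain, langchain-community, langchain-huggingface, chromadb,
# transformers, torch, googletrans, and unstructured or pypdf):
#   export HF_TOKEN=<your Hugging Face access token>
#   python app.py   # app.py is the usual Space entry point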