DHEIVER committed
Commit 7536e7a · verified · 1 Parent(s): 7c8c122

Update app.py

Files changed (1)
  1. app.py +23 -25
app.py CHANGED
@@ -11,41 +11,40 @@ from langchain.memory import ConversationBufferMemory
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
-# List of local models (publicly accessible, no token required)
-list_llm = ["facebook/opt-350m", "distilbert/distilgpt2"]
+# List of lightweight public models
+list_llm = ["EleutherAI/gpt-neo-125m", "distilbert/distilgpt2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-# Load and split PDF document
+# Load and split the PDF documents
 def load_doc(list_file_path):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=1024,
-        chunk_overlap=64
+        chunk_size=512,   # Reduced to speed up retrieval
+        chunk_overlap=32  # Less overlap, less processing
     )
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
-# Create vector database
+# Create the vector database
 def create_db(splits):
     embeddings = HuggingFaceEmbeddings()
     vectordb = FAISS.from_documents(splits, embeddings)
     return vectordb
 
-# Initialize langchain LLM chain with local model
+# Initialize the local LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    # Load the model and tokenizer locally
     tokenizer = AutoTokenizer.from_pretrained(llm_model)
     model = AutoModelForCausalLM.from_pretrained(
         llm_model,
-        device_map="auto",  # Automatically use GPU if available
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # Optimize for GPU or CPU
-        trust_remote_code=True  # Required for some models
+        device_map="auto",  # Use the GPU if available
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # Optimize for GPU
+        trust_remote_code=True
     )
 
-    # Create a pipeline for text generation
+    # Optimized text-generation pipeline
     pipe = pipeline(
         "text-generation",
         model=model,
@@ -53,12 +52,11 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
         max_new_tokens=max_tokens,
         temperature=temperature,
         top_k=top_k,
-        do_sample=True,
+        do_sample=False,  # Greedy decoding for speed
         repetition_penalty=1.1,
         return_full_text=False
     )
 
-    # Wrap the pipeline in HuggingFacePipeline for LangChain
    llm = HuggingFacePipeline(pipeline=pipe)
 
    memory = ConversationBufferMemory(
@@ -67,7 +65,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
        return_messages=True
    )
 
-    retriever = vector_db.as_retriever()
+    retriever = vector_db.as_retriever(search_kwargs={"k": 2})  # Return fewer documents
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
@@ -78,14 +76,14 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    )
    return qa_chain
 
-# Initialize database
+# Initialize the database
 def initialize_database(list_file_obj, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created!"
 
-# Initialize LLM
+# Initialize the LLM
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = list_llm[llm_option]
    print("llm_name: ", llm_name)
@@ -108,10 +106,10 @@ def conversation(qa_chain, message, history):
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
+    response_source3 = ""  # Fewer references to speed things up
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
+    response_source3_page = 0
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
@@ -127,7 +125,7 @@ def demo():
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF chatbot</h1><center>")
-        gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. This version runs locally and does not require an API token. \
+        gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. Optimized for speed without an API token. \
        <b>Please do not upload confidential documents.</b>
        """)
        with gr.Row():
@@ -145,11 +143,11 @@ def demo():
        with gr.Row():
            with gr.Accordion("LLM input parameters", open=False):
                with gr.Row():
-                    slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
+                    slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness (ignored with greedy decoding)", interactive=True)
                with gr.Row():
-                    slider_maxtokens = gr.Slider(minimum=128, maximum=2048, value=512, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
+                    slider_maxtokens = gr.Slider(minimum=64, maximum=512, value=128, step=64, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
                with gr.Row():
-                    slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
+                    slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of tokens to select (ignored with greedy decoding)", interactive=True)
                with gr.Row():
                    qachain_btn = gr.Button("Initialize Question Answering Chatbot")
                with gr.Row():
@@ -174,7 +172,7 @@ def demo():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
 
-        # Preprocessing events
+        # Pre-processing events
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(
            lambda: [None, "", 0, "", 0, "", 0],
@@ -183,7 +181,7 @@ def demo():
            queue=False
        )
 
-        # Chatbot events
+        # Chatbot events
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
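The splitter now uses chunk_size=512 with chunk_overlap=32, trading context per chunk for faster embedding and FAISS search. A minimal standalone sketch of those settings, assuming the classic langchain.text_splitter import path; the sample text is made up and is not part of app.py:

# Standalone illustration of the new splitter settings; smaller chunks mean
# more, shorter passages to embed, which speeds up indexing and retrieval.
from langchain.text_splitter import RecursiveCharacterTextSplitter

sample_text = "Expense reports are due at the end of the month. " * 100  # stand-in for one PDF page
splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
chunks = splitter.split_text(sample_text)
print(len(chunks), "chunks, longest:", max(len(c) for c in chunks), "characters")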
 
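With do_sample=False the transformers text-generation pipeline decodes greedily, which is why the Temperature and top-k sliders are now labelled as ignored. A rough standalone sketch of that behaviour with one of the models from list_llm; the prompt string is purely illustrative:

# Greedy decoding sketch: the same prompt yields the same output on every run,
# regardless of the temperature/top_k values exposed in the UI.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "distilbert/distilgpt2"  # second entry in list_llm
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # GPU if present, else CPU
    max_new_tokens=32,
    do_sample=False,         # greedy decoding, as in the updated app
    repetition_penalty=1.1,
    return_full_text=False,
)

print(pipe("The uploaded PDF describes")[0]["generated_text"])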
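The retriever is now built with search_kwargs={"k": 2}, and conversation() fills the third source slot with an empty string and page 0 so the Gradio outputs keep the same shape. A small sketch of a k=2 retriever, assuming langchain_community import paths and a few throwaway texts standing in for PDF chunks:

# k=2 retriever sketch: at most two source documents come back per query,
# matching the UI change that leaves the third reference box empty.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

texts = [  # made-up snippets, not from any real document
    "Invoices must be submitted before the end of the month.",
    "Travel expenses require manager approval.",
    "The cafeteria is open from 8 am to 3 pm.",
]
vector_db = FAISS.from_texts(texts, HuggingFaceEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 2})

docs = retriever.get_relevant_documents("How do I get an expense approved?")
print(len(docs))  # 2 -- only two source documents reach the chat UI
for doc in docs:
    print(doc.page_content)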