Mattral committed
Commit 32e37e9 · verified · 1 Parent(s): d1e2ffe

Update app.py

Files changed (1):
  app.py  +55 −54
app.py CHANGED
@@ -6,10 +6,7 @@ from PyPDF2 import PdfReader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.callbacks.manager import CallbackManager
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
- from langchain.llms import LlamaCpp
-
- from langchain.vectorstores import Qdrant
- from transformers import AutoModelForCausalLM
+ from ctransformers import AutoModelForCausalLM

  # Load the embedding model
  encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
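
Note: the Qdrant collection created in setup_database must use a vector size equal to this encoder's output dimension (the size= argument falls between hunks and is unchanged by this commit). A minimal sketch to check the dimension at runtime, assuming only that sentence-transformers is installed; the model name comes from the line above:

    from sentence_transformers import SentenceTransformer

    # Load the same embedding model as app.py and report its output size.
    encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
    dim = encoder.get_sentence_embedding_dimension()
    print(dim)  # should match models.VectorParams(size=...) used below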
@@ -17,8 +14,6 @@ print("Embedding model loaded...")

  # Load the LLM
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
- '''
  llm = AutoModelForCausalLM.from_pretrained(
      "TheBloke/Llama-2-7B-Chat-GGUF",
      model_file="llama-2-7b-chat.Q3_K_S.gguf",
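
Note: with the switch from langchain's LlamaCpp wrapper to ctransformers, callback_manager is no longer passed to the model, so the StreamingStdOutCallbackHandler defined above will never fire. A sketch of the new load path, assuming ctransformers is installed; model_type="llama" and the temperature/context settings mirroring the removed LlamaCpp call are assumptions, not part of this commit:

    from ctransformers import AutoModelForCausalLM

    # Load the GGUF model directly through ctransformers.
    llm = AutoModelForCausalLM.from_pretrained(
        "TheBloke/Llama-2-7B-Chat-GGUF",
        model_file="llama-2-7b-chat.Q3_K_S.gguf",
        model_type="llama",       # assumption; often inferable from GGUF metadata
        temperature=0.2,          # assumption, echoing the removed LlamaCpp settings
        repetition_penalty=1.5,
        max_new_tokens=300,
        context_length=2048,      # counterpart of the removed n_ctx=2048
    )
    print(llm("Hello"))           # ctransformers models are callable on a prompt string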
@@ -27,32 +22,25 @@ llm = AutoModelForCausalLM.from_pretrained(
      repetition_penalty=1.5,
      max_new_tokens=300,
  )
- '''
- llm = LlamaCpp(
-     model_path="./llama-2-7b-chat.Q3_K_S.gguf",
-     temperature=0.2,
-     n_ctx=2048,
-     f16_kv=True,  # MUST set to True, otherwise you will run into problems after a couple of calls
-     max_tokens=500,
-     callback_manager=callback_manager,
-     verbose=True,
- )
  print("LLM loaded...")

- client = QdrantClient(path="./db")
+ def get_chunks(text):
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=250,
+         chunk_overlap=50,
+         length_function=len,
+     )
+     return text_splitter.split_text(text)

  def setup_database(files):
      all_chunks = []
      for file in files:
-         pdf_path = file
-         reader = PdfReader(pdf_path)
-         text = "".join(page.extract_text() for page in reader.pages if page.extract_text())
-         text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50, length_function=len)
-         chunks = text_splitter.split_text(text)
+         reader = PdfReader(file)
+         text = "".join(page.extract_text() for page in reader.pages)
+         chunks = get_chunks(text)
          all_chunks.extend(chunks)
-
-     print(f"Total chunks: {len(all_chunks)}")
-
+
+     client = QdrantClient(path="./db")
      client.recreate_collection(
          collection_name="my_facts",
          vectors_config=models.VectorParams(
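
Note: chunking now goes through the get_chunks helper. chunk_size counts characters (length_function=len), and consecutive chunks overlap by up to 50 characters so sentences are not cut off without context. A quick usage sketch with made-up text:

    # Roughly 1000 characters of sample text.
    sample = "Qdrant stores each chunk as a point with a payload. " * 20
    chunks = get_chunks(sample)
    print(len(chunks), [len(c) for c in chunks[:3]])  # each chunk is at most 250 chars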
@@ -60,51 +48,64 @@ def setup_database(files):
          distance=models.Distance.COSINE,
      ),
  )
-
-     print("Collection created...")

-     for idx, chunk in enumerate(all_chunks):
-         client.upload_record(
-             collection_name="my_facts",
-             record=models.Record(
-                 id=idx,
-                 vector=encoder.encode(chunk).tolist(),
-                 payload={"text": chunk}
-             )
-         )
+     records = [
+         models.Record(
+             id=idx,
+             vector=encoder.encode(chunk).tolist(),
+             payload={f"chunk_{idx}": chunk}
+         ) for idx, chunk in enumerate(all_chunks)
+     ]

-     print("Records uploaded...")
+     client.upload_records(
+         collection_name="my_facts",
+         records=records,
+     )

- def answer(question):
+ def answer_question(question):
+     client = QdrantClient(path="./db")
      hits = client.search(
          collection_name="my_facts",
          query_vector=encoder.encode(question).tolist(),
          limit=3
      )

-     context = " ".join(hit.payload["text"] for hit in hits)
-     system_prompt = "You are a helpful co-worker. Use the provided context to answer user questions. Do not use any other information."
-     prompt = f"Context: {context}\nUser: {question}\n{system_prompt}"
-     response = llm(prompt)
+     context = " ".join(hit.payload[f"chunk_{hit.id}"] for hit in hits)
+
+     system_prompt = """You are a helpful co-worker; you will use the provided context to answer user questions.
+     Read the given context before answering questions and think step by step. If you cannot answer a user question based on
+     the provided context, inform the user. Do not use any other information when answering. Provide a detailed answer to the question."""
+
+     B_INST, E_INST = "[INST]", "[/INST]"
+     B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+     instruction = f"Context: {context}\nUser: {question}"
+     prompt_template = f"{B_INST}{B_SYS}{system_prompt}{E_SYS}{instruction}{E_INST}"
+
+     response = llm(prompt_template)
      return response

- def chat(messages):
-     if not messages:
-         return "Please upload PDF documents to initialize the database."
-     last_message = messages[-1]
-     return answer(last_message["message"])
+ def chat(messages, files):
+     if files:
+         setup_database(files)
+     if messages:
+         question = messages[-1]["text"]
+         answer = answer_question(question)
+         messages.append({"text": answer, "is_user": False})
+     return messages

- screen = gr.Interface(
+ interface = gr.Interface(
      fn=chat,
-     inputs=gr.Textbox(placeholder="Type your question here..."),
-     outputs="chatbot",
+     inputs=[
+         gr.Chatbot(label="Chat"),
+         gr.File(label="Upload PDFs", file_count="multiple")
+     ],
+     outputs=gr.Chatbot(label="Chat"),
      title="Q&A with PDFs 👩🏻‍💻📓✍🏻💡",
      description="This app facilitates a conversation with PDFs uploaded💡",
      theme="soft",
+     share=True,
      live=True,
-     allow_flagging=False,
  )

-
- # Add a way to upload and setup the database before starting the chat
- screen.launch()
+ interface.launch()
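
Note: the new payload scheme keys each payload by its record id (f"chunk_{idx}") and the search side rebuilds the key as f"chunk_{hit.id}"; that only works while the id and the key stay in lockstep. The pre-change fixed key is simpler and uses the same qdrant-client API. A sketch of that alternative, reusing the names from the diff:

    records = [
        models.Record(
            id=idx,
            vector=encoder.encode(chunk).tolist(),
            payload={"text": chunk},   # fixed key, as in the old version
        )
        for idx, chunk in enumerate(all_chunks)
    ]
    client.upload_records(collection_name="my_facts", records=records)

    # The query side then no longer depends on hit.id:
    context = " ".join(hit.payload["text"] for hit in hits)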
 
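Note: on recent Gradio releases, share is an argument of launch() rather than of the gr.Interface constructor, so the new share=True kwarg may raise a TypeError. A sketch of the likely intent, assuming the same components as the diff:

    interface = gr.Interface(
        fn=chat,
        inputs=[
            gr.Chatbot(label="Chat"),
            gr.File(label="Upload PDFs", file_count="multiple"),
        ],
        outputs=gr.Chatbot(label="Chat"),
        title="Q&A with PDFs 👩🏻‍💻📓✍🏻💡",
        description="This app facilitates a conversation with PDFs uploaded💡",
        theme="soft",
        live=True,
    )
    interface.launch(share=True)   # share moved to launch(); an assumption about intent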