Mattral committed
Commit 1057268 · verified · 1 Parent(s): 88569b0

Create app.py

Files changed (1)
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ from qdrant_client import models, QdrantClient
+ from sentence_transformers import SentenceTransformer
+ from PyPDF2 import PdfReader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ # ctransformers (rather than transformers) loads GGUF checkpoints and accepts model_file/model_type
+ from ctransformers import AutoModelForCausalLM
+
+ # Load the embedding model
+ encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
+ print("Embedding model loaded...")
+
+ # Load the quantized Llama 2 chat model (GGUF)
+ llm = AutoModelForCausalLM.from_pretrained(
+     "TheBloke/Llama-2-7B-Chat-GGUF",
+     model_file="llama-2-7b-chat.Q3_K_S.gguf",
+     model_type="llama",
+     temperature=0.2,
+     repetition_penalty=1.5,
+     max_new_tokens=300,
+ )
+ print("LLM loaded...")
+
+ # Local, file-backed Qdrant instance
+ client = QdrantClient(path="./db")
+
+ def setup_database(files):
+     # Extract and chunk the text of every uploaded PDF
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50, length_function=len)
+     all_chunks = []
+     for pdf_path in files:
+         reader = PdfReader(pdf_path)
+         text = "".join(page.extract_text() for page in reader.pages if page.extract_text())
+         all_chunks.extend(text_splitter.split_text(text))
+
+     print(f"Total chunks: {len(all_chunks)}")
+
+     client.recreate_collection(
+         collection_name="my_facts",
+         vectors_config=models.VectorParams(
+             size=encoder.get_sentence_embedding_dimension(),
+             distance=models.Distance.COSINE,
+         ),
+     )
+     print("Collection created...")
+
+     # Embed every chunk and store it with its text as the payload
+     client.upload_records(
+         collection_name="my_facts",
+         records=[
+             models.Record(
+                 id=idx,
+                 vector=encoder.encode(chunk).tolist(),
+                 payload={"text": chunk},
+             )
+             for idx, chunk in enumerate(all_chunks)
+         ],
+     )
+     print("Records uploaded...")
+
+ def answer(question):
+     # Retrieve the three chunks most similar to the question
+     hits = client.search(
+         collection_name="my_facts",
+         query_vector=encoder.encode(question).tolist(),
+         limit=3,
+     )
+     context = " ".join(hit.payload["text"] for hit in hits)
+     system_prompt = ("You are a helpful co-worker. Use the provided context "
+                      "to answer user questions. Do not use any other information.")
+     prompt = f"{system_prompt}\nContext: {context}\nUser: {question}\nAnswer:"
+     return llm(prompt)
+
+ def chat(question):
+     # gr.Textbox hands the function the raw string the user typed
+     if not question:
+         return "Please upload PDF documents to initialize the database."
+     return answer(question)
+
+ screen = gr.Interface(
+     fn=chat,
+     inputs=gr.Textbox(placeholder="Type your question here..."),
+     outputs="text",
+     title="Q&A with PDFs 👩🏻‍💻📓✍🏻💡",
+     description="This app facilitates a conversation with uploaded PDFs 💡",
+     theme="soft",
+     allow_flagging="never",
+ )
+
+ # TODO: add a way to upload PDFs and call setup_database before starting the chat
+ screen.launch()
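
One way to resolve the trailing TODO (a sketch only, not part of this commit): wrap the chat flow in gr.Blocks so that a file upload calls setup_database before any question is asked. Component names below (pdf_upload, question_box, answer_box) are illustrative, and the sketch assumes the functions defined in app.py are in scope.

# Sketch: upload PDFs first, then ask questions against them
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# Q&A with PDFs 👩🏻‍💻📓✍🏻💡")
    pdf_upload = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDFs")
    question_box = gr.Textbox(placeholder="Type your question here...", label="Question")
    answer_box = gr.Textbox(label="Answer")
    # Build the Qdrant collection as soon as files are uploaded
    pdf_upload.upload(fn=setup_database, inputs=pdf_upload)
    # Route each submitted question through retrieval and the LLM
    question_box.submit(fn=chat, inputs=question_box, outputs=answer_box)

demo.launch()

gr.File hands setup_database a list of file paths, which is exactly what its loop iterates over, and demo.launch() would then replace screen.launch().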