DrishtiSharma committed on
Commit ddb6580 · verified · 1 Parent(s): 24246f0

Update interim/app.py

Files changed (1)
  1. interim/app.py +182 -0
interim/app.py CHANGED
@@ -0,0 +1,182 @@
+ import sys
+ import os
+ import re
+ import shutil
+ import time
+ import streamlit as st
+ import nltk
+ import tempfile
+
+ # Set up temporary directory for NLTK resources
+ nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
+ os.makedirs(nltk_data_path, exist_ok=True)
+ nltk.data.path = [nltk_data_path]  # Force NLTK to use only the temp directory
+
+ # Force clean download of 'punkt'
+ try:
+     print("Ensuring NLTK 'punkt' resource is downloaded...")
+     if not os.path.exists(os.path.join(nltk_data_path, "tokenizers/punkt")):
+         nltk.download("punkt", download_dir=nltk_data_path)
+ except Exception as e:
+     print(f"Error downloading NLTK 'punkt': {e}")
+     raise e
+
+ sys.path.append(os.path.abspath("."))
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain.llms import OpenAI
+ from langchain.document_loaders import UnstructuredPDFLoader
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.text_splitter import NLTKTextSplitter
+ from patent_downloader import PatentDownloader
+
+ PERSISTED_DIRECTORY = tempfile.mkdtemp()
+
+ # Fetch API key securely from the environment
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ if not OPENAI_API_KEY:
+     st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
+     st.stop()
+
+ def check_poppler_installed():
+     if not shutil.which("pdfinfo"):
+         raise EnvironmentError(
+             "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
+         )
+
+ check_poppler_installed()
+
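+ # Parse the patent PDF with Unstructured and split it into ~1000-character chunks using NLTK sentence tokenization.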
+ def load_docs(document_path):
+     try:
+         loader = UnstructuredPDFLoader(
+             document_path,
+             mode="elements",
+             strategy="fast",
+             ocr_languages=None
+         )
+         documents = loader.load()
+         text_splitter = NLTKTextSplitter(chunk_size=1000)
+         return text_splitter.split_documents(documents)
+     except Exception as e:
+         st.error(f"Failed to load and process PDF: {e}")
+         st.stop()
+
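+ # Check whether this file's chunks are already present in the Chroma collection.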
+ def already_indexed(vectordb, file_name):
+     indexed_sources = set(
+         x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
+     )
+     return file_name in indexed_sources
+
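+ # Index the PDF into Chroma (reusing an existing index when possible) and wrap the
+ # retriever in a ConversationalRetrievalChain with buffer memory.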
+ def load_chain(file_name=None):
+     loaded_patent = st.session_state.get("LOADED_PATENT")
+
+     vectordb = Chroma(
+         persist_directory=PERSISTED_DIRECTORY,
+         embedding_function=HuggingFaceEmbeddings(),
+     )
+     if loaded_patent == file_name or already_indexed(vectordb, file_name):
+         st.write("✅ Already indexed.")
+     else:
+         vectordb.delete_collection()
+         docs = load_docs(file_name)
+         st.write("🔍 Number of Documents: ", len(docs))
+
+         vectordb = Chroma.from_documents(
+             docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
+         )
+         vectordb.persist()
+         st.session_state["LOADED_PATENT"] = file_name
+
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         return_messages=True,
+         input_key="question",
+         output_key="answer",
+     )
+     return ConversationalRetrievalChain.from_llm(
+         OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
+         vectordb.as_retriever(search_kwargs={"k": 3}),
+         return_source_documents=False,
+         memory=memory,
+     )
+
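+ # Pull the patent identifier (e.g. US1234567) out of a Google Patents URL.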
+ def extract_patent_number(url):
+     pattern = r"/patent/([A-Z]{2}\d+)"
+     match = re.search(pattern, url)
+     return match.group(1) if match else None
+
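+ # Download the patent PDF into the system temp directory via PatentDownloader.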
+ def download_pdf(patent_number):
+     try:
+         patent_downloader = PatentDownloader(verbose=True)
+         output_path = patent_downloader.download(patents=patent_number, output_path=tempfile.gettempdir())
+         return output_path[0]
+     except Exception as e:
+         st.error(f"Failed to download patent PDF: {e}")
+         st.stop()
+
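+ # Streamlit entry point: collect a Google Patents link, download and index the PDF,
+ # then run the chat loop against the retrieval chain.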
+ if __name__ == "__main__":
+     st.set_page_config(
+         page_title="Patent Chat: Google Patents Chat Demo",
+         page_icon="📖",
+         layout="wide",
+         initial_sidebar_state="expanded",
+     )
+     st.header("📖 Patent Chat: Google Patents Chat Demo")
+
+     patent_link = st.text_input("Enter Google Patent Link:", key="PATENT_LINK")
+
+     if not patent_link:
+         st.warning("Please enter a Google patent link to proceed.")
+         st.stop()
+
+     patent_number = extract_patent_number(patent_link)
+     if not patent_number:
+         st.error("Invalid patent link format. Please provide a valid Google patent link.")
+         st.stop()
+
+     st.write(f"Patent number: **{patent_number}**")
+
+     pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
+     if os.path.isfile(pdf_path):
+         st.write("✅ File already downloaded.")
+     else:
+         st.write("📥 Downloading patent file...")
+         pdf_path = download_pdf(patent_number)
+         st.write(f"✅ File downloaded: {pdf_path}")
+
+     st.write("🔄 Loading document into the system...")
+     chain = load_chain(pdf_path)
+     st.success("🚀 Document successfully loaded! You can now start asking questions.")
+
+     if "messages" not in st.session_state:
+         st.session_state["messages"] = [
+             {"role": "assistant", "content": "Hello! How can I assist you with this patent?"}
+         ]
+
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     if user_input := st.chat_input("What is your question?"):
+         st.session_state.messages.append({"role": "user", "content": user_input})
+         with st.chat_message("user"):
+             st.markdown(user_input)
+
+         with st.chat_message("assistant"):
+             message_placeholder = st.empty()
+             full_response = ""
+
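+             # The chain returns the full answer at once; the loop below re-renders it
+             # word by word to simulate streaming in the chat window.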
+             with st.spinner("Generating response..."):
+                 try:
+                     assistant_response = chain({"question": user_input})
+                     for chunk in assistant_response["answer"].split():
+                         full_response += chunk + " "
+                         time.sleep(0.05)
+                         message_placeholder.markdown(full_response + "▌")
+                 except Exception as e:
+                     full_response = f"An error occurred: {e}"
+                 finally:
+                     message_placeholder.markdown(full_response)
+
+         st.session_state.messages.append({"role": "assistant", "content": full_response})