DrishtiSharma committed
Commit e90d440 · verified · 1 Parent(s): fc8a155

Update app.py

Files changed (1)
  1. app.py +76 -83
app.py CHANGED
@@ -2,22 +2,10 @@ import sys
  import os
  import re
  import time
+ import tempfile
  import streamlit as st
  import nltk
- from io import BytesIO
-
- # Force NLTK to download 'punkt' into a virtual, in-memory location
- try:
-     from nltk.data import load
-     print("Downloading 'punkt' tokenizer to memory...")
-     nltk.download("punkt")
-     load("tokenizers/punkt/english.pickle")
-     print("✅ 'punkt' successfully loaded into memory.")
- except Exception as e:
-     print(f"Error loading 'punkt': {e}")
-     raise e
-
- sys.path.append(os.path.abspath("."))
+
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
  from langchain.llms import OpenAI
@@ -27,54 +15,52 @@ from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.text_splitter import NLTKTextSplitter
  from patent_downloader import PatentDownloader

- PERSISTED_DIRECTORY = os.path.join(os.getcwd(), "chroma_db")
+ # Download NLTK resources
+ nltk.download("punkt", quiet=True)

- # Fetch API key securely from the environment
+ #fetch API key
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
  if not OPENAI_API_KEY:
-     st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
+     st.error("Critical Error: OpenAI API key not found in environment variables. Please configure it.")
      st.stop()

+
+ def extract_patent_number(url):
+     """Extracts patent number from a Google patent link."""
+     pattern = r"/patent/([A-Z]{2}\d+)"
+     match = re.search(pattern, url)
+     return match.group(1) if match else None
+
+
+ def download_pdf(patent_number):
+     """Downloads patent PDF using a temporary directory."""
+     try:
+         with tempfile.TemporaryDirectory() as temp_dir:
+             patent_downloader = PatentDownloader(verbose=True)
+             output_path = patent_downloader.download(patents=patent_number, output_path=temp_dir)
+             return output_path[0]
+     except Exception as e:
+         st.error(f"Failed to download patent PDF: {e}")
+         return None
+
+
  def load_docs(document_path):
+     """Loads and splits PDF documents into chunks."""
      try:
-         loader = UnstructuredPDFLoader(
-             document_path,
-             mode="elements",
-             strategy="fast",
-             ocr_languages=None
-         )
+         loader = UnstructuredPDFLoader(document_path)
          documents = loader.load()
         text_splitter = NLTKTextSplitter(chunk_size=1000)
         return text_splitter.split_documents(documents)
      except Exception as e:
-         st.error(f"Failed to load and process PDF: {e}")
-         st.stop()
-
- def already_indexed(vectordb, file_name):
-     indexed_sources = set(
-         x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
-     )
-     return file_name in indexed_sources
+         st.error(f"Failed to process PDF: {e}")
+         return []

- def load_chain(file_name=None):
-     loaded_patent = st.session_state.get("LOADED_PATENT")

-     vectordb = Chroma(
-         persist_directory=PERSISTED_DIRECTORY,
-         embedding_function=HuggingFaceEmbeddings(),
+ def load_chain(docs):
+     """Creates a conversational retrieval chain using in-memory ChromaDB."""
+     vectordb = Chroma.from_documents(
+         docs, HuggingFaceEmbeddings(), persist_directory=None
      )
-     if loaded_patent == file_name or already_indexed(vectordb, file_name):
-         st.write("✅ Already indexed.")
-     else:
-         vectordb.delete_collection()
-         docs = load_docs(file_name)
-         st.write("🔍 Number of Documents: ", len(docs))
-
-         vectordb = Chroma.from_documents(
-             docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
-         )
-         vectordb.persist()
-         st.session_state["LOADED_PATENT"] = file_name

      memory = ConversationBufferMemory(
          memory_key="chat_history",
@@ -82,6 +68,7 @@ def load_chain(file_name=None):
          input_key="question",
          output_key="answer",
      )
+
      return ConversationalRetrievalChain.from_llm(
          OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
          vectordb.as_retriever(search_kwargs={"k": 3}),
@@ -89,20 +76,8 @@
          memory=memory,
      )

- def extract_patent_number(url):
-     pattern = r"/patent/([A-Z]{2}\d+)"
-     match = re.search(pattern, url)
-     return match.group(1) if match else None
-
- def download_pdf(patent_number):
-     try:
-         patent_downloader = PatentDownloader(verbose=True)
-         output_path = patent_downloader.download(patents=patent_number, output_path="/tmp")
-         return output_path[0]
-     except Exception as e:
-         st.error(f"Failed to download patent PDF: {e}")
-         st.stop()

+ # Streamlit UI
  if __name__ == "__main__":
      st.set_page_config(
          page_title="Patent Chat: Google Patents Chat Demo",
@@ -110,8 +85,10 @@ if __name__ == "__main__":
          layout="wide",
          initial_sidebar_state="expanded",
      )
+
      st.header("📖 Patent Chat: Google Patents Chat Demo")

+     # Input for Google Patent Link
      patent_link = st.text_input("Enter Google Patent Link:", key="PATENT_LINK")

      if not patent_link:
@@ -123,48 +100,64 @@
          st.error("Invalid patent link format. Please provide a valid Google patent link.")
          st.stop()

-     st.write(f"Patent number: **{patent_number}**")
+     st.write(f"🔍 Patent Number: **{patent_number}**")

-     pdf_path = os.path.join("/tmp", f"{patent_number}.pdf")
-     if os.path.isfile(pdf_path):
-         st.write("✅ File already downloaded.")
-     else:
-         st.write("📥 Downloading patent file...")
+     # Download or Upload PDF
+     st.write("📥 Downloading patent PDF...")
+     pdf_path = None
+
+     try:
          pdf_path = download_pdf(patent_number)
-         st.write(f"✅ File downloaded: {pdf_path}")
+     except Exception:
+         st.error("Automatic download failed. Please upload the PDF manually below.")
+
+     if not pdf_path:
+         uploaded_file = st.file_uploader("Upload the patent PDF file:", type="pdf")
+         if uploaded_file:
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
+                 tmp_file.write(uploaded_file.read())
+                 pdf_path = tmp_file.name
+             st.success("✅ PDF successfully uploaded.")
+         else:
+             st.stop()
+
+     # Load and Process PDF
+     st.write("🔄 Processing document...")
+     docs = load_docs(pdf_path)
+
+     if not docs:
+         st.error("No content found in the PDF. Exiting...")
+         st.stop()

-     st.write("🔄 Loading document into the system...")
-     chain = load_chain(pdf_path)
+     chain = load_chain(docs)
      st.success("🚀 Document successfully loaded! You can now start asking questions.")

+     # Initialize chat history
      if "messages" not in st.session_state:
          st.session_state["messages"] = [
              {"role": "assistant", "content": "Hello! How can I assist you with this patent?"}
          ]

+     # Display chat history
      for message in st.session_state.messages:
          with st.chat_message(message["role"]):
              st.markdown(message["content"])

+     # Handle User Input
      if user_input := st.chat_input("What is your question?"):
          st.session_state.messages.append({"role": "user", "content": user_input})
+
          with st.chat_message("user"):
              st.markdown(user_input)

          with st.chat_message("assistant"):
              message_placeholder = st.empty()
-             full_response = ""
-
-             with st.spinner("Generating response..."):
-                 try:
-                     assistant_response = chain({"question": user_input})
-                     for chunk in assistant_response["answer"].split():
-                         full_response += chunk + " "
-                         time.sleep(0.05)
-                         message_placeholder.markdown(full_response + "▌")
-                 except Exception as e:
-                     full_response = f"An error occurred: {e}"
-                 finally:
-                     message_placeholder.markdown(full_response)
+             with st.spinner("Generating response..."):
+                 try:
+                     assistant_response = chain({"question": user_input})
+                     full_response = assistant_response.get("answer", "I'm sorry, I couldn't generate a response.")
+                 except Exception as e:
+                     full_response = f"An error occurred: {e}"
+             message_placeholder.markdown(full_response)

      st.session_state.messages.append({"role": "assistant", "content": full_response})