|
|
|
import sys |
|
import os |
|
import re |
|
import shutil |
|
import time |
|
import streamlit as st |
|
import nltk |
|
import tempfile |
|
import subprocess |
|
|
|
|
|
# NLTK release this app is pinned to.
REQUIRED_NLTK_VERSION = "3.9.1"

# Install the pinned NLTK only when the imported version differs. Streamlit
# re-executes this script on every interaction, so an unconditional
# ``pip install`` would spawn a pip subprocess on each rerun. ``nltk`` is
# already imported at this point, so a version installed here is only picked
# up after the process restarts.
if nltk.__version__ != REQUIRED_NLTK_VERSION:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"],
        check=False,  # best effort, as before: a failed install does not abort startup
    )

# Keep NLTK data in a writable temp location and register it on NLTK's search
# path exactly once (``nltk.data.path`` is module state that survives reruns,
# so a plain append would add a duplicate entry every time).
nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
os.makedirs(nltk_data_path, exist_ok=True)
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# The 'punkt_tab' tokenizer tables are presumably needed by NLTK's sentence
# tokenizer behind NLTKTextSplitter. Download them only when they cannot be
# found. raise_on_error=True makes a failed download raise instead of merely
# returning False, so the error handler below actually runs.
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    try:
        print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
        nltk.download("punkt_tab", download_dir=nltk_data_path, raise_on_error=True)
    except Exception as e:
        print(f"Error downloading NLTK 'punkt_tab': {e}")
        raise

# Make modules in the current working directory (e.g. ``patent_downloader``)
# importable, without appending a duplicate sys.path entry on every rerun.
_cwd = os.path.abspath(".")
if _cwd not in sys.path:
    sys.path.append(_cwd)
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.llms import OpenAI |
|
from langchain.document_loaders import UnstructuredPDFLoader |
|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.text_splitter import NLTKTextSplitter |
|
from patent_downloader import PatentDownloader |
|
|
|
# Directory backing the Chroma vector store. mkdtemp() creates a new, uniquely
# named directory every time this line executes.
# NOTE(review): Streamlit re-executes the script on each rerun, so this is
# likely a fresh (and never cleaned-up) directory per interaction — confirm.
PERSISTED_DIRECTORY = tempfile.mkdtemp()

# The OpenAI key is mandatory: without it, show an error and halt this run.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is None or OPENAI_API_KEY == "":
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()
|
|
|
def check_poppler_installed():
    """Verify that Poppler's ``pdfinfo`` executable is reachable on ``PATH``.

    Raises:
        EnvironmentError: if ``pdfinfo`` cannot be located; the message tells
            the user to install 'poppler-utils' for PDF processing.
    """
    pdfinfo_location = shutil.which("pdfinfo")
    if pdfinfo_location is None:
        message = "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
        raise EnvironmentError(message)
|
|
|
# Fail fast when Poppler's pdfinfo is missing: raises EnvironmentError at
# script start, before any PDF is loaded.
check_poppler_installed()
|
|
|
def _primitive_metadata(metadata):
    """Return a copy of ``metadata`` keeping only str/int/float/bool values."""
    scalar_types = (str, int, float, bool)
    return {key: value for key, value in metadata.items() if isinstance(value, scalar_types)}


def load_docs(document_path):
    """Load a PDF and split it into ~1000-character chunks.

    The PDF is parsed with ``UnstructuredPDFLoader`` in "elements" mode using
    the "fast" strategy, then split with ``NLTKTextSplitter``. Each chunk's
    metadata is reduced to scalar values only.

    Args:
        document_path: Path of the PDF file to load.

    Returns:
        The list of split documents. On any failure, an error is shown in the
        Streamlit UI and the script run is halted with ``st.stop()``.
    """
    try:
        loader = UnstructuredPDFLoader(
            document_path,
            mode="elements",
            strategy="fast",
            ocr_languages=None,
        )
        raw_documents = loader.load()
        chunks = NLTKTextSplitter(chunk_size=1000).split_documents(raw_documents)

        # Strip non-scalar metadata (lists, dicts, None, ...), presumably
        # because the Chroma store these chunks go into rejects such values.
        for chunk in chunks:
            metadata = getattr(chunk, "metadata", None)
            if isinstance(metadata, dict):
                chunk.metadata = _primitive_metadata(metadata)
        return chunks
    except Exception as exc:
        st.error(f"Failed to load and process PDF: {exc}")
        st.stop()
|
|
|
def already_indexed(vectordb, file_name):
    """Return True if ``vectordb`` holds any chunk whose ``source`` equals ``file_name``.

    Args:
        vectordb: Vector store exposing ``get(include=["metadatas"])`` that
            returns a dict with a ``"metadatas"`` list (as LangChain's
            ``Chroma`` wrapper does).
        file_name: Source path to look for.

    Metadata entries that are ``None``/empty or lack a ``source`` key are
    skipped rather than raising TypeError/KeyError.
    """
    metadatas = vectordb.get(include=["metadatas"])["metadatas"] or []
    indexed_sources = {meta["source"] for meta in metadatas if meta and "source" in meta}
    return file_name in indexed_sources
|
|
|
def load_chain(file_name=None):
    """Build a conversational retrieval chain over the PDF at ``file_name``.

    The PDF's chunks live in a Chroma store under ``PERSISTED_DIRECTORY``. If
    that store already contains chunks whose ``source`` is ``file_name`` it is
    reused; otherwise the collection is dropped and rebuilt from
    ``load_docs(file_name)`` and persisted.

    Args:
        file_name: Path of the PDF to index and chat over.

    Returns:
        A ``ConversationalRetrievalChain`` backed by an OpenAI LLM
        (temperature 0), a retriever returning the top 3 chunks, and a fresh
        ``ConversationBufferMemory`` keyed on ``question``/``answer``.

    Side effects:
        Writes progress to the Streamlit page and records ``file_name`` in
        ``st.session_state["LOADED_PATENT"]``.
    """
    # One embeddings object serves both opening and (re)building the store.
    # HuggingFaceEmbeddings loads a sentence-transformers model when it is
    # constructed, so building it twice is wasted work.
    embeddings = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=embeddings,
    )

    # Decide from the store that will actually be queried, not from session
    # state: PERSISTED_DIRECTORY comes from mkdtemp() at module level, so a
    # remembered LOADED_PATENT can refer to an index that is absent here.
    if already_indexed(vectordb, file_name):
        st.write("✅ Already indexed.")
    else:
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("📊 Number of Documents: ", len(docs))

        vectordb = Chroma.from_documents(
            docs, embeddings, persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
    # Kept for any other reader of this key; records what is now loaded.
    st.session_state["LOADED_PATENT"] = file_name

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
    )
    return ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
|
|
|
def extract_patent_number(url):
    """Pull the patent number out of a Google Patents URL.

    Looks for a ``/patent/`` path segment beginning with two uppercase letters
    followed by digits, e.g. ``.../patent/US8676427B1/en`` -> ``"US8676427"``
    (the trailing kind code such as ``B1`` is not included).

    Returns:
        The matched number, or ``None`` when the URL has no such segment.
    """
    found = re.search(r"/patent/([A-Z]{2}\d+)", url)
    if found is None:
        return None
    return found.group(1)
|
|
|
def download_pdf(patent_number):
    """Download the PDF for ``patent_number`` into the system temp directory.

    Returns:
        The first element of the value returned by
        ``PatentDownloader.download``, presumably the saved file path.
        On any failure, an error is shown in the Streamlit UI and the run is
        halted with ``st.stop()``.
    """
    try:
        downloader = PatentDownloader(verbose=True)
        saved_paths = downloader.download(
            patents=patent_number,
            output_path=tempfile.gettempdir(),
        )
        return saved_paths[0]
    except Exception as exc:
        st.error(f"Failed to download patent PDF: {exc}")
        st.stop()
|
|
|
if __name__ == "__main__":
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="📄",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("📄 Patent Chat: Google Patents Chat Demo")

    # Opening assistant message for every new conversation.
    greeting_text = "Hello! How can I assist you with this patent?"

    # A ``?patent_link=...`` query parameter pre-fills the input box.
    query_params = st.query_params
    default_patent_link = query_params.get("patent_link", "https://patents.google.com/patent/US8676427B1/en")

    patent_link = st.text_area("Enter Google Patent Link:", value=default_patent_link, height=100)

    if st.button("Load and Process Patent"):
        # Treat a whitespace-only entry the same as an empty one.
        patent_link = patent_link.strip()
        if not patent_link:
            st.warning("Please enter a Google patent link to proceed.")
            st.stop()

        patent_number = extract_patent_number(patent_link)
        if not patent_number:
            st.error("Invalid patent link format. Please provide a valid Google patent link.")
            st.stop()

        st.write(f"Patent number: **{patent_number}**")

        # Reuse a previously downloaded PDF from the temp directory if present.
        pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
        if os.path.isfile(pdf_path):
            st.write("✅ File already downloaded.")
        else:
            st.write("📥 Downloading patent file...")
            pdf_path = download_pdf(patent_number)
            st.write(f"✅ File downloaded: {pdf_path}")

        st.write("🔄 Loading document into the system...")

        # Rebuild the chain (and reset the conversation) only when a different
        # PDF is requested than the one already held in session state.
        if "chain" not in st.session_state or st.session_state.get("loaded_file") != pdf_path:
            st.session_state.chain = load_chain(pdf_path)
            st.session_state.loaded_file = pdf_path
            st.session_state.messages = [{"role": "assistant", "content": greeting_text}]

        st.success("🎉 Document successfully loaded! You can now start asking questions.")

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": greeting_text}]

    # Replay the stored conversation; Streamlit redraws the page on each rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if "chain" in st.session_state:
        if user_input := st.chat_input("What is your question?"):
            st.session_state.messages.append({"role": "user", "content": user_input})
            with st.chat_message("user"):
                st.markdown(user_input)

            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                full_response = ""

                with st.spinner("Generating response..."):
                    try:
                        assistant_response = st.session_state.chain({"question": user_input})
                        full_response = assistant_response["answer"]
                    except Exception as e:
                        # Surface the failure in the chat instead of crashing the page.
                        full_response = f"An error occurred: {e}"

                message_placeholder.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})
    else:
        st.info("Press the 'Load and Process Patent' button to start processing.")
|
|