# ChatWithData / app.py
# Uploaded by Fiqa ("Upload 2 files"), commit 628527c (verified), 2.91 kB.
import streamlit as st
import PyPDF2
from langchain.llms import HuggingFaceHub
import pptx
import os
from langchain.vectorstores.cassandra import Cassandra
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain.embeddings import OpenAIEmbeddings
import cassio
from langchain.text_splitter import CharacterTextSplitter
# API credentials are read from environment variables; none are hard-coded here.
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_ID = os.getenv("ASTRA_DB_ID")
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Initialize the Astra DB connection. The Cassandra vector store below is
# created without an explicit session, presumably relying on this global init.
cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_ID)

# Initialize LLM & embeddings.
# HuggingFaceHub does not read HUGGINGFACE_API_KEY by itself (its default
# environment variable is HUGGINGFACEHUB_API_TOKEN), so pass the key
# explicitly. If HUGGINGFACE_API_KEY is unset (None), LangChain still falls
# back to HUGGINGFACEHUB_API_TOKEN.
hf_llm = HuggingFaceHub(
    repo_id="google/flan-t5-large",
    model_kwargs={"temperature": 0, "max_length": 64},
    huggingfacehub_api_token=HUGGINGFACE_API_KEY,
)
embedding = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

# Vector store backed by the "qa_mini_demo" table; uploaded document chunks
# are added to it and queried from it in main().
astra_vector_store = Cassandra(embedding=embedding, table_name="qa_mini_demo")
def extract_text_from_pdf(uploaded_file):
    """Extract the text of every page of a PDF file.

    Args:
        uploaded_file: The PDF to read, passed straight to ``PyPDF2.PdfReader``
            (in this app, the object returned by ``st.file_uploader``).

    Returns:
        The concatenated text of all pages that produced any text, each page
        followed by a newline. Pages whose ``extract_text()`` result is
        ``None`` or empty are skipped. Returns ``""`` if no page has text.
    """
    pdf_reader = PyPDF2.PdfReader(uploaded_file)
    # extract_text() may return None or "" for a page with no extractable
    # text; filtering on truthiness avoids a None + "\n" TypeError.
    page_texts = (page.extract_text() for page in pdf_reader.pages)
    # str.join builds the result in one pass instead of repeated +=.
    return "".join(page_text + "\n" for page_text in page_texts if page_text)
def _iter_shape_texts(shapes):
    """Yield the text of each shape in *shapes*, descending into tables and groups."""
    for shape in shapes:
        if hasattr(shape, "text"):
            # Text-bearing shapes (text boxes, placeholders, autoshapes).
            yield shape.text
        elif getattr(shape, "has_table", False):
            # Tables expose no .text of their own; read each non-empty cell.
            for row in shape.table.rows:
                for cell in row.cells:
                    if cell.text:
                        yield cell.text
        elif hasattr(shape, "shapes"):
            # Group shapes hold their children in .shapes; recurse into them.
            yield from _iter_shape_texts(shape.shapes)


def extract_text_from_ppt(uploaded_file):
    """Extract the text from a PowerPoint (.pptx) file.

    Args:
        uploaded_file: The presentation to read, passed straight to
            ``pptx.Presentation`` (in this app, the Streamlit upload).

    Returns:
        The text of every text-bearing shape, non-empty table cell, and shape
        nested inside a group, slide by slide in shape order, each followed
        by a newline. Returns ``""`` if the presentation contains no text.
    """
    presentation = pptx.Presentation(uploaded_file)
    return "".join(
        shape_text + "\n"
        for slide in presentation.slides
        for shape_text in _iter_shape_texts(slide.shapes)
    )
def _extract_text(uploaded_file):
    """Dispatch to the PDF or PPTX extractor based on the file name's extension.

    Returns ``""`` when the extension is neither ``.pdf`` nor ``.pptx``.
    """
    # Compare case-insensitively so "REPORT.PDF" is handled like "report.pdf".
    file_name = uploaded_file.name.lower()
    if file_name.endswith(".pdf"):
        return extract_text_from_pdf(uploaded_file)
    if file_name.endswith(".pptx"):
        return extract_text_from_ppt(uploaded_file)
    return ""


def main():
    """Render the Streamlit UI: upload a document, index its text, and query it."""
    st.title("Chat with Documents")
    uploaded_file = st.file_uploader("Upload a PDF or PPT file", type=["pdf", "pptx"])
    extract_button = st.button("Extract Text")

    if extract_button:
        if uploaded_file is None:
            st.warning("Please upload a PDF or PPTX file first.")
        else:
            extracted_text = _extract_text(uploaded_file)
            if extracted_text:
                # Split into overlapping chunks before embedding/storing them.
                text_splitter = CharacterTextSplitter(
                    separator="\n",
                    chunk_size=800,
                    chunk_overlap=200,
                    length_function=len,
                )
                texts = text_splitter.split_text(extracted_text)
                astra_vector_store.add_texts(texts)
                st.success(f"Indexed {len(texts)} text chunk(s) from {uploaded_file.name}.")
            else:
                st.warning("No text could be extracted from this file.")

    # Wrap the module-level vector store so it can answer queries via the LLM.
    astra_vector_index = VectorStoreIndexWrapper(vectorstore=astra_vector_store)
    query = st.text_input("Enter your query")
    submit_query = st.button("Submit Query")
    if submit_query:
        # Don't send a blank query to the LLM.
        if not query or not query.strip():
            st.warning("Please enter a query.")
        else:
            value = astra_vector_index.query(query, llm=hf_llm)
            st.write(f"Response: {value}")


if __name__ == "__main__":
    main()