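"""Streamlit chatbot that answers questions from PDF content.

PDF text is chunked, embedded with Gemini embeddings, and stored in a local
FAISS index; questions are answered by a "stuff" QA chain instructed to use
only the retrieved context. Set the APIKEY1 and APIKEY2 environment variables,
then launch with: streamlit run <this file>
"""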
import os

import streamlit as st
from PyPDF2 import PdfReader
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
|
# Two Gemini API keys, read from the environment and rotated round-robin by
# switch_api_key() so successive calls alternate between them.
API_KEYS = [os.getenv("APIKEY1"), os.getenv("APIKEY2")]
current_key_index = -1
|
def switch_api_key():
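    """Rotate to the next key in API_KEYS and return it."""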
|
    global current_key_index
    current_key_index = (current_key_index + 1) % len(API_KEYS)
    return API_KEYS[current_key_index]
|
def get_pdf_text(pdf_docs):
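    """Concatenate the extracted text of every page across the given PDFs."""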
|
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() returns None for pages with no extractable text,
            # so fall back to "" to avoid a TypeError on concatenation.
            text += page.extract_text() or ""
    return text
|
def get_text_chunks(text):
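    """Split raw text into 10,000-character chunks with 1,000-character overlap."""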
|
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
    return text_splitter.split_text(text)
|
def get_vector_store(text_chunks):
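    """Embed the chunks with Gemini embeddings and save a FAISS index to ./faiss_index."""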
|
    api_key = switch_api_key()
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
    vector_store.save_local("faiss_index")
|
def get_conversational_chain():
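    """Build a "stuff" QA chain whose prompt limits answers to the supplied context."""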
|
    api_key = switch_api_key()
    prompt_template = """
You are a helpful assistant that only answers based on the context provided from the PDF documents.
Do not use any external knowledge or assumptions. If the answer is not found in the context below, reply with "I don't know."

Context:
{context}

Question:
{question}

Answer:
"""
    model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0, google_api_key=api_key)
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
    return chain
|
def user_input(user_question):
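    """Answer a question from the saved FAISS index and write the reply to the page."""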
|
    api_key = switch_api_key()
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    # Loading a locally pickled index requires explicitly opting in to deserialization.
    new_db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
    docs = new_db.similarity_search(user_question)
    chain = get_conversational_chain()

    response = chain({"input_documents": docs, "question": user_question}, return_only_outputs=True)
    st.write("Reply: ", response["output_text"])
|
def main():
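    """Render the Streamlit UI and hand any question off to user_input()."""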
|
    st.markdown(
        """
        <style>
        .header {font-size: 20px !important;}
        .subheader {font-size: 16px !important;}
        </style>
        """,
        unsafe_allow_html=True,
    )
    st.markdown('<h1 class="header">CSC 121: Computers and Scientific Thinking (Chatbot)</h1>', unsafe_allow_html=True)
    st.markdown('<h2 class="subheader">Ask a question ONLY from the CSC 121 textbook of Dr. Reed</h2>', unsafe_allow_html=True)

    user_question = st.text_input("Ask a question")

    if user_question:
        user_input(user_question)
|
if __name__ == "__main__":
    main()