bainskarman committed
Commit c0a164f · verified · 1 Parent(s): 4e18c76

Update app.py

Files changed (1)
  1. app.py +47 -44
app.py CHANGED
@@ -1,28 +1,20 @@
 import streamlit as st
+import os
+from huggingface_hub import InferenceApi
 from PyPDF2 import PdfReader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.chains import RetrievalQA
-from langchain.llms import HuggingFacePipeline
-import torch
-from transformers import pipeline
 from langdetect import detect
 
-# Load a smaller LLM with customizable parameters
-def load_llm(temperature, top_k, max_new_tokens):
+# Load the Hugging Face token from environment variables (secrets)
+token = os.environ.get("KEY2")  # Replace "KEY2" with your secret key name
+
+# Initialize the Hugging Face Inference API
+def load_llm():
     model_name = "HuggingFaceH4/zephyr-7b-alpha"  # Replace with your preferred model
-    pipe = pipeline(
-        "text-generation",
-        model=model_name,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        temperature=temperature,
-        top_k=top_k,
-        max_new_tokens=max_new_tokens,  # Use max_new_tokens instead of max_length
-    )
-    llm = HuggingFacePipeline(pipeline=pipe)
-    return llm
+    api = InferenceApi(repo_id=model_name, token=token)
+    return api
 
 # Extract text from PDF
 def extract_text_from_pdf(file):
@@ -51,25 +43,25 @@ def create_vector_store(chunks, indexing_method="multi-representation", **kwargs
     vector_store = FAISS.from_texts(chunks, embeddings)
     return vector_store
 
-# Query the PDF
-def query_pdf(vector_store, query, llm, query_method="multi-query", **kwargs):
-    if query_method == "multi-query":
-        # Implement Multi-Query logic here
-        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vector_store.as_retriever())
-    elif query_method == "rag-fusion":
-        # Implement RAG Fusion logic here
-        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vector_store.as_retriever())
-    elif query_method == "decomposition":
-        # Implement Decomposition logic here
-        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vector_store.as_retriever())
-    elif query_method == "step-back":
-        # Implement Step Back logic here
-        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vector_store.as_retriever())
-    elif query_method == "hyde":
-        # Implement HyDE logic here
-        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vector_store.as_retriever())
-    result = qa.run(query)
-    return result
+# Query the PDF using the Hugging Face API
+def query_pdf(vector_store, query, api, query_method="multi-query", max_new_tokens=200, temperature=0.7, top_k=50):
+    # Retrieve relevant chunks from the vector store
+    docs = vector_store.similarity_search(query)
+    context = " ".join([doc.page_content for doc in docs])
+
+    # Create a prompt for the LLM
+    prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
+
+    # Query the Hugging Face API
+    response = api(
+        inputs=prompt,
+        parameters={
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_k": top_k,
+        },
+    )
+    return response[0]["generated_text"], docs
 
 # Detect language of the text
 def detect_language(text):
@@ -115,9 +107,9 @@ def main():
         st.session_state.vector_store = create_vector_store(st.session_state.chunks, indexing_method=indexing_method)
         st.success("Vector store created!")
 
-    # Step 3: Load LLM with user-defined parameters
-    if "llm" not in st.session_state:
-        st.session_state.llm = None
+    # Step 3: Load LLM (Hugging Face API)
+    if "api" not in st.session_state:
+        st.session_state.api = None
 
     if st.session_state.vector_store:
         st.subheader("LLM Parameters")
@@ -125,11 +117,12 @@ def main():
         top_k = st.slider("Top-k", 1, 100, 50, help="Limits sampling to the top-k tokens.")
         max_new_tokens = st.slider("Max New Tokens", 50, 500, 200, help="Maximum number of tokens to generate.")
         if st.button("Load LLM"):
-            st.session_state.llm = load_llm(temperature=temperature, top_k=top_k, max_new_tokens=max_new_tokens)
+            api = load_llm()
+            st.session_state.api = api
            st.success("LLM loaded!")
 
     # Step 4: Query the PDF
-    if st.session_state.llm:
+    if st.session_state.api:
         st.subheader("Query Translation Options")
         query_method = st.selectbox(
             "Query Translation Method",
@@ -138,9 +131,19 @@
         )
         query = st.text_input("Ask a question about the PDF:")
         if query:
-            result = query_pdf(st.session_state.vector_store, query, st.session_state.llm, query_method=query_method)
-            st.write("**Answer:**", result["answer"])
-            st.write("**Source Text:**", result["source_text"])
+            answer, source_docs = query_pdf(
+                st.session_state.vector_store,
+                query,
+                st.session_state.api,
+                query_method=query_method,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_k=top_k,
+            )
+            st.write("**Answer:**", answer)
+            st.write("**Source Text:**")
+            for doc in source_docs:
+                st.write(doc.page_content)
 
 if __name__ == "__main__":
     main()
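A note on the new dependency: InferenceApi is the legacy client in huggingface_hub, and recent releases deprecate it in favor of InferenceClient. A minimal sketch of the same load_llm/query_pdf call path on the newer client, assuming a recent huggingface_hub and the same KEY2 secret (the helper name generate_answer is hypothetical, not part of this commit):

    import os
    from huggingface_hub import InferenceClient

    client = InferenceClient(model="HuggingFaceH4/zephyr-7b-alpha", token=os.environ.get("KEY2"))

    # Hypothetical helper mirroring query_pdf's prompt format.
    def generate_answer(context: str, query: str) -> str:
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        # text_generation returns the completion as a plain string
        return client.text_generation(
            prompt,
            max_new_tokens=200,
            temperature=0.7,
            top_k=50,
        )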
 
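One caveat on response[0]["generated_text"]: for text-generation models the hosted Inference API typically returns the prompt together with the completion, so the rendered answer may start by repeating the context and question. If that shows up, the standard return_full_text parameter should suppress the echo; hedged, since the default can vary with how the model is served:

    # Assumption: the endpoint honours the standard text-generation
    # parameter return_full_text and strips the prompt from the output.
    response = api(
        inputs=prompt,
        parameters={
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
            "return_full_text": False,  # return only the generated answer
        },
    )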
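
Also worth noting: query_pdf still accepts query_method but no longer branches on it, so every option in the "Query Translation Method" selectbox runs the same plain similarity search. If multi-query is the first method to be wired up, one possible shape is sketched below; multi_query_docs and its paraphrase prompt are hypothetical, not part of this commit:

    # Sketch: paraphrase the question, retrieve for each variant,
    # and deduplicate the chunks before building the context string.
    def multi_query_docs(vector_store, api, query, n_variants=3):
        rewrite_prompt = (
            f"Rewrite the following question in {n_variants} different ways, "
            f"one per line:\n{query}"
        )
        generated = api(inputs=rewrite_prompt)[0]["generated_text"]
        variants = [query] + [q.strip() for q in generated.splitlines() if q.strip()]
        seen, docs = set(), []
        for q in variants:
            for doc in vector_store.similarity_search(q, k=3):
                if doc.page_content not in seen:
                    seen.add(doc.page_content)
                    docs.append(doc)
        return docs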