Update app.py
app.py
CHANGED
@@ -446,7 +446,7 @@ def load_model():
     checkpoint = "MBZUAI/LaMini-T5-738M"
     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
     model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
-    pipe = pipeline(
+    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
     return HuggingFacePipeline(pipeline=pipe)
 
 # --- Extract PDF Text ---
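
Note: the diff does not show app.py's import block, and the removed pipeline call is truncated in the source. For context, here is a minimal sketch of the updated load_model() together with the imports it relies on; the import paths assume the classic transformers/langchain layout this app appears to use and are not confirmed by the diff:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.llms import HuggingFacePipeline  # assumed import path (older langchain layout)

def load_model():
    checkpoint = "MBZUAI/LaMini-T5-738M"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    # "text2text-generation" is the task for T5-style seq2seq models;
    # max_length=512 caps the length of the generated answer
    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
    return HuggingFacePipeline(pipeline=pipe)
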
@@ -461,44 +461,46 @@ def read_pdf(file):
         logging.error(f"Failed to extract text: {e}")
         return ""
 
-# ---
-
-
+# --- Build Retriever (cached per session) ---
+@st.cache_resource
+def build_retriever(full_text):
+    # Save text to temp file
     with open("temp_text.txt", "w") as f:
         f.write(full_text)
 
     loader = TextLoader("temp_text.txt")
     docs = loader.load()
 
-    #
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=
+    # Chunking
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
     splits = text_splitter.split_documents(docs)
 
-    #
+    # Embeddings
     embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
 
+    # Safe temporary directory for Chroma
+    chroma_dir = os.path.join(tempfile.gettempdir(), "chroma_db_rag")
+    if os.path.exists(chroma_dir):
+        shutil.rmtree(chroma_dir)
+    os.makedirs(chroma_dir, exist_ok=True)
 
-
-
-
-    db = Chroma.from_documents(splits, embeddings)
-    retriever = db.as_retriever()
-
+    db = Chroma.from_documents(splits, embeddings, persist_directory=chroma_dir)
+    db.persist()
+    return db.as_retriever(search_kwargs={"k": 6})
 
-
-
+# --- Process Answer ---
+def process_answer(question, full_text, retriever):
     llm = load_model()
 
-    #
-
-
-    # Smart prompting
-    if "summarize" in question.lower() or "summary" in question.lower() or "tl;dr" in question.lower():
+    # Special handling for summary-type queries
+    if any(x in question.lower() for x in ["summarize", "summary", "tl;dr"]):
         prompt = f"Summarize the following document:\n\n{full_text[:3000]}"
-        summary = llm(prompt)
+        summary = llm(prompt)  # Uses the LLM to generate a summary
         return summary
-
-
+
+    # Use RetrievalQA for general queries
+    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
+    return qa_chain.run(question)  # This is the main answer generation with retrieval
 
 # --- UI Layout ---
 with st.sidebar:
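
The chunking parameters introduced above (chunk_size=800, chunk_overlap=300) leave a net stride of roughly 500 characters per chunk. A standalone sketch of just that step, assuming langchain is installed; the sample string is filler, not data from the app:

from langchain.text_splitter import RecursiveCharacterTextSplitter

sample = "word " * 1000  # roughly 5,000 characters of stand-in text
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
chunks = splitter.split_text(sample)

print(len(chunks))                  # ~10 chunks: each chunk advances ~500 net characters
print(max(len(c) for c in chunks))  # no chunk exceeds 800 characters

With the retriever's k=6, a query pulls at most 6 × 800 = 4,800 characters of context into the RetrievalQA prompt.
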
@@ -518,9 +520,12 @@ if uploaded_file:
     st.subheader("💬 Ask a Question")
     user_question = st.text_input("Type your question about the PDF content")
 
+    # Build retriever once per session
+    retriever = build_retriever(full_text)
+
     if user_question:
         with st.spinner("Thinking..."):
-            answer = process_answer(user_question, full_text)
+            answer = process_answer(user_question, full_text, retriever)
         st.markdown("### 🤖 Answer")
         st.write(answer)
 
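
Design note on the wiring above: because build_retriever is decorated with @st.cache_resource, Streamlit keys the cached Chroma retriever on the full_text argument, so the script re-runs that st.text_input triggers on every interaction reuse the cached index instead of re-chunking and re-embedding the PDF. A sketch of the surrounding block; the read_pdf call and the if uploaded_file: structure are inferred from the hunk contexts, not shown verbatim in the diff:

if uploaded_file:
    full_text = read_pdf(uploaded_file)
    if full_text:
        st.subheader("💬 Ask a Question")
        user_question = st.text_input("Type your question about the PDF content")

        # Cache hit on every re-run with the same PDF text
        retriever = build_retriever(full_text)

        if user_question:
            with st.spinner("Thinking..."):
                answer = process_answer(user_question, full_text, retriever)
            st.markdown("### 🤖 Answer")
            st.write(answer)
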
@@ -540,3 +545,4 @@ if uploaded_file:
     st.error("⚠️ No text could be extracted from the PDF. Try another file.")
 else:
     st.info("Upload a PDF to begin.")
+