pradeepsengarr committed
Commit dea11f3 · verified · Parent: 50a3fdd

Update app.py

Files changed (1): app.py (+30 −24)
app.py CHANGED
@@ -446,7 +446,7 @@ def load_model():
     checkpoint = "MBZUAI/LaMini-T5-738M"
     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
     model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
-    pipe = pipeline('text2text-generation', model=model, tokenizer=tokenizer, max_length=512)
+    pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
     return HuggingFacePipeline(pipeline=pipe)
 
 # --- Extract PDF Text ---
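
As a sanity check on the function this hunk touches, the wrapped pipeline can be exercised outside Streamlit. A minimal sketch, assuming the HuggingFacePipeline import path from langchain_community (the app's actual imports are not shown in this diff):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.llms import HuggingFacePipeline  # assumed import path

checkpoint = "MBZUAI/LaMini-T5-738M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# max_length=512 caps the generated sequence length for this text2text pipeline
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
llm = HuggingFacePipeline(pipeline=pipe)

print(llm("What is retrieval-augmented generation?"))  # plain string in, string out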
@@ -461,44 +461,46 @@ def read_pdf(file):
         logging.error(f"Failed to extract text: {e}")
         return ""
 
-# --- Process Answer ---
-def process_answer(question, full_text):
-    # Save the full_text to a temporary file
+# --- Build Retriever (cached per session) ---
+@st.cache_resource
+def build_retriever(full_text):
+    # Save text to temp file
     with open("temp_text.txt", "w") as f:
         f.write(full_text)
 
     loader = TextLoader("temp_text.txt")
     docs = loader.load()
 
-    # Chunk the documents
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
+    # Chunking
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
     splits = text_splitter.split_documents(docs)
 
-    # Load embeddings
+    # Embeddings
     embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
 
-    from langchain_community.vectorstores import Chroma
-
-    # Create Chroma in-memory vector store
-    db = Chroma.from_documents(splits, embeddings)
-    retriever = db.as_retriever()
-
-    # Set up the model
+    # Safe temporary directory for Chroma
+    chroma_dir = os.path.join(tempfile.gettempdir(), "chroma_db_rag")
+    if os.path.exists(chroma_dir):
+        shutil.rmtree(chroma_dir)
+    os.makedirs(chroma_dir, exist_ok=True)
+
+    db = Chroma.from_documents(splits, embeddings, persist_directory=chroma_dir)
+    db.persist()
+    return db.as_retriever(search_kwargs={"k": 6})
+
+# --- Process Answer ---
+def process_answer(question, full_text, retriever):
     llm = load_model()
 
-    # RAG-style retrieval QA
-    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
-
-    # Smart prompting
-    if "summarize" in question.lower() or "summary" in question.lower() or "tl;dr" in question.lower():
+    # Special handling for summary-type queries
+    if any(x in question.lower() for x in ["summarize", "summary", "tl;dr"]):
         prompt = f"Summarize the following document:\n\n{full_text[:3000]}"
-        summary = llm(prompt)
+        summary = llm(prompt)  # Uses the LLM to generate a summary
         return summary
-    else:
-        return qa_chain.run(question)
+
+    # Use RetrievalQA for general queries
+    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
+    return qa_chain.run(question)  # This is the main answer generation with retrieval
 
 # --- UI Layout ---
 with st.sidebar:
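
The st.cache_resource decorator is the heart of this commit: the cache is keyed on the function's arguments, so the split/embed/index work runs once per distinct full_text rather than on every Streamlit rerun. A minimal sketch of that behavior; build_index and its body are illustrative stand-ins, not code from the app, and it is meant to run under streamlit run:

import streamlit as st

@st.cache_resource
def build_index(text: str):
    # Executes only on a cache miss, i.e. the first call with this `text`;
    # reruns with the same argument get the cached object back.
    st.write("indexing...")        # appears once per unique document
    return {"chars": len(text)}    # stand-in for a real retriever object

a = build_index("same document")
b = build_index("same document")   # cache hit: the body does not run again
st.write(a is b)                   # True: cache_resource returns the same object

Because the cache key includes full_text, uploading a different PDF still triggers a rebuild, which is exactly what build_retriever needs.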
@@ -518,9 +520,12 @@ if uploaded_file:
     st.subheader("💬 Ask a Question")
     user_question = st.text_input("Type your question about the PDF content")
 
+    # Build retriever once per session
+    retriever = build_retriever(full_text)
+
     if user_question:
         with st.spinner("Thinking..."):
-            answer = process_answer(user_question, full_text)
+            answer = process_answer(user_question, full_text, retriever)
             st.markdown("### 🤖 Answer")
             st.write(answer)
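
With the retriever built once and threaded through the UI, the answer path is standard RetrievalQA wiring. A self-contained sketch of that path, substituting LangChain's canned FakeListLLM so it runs without downloading the LaMini checkpoint; the sample document, question, and import paths are assumptions, not taken from this diff:

from langchain.chains import RetrievalQA
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms import FakeListLLM      # canned LLM for testing
from langchain_community.vectorstores import Chroma

docs = [Document(page_content="Chroma is an open-source embedding database.")]
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
splits = splitter.split_documents(docs)
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
db = Chroma.from_documents(splits, embeddings)        # in-memory is enough here
retriever = db.as_retriever(search_kwargs={"k": 6})   # k=6 as in this commit

llm = FakeListLLM(responses=["stub answer"])          # stands in for load_model()
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
print(qa_chain.run("What is Chroma?"))                # -> "stub answer"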
@@ -540,3 +545,4 @@ if uploaded_file:
         st.error("⚠️ No text could be extracted from the PDF. Try another file.")
 else:
     st.info("Upload a PDF to begin.")
+
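
The persist-directory handling in build_retriever wipes and recreates the index whenever the cached builder actually runs, so a stale index from an earlier session is never served. The same pattern in isolation; the placeholder document is illustrative, and on newer chromadb releases persistence is automatic, which makes the explicit persist() call redundant there:

import os
import shutil
import tempfile

from langchain.schema import Document
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma

chroma_dir = os.path.join(tempfile.gettempdir(), "chroma_db_rag")
if os.path.exists(chroma_dir):
    shutil.rmtree(chroma_dir)   # drop any stale index from a previous run
os.makedirs(chroma_dir, exist_ok=True)

embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
db = Chroma.from_documents(
    [Document(page_content="placeholder chunk")],  # illustrative content
    embeddings,
    persist_directory=chroma_dir,                  # index lives under the OS temp dir
)
db.persist()  # flush the index to disk (newer chromadb versions persist automatically)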