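"""RAG Book Analyzer.

A Streamlit app: upload a book (PDF, TXT, or DOCX), get a generated summary,
and ask questions answered via retrieval-augmented generation. The document is
split into overlapping chunks, embedded with Sentence-Transformers, indexed in
FAISS, and the top-matching chunks are passed as context to an IBM Granite
instruct model.
"""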
import streamlit as st
st.set_page_config(page_title="RAG Book Analyzer", layout="wide") # Must be the first Streamlit command
import torch
import numpy as np
import faiss
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import fitz # PyMuPDF for PDF extraction
import docx2txt # For DOCX extraction
from langchain_text_splitters import RecursiveCharacterTextSplitter
# ------------------------
# Configuration
# ------------------------
MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ------------------------
# Model Loading with Caching
# ------------------------
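# st.cache_resource keeps the loaded models in memory across reruns and user
# sessions, so the weights are downloaded and loaded only once per process.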
@st.cache_resource
def load_models():
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            revision="main"
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            revision="main",
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            low_cpu_mem_usage=True
        ).eval()
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
        return tokenizer, model, embedder
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()
tokenizer, model, embedder = load_models()
# ------------------------
# Text Processing Functions
# ------------------------
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return splitter.split_text(text)
def extract_text(file):
    file_type = file.type
    if file_type == "application/pdf":
        try:
            doc = fitz.open(stream=file.read(), filetype="pdf")
            return "\n".join([page.get_text() for page in doc])
        except Exception as e:
            st.error("Error processing PDF: " + str(e))
            return ""
    elif file_type == "text/plain":
        return file.read().decode("utf-8")
    elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        try:
            return docx2txt.process(file)
        except Exception as e:
            st.error("Error processing DOCX: " + str(e))
            return ""
    else:
        st.error("Unsupported file type: " + file_type)
        return ""
def build_index(chunks):
    # Embed the chunks, then L2-normalize so inner product equals cosine similarity.
    embeddings = embedder.encode(chunks, show_progress_bar=True)
    embeddings = np.asarray(embeddings, dtype="float32")  # FAISS expects float32
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(embeddings)
    index.add(embeddings)
    return index
# ------------------------
# Summarization and Q&A Functions
# ------------------------
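# The prompts below hard-code <|user|> / <|assistant|> chat markers for the
# Granite instruct model. If the model's chat template differs, building the
# prompt with tokenizer.apply_chat_template() would be the safer option.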
def generate_summary(text):
    # Limit input text to avoid overly long sequences
    prompt = f"<|user|>\nSummarize the following book in a concise and informative paragraph:\n\n{text[:4000]}\n<|assistant|>\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.5)
    # Decode only the newly generated tokens so the prompt is not echoed back.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    summary = tokenizer.decode(generated, skip_special_tokens=True).strip()
    # Return the first non-empty paragraph of the generated text.
    paragraphs = [p.strip() for p in summary.split("\n") if p.strip()]
    return paragraphs[0] if paragraphs else summary
def generate_answer(query, context):
prompt = f"<|user|>\nUsing the context below, answer the following question precisely. If unsure, say 'I don't know'.\n\nContext: {context}\n\nQuestion: {query}\n<|assistant|>\n"
inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
outputs = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.4,
top_p=0.9,
repetition_penalty=1.2,
do_sample=True
)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
answer = answer.replace("<|assistant|>", "").strip()
paragraphs = [p.strip() for p in answer.split("\n") if p.strip()]
return paragraphs[0] if paragraphs else answer
# ------------------------
# Streamlit UI
# ------------------------
st.title("RAG-Based Book Analyzer")
st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")
uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])
if uploaded_file:
    # Process the upload only once per file; keep the results in session state so
    # the summary and FAISS index are not rebuilt on every Streamlit rerun.
    if st.session_state.get("file_name") != uploaded_file.name:
        text = extract_text(uploaded_file)
        if text:
            st.success("File successfully processed!")
            with st.spinner("Generating summary..."):
                st.session_state.summary = generate_summary(text)
            # Split the text into chunks and build the FAISS index for retrieval
            st.session_state.chunks = split_text(text)
            st.session_state.index = build_index(st.session_state.chunks)
            st.session_state.file_name = uploaded_file.name
    if "index" in st.session_state:
        st.markdown("### Book Summary")
        st.write(st.session_state.summary)
        st.markdown("### Ask a Question about the Book:")
        query = st.text_input("Your Question:")
        if query:
            # Retrieve top 3 relevant chunks as context
            query_embedding = np.asarray(embedder.encode([query]), dtype="float32")
            faiss.normalize_L2(query_embedding)
            distances, indices = st.session_state.index.search(query_embedding, k=3)
            chunks = st.session_state.chunks
            retrieved_chunks = [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
            context = "\n".join(retrieved_chunks)
            answer = generate_answer(query, context)
            st.markdown("### Answer")
            st.write(answer)
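# To run this app outside the Space (assuming the file is saved as app.py):
#   streamlit run app.py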