Talha812 committed on
Commit 766a4a2 · verified · 1 Parent(s): 5e54387

Create app.py

Files changed (1)
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import fitz  # PyMuPDF
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+ # Configuration
+ MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
+ EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ CHUNK_SIZE = 512
+ CHUNK_OVERLAP = 50
+
+ @st.cache_resource
+ def load_models():
+     try:
+         # Load the Granite model; AutoModelForCausalLM resolves the correct
+         # architecture from the checkpoint config (this checkpoint is not GPT-NeoX)
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_NAME,
+             trust_remote_code=True
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             device_map="auto" if DEVICE == "cuda" else None,
+             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+             trust_remote_code=True
+         ).eval()
+
+         # Load sentence transformer for embeddings
+         embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
+
+         return tokenizer, model, embedder
+
+     except Exception as e:
+         st.error(f"Model loading failed: {e}")
+         st.stop()
+
+ tokenizer, model, embedder = load_models()
+
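+ # st.cache_resource keeps the loaded models in memory across Streamlit reruns,
+ # so the weights are fetched and initialized only once per server process.
+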
+ # Text processing
+ def process_text(text):
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=CHUNK_SIZE,
+         chunk_overlap=CHUNK_OVERLAP,
+         length_function=len
+     )
+     return splitter.split_text(text)
+
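+ # Note: with length_function=len, CHUNK_SIZE and CHUNK_OVERLAP are measured
+ # in characters, not model tokens.
+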
+ # PDF extraction
+ def extract_pdf_text(uploaded_file):
+     try:
+         doc = fitz.open(stream=uploaded_file.read(), filetype="pdf")
+         return "\n".join([page.get_text() for page in doc])
+     except Exception as e:
+         st.error(f"PDF extraction error: {e}")
+         return ""
+
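+ # fitz.open(stream=...) parses the uploaded bytes in memory, so no temporary
+ # file is written to disk.
+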
+ # Summarization function
+ def generate_summary(text):
+     chunks = process_text(text)[:10]  # cap the work at the first 10 chunks
+     summaries = []
+
+     for chunk in chunks:
+         prompt = f"""<|user|>
+ Summarize this text section focusing on key themes, characters, and plot points:
+ {chunk[:2000]}
+ <|assistant|>
+ """
+
+         inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+         outputs = model.generate(**inputs, max_new_tokens=300, temperature=0.3, do_sample=True)
+         # Decode only the newly generated tokens, not the echoed prompt
+         summaries.append(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
+
+     combined = "\n".join(summaries)
+     final_prompt = f"""<|user|>
+ Combine these section summaries into a coherent book summary:
+ {combined}
+ <|assistant|>
+ """
+
+     inputs = tokenizer(final_prompt, return_tensors="pt").to(DEVICE)
+     outputs = model.generate(**inputs, max_new_tokens=500, temperature=0.5, do_sample=True)
+     return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
+
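+ # The two passes above form a simple map-reduce summary: each chunk is
+ # summarized independently, then the partial summaries are merged into one.
+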
+ # FAISS index creation
+ def build_faiss_index(texts):
+     embeddings = embedder.encode(texts, show_progress_bar=False)
+     dimension = embeddings.shape[1]
+     index = faiss.IndexFlatIP(dimension)
+     faiss.normalize_L2(embeddings)
+     index.add(embeddings)
+     return index
+
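+ # With L2-normalized vectors, the inner-product scores from IndexFlatIP are
+ # cosine similarities, so higher means more relevant (max 1.0).
+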
+ # Answer generation
+ def generate_answer(query, context):
+     prompt = f"""<|user|>
+ Using this context: {context}
+ Answer the question precisely and truthfully. If unsure, say "I don't know".
+ Question: {query}
+ <|assistant|>
+ """
+
+     inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=300,
+         temperature=0.4,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         do_sample=True
+     )
+     # Decode only the generated tokens; splitting the decoded string on
+     # "<|assistant|>" would fail if the marker is stripped as a special token
+     return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
+
+ # Streamlit UI
+ st.set_page_config(page_title="📚 Smart Book Analyst", layout="wide")
+ st.title("📚 AI-Powered Book Analysis System")
+
+ uploaded_file = st.file_uploader("Upload book (PDF or TXT)", type=["pdf", "txt"])
+
+ if uploaded_file:
+     # Streamlit reruns this script on every interaction, so only re-analyze
+     # when a different file is uploaded
+     if st.session_state.get("file_name") != uploaded_file.name:
+         with st.spinner("📖 Analyzing book content..."):
+             try:
+                 if uploaded_file.type == "application/pdf":
+                     text = extract_pdf_text(uploaded_file)
+                 else:
+                     text = uploaded_file.read().decode("utf-8", errors="ignore")
+
+                 chunks = process_text(text)
+                 st.session_state.docs = chunks
+                 st.session_state.index = build_faiss_index(chunks)
+                 st.session_state.summary = generate_summary(text)
+                 st.session_state.file_name = uploaded_file.name
+
+             except Exception as e:
+                 st.error(f"Processing failed: {e}")
+
+     if "summary" in st.session_state:
+         with st.expander("📝 Book Summary", expanded=True):
+             st.write(st.session_state.summary)
+
+ if "index" in st.session_state:
+     query = st.text_input("Ask about the book:")
+     if query:
+         with st.spinner("🔍 Searching for answers..."):
+             try:
+                 query_embed = embedder.encode([query])
+                 faiss.normalize_L2(query_embed)
+                 distances, indices = st.session_state.index.search(query_embed, k=3)
+
+                 context = "\n".join([st.session_state.docs[i] for i in indices[0]])
+                 answer = generate_answer(query, context)
+
+                 st.subheader("Answer")
+                 st.markdown(f"```\n{answer}\n```")
+                 st.caption("Top match cosine similarity: {:.2f}".format(distances[0][0]))
+
+             except Exception as e:
+                 st.error(f"Query failed: {e}")