sunbal7 commited on
Commit
76552c4
·
verified ·
1 Parent(s): 3471ccc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import faiss
4
+ import numpy as np
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
+ from sentence_transformers import SentenceTransformer
7
+ from PyPDF2 import PdfReader
8
+ from docx import Document
9
+ import re
10
+
11
+ # Initialize models
12
+ @st.cache_resource
13
+ def load_models():
14
+ # Text embedding model
15
+ embed_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
16
+
17
+ # IBM Granite models
18
+ summary_tokenizer = AutoTokenizer.from_pretrained("ibm/granite-13b-instruct-v2")
19
+ summary_model = AutoModelForCausalLM.from_pretrained("ibm/granite-13b-instruct-v2")
20
+
21
+ qa_tokenizer = AutoTokenizer.from_pretrained("ibm/granite-13b-instruct-v2")
22
+ qa_model = AutoModelForCausalLM.from_pretrained("ibm/granite-13b-instruct-v2")
23
+
24
+ return embed_model, summary_model, summary_tokenizer, qa_model, qa_tokenizer
25
+
26
+ def process_file(uploaded_file):
27
+ text = ""
28
+ file_type = uploaded_file.name.split('.')[-1].lower()
29
+
30
+ if file_type == 'pdf':
31
+ pdf_reader = PdfReader(uploaded_file)
32
+ for page in pdf_reader.pages:
33
+ text += page.extract_text()
34
+
35
+ elif file_type == 'txt':
36
+ text = uploaded_file.read().decode('utf-8')
37
+
38
+ elif file_type == 'docx':
39
+ doc = Document(uploaded_file)
40
+ for para in doc.paragraphs:
41
+ text += para.text + "\n"
42
+
43
+ return clean_text(text)
44
+
45
+ def clean_text(text):
46
+ text = re.sub(r'\s+', ' ', text)
47
+ text = re.sub(r'[^\x00-\x7F]+', ' ', text)
48
+ return text
49
+
50
+ def split_text(text, chunk_size=500):
51
+ return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
52
+
53
+ def create_faiss_index(text_chunks, embed_model):
54
+ embeddings = embed_model.encode(text_chunks)
55
+ dimension = embeddings.shape[1]
56
+ index = faiss.IndexFlatL2(dimension)
57
+ index.add(np.array(embeddings).astype('float32'))
58
+ return index
59
+
60
+ def generate_summary(text, model, tokenizer):
61
+ inputs = tokenizer(f"Summarize this document: {text[:3000]}", return_tensors="pt", max_length=4096, truncation=True)
62
+ summary_ids = model.generate(inputs.input_ids, max_length=500)
63
+ return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
64
+
65
+ def answer_question(question, index, text_chunks, embed_model, model, tokenizer):
66
+ question_embed = embed_model.encode([question])
67
+ _, indices = index.search(question_embed.astype('float32'), 3)
68
+
69
+ context = " ".join([text_chunks[i] for i in indices[0]])
70
+ prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
71
+
72
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=4096, truncation=True)
73
+ outputs = model.generate(inputs.input_ids, max_length=500)
74
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
75
+
76
+ def main():
77
+ st.title("📖 RAG Book Assistant with IBM Granite")
78
+
79
+ embed_model, summary_model, summary_tokenizer, qa_model, qa_tokenizer = load_models()
80
+
81
+ uploaded_file = st.file_uploader("Upload a document (PDF/TXT/DOCX)", type=['pdf', 'txt', 'docx'])
82
+
83
+ if uploaded_file and 'processed' not in st.session_state:
84
+ with st.spinner("Processing document..."):
85
+ text = process_file(uploaded_file)
86
+ text_chunks = split_text(text)
87
+
88
+ st.session_state.text_chunks = text_chunks
89
+ st.session_state.faiss_index = create_faiss_index(text_chunks, embed_model)
90
+
91
+ summary = generate_summary(text, summary_model, summary_tokenizer)
92
+ st.session_state.summary = summary
93
+ st.session_state.processed = True
94
+
95
+ if 'processed' in st.session_state:
96
+ st.subheader("Document Summary")
97
+ st.write(st.session_state.summary)
98
+
99
+ st.divider()
100
+
101
+ question = st.text_input("Ask a question about the document:")
102
+ if question:
103
+ answer = answer_question(
104
+ question,
105
+ st.session_state.faiss_index,
106
+ st.session_state.text_chunks,
107
+ embed_model,
108
+ qa_model,
109
+ qa_tokenizer
110
+ )
111
+ st.info(f"Answer: {answer}")
112
+
113
+ if __name__ == "__main__":
114
+ main()