sunbal7 committed on
Commit
04b42d6
·
verified ·
1 Parent(s): 999f1a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -155
app.py CHANGED
@@ -1,157 +1,87 @@
1
- # app.py
2
-
3
  import streamlit as st
4
- import arxiv
5
- import networkx as nx
6
- import matplotlib.pyplot as plt
7
- import datetime
8
-
9
- from transformers import pipeline
10
-
11
- # Initialize Hugging Face pipelines for summarization and text generation
12
- @st.cache_resource(show_spinner=False)
13
- def load_summarizer():
14
- return pipeline("summarization", model="facebook/bart-large-cnn")
15
-
16
- @st.cache_resource(show_spinner=False)
17
- def load_generator():
18
- return pipeline("text-generation", model="gpt2")
19
-
20
- summarizer = load_summarizer()
21
- generator = load_generator()
22
-
23
- # -------------------------------
24
- # Helper Functions
25
- # -------------------------------
26
-
27
- def retrieve_papers(query, max_results=5):
28
- """
29
- Retrieve academic papers from arXiv based on the query.
30
- """
31
- search = arxiv.Search(query=query, max_results=max_results)
32
- papers = []
33
- for result in search.results():
34
- paper = {
35
- "title": result.title,
36
- "summary": result.summary,
37
- "url": result.pdf_url,
38
- "authors": [author.name for author in result.authors],
39
- "published": result.published
40
- }
41
- papers.append(paper)
42
- return papers
43
-
44
- def summarize_text(text):
45
- """
46
- Use a generative model to create a concise summary of the input text.
47
- """
48
- # The summarizer may need the text to be below a certain token length.
49
- # If necessary, you could chunk the text.
50
- summarized = summarizer(text, max_length=130, min_length=30, do_sample=False)
51
- return summarized[0]['summary_text']
52
-
53
- def generate_concept_map(papers):
54
- """
55
- Generate a visual concept map by connecting papers with shared authors.
56
- """
57
- G = nx.Graph()
58
- # Add nodes for each paper title
59
- for paper in papers:
60
- G.add_node(paper['title'])
61
- # Create edges between papers that share at least one common author
62
- for i in range(len(papers)):
63
- for j in range(i + 1, len(papers)):
64
- common_authors = set(papers[i]['authors']).intersection(set(papers[j]['authors']))
65
- if common_authors:
66
- G.add_edge(papers[i]['title'], papers[j]['title'])
67
- return G
68
-
69
- def generate_citation(paper):
70
- """
71
- Format citation information in APA style.
72
- """
73
- authors = ", ".join(paper['authors'])
74
- year = paper['published'].year if isinstance(paper['published'], datetime.datetime) else "n.d."
75
- title = paper['title']
76
- url = paper['url']
77
- citation = f"{authors} ({year}). {title}. Retrieved from {url}"
78
- return citation
79
-
80
- def generate_proposal_suggestions(text):
81
- """
82
- Generate research proposal suggestions based on the synthesized literature review.
83
- """
84
- prompt = (
85
- "Based on the following literature review, propose a novel research proposal "
86
- "including potential research questions and an outline for experimental design.\n\n"
87
- f"{text}\n\nProposal:"
88
- )
89
- generated = generator(prompt, max_new_tokens=50, num_return_sequences=1)
90
- return generated[0]['generated_text']
91
-
92
- # -------------------------------
93
- # Streamlit User Interface
94
- # -------------------------------
95
-
96
- st.title("📚PaperPilot – The Intelligent Academic Navigator")
97
- st.markdown("Welcome to **PaperPilot**! Enter a research topic or question below to retrieve academic papers, generate summaries, visualize concept maps, format citations, and get research proposal suggestions.")
98
-
99
- # Input section
100
- query = st.text_input("Research Topic or Question:")
101
-
102
- if st.button("Search"):
103
-
104
- if query.strip() == "":
105
- st.warning("Please enter a research topic or question.")
106
- else:
107
- # --- Step 1: Retrieve Papers ---
108
- with st.spinner("Retrieving relevant academic papers..."):
109
- papers = retrieve_papers(query, max_results=5)
110
-
111
- if not papers:
112
- st.error("No papers found. Please try a different query.")
113
  else:
114
- st.success(f"Found {len(papers)} papers.")
115
-
116
- # --- Step 2: Display Retrieved Papers ---
117
- st.header("Retrieved Papers")
118
- for idx, paper in enumerate(papers, start=1):
119
- with st.expander(f"{idx}. {paper['title']}"):
120
- st.markdown(f"**Authors:** {', '.join(paper['authors'])}")
121
- st.markdown(f"**Published:** {paper['published'].strftime('%Y-%m-%d') if isinstance(paper['published'], datetime.datetime) else 'n.d.'}")
122
- st.markdown(f"**Link:** [PDF Link]({paper['url']})")
123
- st.markdown("**Abstract:**")
124
- st.write(paper['summary'])
125
-
126
- # --- Step 3: Generate Summaries & Literature Review ---
127
- st.header("Automated Summaries & Literature Review")
128
- combined_summary = ""
129
- for paper in papers:
130
- st.subheader(f"Summary for: {paper['title']}")
131
- # Use the paper summary as input for further summarization
132
- summary_text = summarize_text(paper['summary'])
133
- st.write(summary_text)
134
- combined_summary += summary_text + " "
135
-
136
- # --- Step 4: Create Visual Concept Map & Gap Analysis ---
137
- st.header("Visual Concept Map & Gap Analysis")
138
- G = generate_concept_map(papers)
139
- if len(G.nodes) > 0:
140
- fig, ax = plt.subplots(figsize=(8, 6))
141
- pos = nx.spring_layout(G, seed=42)
142
- nx.draw_networkx(G, pos, with_labels=True, node_color='skyblue', edge_color='gray', node_size=1500, font_size=8, ax=ax)
143
- st.pyplot(fig)
144
- else:
145
- st.info("Not enough data to generate a concept map.")
146
-
147
- # --- Step 5: Citation & Reference Management ---
148
- st.header("Formatted Citations (APA Style)")
149
- for paper in papers:
150
- citation = generate_citation(paper)
151
- st.markdown(f"- {citation}")
152
-
153
- # --- Step 6: Research Proposal Assistance ---
154
- st.header("Research Proposal Suggestions")
155
- proposal = generate_proposal_suggestions(combined_summary)
156
- st.write(proposal)
157
- st.caption("Built with ❤️")
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import faiss
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+ import PyPDF2
8
+ import os
9
+
10
# Model setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "ibm-granite/granite-3.1-1b-a400m-instruct"


# Streamlit re-executes this script from the top on every widget interaction.
# Without caching, each rerun would reload both models from disk. The spinner
# is disabled so that no UI element is emitted before st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_language_model(path):
    """Load the tokenizer and causal LM for *path* once per server process."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(path)
    loaded_model = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
    loaded_model.eval()  # inference only: disables dropout etc.
    return loaded_tokenizer, loaded_model


@st.cache_resource(show_spinner=False)
def _load_embedding_model(name):
    """Load the sentence-embedding model *name* once per server process."""
    return SentenceTransformer(name)


tokenizer, model = _load_language_model(model_path)

# Embedding model for FAISS
embedding_model = _load_embedding_model("sentence-transformers/all-MiniLM-L6-v2")

# FAISS index. The vector size is read from the embedding model so the index
# stays consistent if the model name above is ever changed.
dimension = embedding_model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(dimension)
docs = []  # Document texts; position i corresponds to vector i in `index`
25
+
26
+ # Function to extract text from PDF
27
def extract_text_from_pdf(uploaded_file):
    """Return the text of all pages of a PDF, joined with newlines.

    Args:
        uploaded_file: Binary file-like object handed to ``PyPDF2.PdfReader``
            (for example a Streamlit uploaded file).

    Returns:
        The page texts joined by ``"\\n"``. Pages for which no text could be
        extracted are skipped, so the result may be an empty string.
    """
    reader = PyPDF2.PdfReader(uploaded_file)
    # Extract each page exactly once: the original called extract_text()
    # twice per page (once to filter, once to collect), doubling the work.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(page_text for page_text in page_texts if page_text)
31
+
32
+ # Function to process uploaded documents
33
def process_documents(files):
    """Rebuild the document store and the FAISS index from uploaded files.

    Any previously indexed content is discarded, so after the call vector
    ``i`` in ``index`` always corresponds to ``docs[i]``.

    Args:
        files: Iterable of uploaded-file objects exposing ``.type`` (MIME
            type string) and ``.getvalue()`` (raw bytes). PDFs are parsed
            with ``extract_text_from_pdf``; everything else is decoded as
            UTF-8 text.
    """
    global docs, index

    texts = []
    for file in files:
        if file.type == "application/pdf":
            text = extract_text_from_pdf(file)
        else:
            # Replace undecodable bytes instead of crashing on files that
            # are not valid UTF-8.
            text = file.getvalue().decode("utf-8", errors="replace")

        # A document without text (e.g. a scanned PDF) carries no usable
        # context, so it is not indexed.
        if text.strip():
            texts.append(text)

    docs = texts
    # Drop stale vectors: the original reset `docs` but kept adding to the
    # index, so search results could point at the wrong document.
    index.reset()

    if not docs:
        return

    embeddings = embedding_model.encode(docs)
    index.add(np.asarray(embeddings, dtype="float32"))
47
+
48
+ # Function to retrieve relevant context
49
def retrieve_context(query):
    """Return the indexed document closest to *query*.

    Args:
        query: The user's question as a string.

    Returns:
        The text of the nearest document, or the string
        ``"No relevant context found."`` when nothing is indexed or the
        search yields no valid hit.
    """
    no_context = "No relevant context found."

    # Nothing has been indexed yet, so there is nothing to search.
    if not docs or index.ntotal == 0:
        return no_context

    query_embedding = embedding_model.encode([query])
    _, indices = index.search(np.asarray(query_embedding, dtype="float32"), k=1)

    # FAISS reports "no neighbour" as label -1. Checking only the upper bound
    # would let -1 through and silently return docs[-1] (the last document).
    best = int(indices[0][0]) if len(indices) > 0 and len(indices[0]) > 0 else -1
    if 0 <= best < len(docs):
        return docs[best]
    return no_context
56
+
57
+ # Function to generate response using IBM Granite
58
def generate_response(query, context):
    """Generate an answer to *query* grounded in *context*.

    Args:
        query: The user's question.
        context: Retrieved document text to answer from.

    Returns:
        Only the newly generated answer text (the prompt is not echoed back),
        with special tokens removed and surrounding whitespace stripped.
    """
    chat = [
        {"role": "system", "content": "You are a helpful assistant using retrieved knowledge."},
        {"role": "user", "content": f"Context: {context}\nQuestion: {query}\nAnswer based on context:"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # The model is placed by device_map="auto", so follow the model's own
    # device rather than the module-level `device` guess.
    input_tokens = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only: skip gradient bookkeeping to save memory.
    with torch.no_grad():
        output = model.generate(**input_tokens, max_new_tokens=200)

    # generate() returns prompt + continuation; slice off the prompt so the
    # UI shows just the answer instead of the whole context and instructions.
    prompt_length = input_tokens["input_ids"].shape[1]
    answer_tokens = output[0][prompt_length:]
    return tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()
68
+
69
# Streamlit UI
st.set_page_config(page_title="📖 RAG-Based AI", page_icon="🤖")
st.title("📖 RAG-based Q&A using IBM Granite")
st.subheader("Upload documents and ask questions!")

uploaded_files = st.file_uploader("Upload PDFs or TXT files", accept_multiple_files=True)

if uploaded_files:
    with st.spinner("Processing documents..."):
        process_documents(uploaded_files)
    st.success("Documents uploaded and indexed!")

query = st.text_input("Ask a question:")
if st.button("Get Answer"):
    # Validate before doing any model work: an empty question or an empty
    # index cannot produce a meaningful answer.
    if not query.strip():
        st.warning("Please enter a question.")
    elif not uploaded_files:
        st.warning("Please upload at least one document first.")
    else:
        with st.spinner("Retrieving and generating response..."):
            context = retrieve_context(query)
            response = generate_response(query, context)
        st.markdown("### 🤖 Answer:")
        st.write(response)