zliang committed on
Commit
3a16e8c
·
verified ·
1 Parent(s): 52d159a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -8
app.py CHANGED
@@ -79,21 +79,31 @@ def scroll_to_bottom():
79
  # ----------------------------
80
  # Core Processing Functions
81
  # ----------------------------
82
-
83
  @st.cache_data(show_spinner=False, ttl=3600)
84
  @handle_errors
85
- def summarize_pdf(_pdf_file_path, num_clusters=10):
86
- # Basic summarization without citations
 
 
 
 
87
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
88
  llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=openai_api_key, temperature=0.3)
89
 
 
90
  prompt = ChatPromptTemplate.from_template(
91
- """Generate a comprehensive summary with these elements:
92
  1. Key findings and conclusions
93
  2. Main methodologies used
94
  3. Important data points
95
  4. Limitations mentioned
96
- Context: {topic}"""
 
 
 
 
 
 
97
  )
98
 
99
  loader = PyMuPDFLoader(_pdf_file_path)
@@ -106,11 +116,24 @@ Context: {topic}"""
106
 
107
  embeddings = embeddings_model.embed_documents(split_contents)
108
  kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(embeddings)
109
- closest_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1))
110
- for center in kmeans.cluster_centers_]
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  chain = prompt | llm | StrOutputParser()
113
- return chain.invoke({"topic": ' '.join([split_contents[idx] for idx in closest_indices])})
 
 
114
 
115
 
116
  @st.cache_data(show_spinner=False, ttl=3600)
 
79
  # ----------------------------
80
  # Core Processing Functions
81
  # ----------------------------
 
82
  @st.cache_data(show_spinner=False, ttl=3600)
83
  @handle_errors
84
+ def summarize_pdf_with_citations(_pdf_file_path, num_clusters=10):
85
+ """
86
+ Generates a summary that includes in-text citations based on selected context chunks.
87
+ Each context chunk is numbered (e.g. [1], [2], etc.) and is referenced in the summary.
88
+ After the summary, a reference list is provided mapping each citation number to the full original text excerpt.
89
+ """
90
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
91
  llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=openai_api_key, temperature=0.3)
92
 
93
+ # Updated prompt instructs the LLM to use the full excerpt in the reference list.
94
  prompt = ChatPromptTemplate.from_template(
95
+ """Generate a comprehensive summary with the following elements:
96
  1. Key findings and conclusions
97
  2. Main methodologies used
98
  3. Important data points
99
  4. Limitations mentioned
100
+
101
+ For any information that is directly derived from the provided context excerpts, insert an in-text citation in the format [n] where n corresponds to the excerpt number.
102
+
103
+ After the summary, please provide a reference list where each citation number is mapped to the full original text excerpt as provided below. Do not simply echo the citation number; include the complete excerpt text.
104
+
105
+ Context Excerpts:
106
+ {contexts}"""
107
  )
108
 
109
  loader = PyMuPDFLoader(_pdf_file_path)
 
116
 
117
  embeddings = embeddings_model.embed_documents(split_contents)
118
  kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(embeddings)
119
+
120
+ citation_indices = []
121
+ for center in kmeans.cluster_centers_:
122
+ distances = np.linalg.norm(embeddings - center, axis=1)
123
+ idx = int(np.argmin(distances))
124
+ citation_indices.append(idx)
125
+
126
+ # Create a context string with citations including the full original text excerpts
127
+ citation_contexts = []
128
+ for i, idx in enumerate(citation_indices):
129
+ # Using the full excerpt from split_contents for the reference list.
130
+ citation_contexts.append(f"[{i+1}]: {split_contents[idx]}")
131
+ combined_contexts = "\n\n".join(citation_contexts)
132
 
133
  chain = prompt | llm | StrOutputParser()
134
+ result = chain.invoke({"contexts": combined_contexts})
135
+ return result
136
+
137
 
138
 
139
  @st.cache_data(show_spinner=False, ttl=3600)