chaaim123 committed on
Commit fb4cb2b · verified · 1 Parent(s): f1a2317

Create retriever/llm_manager.py

Files changed (1)
  1. retriever/llm_manager.py +309 -0
retriever/llm_manager.py ADDED
@@ -0,0 +1,309 @@
import logging
import os
from typing import List, Dict, Any, Tuple
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate

class LLMManager:
    DEFAULT_MODEL = "gemma2-9b-it"  # Default model name

    def __init__(self):
        self.generation_llm = None
        logging.info("LLMManager initialized")

        # Initialize the default model during construction
        try:
            self.initialize_generation_llm(self.DEFAULT_MODEL)
            logging.info(f"Initialized default LLM model: {self.DEFAULT_MODEL}")
        except ValueError as e:
            logging.error(f"Failed to initialize default LLM model: {str(e)}")

    def initialize_generation_llm(self, model_name: str) -> None:
        """
        Initialize the generation LLM using the Groq API.

        Args:
            model_name (str): The name of the model to use for generation.

        Raises:
            ValueError: If GROQ_API_KEY is not set.
        """
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY is not set. Please add it to your environment variables.")

        # ChatGroq reads GROQ_API_KEY from the environment, so no re-export is needed.
        self.generation_llm = ChatGroq(model=model_name, temperature=0.7)
        self.generation_llm.name = model_name
        logging.info(f"Generation LLM {model_name} initialized")

    def reinitialize_llm(self, model_name: str) -> str:
        """
        Reinitialize the LLM with a new model name.

        Args:
            model_name (str): The name of the new model to initialize.

        Returns:
            str: Status message indicating success or failure.
        """
        try:
            self.initialize_generation_llm(model_name)
            return f"LLM model changed to {model_name}"
        except ValueError as e:
            logging.error(f"Failed to reinitialize LLM with model {model_name}: {str(e)}")
            return f"Error: Failed to change LLM model: {str(e)}"

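    # Usage note (hedged): the model name is passed straight through to ChatGroq,
    # e.g. manager.reinitialize_llm("llama-3.1-8b-instant") -- a hypothetical
    # model choice, not one prescribed by this repo.
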
    def generate_response(self, question: str, relevant_docs: List[Dict[str, Any]]) -> Tuple[str, List[Document]]:
        """
        Generate a response using the generation LLM based on the question and relevant documents.

        Args:
            question (str): The user's query.
            relevant_docs (List[Dict[str, Any]]): Relevant document chunks with text, metadata, and scores.

        Returns:
            Tuple[str, List[Document]]: The LLM's response and the source documents used.

        Raises:
            ValueError: If the generation LLM is not initialized.
            Exception: If there's an error during the QA chain invocation.
        """
        if not self.generation_llm:
            raise ValueError("Generation LLM is not initialized. Call initialize_generation_llm first.")

        # Convert the relevant documents into LangChain Document objects
        documents = [
            Document(page_content=doc['text'], metadata=doc['metadata'])
            for doc in relevant_docs
        ]

        # A minimal retriever that always returns the pre-fetched documents.
        # BaseRetriever is a pydantic model, so the documents are declared as a
        # model field rather than assigned in a custom __init__; the base class
        # provides an async variant that delegates to the sync method.
        class SimpleRetriever(BaseRetriever):
            docs: List[Document]

            def _get_relevant_documents(
                self, query: str, *, run_manager: CallbackManagerForRetrieverRun
            ) -> List[Document]:
                logging.debug(f"SimpleRetriever._get_relevant_documents called with query: {query}")
                return self.docs

        # Instantiate the retriever
        retriever = SimpleRetriever(docs=documents)

        # Create a retrieval-based question-answering chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.generation_llm,
            retriever=retriever,
            return_source_documents=True
        )

        try:
            result = qa_chain.invoke({"query": question})
            response = result['result']
            source_docs = result['source_documents']
            return response, source_docs
        except Exception as e:
            logging.error(f"Error during QA chain invocation: {str(e)}")
            raise

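    # A minimal usage sketch for generate_response (hypothetical data; assumes
    # GROQ_API_KEY is set and each chunk dict carries 'text' and 'metadata'):
    #
    #   manager = LLMManager()
    #   answer, sources = manager.generate_response(
    #       "How is data encrypted in transit?",
    #       [{"text": "Data is encrypted using TLS ...", "metadata": {"page": 2}}],
    #   )
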
    def generate_summary_v0(self, chunks: List[Dict[str, Any]]) -> str:
        logging.info("Generating summary ...")

        # Limit the number of chunks (for example, the top 30 chunks)
        limited_chunks = chunks[:30]

        # Combine text from the selected chunks
        full_text = "\n".join(chunk['text'] for chunk in limited_chunks)
        text_length = len(full_text)
        logging.info(f"Total text length (characters): {text_length}")

        # Cap the input so it fits a 1024-token context; roughly 3200 characters
        # is a safe limit for many models.
        MAX_CHAR_LIMIT = 3200
        if text_length > MAX_CHAR_LIMIT:
            logging.warning(f"Input text too long ({text_length} chars), truncating to {MAX_CHAR_LIMIT} chars.")
            full_text = full_text[:MAX_CHAR_LIMIT]

        # Custom prompt instructing concise, bullet-point summarization.
        custom_prompt_template = """
        You are an expert summarizer. Summarize the following text into a concise summary using bullet points.
        Ensure that the final summary is no longer than 20-30 bullet points and fits within 15-20 lines.
        Focus only on the most critical points.

        Text to summarize:
        {text}

        Summary:
        """
        prompt = PromptTemplate(input_variables=["text"], template=custom_prompt_template)

        # Use the 'stuff' chain type to send a single LLM request with the custom prompt.
        chain = load_summarize_chain(self.generation_llm, chain_type="stuff", prompt=prompt)

        # Wrap the full text in a single Document (the chain expects a list of Documents)
        docs = [Document(page_content=full_text)]

        # Generate the summary
        summary = chain.invoke(docs)
        return summary['output_text']

    def generate_questions(self, chunks: List[Dict[str, Any]]) -> List[str]:
        logging.info("Generating sample questions ...")

        # Use the top 30 chunks or fewer
        limited_chunks = chunks[:30]

        # Combine text from chunks
        full_text = "\n".join(chunk['text'] for chunk in limited_chunks)
        text_length = len(full_text)
        logging.info(f"Total text length for questions: {text_length}")

        MAX_CHAR_LIMIT = 3200
        if text_length > MAX_CHAR_LIMIT:
            logging.warning(f"Input text too long ({text_length} chars), truncating to {MAX_CHAR_LIMIT} chars.")
            full_text = full_text[:MAX_CHAR_LIMIT]

        # Prompt template for generating questions
        question_prompt_template = """
        You are an AI expert at creating questions from documents.

        Based on the text below, generate at least 20 insightful and highly relevant sample questions that a user might ask to better understand the content.

        **Instructions:**
        - Questions must be specific to the document's content and context.
        - Avoid generic questions like 'What is this document about?'
        - Do not include numbers, prefixes (e.g., '1.', '2.'), or explanations (e.g., '(Clarifies...)').
        - Each question should be a single, clear sentence ending with a question mark.
        - Focus on key concepts, processes, components, or use cases mentioned in the text.

        Text:
        {text}

        Output format:
        What is the purpose of the Communication Server in Collateral Management?
        How does the system handle data encryption for secure communication?
        ...
        """
        prompt = PromptTemplate(input_variables=["text"], template=question_prompt_template)

        chain = load_summarize_chain(self.generation_llm, chain_type="stuff", prompt=prompt)
        docs = [Document(page_content=full_text)]

        try:
            result = chain.invoke(docs)
            question_output = result.get("output_text", "").strip()

            # Clean and parse the output into a list of questions
            questions = []
            for line in question_output.split("\n"):
                # Strip whitespace, list markers, and numbering
                cleaned_line = line.strip().strip("-*1234567890. ").rstrip(".")
                # Drop any trailing explanation in parentheses
                cleaned_line = cleaned_line.split("(")[0].strip()
                # Keep only non-empty lines that end with a question mark
                if cleaned_line and cleaned_line.endswith("?"):
                    questions.append(cleaned_line)

            # Keep at most 10 questions
            questions = questions[:10]
            logging.info(f"Generated questions: {questions}")
            return questions
        except Exception as e:
            logging.error(f"Error generating questions: {e}")
            return []

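    # Hypothetical input shape for generate_questions, mirroring the chunk
    # dicts consumed by generate_response (an assumption, not confirmed here):
    #
    #   questions = manager.generate_questions(
    #       [{"text": "Chapter 1 introduces ...", "metadata": {"page": 3}}]
    #   )
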
    def generate_summary(self, chunks: List[Dict[str, Any]], toc_text: str, summary_type: str = "medium", k: int = 20) -> str:
        """
        Generate a summary of the document using LangChain's summarization chains.

        Args:
            chunks (List[Dict[str, Any]]): Document chunks with 'text' and optional 'metadata'.
            toc_text (str): Table-of-contents text to guide the summary ("" if unavailable).
            summary_type (str): Type of summary ("small", "medium", "detailed");
                unknown values fall back to "detailed".
            k (int): Upper bound on the number of chunks to summarize.

        Returns:
            str: Generated summary, or an error message on failure.
        """
        # Choose chunk budget, chain type, and target length per summary type
        if summary_type == "small":
            k = min(k, 3)  # Fewer chunks for a small summary
            chain_type = "stuff"  # Single LLM call for small summaries
            word_count = "50-100"
        elif summary_type == "medium":
            k = min(k, 10)
            chain_type = "map_reduce"  # Map-reduce for medium summaries
            word_count = "200-400"
        else:  # detailed
            k = min(k, 20)
            chain_type = "map_reduce"  # Map-reduce for detailed summaries
            word_count = "500-1000"

        # Convert the selected chunks into Documents (the summarize chains
        # expect a list of Documents, not raw dicts)
        docs = [
            Document(page_content=chunk['text'], metadata=chunk.get('metadata', {}))
            for chunk in chunks[:k]
        ]

        # Define prompts. Summary type, word count, and TOC are baked in up
        # front, while {text} is left for the chain to fill in. This assumes
        # the TOC text contains no literal curly braces, which PromptTemplate
        # would otherwise treat as variables.
        toc_section = f"Table of Contents:\n{toc_text}\n\n" if toc_text else ""
        if chain_type == "stuff":
            prompt = PromptTemplate(
                input_variables=["text"],
                template=(
                    f"Generate a {summary_type} summary ({word_count} words) of the following document excerpts. "
                    "Focus on key points and ensure clarity. Stick strictly to the provided text:\n\n"
                    + toc_section + "{text}"
                )
            )
            chain = load_summarize_chain(
                llm=self.generation_llm,
                chain_type="stuff",
                prompt=prompt
            )
        else:  # map_reduce
            map_prompt = PromptTemplate(
                input_variables=["text"],
                template=(
                    "Summarize the following document excerpt in 1-2 sentences, focusing on key points. "
                    "Consider the document's structure from this table of contents:\n\n"
                    "Table of Contents:\n" + (toc_text if toc_text else "Not provided")
                    + "\n\nExcerpt:\n{text}"
                )
            )
            combine_prompt = PromptTemplate(
                input_variables=["text"],
                template=(
                    f"Combine the following summaries into a cohesive {summary_type} summary "
                    f"({word_count} words) of the document. Ensure clarity, avoid redundancy, and "
                    "organize by key themes or sections if applicable:\n\n{text}"
                )
            )
            chain = load_summarize_chain(
                llm=self.generation_llm,
                chain_type="map_reduce",
                map_prompt=map_prompt,
                combine_prompt=combine_prompt,
                return_intermediate_steps=False
            )

        # Run the chain
        try:
            logging.info(f"Generating {summary_type} summary with {len(docs)} chunks")
            result = chain.invoke(docs)
            logging.info(f"{summary_type.capitalize()} summary generated successfully")
            return result['output_text']
        except Exception as e:
            logging.error(f"Error generating summary: {str(e)}")
            return f"Error generating summary: {str(e)}"
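
For reference, a minimal end-to-end sketch of how this class might be driven. The chunk contents, question, and expected workflow below are hypothetical illustrations; the sketch assumes GROQ_API_KEY is exported and the LangChain/Groq packages imported above are installed.

import os

from retriever.llm_manager import LLMManager

assert os.getenv("GROQ_API_KEY"), "export GROQ_API_KEY before running"

manager = LLMManager()  # initializes the default gemma2-9b-it model

chunks = [
    {"text": "The Communication Server relays messages between components ...", "metadata": {"page": 1}},
    {"text": "Data is encrypted in transit using TLS ...", "metadata": {"page": 2}},
]

summary = manager.generate_summary(chunks, toc_text="", summary_type="small")
questions = manager.generate_questions(chunks)
answer, sources = manager.generate_response("How is data encrypted?", chunks)
print(answer)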