import concurrent.futures
import os

from dotenv import load_dotenv
from loguru import logger
from openai import OpenAI

from rag_demo.preprocessing.base import EmbeddedChunk
from rag_demo.rag.base.query import EmbeddedQuery, Query

from .prompt_templates import AnswerGenerationTemplate
from .query_classifier import QueryClassifier
from .query_expansion import QueryExpansion
from .reranker import Reranker
from .source_annotator import SourceAnnotator

load_dotenv()


def flatten(nested_list: list) -> list:
    """Flatten a list of lists into a single list."""
    return [item for sublist in nested_list for item in sublist]


class RAGPipeline:
    """End-to-end RAG pipeline: query classification, query expansion,
    retrieval, optional reranking, answer generation, and source annotation."""

    def __init__(self, mock: bool = False) -> None:
        self._query_expander = QueryExpansion(mock=mock)
        self._reranker = Reranker(mock=mock)
        self._source_annotator = SourceAnnotator()
        self._query_classifier = QueryClassifier(mock=mock)

    def search(
        self,
        query: str,
        k: int = 3,
        expand_to_n_queries: int = 3,
    ) -> list:
        """Expand the query into multiple variants, search for each in
        parallel, and return the deduplicated top-k documents."""
        query_model = Query.from_str(query)

        n_generated_queries = self._query_expander.generate(
            query_model, expand_to_n=expand_to_n_queries
        )
        logger.info(
            f"Successfully generated {len(n_generated_queries)} search queries."
        )

        # Run one vector search per expanded query concurrently.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [
                executor.submit(self._search, _query_model, k)
                for _query_model in n_generated_queries
            ]
            n_k_documents = [
                task.result() for task in concurrent.futures.as_completed(search_tasks)
            ]
            n_k_documents = flatten(n_k_documents)
            n_k_documents = list(set(n_k_documents))

        logger.info(f"{len(n_k_documents)} documents retrieved successfully")

        if len(n_k_documents) > 0:
            # Reranking is currently disabled; the first k documents are kept.
            # k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
            k_documents = n_k_documents[:k]
        else:
            k_documents = []

        return k_documents

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        """Embed a single query and retrieve its k nearest chunks."""
        assert k >= 3, "k should be >= 3"

        def _search_data(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k,
            )

        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        embedded_query = EmbeddedQuery(
            embedding=api.embeddings.create(
                model="text-embedding-3-small", input=query.content
            )
            .data[0]
            .embedding,
            id=query.id,
            content=query.content,
        )

        retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
        logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")

        return retrieved_chunks

    def rerank(
        self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
        """Rerank retrieved chunks against the query and keep the top results."""
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(
            query=query, chunks=chunks, keep_top_k=keep_top_k
        )
        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents

    def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Build a prompt from the retrieved context and generate an answer."""
        context = ""
        for chunk in reranked_chunks:
            context += "\n Document: "
            context += chunk.content

        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        answer_generation_template = AnswerGenerationTemplate()
        prompt = answer_generation_template.create_template(context, query)
        logger.info(prompt)

        response = api.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=8192,
        )
        return response.choices[0].message.content

    def add_context(self, response: str, reranked_chunks: list[EmbeddedChunk]) -> str:
        """Annotate the answer with the sources it was generated from."""
        logger.info("Adding context to the answer")
        return self._source_annotator.annotate(response, reranked_chunks)

    def rag(self, query: str) -> tuple[str, list[str]]:
        """Run the full pipeline: classify the query, retrieve sources if
        needed, generate an answer, and return it with its source filenames."""
        query_type = self._query_classifier.generate(Query.from_str(query))
        logger.info(f"Query type: {query_type}")

        if query_type == "Sources_needed":
            docs = self.search(query, k=10)
        else:
            docs = []

        answer = self.generate_answer(query, docs)
        annotated_answer = self.add_context(answer, docs) if docs else answer

        return (
            annotated_answer,
            list({doc.metadata["filename"].split(".pdf")[0] for doc in docs}),
        )