import logging

from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnablePassthrough
from langchain_chroma import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from prompt import prompt
from utils import execute_query_and_return_df
from constant import (
    GEMINI_MODEL,
    GOOGLE_API_KEY,
    PATH_SCHEMA,
    PATH_DB,
    EMBEDDING_MODEL
)


class SQLOutput(BaseModel):
    """Structured output schema the LLM must fill: a SQL query plus reasoning."""

    query: str = Field(description="The SQL query to run.")
    reasoning: str = Field(description="Reasoning to understand the SQL query.")


class Text2SQLRAG:
    def __init__(self,
                 path_schema: str = PATH_SCHEMA,
                 path_db: str = PATH_DB,
                 model: str = GEMINI_MODEL,
                 api_key: str = GOOGLE_API_KEY,
                 embedding_model: str = EMBEDDING_MODEL
                 ):
        """
        A class for generating SQL queries based on natural language text.

        Builds a RAG pipeline: the database schema (JSON) is chunked and
        indexed into an in-memory Chroma vector store, and a Gemini chat
        model with structured output (SQLOutput) generates the SQL.

        Args:
            path_schema: Path to the JSON file describing the DB schema.
            path_db: Path to the database itself.
                NOTE(review): stored on self.db but never used in this
                module — confirm it is consumed elsewhere (e.g. when
                executing the generated query) or remove it.
            model: Gemini model name.
            api_key: Google API key for the Gemini model.
            embedding_model: HuggingFace embedding model name.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info('Initializing Text2SQLRAG')

        self.schema = path_schema
        self.db = path_db
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        # Generation parameters are first-class kwargs on
        # ChatGoogleGenerativeAI, so pass them directly instead of via
        # model_kwargs (where they were not reliably applied). The former
        # stop_sequences=["\n\nHuman:"] was an Anthropic Claude artifact
        # and has no meaning for Gemini, so it is dropped.
        # NOTE(review): top_k=250 also looks like a Claude default —
        # confirm it is within the supported range for the chosen Gemini
        # model.
        self.model = ChatGoogleGenerativeAI(
            model=model,
            api_key=api_key,
            temperature=0.2,
            top_k=250,
            top_p=1,
            max_output_tokens=512,
        )
        # Force the model to answer in the SQLOutput schema.
        self.llm = self.model.with_structured_output(SQLOutput)
        self.retriever = self._indexing_vectore()

    def _indexing_vectore(self):
        """
        Indexes the database schema using a vector store for efficient retrieval.

        This method loads the schema from a JSON file, splits it into chunks,
        embeds the chunks using a specified embedding model, and stores them
        in a vector store. It returns a retriever configured to search for
        the top k relevant documents.

        Returns:
            retriever: An object capable of retrieving the most relevant
                schema chunks based on the given search parameters.
        """
        self.logger.info('Indexing schema')
        # jq_schema='.' loads the whole JSON document as one Document;
        # text_content=False keeps non-string JSON values.
        db_schema_loader = JSONLoader(
            file_path=self.schema,
            jq_schema='.',
            text_content=False
        )
        # NOTE(review): separators=["separator"] splits on the literal
        # word "separator" — this looks like a placeholder. Confirm the
        # schema JSON really contains that token, otherwise use the
        # splitter's default separators.
        text_splitter = RecursiveCharacterTextSplitter(
            separators=["separator"],
            chunk_size=10000,
            chunk_overlap=100
        )
        docs = text_splitter.split_documents(db_schema_loader.load())
        # In-memory store (no persist_directory): the index is rebuilt on
        # every instantiation.
        vectorstore = Chroma.from_documents(documents=docs, embedding=self.embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        self.logger.info('Finished indexing schema')
        return retriever

    def run(self, question: str):
        """
        Generate a SQL query for a natural-language question.

        The question is passed both to the retriever (to fetch relevant
        schema chunks as context) and verbatim into the prompt; the LLM
        returns a structured SQLOutput.

        Args:
            question: The natural-language question to translate to SQL.

        Returns:
            SQLOutput: The generated SQL query and the model's reasoning.
        """
        # Lazy %-style args so the message is only formatted if emitted.
        self.logger.info('Running Text2SQLRAG for question: %s', question)
        rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
        )
        response = rag_chain.invoke(question)
        return response