import logging
from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnablePassthrough
from langchain_chroma import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import JSONLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from prompt import prompt
from utils import execute_query_and_return_df
from constant import (
    GEMINI_MODEL, 
    GOOGLE_API_KEY, 
    PATH_SCHEMA, 
    PATH_DB,
    EMBEDDING_MODEL
)


class SQLOutput(BaseModel):
    query: str = Field(description="The SQL query to run.")
    reasoning: str = Field(description="Step-by-step reasoning behind the generated SQL query.")


class Text2SQLRAG:
    """
    Generates SQL queries from natural-language questions, using
    retrieval-augmented generation (RAG) over the database schema.
    """

    def __init__(self,
                 path_schema: str = PATH_SCHEMA,
                 path_db: str = PATH_DB,
                 model: str = GEMINI_MODEL,
                 api_key: str = GOOGLE_API_KEY,
                 embedding_model: str = EMBEDDING_MODEL):
        self.logger = logging.getLogger(__name__)
        self.logger.info('Initializing Text2SQLRAG')

        self.schema = path_schema
        self.db = path_db
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        # Generation settings go to ChatGoogleGenerativeAI as direct
        # constructor arguments; Gemini uses max_output_tokens, and the
        # Anthropic-style model_kwargs (max_tokens, top_k=250, the
        # "\n\nHuman:" stop sequence) do not apply to it.
        self.model = ChatGoogleGenerativeAI(
            model=model,
            api_key=api_key,
            max_output_tokens=512,
            temperature=0.2,
            top_p=1,
        )
        self.llm = self.model.with_structured_output(SQLOutput)
        self.retriever = self._index_vectorstore()

    def _index_vectorstore(self):
        """
        Indexes the database schema using a vector store for efficient retrieval.

        This method loads the schema from a JSON file, splits it into chunks,
        embeds the chunks using a specified embedding model, and stores them in
        a vector store. It returns a retriever configured to search for the top
        k relevant documents.

        Returns:
            retriever: An object capable of retrieving the most relevant schema
            chunks based on the given search parameters.
        """

        self.logger.info('Indexing schema')
        # Load the entire schema JSON file as a single document
        # (jq_schema='.' selects the full object).
        db_schema_loader = JSONLoader(
            file_path=self.schema,
            jq_schema='.',
            text_content=False
        )
        # Split on the literal string "separator", which the schema file is
        # presumably expected to contain between entries; the large chunk_size
        # keeps each entry in a single chunk.
        text_splitter = RecursiveCharacterTextSplitter(
            separators=["separator"],
            chunk_size=10000,
            chunk_overlap=100
        )

        docs = text_splitter.split_documents(db_schema_loader.load())  

        vectorstore = Chroma.from_documents(
            documents=docs,
            embedding=self.embeddings
        )

        retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        self.logger.info('Finished indexing schema')
        return retriever
    
    def run(self, question: str) -> SQLOutput:
        """
        Generates a SQL query (plus its reasoning) for a natural-language
        question. The retriever supplies the most relevant schema chunks as
        context, the prompt combines them with the question, and the
        structured LLM returns a SQLOutput.
        """
        self.logger.info(f'Running Text2SQLRAG for question: {question}')
        rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
        )
        response = rag_chain.invoke(question)
        return response
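

# Example usage: a minimal sketch. It assumes constant.py provides valid
# values for PATH_SCHEMA, PATH_DB, GEMINI_MODEL, GOOGLE_API_KEY and
# EMBEDDING_MODEL; the question string is purely illustrative.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    text2sql = Text2SQLRAG()
    result = text2sql.run("How many orders were placed in the last 30 days?")

    print("SQL:", result.query)
    print("Reasoning:", result.reasoning)

    # The generated query could then be run against the database, e.g. via
    # utils.execute_query_and_return_df (its signature is not shown in this
    # file, so the call below is an assumption):
    # df = execute_query_and_return_df(result.query)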