"""Chain for interacting with SQL Database.""" | |
from __future__ import annotations | |
import warnings | |
from typing import Any, Dict, List, Optional | |
from langchain.chains.base import Chain | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS | |
from langchain.schema import BasePromptTemplate | |
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER | |
from langchain_community.utilities.sql_database import SQLDatabase | |
from langchain_core.callbacks.manager import CallbackManagerForChainRun | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.prompts.prompt import PromptTemplate | |
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator | |
INTERMEDIATE_STEPS_KEY = "intermediate_steps" | |
SQL_QUERY = "SQLQuery:" | |
SQL_RESULT = "SQLResult:" | |


class SQLDatabaseChain(Chain):
    """Chain for interacting with SQL Database.

    Example:
        .. code-block:: python

            from langchain_experimental.sql import SQLDatabaseChain
            from langchain_community.llms import OpenAI
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include the permissions this chain needs.
    Failure to do so may result in data corruption or loss, since this chain may
    attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this chain.
    This issue shows an example negative outcome if these steps are not taken:
    https://github.com/langchain-ai/langchain/issues/5923
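
    One way to narrow the scope is to connect as a read-only database user and
    expose only the tables the chain needs. A minimal sketch (the URI and table
    names below are illustrative assumptions):

    .. code-block:: python

        from langchain_community.utilities import SQLDatabase

        db = SQLDatabase.from_uri(
            "postgresql://readonly_user:...@localhost/mydb",  # hypothetical read-only DSN
            include_tables=["orders", "customers"],  # hypothetical table names
        )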
""" | |
    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    database: SQLDatabase = Field(exclude=True)
    """SQL Database to connect to."""
    prompt: Optional[BasePromptTemplate] = None
    """[Deprecated] Prompt to use to translate natural language to SQL."""
    top_k: int = 5
    """Number of results to return from the query."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_sql: bool = False
    """Whether to return the generated SQL command directly, without executing it."""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether or not the query checker tool should be used to attempt
    to fix the initial SQL from the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """The prompt template that should be used by the query checker."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                database = values["database"]
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                if SQL_QUERY in sql_cmd:
                    sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
                if SQL_RESULT in sql_cmd:
                    sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return_direct, set the final result equal to the raw SQL query
            # result; otherwise ask the LLM for a human readable final answer.
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds.
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

    @property
    def _chain_type(self) -> str:
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
"""Create a SQLDatabaseChain from an LLM and a database connection. | |
*Security note*: Make sure that the database connection uses credentials | |
that are narrowly-scoped to only include the permissions this chain needs. | |
Failure to do so may result in data corruption or loss, since this chain may | |
attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted. | |
The best way to guard against such negative outcomes is to (as appropriate) | |
limit the permissions granted to the credentials used with this chain. | |
This issue shows an example negative outcome if these steps are not taken: | |
https://github.com/langchain-ai/langchain/issues/5923 | |
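
        A minimal usage sketch (the keyword arguments shown are optional fields of
        this chain; the values are illustrative assumptions):

        .. code-block:: python

            db_chain = SQLDatabaseChain.from_llm(
                llm,
                db,
                use_query_checker=True,  # ask the LLM to double-check the generated SQL
                return_intermediate_steps=True,  # expose prompts, SQL and results
                top_k=10,  # number of rows requested in generated queries
            )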
""" | |
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)


class SQLDatabaseSequentialChain(Chain):
    """Chain for querying SQL database that is a sequential chain.

    The chain is as follows:

    1. Based on the query, determine which tables to use.
    2. Based on those tables, call the normal SQL database chain.

    This is useful in cases where the number of tables in the database is large.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Load the necessary chains."""
        sql_chain = SQLDatabaseChain.from_llm(llm, db, prompt=query_prompt, **kwargs)
        decider_chain = LLMChain(
            llm=llm, prompt=decider_prompt, output_key="table_names"
        )
        return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _table_names = self.sql_chain.database.get_usable_table_names()
        table_names = ", ".join(_table_names)
        llm_inputs = {
            "query": inputs[self.input_key],
            "table_names": table_names,
        }
        _lowercased_table_names = [name.lower() for name in _table_names]
        table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)
        table_names_to_use = [
            name
            for name in table_names_from_chain
            if name.lower() in _lowercased_table_names
        ]
        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(table_names_to_use), color="yellow", verbose=self.verbose
        )
        new_inputs = {
            self.sql_chain.input_key: inputs[self.input_key],
            "table_names_to_use": table_names_to_use,
        }
        return self.sql_chain(
            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
        )

    @property
    def _chain_type(self) -> str:
        return "sql_database_sequential_chain"