# Hugging Face Spaces: langchain-qa-bot (status: Runtime error)
# File: docs/langchain/libs/community/langchain_community/tools/databricks/tool.py
import json | |
from datetime import date, datetime | |
from decimal import Decimal | |
from hashlib import md5 | |
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union | |
from langchain_core.pydantic_v1 import BaseModel, Field, create_model | |
from langchain_core.tools import BaseTool, BaseToolkit, StructuredTool | |
from typing_extensions import Self | |
if TYPE_CHECKING: | |
from databricks.sdk import WorkspaceClient | |
from databricks.sdk.service.catalog import FunctionInfo | |
from langchain_community.tools.databricks._execution import execute_function | |
def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
    """Convert a Unity Catalog type description into a Python / pydantic type.

    Args:
        uc_type_json: Either a primitive type name such as ``"long"`` or
            ``"decimal(10,2)"``, or a dict describing a complex type whose
            ``"type"`` key is ``"array"``, ``"map"`` or ``"struct"``.

    Returns:
        A builtin / ``date`` / ``datetime`` / ``Decimal`` for primitives,
        ``List[...]`` for arrays, ``Dict[str, ...]`` for maps, and a
        generated pydantic model for structs. Nullable element, value and
        field types are wrapped in ``Optional``.

    Raises:
        TypeError: If the type is unsupported (``void``, ``interval...``, or a
            MAP with a non-STRING key) or unknown.
    """
    mapping = {
        "long": int,
        "binary": bytes,
        "boolean": bool,
        "date": date,
        "double": float,
        "float": float,
        "integer": int,
        "short": int,
        "string": str,
        "timestamp": datetime,
        "timestamp_ntz": datetime,
        "byte": int,
    }
    if isinstance(uc_type_json, str):
        if uc_type_json in mapping:
            return mapping[uc_type_json]
        if uc_type_json.startswith("decimal"):
            # Parameterized decimals such as "decimal(10,2)".
            return Decimal
        if uc_type_json == "void" or uc_type_json.startswith("interval"):
            raise TypeError(f"Type {uc_type_json} is not supported.")
        raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")

    # Explicit check instead of `assert` so validation survives `python -O`.
    if not isinstance(uc_type_json, dict):
        raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")

    tpe = uc_type_json["type"]
    if tpe == "array":
        element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"])
        if uc_type_json["containsNull"]:
            element_type = Optional[element_type]  # type: ignore
        return List[element_type]  # type: ignore
    if tpe == "map":
        key_type = uc_type_json["keyType"]
        # Previously an `assert` carrying a TypeError as its message: it was
        # stripped under -O and raised AssertionError instead of TypeError.
        if key_type != "string":
            raise TypeError(
                f"Only support STRING key type for MAP but got {key_type}."
            )
        value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"])
        if uc_type_json["valueContainsNull"]:
            value_type = Optional[value_type]  # type: ignore
        return Dict[str, value_type]  # type: ignore
    if tpe == "struct":
        fields = {}
        for field in uc_type_json["fields"]:
            field_type = _uc_type_to_pydantic_type(field["type"])
            if field.get("nullable"):
                field_type = Optional[field_type]  # type: ignore
            # Each struct field carries its own metadata; read the comment
            # from the field rather than from the enclosing struct.
            comment = (field.get("metadata") or {}).get("comment")
            fields[field["name"]] = (field_type, Field(..., description=comment))
        # Derive a stable model name from the struct definition itself.
        uc_type_json_str = json.dumps(uc_type_json, sort_keys=True)
        type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8]
        return create_model(f"Struct_{type_hash}", **fields)  # type: ignore
    raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")
def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
    """Build a pydantic model describing the input parameters of a UC function.

    Each parameter's ``type_json`` is converted with
    ``_uc_type_to_pydantic_type``. Parameters that declare a default become
    ``Optional`` with a default of ``None``, and the declared default is
    appended to the field description.

    Args:
        function: The UC function metadata.

    Returns:
        The bare ``BaseModel`` when the function has no ``input_params``;
        otherwise a model named ``<catalog>__<schema>__<name>__params``.
    """
    if function.input_params is None:
        return BaseModel

    parameters = function.input_params.parameters
    assert parameters is not None

    model_fields = {}
    for param in parameters:
        assert param.type_json is not None
        field_type = _uc_type_to_pydantic_type(json.loads(param.type_json)["type"])
        field_description = param.comment
        field_default: Any = ...  # Ellipsis marks the field as required.
        if param.parameter_default:
            # TODO: Convert default value string to the correct type.
            # We might need to use statement execution API
            # to get the JSON representation of the value.
            field_type = Optional[field_type]  # type: ignore
            field_default = None
            default_note = f"(Default: {param.parameter_default})"
            field_description = (
                f"{field_description} {default_note}"
                if field_description
                else default_note
            )
        model_fields[param.name] = (
            field_type,
            Field(default=field_default, description=field_description),
        )

    model_name = (
        f"{function.catalog_name}__{function.schema_name}__{function.name}__params"
    )
    return create_model(model_name, **model_fields)  # type: ignore
def _get_tool_name(function: "FunctionInfo") -> str:
    """Return the tool name for a UC function.

    The name is ``<catalog>__<schema>__<function>``, truncated to its last 64
    characters (presumably to respect LLM tool-name length limits).
    """
    qualified_name = (
        f"{function.catalog_name}__{function.schema_name}__{function.name}"
    )
    max_length = 64
    return qualified_name[-max_length:]
def _get_default_workspace_client() -> "WorkspaceClient":
    """Construct a ``databricks.sdk.WorkspaceClient`` with no arguments.

    Raises:
        ImportError: If the ``databricks-sdk`` package is not installed.
    """
    try:
        # Imported lazily so the SDK is only required when actually used.
        from databricks.sdk import WorkspaceClient
    except ImportError as exc:
        raise ImportError(
            "Could not import databricks-sdk python package. "
            "Please install it with `pip install databricks-sdk`."
        ) from exc
    else:
        return WorkspaceClient()
def _json_default(obj: Any) -> Any:
    """Fallback serializer for tool arguments passed to ``json.dumps``.

    Struct-typed arguments may reach the tool as pydantic model instances
    (the struct types are generated with ``create_model``). ``str(model)``
    would yield text like ``"a=1 b='x'"`` instead of a JSON object, so models
    are converted with ``.dict()``. Any other non-JSON value (``date``,
    ``datetime``, ``Decimal``, ...) falls back to ``str``, as before.
    """
    if isinstance(obj, BaseModel):
        return obj.dict()
    return str(obj)


class UCFunctionToolkit(BaseToolkit):
    """Toolkit exposing Databricks Unity Catalog (UC) functions as tools.

    Functions are registered with ``include``. Each resulting tool runs its
    function through ``execute_function`` on the SQL warehouse identified by
    ``warehouse_id``.
    """

    warehouse_id: str = Field(
        description="The ID of a Databricks SQL Warehouse to execute functions."
    )
    workspace_client: "WorkspaceClient" = Field(
        default_factory=_get_default_workspace_client,
        description="Databricks workspace client.",
    )
    # Registered tools, keyed by the function name passed to `include`.
    tools: Dict[str, BaseTool] = Field(default_factory=dict)

    class Config:
        # Lets pydantic accept the non-pydantic WorkspaceClient type.
        arbitrary_types_allowed = True

    def include(self, *function_names: str, **kwargs: Any) -> Self:
        """Include UC functions in the toolkit.

        Args:
            function_names: UC function names in the format
                "catalog_name.schema_name.function_name" or
                "catalog_name.schema_name.*".
                If a name ends with ".*", all functions in the schema
                are added.
            kwargs: Extra arguments to pass to StructuredTool, e.g.,
                `return_direct`.

        Returns:
            The toolkit itself, so calls can be chained.

        Raises:
            ValueError: If a wildcard name is not of the form
                "catalog_name.schema_name.*".
        """
        for name in function_names:
            if name.endswith(".*"):
                parts = name[:-2].split(".")
                if len(parts) != 2:
                    raise ValueError(
                        f"Invalid function name pattern {name!r}; expected "
                        '"catalog_name.schema_name.*".'
                    )
                catalog_name, schema_name = parts
                # TODO: handle pagination, warn and truncate if too many
                functions = self.workspace_client.functions.list(
                    catalog_name=catalog_name, schema_name=schema_name
                )
                for f in functions:
                    assert f.full_name is not None
                    self.include(f.full_name, **kwargs)
            elif name not in self.tools:
                # Already-registered names are skipped, not rebuilt.
                self.tools[name] = self._make_tool(name, **kwargs)
        return self

    def _make_tool(self, function_name: str, **kwargs: Any) -> BaseTool:
        """Build a StructuredTool that executes the named UC function.

        Args:
            function_name: Full name of the UC function to fetch and wrap.
            kwargs: Extra arguments forwarded to ``StructuredTool``.

        Returns:
            A tool whose name, description and argument schema are derived
            from the function's UC metadata.
        """
        function = self.workspace_client.functions.get(function_name)
        name = _get_tool_name(function)
        description = function.comment or ""
        args_schema = _generate_args_schema(function)

        def func(*args: Any, **tool_input: Any) -> str:
            # TODO: We expect all named args and ignore args.
            # Non-empty args show up when the function has no parameters.
            # Round-trip through JSON so every argument is a plain JSON value.
            # Nested models become dicts; other non-JSON values become str.
            args_json = json.loads(json.dumps(tool_input, default=_json_default))
            result = execute_function(
                ws=self.workspace_client,
                warehouse_id=self.warehouse_id,
                function=function,
                parameters=args_json,
            )
            return result.to_json()

        return StructuredTool(
            name=name,
            description=description,
            args_schema=args_schema,
            func=func,
            **kwargs,
        )

    def get_tools(self) -> List[BaseTool]:
        """Return every tool registered via ``include``."""
        return list(self.tools.values())