import os
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    PrivateAttr,
    root_validator,
    validator,
)

__all__ = ["Databricks"]


class _DatabricksClientBase(BaseModel, ABC):
    """A base JSON API client that talks to Databricks."""

    api_url: str
    api_token: str

    def request(self, method: str, url: str, request: Any) -> Any:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.request(
            method=method, url=url, headers=headers, json=request
        )
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    def _get(self, url: str) -> Any:
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        return self.request("POST", url, request)

    @abstractmethod
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        ...

    @property
    def llm(self) -> bool:
        return False


def _transform_completions(response: Dict[str, Any]) -> str:
    return response["choices"][0]["text"]


def _transform_llama2_chat(response: Dict[str, Any]) -> str:
    return response["candidates"][0]["text"]


def _transform_chat(response: Dict[str, Any]) -> str:
    return response["choices"][0]["message"]["content"]


class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str
    databricks_uri: str
    client: Any = None
    external_or_foundation: bool = False
    task: Optional[str] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        try:
            from mlflow.deployments import get_deploy_client

            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e

        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        if self.task is None:
            self.task = endpoint.get("task")

    @property
    def llm(self) -> bool:
        return self.task in ("llm/v1/chat", "llm/v1/completions", "llama2/chat")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)

            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)

            return resp
        else:
            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
            wrapped_request = {"dataframe_records": [request]}
            response = self.client.predict(
                endpoint=self.endpoint_name, inputs=wrapped_request
            )
            preds = response["predictions"]
            # For a single-record query, the result is not a list.
            pred = preds[0] if isinstance(preds, list) else preds
            if self.task == "llama2/chat":
                return _transform_llama2_chat(pred)
            return transform_output_fn(pred) if transform_output_fn else pred


class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a Databricks cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            cluster_id = values["cluster_id"]
            port = values["cluster_driver_port"]
            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        resp = self._post(self.api_url, request)
        return transform_output_fn(resp) if transform_output_fn else resp


def get_repl_context() -> Any:
    """Get the notebook REPL context if running inside a Databricks notebook.
    Raises an ImportError otherwise.
    """
    try:
        from dbruntime.databricks_repl_context import get_context

        return get_context()
    except ImportError:
        raise ImportError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        )


def get_default_host() -> str:
    """Get the default Databricks workspace hostname.
    Raises an error if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # str.lstrip strips a character set, not a prefix, so use a regex instead.
    host = re.sub(r"^https?://", "", host).rstrip("/")
    return host


def get_default_api_token() -> str:
    """Get the default Databricks personal access token.
    Raises an error if the token cannot be automatically determined.
    """
    if api_token := os.getenv("DATABRICKS_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token


def _is_hex_string(data: str) -> bool:
    """Checks if a data is a valid hexadecimal string using a regular expression."""
    if not isinstance(data, str):
        return False
    pattern = r"^[0-9a-fA-F]+$"
    return bool(re.match(pattern, data))


def _load_pickled_fn_from_hex_string(
    data: str, allow_dangerous_deserialization: Optional[bool]
) -> Callable:
    """Loads a pickled function from a hexadecimal string."""
    if not allow_dangerous_deserialization:
        raise ValueError(
            "This code relies on the pickle module. "
            "You will need to set allow_dangerous_deserialization=True "
            "if you want to opt-in to allow deserialization of data using pickle. "
            "Data modified by a malicious actor can include a payload that, "
            "when deserialized with pickle, executes arbitrary code on your "
            "machine."
        )

    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.loads(bytes.fromhex(data))
    except Exception as e:
        raise ValueError(
            f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
        )


def _pickle_fn_to_hex_string(fn: Callable) -> str:
    """Pickles a function and returns the hexadecimal string."""
    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.dumps(fn).hex()
    except Exception as e:
        raise ValueError(f"Failed to pickle the function: {e}")


class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected model
      signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response from the
      endpoint is automatically transformed to the expected format unless
      ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local HTTP
      server on the driver node to serve the model at ``/`` using HTTP POST method
      with JSON input/output.
      Please use a port number between ``3000`` and ``8000`` and let the server
      listen on the driver IP address, or simply ``0.0.0.0``, instead of
      localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}`

      * outputs: ``{"type": "string"}``
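
      A minimal driver proxy server sketch matching this schema (Flask is shown
      for illustration; any HTTP framework works, and ``generate`` is a
      placeholder for your model's inference call)::

          from flask import Flask, request

          app = Flask("llm-driver-proxy")

          @app.route("/", methods=["POST"])
          def serve() -> str:
              data = request.get_json()
              return generate(data["prompt"], stop=data.get("stop"))

          app.run(host="0.0.0.0", port=7777)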

    If the endpoint model signature is different or you want to set extra params,
    you can use ``transform_input_fn`` and ``transform_output_fn`` to apply
    necessary transformations before and after the query.
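
    Example (a minimal sketch; the endpoint name, cluster ID, and port below
    are placeholders)::

        from langchain_community.llms import Databricks

        # Connect to a model serving endpoint.
        llm = Databricks(endpoint_name="my-llm-endpoint")

        # Or connect to a cluster driver proxy app.
        # llm = Databricks(cluster_id="0123-456789-xxxxxxxx",
        #                  cluster_driver_port="7777")

        llm.invoke("How are you?")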
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in "single user"
      or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code runs
    inside a Databricks notebook attached to an interactive cluster in "single user"
    or "no isolation shared" mode, the current cluster ID is used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver node.
    The server should listen on the driver IP address, or simply ``0.0.0.0``, so
    that clients can connect.
    We recommend using a port number between ``3000`` and ``8000``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """
    Deprecated. Please use ``extra_params`` instead. Extra parameters to pass to
    the endpoint.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
    request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
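
    A minimal sketch (the template text is illustrative)::

        def transform_input(**request):
            request["prompt"] = f"Answer concisely: {request['prompt']}"
            return request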
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the generated text.
    """

    databricks_uri: str = "databricks"
    """The databricks URI. Only used when using a serving endpoint."""

    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: Dict[str, Any] = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    task: Optional[str] = None
    """The task of the endpoint. Only used when using a serving endpoint.
    If not provided, the task is automatically inferred from the endpoint.
    """

    allow_dangerous_deserialization: bool = False
    """Whether to allow dangerous deserialization of the data which 
    involves loading data using pickle.
    
    If the data has been modified by a malicious actor, it can deliver a
    malicious payload that results in execution of arbitrary code on the target
    machine.
    """

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

    @property
    def _llm_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "n": self.n,
        }
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif values["endpoint_name"]:
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set. "
                    "And the cluster_id cannot be automatically determined. Received"
                    f" error: {e}"
                )

    @validator("cluster_driver_port", always=True)
    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
        elif values["endpoint_name"]:
            return None
        elif v is None:
            raise ValueError(
                "Must set cluster_driver_port to connect to a cluster driver."
            )
        elif int(v) <= 0:
            raise ValueError(f"Invalid cluster_driver_port: {v}")
        else:
            return v

    @validator("model_kwargs", always=True)
    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        if v:
            assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
            assert "stop" not in v, "model_kwargs must not contain key 'stop'"
        return v

    def __init__(self, **data: Any):
        if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
            data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_input_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )
        if "transform_output_fn" in data and _is_hex_string(
            data["transform_output_fn"]
        ):
            data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_output_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )

        super().__init__(**data)
        # ``extra_params`` defaults to an empty dict, so test for a non-empty
        # value here; an ``is not None`` check would always be true.
        if self.model_kwargs is not None and self.extra_params:
            raise ValueError("Cannot set both model_kwargs and extra_params.")
        elif self.model_kwargs is not None:
            warnings.warn(
                "model_kwargs is deprecated. Please use extra_params instead.",
                DeprecationWarning,
            )
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
                databricks_uri=self.databricks_uri,
                task=self.task,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(  # type: ignore[call-arg]
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
            "task": self.task,
            "transform_input_fn": None
            if self.transform_input_fn is None
            else _pickle_fn_to_hex_string(self.transform_input_fn),
            "transform_output_fn": None
            if self.transform_output_fn is None
            else _pickle_fn_to_hex_string(self.transform_output_fn),
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""

        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt}
        if self._client.llm:
            request.update(self._llm_params)
        request.update(self.model_kwargs or self.extra_params)
        request.update(kwargs)
        if stop:
            request["stop"] = stop

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        return self._client.post(request, transform_output_fn=self.transform_output_fn)