"""
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
"""

from typing import List, Optional, Union

from pydantic import BaseModel

from litellm.litellm_core_utils.prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
    strip_name_from_messages,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField

from ...openai_like.chat.transformation import OpenAILikeChatConfig


class DatabricksConfig(OpenAILikeChatConfig):
    """
    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
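
    Illustrative usage (the model name below is only an example; valid
    Databricks credentials via `api_key`/`api_base` are assumed):

        import litellm

        response = litellm.completion(
            model="databricks/databricks-dbrx-instruct",
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=100,
        )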
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop: Optional[Union[List[str], str]] = None
    n: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
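        # Mirror explicitly-passed params onto the class (not the instance) so
        # they are visible to get_config(); this follows the pattern used by
        # litellm's other provider configs.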
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description."""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def get_supported_openai_params(self, model: Optional[str] = None) -> list:
        """OpenAI params this config can translate for the Databricks chat endpoint."""
        return [
            "stream",
            "stop",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "n",
            "response_format",
            "tools",
            "tool_choice",
        ]

    def _should_fake_stream(self, optional_params: dict) -> bool:
        """
        Databricks doesn't support 'response_format' while streaming.
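
        When this returns True, litellm's shared fake-stream path sends a
        non-streaming request and replays the response as stream chunks
        (a description of upstream behavior, not logic in this class).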
        """
        if optional_params.get("response_format") is not None:
            return True

        return False

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """
        Databricks does not support:
        - content in list format.
        - 'name' in user messages.
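
        Illustrative before/after (hypothetical message values):

            in:  {"role": "user", "name": "bob",
                  "content": [{"type": "text", "text": "Hi"}]}
            out: {"role": "user", "content": "Hi"}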
        """
        # Flatten pydantic message objects into plain dicts before mutation.
        new_messages = []
        for message in messages:
            if isinstance(message, BaseModel):
                _message = message.model_dump(exclude_none=True)
            else:
                _message = message
            new_messages.append(_message)

        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
        new_messages = strip_name_from_messages(new_messages)

        return super()._transform_messages(messages=new_messages, model=model)