File size: 3,209 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. 
"""

from typing import Optional

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import RerankResponse


class AzureAIRerankConfig(CohereRerankConfig):
    """
    Azure AI Rerank - Follows the same Spec as Cohere Rerank.

    Relative to ``CohereRerankConfig`` this overrides endpoint URL
    construction, auth/header building, and post-processing of the response's
    hidden params (mapping the Azure model group to a litellm model name).
    """

    # Azure model-group identifiers (read from the response's hidden params)
    # that correspond to a known litellm model name. Groups not listed here
    # are passed through unchanged by `_get_base_model`.
    _AZURE_MODEL_GROUP_TO_BASE_MODEL = {
        "offer-cohere-rerank-mul-paygo": "azure_ai/cohere-rerank-v3-multilingual",
        "offer-cohere-rerank-eng-paygo": "azure_ai/cohere-rerank-v3-english",
    }

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        Build the rerank endpoint URL from ``api_base``.

        Accepts a bare base (``https://host``), a base already ending in
        ``/v1``, or the full ``.../v1/rerank`` endpoint. Trailing slashes are
        ignored, so the result never contains ``//`` or a doubled ``/v1``.

        Args:
            api_base: Azure AI endpoint base URL.
            model: Unused here; part of the shared config interface.

        Returns:
            The URL ending in ``/v1/rerank``.

        Raises:
            ValueError: if ``api_base`` is None or empty.
        """
        if not api_base:
            raise ValueError(
                f"Azure AI API Base is required. api_base={api_base!r}. Set in call or via `AZURE_AI_API_BASE` env var."
            )
        # Normalise away trailing slashes before checking/appending the path,
        # otherwise "https://host/" would become "https://host//v1/rerank".
        api_base = api_base.rstrip("/")
        if api_base.endswith("/v1/rerank"):
            return api_base
        if api_base.endswith("/v1"):
            # Caller already supplied the version segment; don't repeat it.
            return f"{api_base}/rerank"
        return f"{api_base}/v1/rerank"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Return the request headers for an Azure AI rerank call.

        The API key is resolved from, in order: the ``api_key`` argument, the
        ``AZURE_AI_API_KEY`` secret, then ``litellm.azure_key``. Caller-supplied
        ``headers`` override the defaults. Header names are compared
        case-insensitively, so e.g. a caller's ``authorization`` replaces the
        default ``Authorization`` instead of both being sent.

        Args:
            headers: Caller-supplied headers; these take precedence.
            model: Unused here; part of the shared config interface.
            api_key: Explicit API key, if any.

        Returns:
            A new dict of merged headers (``headers`` is not mutated).

        Raises:
            ValueError: if no API key can be resolved and the caller did not
                supply their own Authorization header.
        """
        # HTTP field names are case-insensitive; track what the caller set.
        caller_header_names = {name.lower() for name in headers}

        if api_key is None:
            api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key

        # A caller-provided Authorization header is a complete substitute for
        # the key, so only insist on a key when none was given.
        if api_key is None and "authorization" not in caller_header_names:
            raise ValueError(
                "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
            )

        default_headers = {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        # Keep only the defaults the caller did not override (in any casing),
        # then layer the caller's headers on top so they always win.
        merged = {
            name: value
            for name, value in default_headers.items()
            if name.lower() not in caller_header_names
        }
        merged.update(headers)
        return merged

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: Optional[dict] = None,
        optional_params: Optional[dict] = None,
        litellm_params: Optional[dict] = None,
    ) -> RerankResponse:
        """
        Parse the response via the Cohere transformation, then record the
        litellm model name in ``_hidden_params["model"]``.

        The model name is derived from the
        ``llm_provider-azureml-model-group`` hidden param via
        `_get_base_model` (None when that param is absent).

        ``request_data``, ``optional_params`` and ``litellm_params`` default
        to None and are replaced by fresh empty dicts per call, avoiding the
        shared-mutable-default pitfall.
        """
        rerank_response = super().transform_rerank_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=request_data if request_data is not None else {},
            optional_params=optional_params if optional_params is not None else {},
            litellm_params=litellm_params if litellm_params is not None else {},
        )
        azure_model_group = rerank_response._hidden_params.get(
            "llm_provider-azureml-model-group"
        )
        rerank_response._hidden_params["model"] = self._get_base_model(
            azure_model_group
        )
        return rerank_response

    def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
        """
        Map an Azure model-group id to a litellm model name.

        Known groups are translated via `_AZURE_MODEL_GROUP_TO_BASE_MODEL`;
        unknown groups are returned unchanged; None stays None.
        """
        if azure_model_group is None:
            return None
        return self._AZURE_MODEL_GROUP_TO_BASE_MODEL.get(
            azure_model_group, azure_model_group
        )