Spaces:
Runtime error
Runtime error
File size: 4,297 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
class ArceeRetriever(BaseRetriever):
    """Arcee Domain Adapted Language Models (DALMs) retriever.

    To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
    or pass ``arcee_api_key`` as a named parameter.

    ``arcee_api_url``, ``arcee_app_url`` and ``arcee_api_version`` may likewise be
    supplied through the ``ARCEE_API_URL``, ``ARCEE_APP_URL`` and
    ``ARCEE_API_VERSION`` environment variables. An explicit (non-empty) keyword
    argument takes precedence over the environment; the class defaults are used
    when neither is set.

    Example:
        .. code-block:: python

            from langchain_community.retrievers import ArceeRetriever

            retriever = ArceeRetriever(
                model="DALM-PubMed",
                arcee_api_key="ARCEE-API-KEY"
            )

            documents = retriever.invoke("AI-driven music therapy")
    """

    _client: Optional[ArceeWrapper] = None  #: :meta private:
    """Arcee client."""

    arcee_api_key: SecretStr
    """Arcee API Key"""

    model: str
    """Arcee DALM name"""

    arcee_api_url: str = "https://api.arcee.ai"
    """Arcee API URL"""

    arcee_api_version: str = "v2"
    """Arcee API Version"""

    arcee_app_url: str = "https://app.arcee.ai"
    """Arcee App URL"""

    model_kwargs: Optional[Dict[str, Any]] = None
    """Keyword arguments to pass to the model."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        # Lets ``_client`` be stored as a private attribute rather than a field.
        underscore_attrs_are_private = True

    def __init__(self, **data: Any) -> None:
        """Validate fields, then build the Arcee client.

        The client is created from the validated settings, and
        ``validate_model_training_status`` is called on it immediately.
        Construction therefore performs whatever check that wrapper method
        does before returning.
        """
        super().__init__(**data)
        self._client = ArceeWrapper(
            arcee_api_key=self.arcee_api_key.get_secret_value(),
            arcee_api_url=self.arcee_api_url,
            arcee_api_version=self.arcee_api_version,
            model_kwargs=self.model_kwargs,
            model_name=self.model,
        )
        self._client.validate_model_training_status()

    @root_validator(pre=True)
    def validate_environments(cls, values: Dict) -> Dict:
        """Resolve Arcee settings from kwargs / env vars and validate ``model_kwargs``.

        This runs *before* field validation (``pre=True``) for two reasons:

        * A missing ``arcee_api_key`` can be filled from ``ARCEE_API_KEY`` before
          the required-field check runs. As a post validator, the
          "field required" error was already recorded and the env fallback
          could not help.
        * ``ARCEE_API_URL`` / ``ARCEE_APP_URL`` / ``ARCEE_API_VERSION`` are
          actually consulted. After field defaults are applied those keys are
          always present, so the environment lookup was never reached.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            ``values`` with the Arcee settings resolved.

        Raises:
            ValueError: If no API key is supplied, ``model_kwargs["size"]`` is
                negative, or ``model_kwargs["filters"]`` is not a list whose
                items are accepted by ``DALMFilter``.
        """
        values["arcee_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "arcee_api_key", "ARCEE_API_KEY")
        )

        # Precedence: explicit kwarg -> environment variable -> field default.
        for field_name, env_key in (
            ("arcee_api_url", "ARCEE_API_URL"),
            ("arcee_app_url", "ARCEE_APP_URL"),
            ("arcee_api_version", "ARCEE_API_VERSION"),
        ):
            values[field_name] = get_from_dict_or_env(
                values,
                field_name,
                env_key,
                default=cls.__fields__[field_name].default,
            )

        # Validate model kwargs. A non-dict value is left for pydantic's own
        # field validation of ``model_kwargs`` to reject.
        kw = values.get("model_kwargs")
        if kw and isinstance(kw, dict):
            size = kw.get("size")
            # `not size >= 0` (rather than `size < 0`) also rejects NaN.
            if size is not None and not size >= 0:
                raise ValueError("`size` must not be negative.")

            filters = kw.get("filters")
            if filters is not None:
                if not isinstance(filters, list):
                    raise ValueError("`filters` must be a list.")
                # Constructing each DALMFilter validates the filter spec.
                for f in filters:
                    DALMFilter(**f)

        return values

    def _get_relevant_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Retrieve {size} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            size: The max number of context results to retrieve.
                Defaults to 3. (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The documents returned by the Arcee client.

        Raises:
            ValueError: If the client is not initialized, or if the
                underlying retrieval call fails (the original exception is
                chained as ``__cause__``).
        """
        # Checked outside the try so this error is not re-wrapped below.
        if not self._client:
            raise ValueError("Client is not initialized.")
        try:
            return self._client.retrieve(query=query, **kwargs)
        except Exception as e:
            raise ValueError(f"Error while retrieving documents: {e}") from e
|