"""Hugging Face Chat Wrapper.""" | |
from typing import Any, AsyncIterator, Iterator, List, Optional | |
from langchain_core._api.deprecation import deprecated | |
from langchain_core.callbacks.manager import ( | |
AsyncCallbackManagerForLLMRun, | |
CallbackManagerForLLMRun, | |
) | |
from langchain_core.language_models.chat_models import ( | |
BaseChatModel, | |
agenerate_from_stream, | |
generate_from_stream, | |
) | |
from langchain_core.messages import ( | |
AIMessage, | |
AIMessageChunk, | |
BaseMessage, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain_core.outputs import ( | |
ChatGeneration, | |
ChatGenerationChunk, | |
ChatResult, | |
LLMResult, | |
) | |
from langchain_core.pydantic_v1 import root_validator | |
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint | |
from langchain_community.llms.huggingface_hub import HuggingFaceHub | |
from langchain_community.llms.huggingface_text_gen_inference import ( | |
HuggingFaceTextGenInference, | |
) | |
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.""" | |


class ChatHuggingFace(BaseChatModel):
    """
    Wrapper for using Hugging Face LLMs as ChatModels.

    Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
    and `HuggingFaceHub` LLMs.

    Upon instantiating this class, the model_id is resolved from the URL
    provided to the LLM, and the appropriate tokenizer is loaded from
    the HuggingFace Hub.

    Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
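
    Example (an illustrative sketch; the endpoint URL is a placeholder and a
    valid Hugging Face token must be configured so the model_id can be
    resolved from the endpoint):
        .. code-block:: python

            from langchain_community.chat_models.huggingface import ChatHuggingFace
            from langchain_community.llms import HuggingFaceEndpoint
            from langchain_core.messages import HumanMessage

            llm = HuggingFaceEndpoint(
                endpoint_url="https://<your-endpoint>.endpoints.huggingface.cloud",
                task="text-generation",
            )
            chat = ChatHuggingFace(llm=llm)
            chat.invoke([HumanMessage(content="Hello, how are you?")])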
""" | |

    llm: Any
    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
        HuggingFaceHub."""
    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
    tokenizer: Any = None
    model_id: Optional[str] = None
    streaming: bool = False

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

        from transformers import AutoTokenizer

        # Resolve the backing model_id from the wrapped LLM, then load its
        # tokenizer unless one was supplied explicitly.
        self._resolve_model_id()
        self.tokenizer = (
            AutoTokenizer.from_pretrained(self.model_id)
            if self.tokenizer is None
            else self.tokenizer
        )

    @root_validator()
    def validate_llm(cls, values: dict) -> dict:
        if not isinstance(
            values["llm"],
            (HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
        ):
            raise TypeError(
                "Expected llm to be one of HuggingFaceTextGenInference, "
                f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
            )
        return values

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = self._to_chat_prompt(messages)

        for data in self.llm.stream(request, **kwargs):
            delta = data
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        request = self._to_chat_prompt(messages)

        async for data in self.llm.astream(request, **kwargs):
            delta = data
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                await run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM."""
        if not messages:
            raise ValueError("At least one HumanMessage must be provided!")

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("Last message must be a HumanMessage!")

        messages_dicts = [self._to_chatml_format(m) for m in messages]

        return self.tokenizer.apply_chat_template(
            messages_dicts, tokenize=False, add_generation_prompt=True
        )

    def _to_chatml_format(self, message: BaseMessage) -> dict:
        """Convert LangChain message to ChatML format."""

        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, HumanMessage):
            role = "user"
        else:
            raise ValueError(f"Unknown message type: {type(message)}")

        return {"role": role, "content": message.content}

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        chat_generations = []

        for g in llm_result.generations[0]:
            chat_generation = ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            chat_generations.append(chat_generation)

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )

    def _resolve_model_id(self) -> None:
        """Resolve the model_id from the LLM's inference_server_url."""

        from huggingface_hub import list_inference_endpoints

        available_endpoints = list_inference_endpoints("*")

        if isinstance(self.llm, HuggingFaceHub) or (
            hasattr(self.llm, "repo_id") and self.llm.repo_id
        ):
            self.model_id = self.llm.repo_id
            return
        elif isinstance(self.llm, HuggingFaceTextGenInference):
            endpoint_url: Optional[str] = self.llm.inference_server_url
        else:
            endpoint_url = self.llm.endpoint_url
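
        # Match the configured URL against the account's Inference Endpoints.
        # Illustrative example (hypothetical values): an endpoint served at
        # https://abc123.us-east-1.aws.endpoints.huggingface.cloud resolves to
        # its backing repository, e.g. "HuggingFaceH4/zephyr-7b-beta".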
        for endpoint in available_endpoints:
            if endpoint.url == endpoint_url:
                self.model_id = endpoint.repository

        if not self.model_id:
            raise ValueError(
                "Failed to resolve model_id: "
                f"could not find model id for inference server {endpoint_url}. "
                "Make sure that your Hugging Face token has access to the endpoint."
            )

    @property
    def _llm_type(self) -> str:
        return "huggingface-chat-wrapper"