"""Fake ChatModel for testing purposes.""" | |
import asyncio | |
import time | |
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForLLMRun, | |
CallbackManagerForLLMRun, | |
) | |
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel | |
from langchain_core.messages import AIMessageChunk, BaseMessage | |
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |

class FakeMessagesListChatModel(BaseChatModel):
    """Fake ChatModel for testing purposes."""

    responses: List[BaseMessage]
    sleep: Optional[float] = None
    i: int = 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Return the next canned response, cycling back to the start
        # once the list is exhausted.
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        generation = ChatGeneration(message=response)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "fake-messages-list-chat-model"


class FakeListChatModel(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    responses: List
    sleep: Optional[float] = None
    i: int = 0

    @property
    def _llm_type(self) -> str:
        return "fake-list-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the next response in the list, cycling back to the start."""
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        return response

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[CallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        # Emit the response one character per chunk, optionally pausing
        # between chunks to simulate network latency.
        for c in response:
            if self.sleep is not None:
                time.sleep(self.sleep)
            yield ChatGenerationChunk(message=AIMessageChunk(content=c))

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        response = self.responses[self.i]
        if self.i < len(self.responses) - 1:
            self.i += 1
        else:
            self.i = 0
        # Async variant of _stream: same character-by-character chunking,
        # but with a non-blocking sleep between chunks.
        for c in response:
            if self.sleep is not None:
                await asyncio.sleep(self.sleep)
            yield ChatGenerationChunk(message=AIMessageChunk(content=c))

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"responses": self.responses}