Superkingjcj's picture
Upload 111 files
e679d69 verified
raw
history blame contribute delete
1.47 kB
import importlib
import sys
from typing import Dict
import ray
from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string
class AsyncAgentRayActor:
def __init__(
self,
config: Dict,
num_gpus: int,
):
cls_name = config.pop('type')
python_path = config.pop('python_path', None)
cls_name = load_class_from_string(cls_name, python_path) if isinstance(
cls_name, str) else cls_name
AsyncAgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
self.agent_actor = AsyncAgentActor.remote(**config)
async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
response = await self.agent_actor.__call__.remote(
*message, session_id=session_id, **kwargs)
return response
class AgentRayActor:
def __init__(
self,
config: Dict,
num_gpus: int,
):
cls_name = config.pop('type')
python_path = config.pop('python_path', None)
cls_name = load_class_from_string(cls_name, python_path) if isinstance(
cls_name, str) else cls_name
AgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
self.agent_actor = AgentActor.remote(**config)
def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
response = self.agent_actor.__call__.remote(
*message, session_id=session_id, **kwargs)
return ray.get(response)