File size: 2,264 Bytes
87337b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from dataclasses import dataclass
import requests
from openai import AsyncOpenAI, AsyncAzureOpenAI
from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.config import BaseConfig
@dataclass
class OpenAIImageGenerateToolConfig(BaseConfig):
    """Settings for the OpenAI / Azure OpenAI image-generation tool.

    Every field has a default. How the values get populated (presumably via
    ``BaseConfig`` from extension properties) is not visible here; TODO confirm.
    """

    # Credential handed to the OpenAI / Azure OpenAI SDK client.
    api_key: str = ""
    # API root for the non-Azure client (ignored when vendor == "azure").
    base_url: str = "https://api.openai.com/v1"
    # Image model name sent as the "model" field of each generation request.
    model: str = "dall-e-3"
    # Sent as the "size" field of each generation request.
    size: str = "1024x1024"
    # Sent as the "quality" field of each generation request.
    quality: str = "standard"
    # Sent as the "n" field; the client only returns the first image's URL.
    n: int = 1
    # Optional proxy URL; an empty string means no proxy is configured.
    proxy_url: str = ""
    # "azure" selects the Azure OpenAI client; any other value uses the standard one.
    vendor: str = "openai"
    # Azure-only settings passed to the Azure client as azure_endpoint / api_version.
    azure_endpoint: str = ""
    azure_api_version: str = ""
class OpenAIImageGenerateClient:
    """Async image-generation client backed by OpenAI or Azure OpenAI.

    The vendor, credentials, model and request parameters all come from an
    ``OpenAIImageGenerateToolConfig``.
    """

    # Class-level placeholder; every instance replaces it in ``__init__``.
    client = None

    def __init__(self, ten_env: AsyncTenEnv, config: OpenAIImageGenerateToolConfig):
        """Create the underlying SDK client described by ``config``.

        Args:
            ten_env: Environment used for logging.
            config: Tool configuration. ``vendor == "azure"`` selects
                ``AsyncAzureOpenAI``; any other value selects ``AsyncOpenAI``.
        """
        self.config = config
        # Never log the API key itself: it is a secret, and logs are routinely
        # shipped to places where secrets must not appear.
        ten_env.log_info(
            f"OpenAIImageGenerateClient initialized with vendor: {config.vendor}, "
            f"model: {config.model}"
        )

        # The OpenAI SDK performs HTTP through httpx, not ``requests``, so a
        # proxy only takes effect when supplied to the SDK as ``http_client``.
        # ``None`` means "use the SDK's default HTTP client".
        http_client = None
        # Retained for backward compatibility with code that reads
        # ``self.session``; the SDK itself never uses it.
        self.session = requests.Session()
        if config.proxy_url:
            proxies = {
                "http": config.proxy_url,
                "https": config.proxy_url,
            }
            ten_env.log_info(f"Setting proxies: {proxies}")
            self.session.proxies.update(proxies)
            # Imported lazily so that the no-proxy path does not depend on an
            # SDK version that exports DefaultAsyncHttpxClient.
            from openai import DefaultAsyncHttpxClient

            http_client = DefaultAsyncHttpxClient(proxy=config.proxy_url)

        if config.vendor == "azure":
            self.client = AsyncAzureOpenAI(
                api_key=config.api_key,
                api_version=config.azure_api_version,
                azure_endpoint=config.azure_endpoint,
                http_client=http_client,
            )
            ten_env.log_info(
                f"Using Azure OpenAI with endpoint: {config.azure_endpoint}, api_version: {config.azure_api_version}"
            )
        else:
            self.client = AsyncOpenAI(
                api_key=config.api_key,
                base_url=config.base_url,
                http_client=http_client,
            )

    async def generate_images(self, prompt: str) -> str:
        """Generate image(s) for ``prompt`` and return the first image's URL.

        Args:
            prompt: Text prompt sent to the image-generation endpoint.

        Returns:
            The URL of the first generated image.

        Raises:
            RuntimeError: If the API call fails or the response holds no image URL.
        """
        req = {
            "model": self.config.model,
            "prompt": prompt,
            "size": self.config.size,
            "quality": self.config.quality,
            "n": self.config.n,
        }
        try:
            response = await self.client.images.generate(**req)
        except Exception as e:
            raise RuntimeError(f"GenerateImages failed, err: {e}") from e

        # An empty or missing ``data`` list, or an item without a URL, would
        # otherwise surface as IndexError or a None return despite ``-> str``.
        images = response.data
        if not images or not images[0].url:
            raise RuntimeError(
                "GenerateImages failed, err: response contained no image URL"
            )
        return images[0].url