Spaces:
Running
Running
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import re
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
import aiohttp
|
9 |
+
import asyncio
|
10 |
+
from aiohttp import web
|
11 |
+
import unicodedata
|
12 |
+
from dataclasses import dataclass, asdict
|
13 |
+
import logging
|
14 |
+
|
15 |
+
logging.basicConfig(level=logging.INFO)
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class Config:
|
21 |
+
# SILICONFLOW-Key
|
22 |
+
AUTH_TOKEN: str = 'sk-xxxxxxxxx'
|
23 |
+
# One-API/New-API 中转地址
|
24 |
+
OPENAI_CHAT_API: str = 'https://xxxxxxxxxxxx/v1/chat/completions'
|
25 |
+
# key
|
26 |
+
OPENAI_CHAT_API_KEY: str = 'sk-xxxxxxxxxx'
|
27 |
+
# 默认的翻译模型
|
28 |
+
DEFAULT_TRANSLATE_MODEL: str = 'deepseek-chat'
|
29 |
+
# 增强的翻译模型
|
30 |
+
DEFAULT_PROMPT_MODEL: str = 'Qwen2-72B-Instruct'
|
31 |
+
|
32 |
+
|
33 |
+
config = Config()
|
34 |
+
|
35 |
+
URLS = {
|
36 |
+
'API_FLUX1_API4GPT_COM': 'https://api-flux1.api4gpt.com',
|
37 |
+
'FLUXAIWEB_COM_TOKEN': 'https://fluxaiweb.com/flux/getToken',
|
38 |
+
'FLUXAIWEB_COM_GENERATE': 'https://fluxaiweb.com/flux/generateImage',
|
39 |
+
'FLUXIMG_COM': 'https://fluximg.com/api/image/generateImage',
|
40 |
+
'API_SILICONFLOW_CN': 'https://api.siliconflow.cn/v1/chat/completions'
|
41 |
+
}
|
42 |
+
|
43 |
+
URL_MAP = {
|
44 |
+
'flux': "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image",
|
45 |
+
'sd3': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image",
|
46 |
+
'sdxl': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image",
|
47 |
+
'sd2': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/text-to-image",
|
48 |
+
'sdt': "https://api.siliconflow.cn/v1/stabilityai/sd-turbo/text-to-image",
|
49 |
+
'sdxlt': "https://api.siliconflow.cn/v1/stabilityai/sdxl-turbo/text-to-image",
|
50 |
+
'sdxll': "https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/text-to-image"
|
51 |
+
}
|
52 |
+
|
53 |
+
IMG_URL_MAP = {
|
54 |
+
'sdxl': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/image-to-image",
|
55 |
+
'sd2': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/image-to-image",
|
56 |
+
'sdxll': "https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/image-to-image",
|
57 |
+
'pm': "https://api.siliconflow.cn/v1/TencentARC/PhotoMaker/image-to-image"
|
58 |
+
}
|
59 |
+
|
60 |
+
RATIO_MAP = {
|
61 |
+
"1:1": "1024x1024",
|
62 |
+
"1:2": "1024x2048",
|
63 |
+
"3:2": "1536x1024",
|
64 |
+
"4:3": "1536x2048",
|
65 |
+
"16:9": "2048x1152",
|
66 |
+
"9:16": "1152x2048"
|
67 |
+
}
|
68 |
+
|
69 |
+
SYSTEM_ASSISTANT = """作为 Stable Diffusion Prompt 提示词专家,您将从关键词中创建提示,通常来自 Danbooru 等数据库。
|
70 |
+
提示通常描述图像,使用常见词汇,按重要性排列,并用逗号分隔。避免使用"-"或".",但可以接受空格和自然语言。避免词汇重复。
|
71 |
+
|
72 |
+
为了强调关键词,请将其放在括号中以增加其权重。例如,"(flowers)"将'flowers'的权重增加1.1倍,而"(((flowers)))"将其增加1.331倍。使用"(flowers:1.5)"将'flowers'的权重增加1.5倍。只为重要的标签增加权重。
|
73 |
+
|
74 |
+
提示包括三个部分:**前缀**(质量标签+风格词+效果器)+ **主题**(图像的主要焦点)+ **场景**(背景、环境)。
|
75 |
+
|
76 |
+
* 前缀影响图像质量。像"masterpiece"、"best quality"、"4k"这样的标签可以提高图像的细节。像"illustration"、"lensflare"这样的风格词定义图像的风格。像"bestlighting"、"lensflare"、"depthoffield"这样的效果器会影响光照和深度。
|
77 |
+
|
78 |
+
* 主题是图像的主要焦点,如角色或场景。对主题进行详细描述可以确保图像丰富而详细。增加主题的权重以增强其清晰度。对于角色,描述面部、头发、身体、服装、姿势等特征。
|
79 |
+
|
80 |
+
* 场景描述环境。没有场景,图像的背景是平淡的,主题显得过大。某些主题本身包含场景(例如建筑物、风景)。像"花草草地"、"阳光"、"河流"这样的环境词可以丰富场景。你的任务是设计图像生成的提示。请按照以下步骤进行操作:
|
81 |
+
|
82 |
+
1. 我会发送给您一个图像场景。需要你生成详细的图像描述
|
83 |
+
2. 图像描述必须是英文,输出为Positive Prompt。
|
84 |
+
|
85 |
+
示例:
|
86 |
+
|
87 |
+
我发送:二战时期的护士。
|
88 |
+
您回复只回复:
|
89 |
+
A WWII-era nurse in a German uniform, holding a wine bottle and stethoscope, sitting at a table in white attire, with a table in the background, masterpiece, best quality, 4k, illustration style, best lighting, depth of field, detailed character, detailed environment.
|
90 |
+
"""
|
91 |
+
|
92 |
+
|
93 |
+
async def select_random_image_generator():
|
94 |
+
generators = [generate_image1, generate_image2, generate_image3, generate_image4]
|
95 |
+
generator = random.choice(generators)
|
96 |
+
if generator == generate_image4:
|
97 |
+
return lambda prompt, size, model: generator(prompt, size, model)
|
98 |
+
else:
|
99 |
+
return lambda prompt, size, model=None: generator(prompt, size)
|
100 |
+
|
101 |
+
|
102 |
+
def extract_size_and_model_from_prompt(prompt):
|
103 |
+
size_match = re.search(r'--ar\s+(\S+)', prompt)
|
104 |
+
model_match = re.search(r'--m\s+(\S+)', prompt)
|
105 |
+
|
106 |
+
size = size_match.group(1) if size_match else '1:1'
|
107 |
+
model = model_match.group(1) if model_match else ''
|
108 |
+
|
109 |
+
clean_prompt = re.sub(r'--ar\s+\S+', '', prompt).strip()
|
110 |
+
clean_prompt = re.sub(r'--m\s+\S+', '', clean_prompt).strip()
|
111 |
+
|
112 |
+
return {'size': size, 'model': model, 'clean_prompt': clean_prompt}
|
113 |
+
|
114 |
+
|
115 |
+
def is_chinese(char):
|
116 |
+
return 'CJK' in unicodedata.name(char, '')
|
117 |
+
|
118 |
+
|
119 |
+
async def translate_prompt(prompt):
|
120 |
+
if all(not is_chinese(char) for char in prompt):
|
121 |
+
logger.info('Prompt is already in English, skipping translation')
|
122 |
+
return prompt
|
123 |
+
|
124 |
+
try:
|
125 |
+
async with aiohttp.ClientSession() as session:
|
126 |
+
async with session.post(config.OPENAI_CHAT_API, json={
|
127 |
+
'model': config.DEFAULT_TRANSLATE_MODEL, # 使用config中的model
|
128 |
+
'messages': [
|
129 |
+
{'role': 'system', 'content': SYSTEM_ASSISTANT},
|
130 |
+
{'role': 'user', 'content': prompt}
|
131 |
+
],
|
132 |
+
}, headers={
|
133 |
+
'Content-Type': 'application/json',
|
134 |
+
'Authorization': f'Bearer {config.OPENAI_CHAT_API_KEY}'
|
135 |
+
}) as response:
|
136 |
+
if response.status != 200:
|
137 |
+
error_text = await response.text()
|
138 |
+
logger.error(f'HTTP error! status: {response.status}, body: {error_text}')
|
139 |
+
raise Exception(f'HTTP error! status: {response.status}')
|
140 |
+
|
141 |
+
if 'application/json' not in response.headers.get('Content-Type', ''):
|
142 |
+
error_text = await response.text()
|
143 |
+
logger.error(f'Unexpected content type: {response.headers.get("Content-Type")}, body: {error_text}')
|
144 |
+
raise Exception(f'Unexpected content type: {response.headers.get("Content-Type")}')
|
145 |
+
|
146 |
+
result = await response.json()
|
147 |
+
return result['choices'][0]['message']['content']
|
148 |
+
|
149 |
+
except Exception as e:
|
150 |
+
logger.error('Translation error:', e)
|
151 |
+
return prompt
|
152 |
+
|
153 |
+
|
154 |
+
async def handle_request(request):
|
155 |
+
if request.method != 'POST' or not request.url.path.endswith('/v1/chat/completions'):
|
156 |
+
return web.Response(text='Not Found', status=404)
|
157 |
+
|
158 |
+
try:
|
159 |
+
data = await request.json()
|
160 |
+
messages = data.get('messages', [])
|
161 |
+
stream = data.get('stream', False)
|
162 |
+
|
163 |
+
user_message = next((msg['content'] for msg in reversed(messages) if msg['role'] == 'user'), None)
|
164 |
+
|
165 |
+
if not user_message:
|
166 |
+
return web.json_response({'error': "未找到用户消息"}, status=400)
|
167 |
+
|
168 |
+
size_and_model = extract_size_and_model_from_prompt(user_message)
|
169 |
+
translated_prompt = await translate_prompt(size_and_model['clean_prompt'])
|
170 |
+
|
171 |
+
selected_generator = await select_random_image_generator()
|
172 |
+
attempts = 0
|
173 |
+
max_attempts = 3
|
174 |
+
|
175 |
+
while attempts < max_attempts:
|
176 |
+
try:
|
177 |
+
image_data = await selected_generator(translated_prompt, size_and_model['size'],
|
178 |
+
size_and_model['model'])
|
179 |
+
break
|
180 |
+
except Exception as e:
|
181 |
+
logger.error(f"Error generating image with generator {selected_generator.__name__}: {e}")
|
182 |
+
selected_generator = await select_random_image_generator()
|
183 |
+
attempts += 1
|
184 |
+
|
185 |
+
if attempts == max_attempts:
|
186 |
+
logger.error("Failed to generate image after multiple attempts")
|
187 |
+
return web.json_response({'error': "生成图像失败"}, status=500)
|
188 |
+
|
189 |
+
unique_id = f"chatcmpl-{int(time.time())}"
|
190 |
+
created_timestamp = int(time.time())
|
191 |
+
model_name = "flux"
|
192 |
+
system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))
|
193 |
+
|
194 |
+
if stream:
|
195 |
+
return await handle_stream_response(request, unique_id, image_data, size_and_model['clean_prompt'],
|
196 |
+
translated_prompt, size_and_model['size'], created_timestamp,
|
197 |
+
model_name, system_fingerprint)
|
198 |
+
else:
|
199 |
+
return handle_non_stream_response(unique_id, image_data, size_and_model['clean_prompt'], translated_prompt,
|
200 |
+
size_and_model['size'], created_timestamp, model_name, system_fingerprint)
|
201 |
+
except Exception as e:
|
202 |
+
logger.error('Error handling request:', e)
|
203 |
+
return web.json_response({'error': f"处理请求失败: {str(e)}"}, status=500)
|
204 |
+
|
205 |
+
|
206 |
+
async def handle_stream_response(request, unique_id, image_data, original_prompt, translated_prompt, size, created,
|
207 |
+
model, system_fingerprint):
|
208 |
+
logger.debug("Starting stream response")
|
209 |
+
response = web.StreamResponse(
|
210 |
+
status=200,
|
211 |
+
reason='OK',
|
212 |
+
headers={
|
213 |
+
'Content-Type': 'text/event-stream',
|
214 |
+
'Cache-Control': 'no-cache',
|
215 |
+
'Connection': 'keep-alive'
|
216 |
+
}
|
217 |
+
)
|
218 |
+
await response.prepare(request)
|
219 |
+
logger.debug("Response prepared")
|
220 |
+
|
221 |
+
chunks = [
|
222 |
+
f"原始提示词:\n{original_prompt}\n",
|
223 |
+
f"翻译后的提示词:\n{translated_prompt}\n",
|
224 |
+
f"图像规格���{size}\n",
|
225 |
+
"正在根据提示词生成图像...\n",
|
226 |
+
"图像正在处理中...\n",
|
227 |
+
"即将完成...\n",
|
228 |
+
f"生成成功!\n图像生成完毕,以下是结果:\n\n"
|
229 |
+
]
|
230 |
+
|
231 |
+
for i, chunk in enumerate(chunks):
|
232 |
+
json_chunk = json.dumps({
|
233 |
+
"id": unique_id,
|
234 |
+
"object": "chat.completion.chunk",
|
235 |
+
"created": created,
|
236 |
+
"model": model,
|
237 |
+
"system_fingerprint": system_fingerprint,
|
238 |
+
"choices": [{
|
239 |
+
"index": 0,
|
240 |
+
"delta": {"content": chunk},
|
241 |
+
"logprobs": None,
|
242 |
+
"finish_reason": None
|
243 |
+
}]
|
244 |
+
})
|
245 |
+
try:
|
246 |
+
await response.write(f"data: {json_chunk}\n\n".encode('utf-8'))
|
247 |
+
logger.debug(f"Chunk {i + 1} sent")
|
248 |
+
except Exception as e:
|
249 |
+
logger.error(f"Error sending chunk {i + 1}: {str(e)}")
|
250 |
+
await asyncio.sleep(0.5) # 模拟生成时间
|
251 |
+
|
252 |
+
final_chunk = json.dumps({
|
253 |
+
"id": unique_id,
|
254 |
+
"object": "chat.completion.chunk",
|
255 |
+
"created": created,
|
256 |
+
"model": model,
|
257 |
+
"system_fingerprint": system_fingerprint,
|
258 |
+
"choices": [{
|
259 |
+
"index": 0,
|
260 |
+
"delta": {},
|
261 |
+
"logprobs": None,
|
262 |
+
"finish_reason": "stop"
|
263 |
+
}]
|
264 |
+
})
|
265 |
+
try:
|
266 |
+
await response.write(f"data: {final_chunk}\n\n".encode('utf-8'))
|
267 |
+
logger.debug("Final chunk sent")
|
268 |
+
except Exception as e:
|
269 |
+
logger.error(f"Error sending final chunk: {str(e)}")
|
270 |
+
|
271 |
+
await response.write_eof()
|
272 |
+
logger.debug("Stream response completed")
|
273 |
+
return response
|
274 |
+
|
275 |
+
|
276 |
+
def handle_non_stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model,
|
277 |
+
system_fingerprint):
|
278 |
+
content = (
|
279 |
+
f"原始提示词:{original_prompt}\n"
|
280 |
+
f"翻译后的提示词:{translated_prompt}\n"
|
281 |
+
f"图像规格:{size}\n"
|
282 |
+
f"图像生成成功!\n"
|
283 |
+
f"以下是结果:\n\n"
|
284 |
+
f""
|
285 |
+
)
|
286 |
+
|
287 |
+
response = {
|
288 |
+
'id': unique_id,
|
289 |
+
'object': "chat.completion",
|
290 |
+
'created': created,
|
291 |
+
'model': model,
|
292 |
+
'system_fingerprint': system_fingerprint,
|
293 |
+
'choices': [{
|
294 |
+
'index': 0,
|
295 |
+
'message': {
|
296 |
+
'role': "assistant",
|
297 |
+
'content': content
|
298 |
+
},
|
299 |
+
'finish_reason': "stop"
|
300 |
+
}],
|
301 |
+
'usage': {
|
302 |
+
'prompt_tokens': len(original_prompt),
|
303 |
+
'completion_tokens': len(content),
|
304 |
+
'total_tokens': len(original_prompt) + len(content)
|
305 |
+
}
|
306 |
+
}
|
307 |
+
|
308 |
+
return web.json_response(response)
|
309 |
+
|
310 |
+
|
311 |
+
async def generate_image1(prompt, size):
|
312 |
+
# 调用 get_prompt 函数来增强提示词
|
313 |
+
enhanced_prompt = await get_prompt(prompt)
|
314 |
+
|
315 |
+
prompt_without_spaces = enhanced_prompt.replace(" ", "")
|
316 |
+
image_url = f"{URLS['API_FLUX1_API4GPT_COM']}/?prompt={prompt_without_spaces}&size={size}"
|
317 |
+
return {
|
318 |
+
'data': [{'url': image_url}],
|
319 |
+
'size': size
|
320 |
+
}
|
321 |
+
|
322 |
+
|
323 |
+
async def generate_image2(prompt, size):
|
324 |
+
random_ip = generate_random_ip()
|
325 |
+
# 调用 get_prompt 来增强提示词
|
326 |
+
enhanced_prompt = await get_prompt(prompt)
|
327 |
+
async with aiohttp.ClientSession() as session:
|
328 |
+
async with session.get(URLS['FLUXAIWEB_COM_TOKEN'],
|
329 |
+
headers={'X-Forwarded-For': random_ip}) as token_response:
|
330 |
+
token_data = await token_response.json()
|
331 |
+
token = token_data['data']['token']
|
332 |
+
|
333 |
+
async with session.post(URLS['FLUXAIWEB_COM_GENERATE'], headers={
|
334 |
+
'Content-Type': 'application/json',
|
335 |
+
'token': token,
|
336 |
+
'X-Forwarded-For': random_ip
|
337 |
+
}, json={
|
338 |
+
'prompt': enhanced_prompt,
|
339 |
+
'aspectRatio': size,
|
340 |
+
'outputFormat': 'webp',
|
341 |
+
'numOutputs': 1,
|
342 |
+
'outputQuality': 90
|
343 |
+
}) as image_response:
|
344 |
+
image_data = await image_response.json()
|
345 |
+
return {
|
346 |
+
'data': [{'url': image_data['data']['image']}],
|
347 |
+
'size': size
|
348 |
+
}
|
349 |
+
|
350 |
+
|
351 |
+
async def generate_image3(prompt, size):
|
352 |
+
json_body = {
|
353 |
+
'textStr': prompt,
|
354 |
+
'model': "black-forest-labs/flux-schnell",
|
355 |
+
'size': size
|
356 |
+
}
|
357 |
+
|
358 |
+
max_retries = 3
|
359 |
+
for attempt in range(max_retries):
|
360 |
+
try:
|
361 |
+
async with aiohttp.ClientSession() as session:
|
362 |
+
async with session.post(URLS['FLUXIMG_COM'], data=json.dumps(json_body),
|
363 |
+
headers={'Content-Type': 'text/plain;charset=UTF-8'}) as response:
|
364 |
+
if response.status == 200:
|
365 |
+
image_url = await response.text()
|
366 |
+
return {
|
367 |
+
'data': [{'url': image_url}],
|
368 |
+
'size': size
|
369 |
+
}
|
370 |
+
else:
|
371 |
+
logger.error(
|
372 |
+
f"Unexpected response status: {response.status}, response text: {await response.text()}")
|
373 |
+
except aiohttp.ClientConnectorError as e:
|
374 |
+
logger.error(f"Connection error on attempt {attempt + 1}: {e}")
|
375 |
+
await asyncio.sleep(2 ** attempt) # Exponential backoff
|
376 |
+
|
377 |
+
logger.error("Failed to generate image after multiple attempts")
|
378 |
+
return {
|
379 |
+
'data': [{'url': "https://via.placeholder.com/640x480/428675/ffffff?text=Error"}],
|
380 |
+
'size': size
|
381 |
+
}
|
382 |
+
|
383 |
+
|
384 |
+
async def generate_image4(prompt, size, model):
|
385 |
+
if not config.AUTH_TOKEN:
|
386 |
+
raise Exception("AUTH_TOKEN is required for this method")
|
387 |
+
|
388 |
+
api_url = URL_MAP.get(model, URL_MAP['flux'])
|
389 |
+
clean_prompt = re.sub(r'--m\s+\S+', '', prompt).strip()
|
390 |
+
|
391 |
+
# 调用 get_prompt 函数来增强提示词
|
392 |
+
enhanced_prompt = await get_prompt(clean_prompt)
|
393 |
+
|
394 |
+
json_body = {
|
395 |
+
'prompt': enhanced_prompt,
|
396 |
+
'image_size': RATIO_MAP.get(size, "1024x1024"),
|
397 |
+
'num_inference_steps': 50
|
398 |
+
}
|
399 |
+
|
400 |
+
if model and model != "flux":
|
401 |
+
json_body['batch_size'] = 1
|
402 |
+
json_body['guidance_scale'] = 7.5
|
403 |
+
|
404 |
+
if model in ["sdt", "sdxlt"]:
|
405 |
+
json_body['num_inference_steps'] = 6
|
406 |
+
json_body['guidance_scale'] = 1
|
407 |
+
elif model == "sdxll":
|
408 |
+
json_body['num_inference_steps'] = 4
|
409 |
+
json_body['guidance_scale'] = 1
|
410 |
+
|
411 |
+
async with aiohttp.ClientSession() as session:
|
412 |
+
async with session.post(api_url, headers={
|
413 |
+
'authorization': config.AUTH_TOKEN if config.AUTH_TOKEN.startswith(
|
414 |
+
'Bearer ') else f'Bearer {config.AUTH_TOKEN}',
|
415 |
+
'Accept': 'application/json',
|
416 |
+
'Content-Type': 'application/json'
|
417 |
+
}, json=json_body) as response:
|
418 |
+
if response.status != 200:
|
419 |
+
raise Exception(f'Unexpected response {response.status}')
|
420 |
+
|
421 |
+
json_response = await response.json()
|
422 |
+
return {
|
423 |
+
'data': [{'url': json_response['images'][0]['url']}],
|
424 |
+
'size': size
|
425 |
+
}
|
426 |
+
|
427 |
+
|
428 |
+
def generate_random_ip():
|
429 |
+
return f"{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
430 |
+
|
431 |
+
|
432 |
+
async def get_prompt(prompt):
|
433 |
+
logger.info(f"Original Prompt: {prompt}") # 记录输入的原始提示词
|
434 |
+
request_body_json = json.dumps({
|
435 |
+
'model': config.DEFAULT_PROMPT_MODEL, # 使用config中的model
|
436 |
+
'messages': [
|
437 |
+
{
|
438 |
+
'role': "system",
|
439 |
+
'content': SYSTEM_ASSISTANT
|
440 |
+
},
|
441 |
+
{
|
442 |
+
'role': "user",
|
443 |
+
'content': prompt
|
444 |
+
}
|
445 |
+
],
|
446 |
+
'stream': False,
|
447 |
+
'max_tokens': 512,
|
448 |
+
'temperature': 0.7,
|
449 |
+
'top_p': 0.7,
|
450 |
+
'top_k': 50,
|
451 |
+
'frequency_penalty': 0.5,
|
452 |
+
'n': 1
|
453 |
+
})
|
454 |
+
|
455 |
+
# 打印出请求的详细信息
|
456 |
+
logger.debug(f"Request Body: {request_body_json}")
|
457 |
+
request_headers = {
|
458 |
+
'accept': 'application/json',
|
459 |
+
'authorization': config.OPENAI_CHAT_API_KEY if config.OPENAI_CHAT_API_KEY.startswith(
|
460 |
+
'Bearer ') else f'Bearer {config.OPENAI_CHAT_API_KEY}',
|
461 |
+
'content-type': 'application/json'
|
462 |
+
}
|
463 |
+
logger.debug(f"Request Headers: {request_headers}")
|
464 |
+
|
465 |
+
try:
|
466 |
+
async with aiohttp.ClientSession() as session:
|
467 |
+
async with session.post(config.OPENAI_CHAT_API, headers=request_headers,
|
468 |
+
data=request_body_json) as response:
|
469 |
+
if response.status != 200:
|
470 |
+
error_text = await response.text()
|
471 |
+
logger.error(f"Failed to get response, status code: {response.status}, response: {error_text}")
|
472 |
+
return prompt
|
473 |
+
|
474 |
+
json_response = await response.json()
|
475 |
+
logger.debug(f"API Response: {json_response}") # 记录API的完整响应
|
476 |
+
|
477 |
+
if 'choices' in json_response and len(json_response['choices']) > 0:
|
478 |
+
enhanced_prompt = json_response['choices'][0]['message']['content']
|
479 |
+
logger.info(f"Enhanced Prompt: {enhanced_prompt}") # 记录增强后的提示词
|
480 |
+
return enhanced_prompt
|
481 |
+
else:
|
482 |
+
logger.warning("No enhanced prompt found in the response, returning original prompt.")
|
483 |
+
return prompt
|
484 |
+
except Exception as e:
|
485 |
+
logger.error(f"Exception occurred: {e}")
|
486 |
+
return prompt
|
487 |
+
|
488 |
+
|
489 |
+
app = web.Application()
|
490 |
+
app.router.add_post('/v1/chat/completions', handle_request)
|
491 |
+
|
492 |
+
if __name__ == '__main__':
|
493 |
+
web.run_app(app, host='0.0.0.0', port=6666)
|