tianlong12 commited on
Commit
725308e
·
verified ·
1 Parent(s): 66ed197

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +493 -0
main.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import random
5
+ import re
6
+ import json
7
+ import time
8
+ import aiohttp
9
+ import asyncio
10
+ from aiohttp import web
11
+ import unicodedata
12
+ from dataclasses import dataclass, asdict
13
+ import logging
14
+
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@dataclass
class Config:
    """Runtime configuration.

    Each field can be overridden through an environment variable of the same
    name; the hard-coded placeholders remain the defaults, so existing
    deployments that edit this file directly keep working.
    """
    # SiliconFlow API key (used by generate_image4's text-to-image calls)
    AUTH_TOKEN: str = os.getenv('AUTH_TOKEN', 'sk-xxxxxxxxx')
    # One-API / New-API relay endpoint for chat completions
    OPENAI_CHAT_API: str = os.getenv('OPENAI_CHAT_API', 'https://xxxxxxxxxxxx/v1/chat/completions')
    # API key for the relay endpoint above
    OPENAI_CHAT_API_KEY: str = os.getenv('OPENAI_CHAT_API_KEY', 'sk-xxxxxxxxxx')
    # Default model used to translate/expand non-English prompts
    DEFAULT_TRANSLATE_MODEL: str = os.getenv('DEFAULT_TRANSLATE_MODEL', 'deepseek-chat')
    # Model used for prompt enhancement in get_prompt
    DEFAULT_PROMPT_MODEL: str = os.getenv('DEFAULT_PROMPT_MODEL', 'Qwen2-72B-Instruct')
31
+
32
+
33
# Module-wide configuration instance shared by all handlers/generators.
config = Config()

# Third-party endpoints used by the image generators below.
URLS = {
    'API_FLUX1_API4GPT_COM': 'https://api-flux1.api4gpt.com',
    'FLUXAIWEB_COM_TOKEN': 'https://fluxaiweb.com/flux/getToken',
    'FLUXAIWEB_COM_GENERATE': 'https://fluxaiweb.com/flux/generateImage',
    'FLUXIMG_COM': 'https://fluximg.com/api/image/generateImage',
    'API_SILICONFLOW_CN': 'https://api.siliconflow.cn/v1/chat/completions'
}

# SiliconFlow text-to-image endpoints keyed by the `--m` model alias
# (consumed by generate_image4; 'flux' is the fallback).
URL_MAP = {
    'flux': "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image",
    'sd3': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image",
    'sdxl': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image",
    'sd2': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/text-to-image",
    'sdt': "https://api.siliconflow.cn/v1/stabilityai/sd-turbo/text-to-image",
    'sdxlt': "https://api.siliconflow.cn/v1/stabilityai/sdxl-turbo/text-to-image",
    'sdxll': "https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/text-to-image"
}

# SiliconFlow image-to-image endpoints keyed by model alias.
# NOTE(review): nothing in this file reads IMG_URL_MAP — presumably reserved
# for a future image-to-image route; confirm before removing.
IMG_URL_MAP = {
    'sdxl': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/image-to-image",
    'sd2': "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-2-1/image-to-image",
    'sdxll': "https://api.siliconflow.cn/v1/ByteDance/SDXL-Lightning/image-to-image",
    'pm': "https://api.siliconflow.cn/v1/TencentARC/PhotoMaker/image-to-image"
}

# Aspect-ratio flag (`--ar`) to pixel dimensions for generate_image4.
RATIO_MAP = {
    "1:1": "1024x1024",
    "1:2": "1024x2048",
    "3:2": "1536x1024",
    "4:3": "1536x2048",
    "16:9": "2048x1152",
    "9:16": "1152x2048"
}
68
+
69
# System prompt (in Chinese) sent to the chat API: instructs the model to act
# as a Stable Diffusion prompt expert and rewrite the user's scene description
# into a single English positive prompt (quality tags + subject + scene).
# This is runtime payload text — do not translate or reformat it.
SYSTEM_ASSISTANT = """作为 Stable Diffusion Prompt 提示词专家,您将从关键词中创建提示,通常来自 Danbooru 等数据库。
提示通常描述图像,使用常见词汇,按重要性排列,并用逗号分隔。避免使用"-"或".",但可以接受空格和自然语言。避免词汇重复。

为了强调关键词,请将其放在括号中以增加其权重。例如,"(flowers)"将'flowers'的权重增加1.1倍,而"(((flowers)))"将其增加1.331倍。使用"(flowers:1.5)"将'flowers'的权重增加1.5倍。只为重要的标签增加权重。

提示包括三个部分:**前缀**(质量标签+风格词+效果器)+ **主题**(图像的主要焦点)+ **场景**(背景、环境)。

* 前缀影响图像质量。像"masterpiece"、"best quality"、"4k"这样的标签可以提高图像的细节。像"illustration"、"lensflare"这样的风格词定义图像的风格。像"bestlighting"、"lensflare"、"depthoffield"这样的效果器会影响光照和深度。

* 主题是图像的主要焦点,如角色或场景。对主题进行详细描述可以确保图像丰富而详细。增加主题的权重以增强其清晰度。对于角色,描述面部、头发、身体、服装、姿势等特征。

* 场景描述环境。没有场景,图像的背景是平淡的,主题显得过大。某些主题本身包含场景(例如建筑物、风景)。像"花草草地"、"阳光"、"河流"这样的环境词可以丰富场景。你的任务是设计图像生成的提示。请按照以下步骤进行操作:

1. 我会发送给您一个图像场景。需要你生成详细的图像描述
2. 图像描述必须是英文,输出为Positive Prompt。

示例:

我发送:二战时期的护士。
您回复只回复:
A WWII-era nurse in a German uniform, holding a wine bottle and stethoscope, sitting at a table in white attire, with a table in the background, masterpiece, best quality, 4k, illustration style, best lighting, depth of field, detailed character, detailed environment.
"""
91
+
92
+
93
async def select_random_image_generator():
    """Pick one of the four image generators at random.

    Returns a uniform 3-argument callable `(prompt, size, model)`; only
    generate_image4 actually consumes the model argument, the others drop it.
    """
    candidates = [generate_image1, generate_image2, generate_image3, generate_image4]
    chosen = random.choice(candidates)
    if chosen is generate_image4:
        return lambda prompt, size, model: chosen(prompt, size, model)
    return lambda prompt, size, model=None: chosen(prompt, size)
100
+
101
+
102
def extract_size_and_model_from_prompt(prompt):
    """Pull the `--ar <ratio>` and `--m <model>` flags out of a prompt.

    Returns a dict with 'size' (default '1:1'), 'model' (default '') and
    'clean_prompt' — the prompt with both flags removed and whitespace trimmed.
    """
    ar_match = re.search(r'--ar\s+(\S+)', prompt)
    m_match = re.search(r'--m\s+(\S+)', prompt)

    stripped = re.sub(r'--ar\s+\S+', '', prompt).strip()
    stripped = re.sub(r'--m\s+\S+', '', stripped).strip()

    return {
        'size': ar_match.group(1) if ar_match else '1:1',
        'model': m_match.group(1) if m_match else '',
        'clean_prompt': stripped,
    }
113
+
114
+
115
def is_chinese(char):
    """Heuristically detect a CJK character via its Unicode character name.

    Unnamed characters (the '' default) are treated as non-CJK.
    """
    char_name = unicodedata.name(char, '')
    return 'CJK' in char_name
117
+
118
+
119
async def translate_prompt(prompt):
    """Translate/expand a prompt into an English SD prompt via the chat API.

    Prompts containing no CJK characters are returned unchanged. On any
    HTTP or parsing failure the original prompt is returned as a fallback
    (best-effort translation, never fatal).
    """
    if all(not is_chinese(char) for char in prompt):
        logger.info('Prompt is already in English, skipping translation')
        return prompt

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(config.OPENAI_CHAT_API, json={
                'model': config.DEFAULT_TRANSLATE_MODEL,  # model comes from config
                'messages': [
                    {'role': 'system', 'content': SYSTEM_ASSISTANT},
                    {'role': 'user', 'content': prompt}
                ],
            }, headers={
                'Content-Type': 'application/json',
                'Authorization': f'Bearer {config.OPENAI_CHAT_API_KEY}'
            }) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f'HTTP error! status: {response.status}, body: {error_text}')
                    raise Exception(f'HTTP error! status: {response.status}')

                # Guard against proxies returning HTML error pages etc.
                if 'application/json' not in response.headers.get('Content-Type', ''):
                    error_text = await response.text()
                    logger.error(f'Unexpected content type: {response.headers.get("Content-Type")}, body: {error_text}')
                    raise Exception(f'Unexpected content type: {response.headers.get("Content-Type")}')

                result = await response.json()
                return result['choices'][0]['message']['content']

    except Exception as e:
        # BUG FIX: logger.error('Translation error:', e) passed the exception
        # as a %-format argument with no placeholder, so it was never rendered;
        # use lazy %-formatting so the error is actually logged.
        logger.error('Translation error: %s', e)
        return prompt
152
+
153
+
154
async def handle_request(request):
    """aiohttp handler for POST /v1/chat/completions.

    Extracts the last user message, parses `--ar`/`--m` flags, translates the
    prompt, generates an image via a randomly chosen backend (up to 3
    attempts, re-rolling the backend on failure), and replies in OpenAI
    chat-completion format (streaming SSE or a single JSON body).
    """
    if request.method != 'POST' or not request.url.path.endswith('/v1/chat/completions'):
        return web.Response(text='Not Found', status=404)

    try:
        data = await request.json()
        messages = data.get('messages', [])
        stream = data.get('stream', False)

        # The most recent user message wins.
        user_message = next((msg['content'] for msg in reversed(messages) if msg['role'] == 'user'), None)

        if not user_message:
            return web.json_response({'error': "未找到用户消息"}, status=400)

        size_and_model = extract_size_and_model_from_prompt(user_message)
        translated_prompt = await translate_prompt(size_and_model['clean_prompt'])

        selected_generator = await select_random_image_generator()
        attempts = 0
        max_attempts = 3
        image_data = None  # defensive: only read after a successful attempt

        while attempts < max_attempts:
            try:
                image_data = await selected_generator(translated_prompt, size_and_model['size'],
                                                      size_and_model['model'])
                break
            except Exception as e:
                # NOTE(review): selected_generator is a lambda, so __name__ is
                # always '<lambda>' here — the log line cannot identify the backend.
                logger.error(f"Error generating image with generator {selected_generator.__name__}: {e}")
                selected_generator = await select_random_image_generator()
                attempts += 1

        if attempts == max_attempts:
            logger.error("Failed to generate image after multiple attempts")
            return web.json_response({'error': "生成图像失败"}, status=500)

        unique_id = f"chatcmpl-{int(time.time())}"
        created_timestamp = int(time.time())
        model_name = "flux"
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))

        if stream:
            return await handle_stream_response(request, unique_id, image_data, size_and_model['clean_prompt'],
                                                translated_prompt, size_and_model['size'], created_timestamp,
                                                model_name, system_fingerprint)
        else:
            return handle_non_stream_response(unique_id, image_data, size_and_model['clean_prompt'], translated_prompt,
                                              size_and_model['size'], created_timestamp, model_name, system_fingerprint)
    except Exception as e:
        # BUG FIX: logger.error('Error handling request:', e) dropped the
        # exception (no %s placeholder); log it via lazy %-formatting.
        logger.error('Error handling request: %s', e)
        return web.json_response({'error': f"处理请求失败: {str(e)}"}, status=500)
204
+
205
+
206
async def handle_stream_response(request, unique_id, image_data, original_prompt, translated_prompt, size, created,
                                 model, system_fingerprint):
    """Stream the generation progress as OpenAI-style chat.completion.chunk SSE events.

    Emits a fixed sequence of progress chunks (0.5 s apart to simulate work),
    ending with the image markdown, then a final chunk with finish_reason 'stop'.
    """
    logger.debug("Starting stream response")
    response = web.StreamResponse(
        status=200,
        reason='OK',
        headers={
            'Content-Type': 'text/event-stream',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
    )
    await response.prepare(request)
    logger.debug("Response prepared")

    chunks = [
        f"原始提示词:\n{original_prompt}\n",
        f"翻译后的提示词:\n{translated_prompt}\n",
        # BUG FIX: this line contained mojibake (U+FFFD replacement characters
        # where the full-width colon belongs); restored to match the wording
        # used by handle_non_stream_response.
        f"图像规格:{size}\n",
        "正在根据提示词生成图像...\n",
        "图像正在处理中...\n",
        "即将完成...\n",
        f"生成成功!\n图像生成完毕,以下是结果:\n\n![生成的图像]({image_data['data'][0]['url']})"
    ]

    for i, chunk in enumerate(chunks):
        json_chunk = json.dumps({
            "id": unique_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "system_fingerprint": system_fingerprint,
            "choices": [{
                "index": 0,
                "delta": {"content": chunk},
                "logprobs": None,
                "finish_reason": None
            }]
        })
        try:
            await response.write(f"data: {json_chunk}\n\n".encode('utf-8'))
            logger.debug(f"Chunk {i + 1} sent")
        except Exception as e:
            logger.error(f"Error sending chunk {i + 1}: {str(e)}")
        await asyncio.sleep(0.5)  # simulate generation time between chunks

    # Terminal chunk: empty delta with finish_reason "stop".
    final_chunk = json.dumps({
        "id": unique_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "system_fingerprint": system_fingerprint,
        "choices": [{
            "index": 0,
            "delta": {},
            "logprobs": None,
            "finish_reason": "stop"
        }]
    })
    try:
        await response.write(f"data: {final_chunk}\n\n".encode('utf-8'))
        logger.debug("Final chunk sent")
    except Exception as e:
        logger.error(f"Error sending final chunk: {str(e)}")

    await response.write_eof()
    logger.debug("Stream response completed")
    return response
274
+
275
+
276
def handle_non_stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model,
                               system_fingerprint):
    """Build a complete (non-streaming) chat.completion JSON response.

    Token counts in 'usage' are character-length approximations, not real
    tokenizer counts.
    """
    parts = [
        f"原始提示词:{original_prompt}\n",
        f"翻译后的提示词:{translated_prompt}\n",
        f"图像规格:{size}\n",
        "图像生成成功!\n",
        "以下是结果:\n\n",
        f"![生成的图像]({image_data['data'][0]['url']})",
    ]
    content = "".join(parts)

    payload = {
        'id': unique_id,
        'object': "chat.completion",
        'created': created,
        'model': model,
        'system_fingerprint': system_fingerprint,
        'choices': [{
            'index': 0,
            'message': {'role': "assistant", 'content': content},
            'finish_reason': "stop"
        }],
        'usage': {
            'prompt_tokens': len(original_prompt),
            'completion_tokens': len(content),
            'total_tokens': len(original_prompt) + len(content)
        }
    }

    return web.json_response(payload)
309
+
310
+
311
async def generate_image1(prompt, size):
    """Generate an image via the api-flux1 GET endpoint.

    BUG FIX: the original merely stripped spaces from the prompt before
    embedding it in the query string, leaving characters like '&', '#' or '?'
    free to corrupt the URL. The enhanced prompt (and size) are now
    percent-encoded instead.
    """
    from urllib.parse import quote

    # Enhance the prompt through the chat model first.
    enhanced_prompt = await get_prompt(prompt)

    image_url = f"{URLS['API_FLUX1_API4GPT_COM']}/?prompt={quote(enhanced_prompt)}&size={quote(size)}"
    return {
        'data': [{'url': image_url}],
        'size': size
    }
321
+
322
+
323
async def generate_image2(prompt, size):
    """Generate an image through fluxaiweb.com: fetch a token, then generate.

    A random X-Forwarded-For value is sent with both requests.
    """
    spoofed_ip = generate_random_ip()
    # Enhance the prompt via the chat model before submitting.
    enhanced_prompt = await get_prompt(prompt)

    async with aiohttp.ClientSession() as session:
        async with session.get(URLS['FLUXAIWEB_COM_TOKEN'],
                               headers={'X-Forwarded-For': spoofed_ip}) as token_response:
            token_payload = await token_response.json()
            token = token_payload['data']['token']

        generate_headers = {
            'Content-Type': 'application/json',
            'token': token,
            'X-Forwarded-For': spoofed_ip
        }
        generate_body = {
            'prompt': enhanced_prompt,
            'aspectRatio': size,
            'outputFormat': 'webp',
            'numOutputs': 1,
            'outputQuality': 90
        }
        async with session.post(URLS['FLUXAIWEB_COM_GENERATE'], headers=generate_headers,
                                json=generate_body) as image_response:
            result = await image_response.json()
            return {
                'data': [{'url': result['data']['image']}],
                'size': size
            }
349
+
350
+
351
async def generate_image3(prompt, size):
    """Generate an image via fluximg.com, retrying connection errors.

    Retries up to 3 times with exponential backoff on connection failures;
    non-200 responses are logged and retried immediately. Returns a
    placeholder image URL if every attempt fails.
    """
    payload = json.dumps({
        'textStr': prompt,
        'model': "black-forest-labs/flux-schnell",
        'size': size
    })

    max_retries = 3
    for attempt in range(max_retries):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(URLS['FLUXIMG_COM'], data=payload,
                                        headers={'Content-Type': 'text/plain;charset=UTF-8'}) as response:
                    if response.status == 200:
                        image_url = await response.text()
                        return {
                            'data': [{'url': image_url}],
                            'size': size
                        }
                    logger.error(
                        f"Unexpected response status: {response.status}, response text: {await response.text()}")
        except aiohttp.ClientConnectorError as e:
            logger.error(f"Connection error on attempt {attempt + 1}: {e}")
            await asyncio.sleep(2 ** attempt)  # exponential backoff

    logger.error("Failed to generate image after multiple attempts")
    return {
        'data': [{'url': "https://via.placeholder.com/640x480/428675/ffffff?text=Error"}],
        'size': size
    }
382
+
383
+
384
async def generate_image4(prompt, size, model):
    """Generate an image via the SiliconFlow text-to-image API.

    The endpoint is chosen from URL_MAP by model alias (falling back to
    'flux'); the aspect ratio is mapped to pixel dimensions via RATIO_MAP.
    Raises if AUTH_TOKEN is unset or the API returns a non-200 status.
    """
    if not config.AUTH_TOKEN:
        raise Exception("AUTH_TOKEN is required for this method")

    endpoint = URL_MAP.get(model, URL_MAP['flux'])
    # Strip any residual --m flag, then enhance the prompt.
    stripped_prompt = re.sub(r'--m\s+\S+', '', prompt).strip()
    enhanced_prompt = await get_prompt(stripped_prompt)

    request_body = {
        'prompt': enhanced_prompt,
        'image_size': RATIO_MAP.get(size, "1024x1024"),
        'num_inference_steps': 50
    }

    # Non-flux models take extra sampling parameters.
    if model and model != "flux":
        request_body['batch_size'] = 1
        request_body['guidance_scale'] = 7.5

    # Turbo/Lightning variants need very few steps and minimal guidance.
    if model in ("sdt", "sdxlt"):
        request_body['num_inference_steps'] = 6
        request_body['guidance_scale'] = 1
    elif model == "sdxll":
        request_body['num_inference_steps'] = 4
        request_body['guidance_scale'] = 1

    token = config.AUTH_TOKEN
    auth_value = token if token.startswith('Bearer ') else f'Bearer {token}'

    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers={
            'authorization': auth_value,
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }, json=request_body) as response:
            if response.status != 200:
                raise Exception(f'Unexpected response {response.status}')

            payload = await response.json()
            return {
                'data': [{'url': payload['images'][0]['url']}],
                'size': size
            }
426
+
427
+
428
def generate_random_ip():
    """Return a random dotted-quad IPv4 string (each octet 0-255)."""
    return '.'.join(str(random.randint(0, 255)) for _ in range(4))
430
+
431
+
432
async def get_prompt(prompt):
    """Ask the chat API to rewrite *prompt* into an enhanced SD prompt.

    Returns the model's reply on success; on any HTTP failure, missing
    'choices', or exception, returns the original prompt unchanged.
    """
    logger.info(f"Original Prompt: {prompt}")  # log the incoming prompt

    body = json.dumps({
        'model': config.DEFAULT_PROMPT_MODEL,  # model comes from config
        'messages': [
            {'role': "system", 'content': SYSTEM_ASSISTANT},
            {'role': "user", 'content': prompt}
        ],
        'stream': False,
        'max_tokens': 512,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'frequency_penalty': 0.5,
        'n': 1
    })
    logger.debug(f"Request Body: {body}")

    api_key = config.OPENAI_CHAT_API_KEY
    headers = {
        'accept': 'application/json',
        'authorization': api_key if api_key.startswith('Bearer ') else f'Bearer {api_key}',
        'content-type': 'application/json'
    }
    logger.debug(f"Request Headers: {headers}")

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(config.OPENAI_CHAT_API, headers=headers, data=body) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Failed to get response, status code: {response.status}, response: {error_text}")
                    return prompt

                payload = await response.json()
                logger.debug(f"API Response: {payload}")  # full API response for debugging

                if 'choices' in payload and len(payload['choices']) > 0:
                    enhanced = payload['choices'][0]['message']['content']
                    logger.info(f"Enhanced Prompt: {enhanced}")
                    return enhanced

                logger.warning("No enhanced prompt found in the response, returning original prompt.")
                return prompt
    except Exception as e:
        logger.error(f"Exception occurred: {e}")
        return prompt
487
+
488
+
489
# Application wiring: a single POST route implementing the OpenAI
# chat-completions-compatible image-generation endpoint.
app = web.Application()
app.router.add_post('/v1/chat/completions', handle_request)

if __name__ == '__main__':
    # NOTE(review): binds to all interfaces on port 6666 — confirm this
    # exposure is intended for the deployment environment.
    web.run_app(app, host='0.0.0.0', port=6666)