Spaces:
Running
Running
from asyncio import sleep | |
from base64 import b64decode | |
from binascii import Error as BinasciiError | |
from contextlib import asynccontextmanager | |
from io import BytesIO | |
from json import dumps, loads | |
from logging import Formatter, INFO, StreamHandler, getLogger | |
from pathlib import Path | |
from random import choice | |
from typing import AsyncGenerator | |
from PIL.Image import open as image_open | |
from fastapi import FastAPI, Request | |
from fastapi.responses import HTMLResponse, JSONResponse | |
from httpx import AsyncClient | |
from patchright.async_api import FilePayload, Request as PlaywrightRequest, async_playwright | |
from prlps_fakeua import UserAgent | |
from starlette.responses import Response | |
# Application logger: INFO and above go to a StreamHandler with a timestamped format.
logger = getLogger('RHYMES_AI_API')
logger.setLevel(INFO)

formatter = Formatter(fmt='%(asctime)s | %(levelname)s : %(message)s', datefmt='%d.%m.%Y %H:%M:%S')
handler = StreamHandler()
handler.setLevel(INFO)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.info('инициализация приложения...')

# Source of random user-agent strings (windows/mac only); used for the browser context.
ua = UserAgent(os=['windows', 'mac'])

# Directory containing this file, and the JSON file that stores the
# fn_index / trigger_id / session_hash captured from the browser session.
workdir = Path(__file__).parent
infer_data = workdir / 'infer_data.json'

# Root URL of the Gradio Space every request is sent to.
BASE_URL = 'https://akhaliq-anychat.hf.space'
def base64_to_jpeg_bytes(base64_str: str) -> bytes:
    """Convert a comma-prefixed base64 image string into JPEG bytes.

    Everything after the first comma is base64-decoded, opened as an image,
    converted to RGB and saved as JPEG (quality 90, optimized).

    Args:
        base64_str: string of the form ``<prefix>,<base64 payload>``.

    Returns:
        The encoded JPEG file contents.

    Raises:
        ValueError: if the string contains no comma, if the payload is not
            valid base64, or if the image cannot be processed (``OSError``).
    """
    _, comma, payload = base64_str.partition(',')
    if not comma:
        raise ValueError("недопустимый формат строки base64")
    try:
        raw_image = b64decode(payload)
        with image_open(BytesIO(raw_image)) as img:
            output = BytesIO()
            img.convert('RGB').save(output, format='JPEG', quality=90, optimize=True)
            return output.getvalue()
    except (BinasciiError, OSError) as e:
        raise ValueError('данные не являются корректным изображением') from e
def image_bytes(base64_image_str: str) -> FilePayload:
    """Wrap a base64 image string as a JPEG `FilePayload`.

    The payload gets a random 12-character name with a ``.jpeg`` suffix, the
    ``image/jpeg`` MIME type and the JPEG bytes produced by
    `base64_to_jpeg_bytes` (whose ``ValueError`` propagates).
    """
    file_name = f'{generate_random_string(12)}.jpeg'
    jpeg_content = base64_to_jpeg_bytes(base64_image_str)
    return FilePayload(name=file_name, mimeType='image/jpeg', buffer=jpeg_content)
def generate_random_string(length):
    """Return `length` characters picked at random from a-z and 0-9."""
    alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789'
    picked = [choice(alphabet) for _ in range(length)]
    return ''.join(picked)
def get_infer_data() -> tuple[int, int, str]:
    """Read the saved request parameters from the `infer_data` JSON file.

    Returns:
        ``(fn_index, trigger_id, session_hash)`` exactly as stored in the file.
    """
    stored = loads(infer_data.read_text())
    logger.debug(f'загруженные из файла данные `get_infer_data`: {stored}')
    fn_index = stored['fn_index']
    trigger_id = stored['trigger_id']
    session_hash = stored['session_hash']
    return fn_index, trigger_id, session_hash
def prepare_data(gradio_file_path: str, question: str, fn_index: int, trigger_id: int, session_hash: str) -> dict:
    """Build the JSON body posted to the Space's ``queue/join`` endpoint.

    The ``data`` field holds ``None`` followed by a two-entry list: the
    uploaded file (described as a ``gradio.FileData`` dict) and the question,
    each paired with ``None``.
    # NOTE(review): presumably this mirrors a chatbot history of
    # [user, bot] pairs — confirm against the Space's API.

    Args:
        gradio_file_path: server-side path returned by the upload endpoint.
        question: text to send along with the image.
        fn_index: value for the ``fn_index`` field.
        trigger_id: value for the ``trigger_id`` field.
        session_hash: value for the ``session_hash`` field.
    """
    file_descriptor = {
        "path": gradio_file_path,
        "url": f"{BASE_URL}/gradio_api/file={gradio_file_path}",
        "size": None,
        "orig_name": None,
        "mime_type": "image/jpeg",
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }
    image_entry = [{"file": file_descriptor, "alt_text": None}, None]
    question_entry = [question, None]
    return {
        "data": [None, [image_entry, question_entry]],
        "event_data": None,
        "fn_index": fn_index,
        "trigger_id": trigger_id,
        "session_hash": session_hash,
    }
async def fetch_result(base64_image_str: str, question: str) -> str | None:
    """Query the Space over plain HTTP using previously saved session parameters.

    Steps: upload the image as multipart form data, join the queue with the
    payload from `prepare_data`, then read the server-sent event stream until
    a ``process_completed`` event that carries output data arrives.

    Args:
        base64_image_str: image as a comma-prefixed base64 string.
        question: text sent together with the image.

    Returns:
        ``output.data[0][1][1]`` of the first completed event that has data,
        or ``None`` if the stream ends without one.

    Raises:
        Anything raised by the helpers or by httpx (for example
        ``raise_for_status`` on a non-2xx response); nothing is caught here.
    """
    fn_index, trigger_id, session_hash = get_infer_data()
    async with AsyncClient(follow_redirects=True, timeout=40) as client:
        image_file = image_bytes(base64_image_str)
        boundary = f'----WebKitFormBoundary{generate_random_string(15).upper()}'
        # The multipart body is assembled as bytes so the JPEG data is never
        # round-tripped through a latin1 string. Double quotes are used inside
        # the single-quoted f-strings so the code also parses before Python 3.12.
        part_header = (
            f'--{boundary}\r\n'
            f'Content-Disposition: form-data; name="files"; filename="{image_file.get("name")}"\r\n'
            f'Content-Type: {image_file.get("mimeType")}\r\n\r\n'
        ).encode('latin1')
        closing_delimiter = f'\r\n--{boundary}--\r\n'.encode('latin1')
        upload_response = await client.post(
            f'{BASE_URL}/gradio_api/upload?upload_id={generate_random_string(11)}',
            headers={
                'Content-Type': f'multipart/form-data; boundary={boundary}',
                'accept': '*/*'
            },
            content=b''.join((part_header, image_file.get('buffer'), closing_delimiter))
        )
        upload_response.raise_for_status()
        gradio_file_path = upload_response.json()[0]
        logger.debug(f'gradio_file_path: {gradio_file_path}')

        send_response = await client.post(
            f'{BASE_URL}/gradio_api/queue/join',
            headers={
                'accept': '*/*',
                'content-type': 'application/json'
            },
            json=prepare_data(gradio_file_path, question, fn_index, trigger_id, session_hash)
        )
        send_response.raise_for_status()
        logger.debug(f'send_response: {send_response.text}')

        async with client.stream(
                'GET',
                f'{BASE_URL}/gradio_api/queue/data?session_hash={session_hash}',
                headers={'accept': 'text/event-stream', 'content-type': 'application/json'}
        ) as result_response:
            result_response.raise_for_status()
            async for line in result_response.aiter_lines():
                if not line.startswith('data:'):
                    continue
                logger.debug(f'result_response line: {line}')
                # Slice off only the field name; the JSON parser skips the
                # optional space after the colon, so both forms are accepted.
                event_data = loads(line[5:])
                if event_data.get('msg') != 'process_completed':
                    continue
                logger.debug(f'process_completed: {event_data}')
                # `output` may be present but null; treat that like "no data".
                data = (event_data.get('output') or {}).get('data')
                if data:
                    return data[0][1][1]
    return None
def take_infer_data(request: PlaywrightRequest):
    """Capture queue-join parameters from a browser request and save them.

    Intended as a page ``request`` event handler. For requests whose URL
    starts with the Space's ``queue/join`` endpoint, the JSON post body is
    parsed and, when it has a non-empty ``data`` field plus ``fn_index``,
    ``trigger_id`` and ``session_hash``, those three values are written to
    the `infer_data` file. Errors are logged and never propagated.

    Args:
        request: the intercepted browser request.
    """
    if not request.url.startswith(f'{BASE_URL}/gradio_api/queue/join'):
        return
    try:
        post_data = request.post_data
        if not post_data:
            # A request without a body has nothing to record.
            return
        data = loads(post_data)
        if not data.get('data'):
            return
        fn_index = data.get('fn_index')
        trigger_id = data.get('trigger_id')
        session_hash = data.get('session_hash')
        # 0 is a legitimate numeric id, so only a missing value (None)
        # disqualifies the numeric fields; the hash must be non-empty.
        if fn_index is None or trigger_id is None or not session_hash:
            return
        infer_data_json = {
            'fn_index': fn_index,
            'trigger_id': trigger_id,
            'session_hash': session_hash
        }
        infer_data.write_text(dumps(infer_data_json, indent=4))
        logger.debug(f'полученные из браузера данные в `take_infer_data`: {infer_data_json}')
    except Exception as ext:
        # Best-effort hook: a failure here must not break the page interaction.
        logger.error(f'ошибка `take_infer_data`: {ext}')
async def browser_request(base64_image_str: str, question: str) -> str | None:
    """Get an answer by driving the Space's web UI in a headless Chromium.

    Opens the page, selects the ``Grok`` tab, fills in the question, uploads
    the image, submits, waits for the ``Retry`` button to appear and then
    collects the text of all ``bot`` elements. `take_infer_data` is attached
    to the page's ``request`` event so queue parameters get saved on the way.

    Args:
        base64_image_str: image as a comma-prefixed base64 string.
        question: text typed into the message box.

    Returns:
        The joined bot text, or ``None`` if it is empty or any step failed
        (failures are logged, not raised).
    """
    try:
        async with async_playwright() as playwright:
            browser = await playwright.chromium.launch(
                headless=True,
                args=['--disable-blink-features=AutomationControlled'],
            )
            # try/finally guarantees the context and browser are closed even
            # when a UI step raises or times out.
            try:
                context = await browser.new_context(
                    viewport={'width': 2560, 'height': 1440},
                    screen={'width': 2560, 'height': 1286},
                    color_scheme='dark',
                    ignore_https_errors=True,
                    locale='en-US',
                    user_agent=ua.random,
                )
                try:
                    page = await context.new_page()
                    image_file = image_bytes(base64_image_str)
                    page.on('request', take_infer_data)
                    await page.goto(f'{BASE_URL}/?__theme=light')
                    await page.get_by_role('tab', name='Grok').click()
                    await page.get_by_role('textbox', name='Type a message...').fill(question)
                    multimedia_input = page.get_by_role('group', name='Multimedia input field')
                    await multimedia_input.get_by_test_id('file-upload').set_input_files(image_file)
                    await page.wait_for_selector('img.thumbnail-image')
                    submit_button = multimedia_input.locator('.submit-button')
                    await submit_button.click()
                    await page.wait_for_selector('button[aria-label="Retry"]', state='visible')
                    await submit_button.wait_for(state='visible')
                    caption = ' '.join(await page.get_by_test_id('bot').all_text_contents()).strip()
                finally:
                    await context.close()
            finally:
                await browser.close()
        if caption:
            logger.info('результат получен из `browser_request`')
            return caption
    except Exception as exc:
        logger.error(f'ошибка `browser_request`: {exc}')
    return None
async def httpx_request(base64_image_str: str, question: str) -> str | None:
    """Call `fetch_result` and turn any failure into ``None``.

    Args:
        base64_image_str: image as a comma-prefixed base64 string.
        question: text sent together with the image.

    Returns:
        The non-empty result of `fetch_result`, or ``None`` if it was empty
        or an exception occurred (the exception is logged).
    """
    try:
        caption = await fetch_result(base64_image_str, question)
        logger.debug(caption)
        if caption:
            logger.info('результат получен из `httpx_request`')
            return caption
    except Exception as exc:
        # The message names this function so HTTP-path failures are not
        # mistaken for browser-path failures in the log.
        logger.error(f'ошибка `httpx_request`: {exc}')
    return None
async def get_grok_caption(base64_image_str: str, question: str) -> str | None:
    """Obtain an answer, trying the HTTP path first and the browser as fallback.

    Up to three rounds are made. Each round calls `httpx_request` and, if that
    yields nothing, `browser_request`. The browser run also attaches
    `take_infer_data`, which can rewrite the saved parameters the HTTP path
    reads on the next round.

    Args:
        base64_image_str: image as a comma-prefixed base64 string.
        question: text sent together with the image.

    Returns:
        The first non-empty result, or ``None`` after all rounds failed.
    """
    attempts = 3
    for attempt in range(1, attempts + 1):
        result = await httpx_request(base64_image_str, question)
        if result:
            return result
        result = await browser_request(base64_image_str, question)
        if result:
            return result
        # Pause only between rounds; waiting after the last one just delays
        # the failure response.
        if attempt < attempts:
            await sleep(1.5)
    logger.error('превышено максимальное количество попыток')
    return None
@asynccontextmanager
async def app_lifespan(_) -> AsyncGenerator:
    """Lifespan context for the FastAPI app: log startup and shutdown.

    Wrapped with ``asynccontextmanager`` because the ``lifespan`` argument is
    expected to be an async context manager factory; passing a bare async
    generator function relies on a deprecated fallback.

    Args:
        _: the application instance (unused).
    """
    logger.info('запуск приложения')
    try:
        logger.info('старт API')
        yield
    finally:
        # Runs when the application shuts down, including on errors.
        logger.info('приложение завершено')
# The FastAPI application object; `app_lifespan` is supplied as its lifespan handler.
app = FastAPI(title='RHYMES_AI_API', lifespan=app_lifespan)

# Request paths that `block_banned_endpoints` compares against.
# NOTE(review): 'swagger_ui_redirect' has no leading slash, so it looks like a
# route name rather than a path — confirm whether it can ever match.
banned_endpoints = ['/openapi.json', '/docs', '/docs/oauth2-redirect', 'swagger_ui_redirect', '/redoc']
async def block_banned_endpoints(request: Request, call_next):
    """Answer 403 for paths listed in `banned_endpoints`, otherwise pass through.

    Args:
        request: the incoming request.
        call_next: awaitable callable that produces the downstream response.

    Returns:
        An empty 403 response for a banned path, else the downstream response.
    """
    # NOTE(review): no middleware registration for this function is visible
    # in this file — confirm it is attached to `app`.
    path = request.url.path
    logger.debug(f'получен запрос: {path}')
    if path not in banned_endpoints:
        return await call_next(request)
    logger.warning(f'запрещенный endpoint: {path}')
    return Response(status_code=403)
async def describe_v1(request: Request):
    """Handle a caption request: extract text and image, return the answer.

    The JSON body is expected to be an object with a ``messages`` list. Text
    and an inline ``data:image/`` URL are gathered from its ``system`` and
    ``user`` messages and passed to `get_grok_caption`.

    Returns:
        A JSON response with a ``caption`` key: status 400 when the body is
        not a JSON object or text/image is missing, 200 with the result of
        `get_grok_caption`, or 500 with the error text if that call raises.
    """
    # NOTE(review): no route registration for this handler is visible in this
    # file — confirm how it is attached to `app`.
    logger.info('запрос `describe_v1`')
    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON (or undecodable bytes) is a client error, not a crash.
        body = None
    if not isinstance(body, dict):
        return JSONResponse({'caption': 'тело запроса должно быть корректным JSON-объектом'}, status_code=400)

    messages = body.get('messages')
    content_text, image_data = _collect_text_and_image(messages if isinstance(messages, list) else [])
    if not content_text or not image_data:
        return JSONResponse({'caption': 'изображение должно быть передано как строка base64 `data:image/jpeg;base64,{base64_img}` а также текст'}, status_code=400)
    try:
        caption = await get_grok_caption(image_data, content_text)
        return JSONResponse({'caption': caption}, status_code=200)
    except Exception as e:
        logger.error(f'ошибка `describe_v1`: {e}')
        return JSONResponse({'caption': str(e)}, status_code=500)


def _collect_text_and_image(messages: list) -> tuple[str, str]:
    """Return ``(text, image_data_url)`` gathered from system/user messages.

    String contents and ``text`` items are joined with single spaces; the last
    ``image_url`` item whose URL starts with ``data:image/`` wins. Entries that
    are not dicts, or whose fields have unexpected types, are skipped.
    """
    text_parts = []
    image_data = ''
    for message in messages:
        if not isinstance(message, dict) or message.get('role') not in ('system', 'user'):
            continue
        content = message.get('content')
        if isinstance(content, str):
            text_parts.append(content)
            continue
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type')
            if item_type == 'text':
                text = item.get('text', '')
                if isinstance(text, str):
                    text_parts.append(text)
            elif item_type == 'image_url':
                image_url = item.get('image_url', {})
                url = image_url.get('url') if isinstance(image_url, dict) else None
                if isinstance(url, str) and url.startswith('data:image/'):
                    image_data = url
    return ' '.join(text_parts).strip(), image_data.strip()
async def root():
    """Return a fixed HTML body with status 200."""
    # NOTE(review): no route registration for this handler is visible in this
    # file — confirm how it is attached to `app`.
    page_body = 'ну пролапс, ну и что'
    return HTMLResponse(content=page_body, status_code=200)
if __name__ == '__main__':
    # uvicorn is imported only when the file is executed as a script.
    from uvicorn import run as uvicorn_run

    # Listen on all interfaces, port 7860.
    server_options = {'host': '0.0.0.0', 'port': 7860}
    logger.info('запуск сервера uvicorn')
    uvicorn_run(app, **server_options)