Spaces:
Running
Running
import asyncio | |
import gc | |
import io | |
import pathlib | |
from io import BytesIO | |
from PIL import Image | |
from loguru import logger | |
from app.Models.api_models.admin_query_params import UploadImageThumbnailMode | |
from app.Models.errors import PointDuplicateError | |
from app.Models.img_data import ImageData | |
from app.Services.index_service import IndexService | |
from app.Services.lifespan_service import LifespanService | |
from app.Services.storage import StorageService | |
from app.Services.vector_db_context import VectorDbContext | |
from app.config import config | |
from app.util.generate_uuid import generate_uuid | |
class UploadService(LifespanService):
    """Index uploaded images and store them (and optional thumbnails).

    An upload can be processed inline with ``sync_upload_image`` or handed to
    the background worker task with ``queue_upload_image``. The worker task is
    created in ``__init__`` and consumes a bounded ``asyncio.Queue``.
    """

    # In IF_NECESSARY mode a thumbnail is generated only for payloads larger
    # than this many bytes.
    _THUMBNAIL_MIN_BYTES = 1024 * 500
    # Bounding box handed to ``Image.thumbnail``.
    _THUMBNAIL_SIZE = (256, 256)
    # ``gc.collect()`` is triggered once per this many processed queue items.
    _GC_INTERVAL = 50

    def __init__(self, storage_service: StorageService, db_context: VectorDbContext, index_service: IndexService):
        """Store the collaborating services and start the background worker.

        ``asyncio.create_task`` is called here, so the instance has to be
        constructed while an event loop is running.
        """
        self._storage_service = storage_service
        self._db_context = db_context
        self._index_service = index_service
        # Bounded queue: ``put`` waits once the configured length is reached.
        self._queue = asyncio.Queue(config.admin_index_queue_max_length)
        self._upload_worker_task = asyncio.create_task(self._upload_worker())
        # Ids that are queued or currently being processed by the worker.
        self.uploading_ids = set()
        self._processed_count = 0

    async def _upload_worker(self) -> None:
        """Consume the upload queue forever, one item at a time.

        A failing upload is logged and does not stop the loop.
        """
        while True:
            img_data, *args = await self._queue.get()
            try:
                await self._upload_task(img_data, *args)
                logger.success("Image {} uploaded and indexed. Queue Length: {} [-1]", img_data.id, self._queue.qsize())
            except Exception as ex:
                logger.error("Error occurred while uploading image {}", img_data.id)
                logger.exception(ex)
            finally:
                self._queue.task_done()
                # ``discard`` instead of ``remove``: a missing id (for example
                # the same id queued twice) must not raise here, because an
                # exception escaping this loop would end the only worker task.
                self.uploading_ids.discard(img_data.id)
                self._processed_count += 1
                if self._processed_count % self._GC_INTERVAL == 0:
                    gc.collect()

    async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                           thumbnail_mode: UploadImageThumbnailMode) -> None:
        """Index one image, then upload the original and/or a thumbnail.

        Args:
            img_data: Metadata of the image; ``url``, ``thumbnail_url`` and
                ``local_thumbnail`` are filled in here before indexing.
            img_bytes: Encoded image content.
            skip_ocr: Forwarded to ``IndexService.index_image``.
            thumbnail_mode: Decides whether a WebP thumbnail is generated.

        Any exception raised while opening, indexing or uploading propagates
        to the caller; the opened image is closed in every case.
        """
        img = Image.open(BytesIO(img_bytes))
        try:
            logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes))
            file_name = f"{img_data.id}.{img_data.format}"
            thumb_path = f"thumbnails/{img_data.id}.webp"
            gen_thumb = thumbnail_mode == UploadImageThumbnailMode.ALWAYS or (
                thumbnail_mode == UploadImageThumbnailMode.IF_NECESSARY
                and len(img_bytes) > self._THUMBNAIL_MIN_BYTES)

            # NOTE(review): thumbnail handling is assumed to be independent of
            # ``img_data.local`` (sibling branches, not nested) - confirm.
            # URLs are resolved before indexing so they are already set on
            # ``img_data`` when it is handed to the index service.
            if img_data.local:
                img_data.url = await self._storage_service.active_storage.url(file_name)
            if gen_thumb:
                img_data.thumbnail_url = await self._storage_service.active_storage.url(thumb_path)
                img_data.local_thumbnail = True

            await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, background=True)
            logger.success("Image {} indexed.", img_data.id)

            if img_data.local:
                logger.info("Start uploading image {} to local storage.", img_data.id)
                await self._storage_service.active_storage.upload(img_bytes, file_name)
                logger.success("Image {} uploaded to local storage.", img_data.id)
            if gen_thumb:
                logger.info("Start generate and upload thumbnail for {}.", img_data.id)
                # ``thumbnail`` resizes ``img`` in place, so it runs only after
                # the full-size image has been indexed.
                img.thumbnail(self._THUMBNAIL_SIZE, resample=Image.Resampling.LANCZOS)
                img_byte_arr = BytesIO()
                img.save(img_byte_arr, 'WebP', save_all=True)
                await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path)
                logger.success("Thumbnail for {} generated and uploaded!", img_data.id)
        finally:
            # Release the image even when indexing or an upload fails.
            img.close()

    async def queue_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                                 thumbnail_mode: UploadImageThumbnailMode) -> None:
        """Register the image id as in-flight and enqueue it for the worker.

        Waits while the queue is full.
        """
        self.uploading_ids.add(img_data.id)
        await self._queue.put((img_data, img_bytes, skip_ocr, thumbnail_mode))
        logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize())

    async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes):
        """Derive the id for an image and reject it if it is already known.

        Returns:
            The id produced by ``generate_uuid`` for ``img_file``.

        Raises:
            PointDuplicateError: If the id is currently in ``uploading_ids`` or
                ``validate_ids`` reports it as present in the database.
        """
        img_id = generate_uuid(img_file)
        # check for duplicate points
        if img_id in self.uploading_ids or len(await self._db_context.validate_ids([str(img_id)])) != 0:
            logger.warning("Duplicate upload request for image id: {}", img_id)
            raise PointDuplicateError(f"The uploaded point is already contained in the database! entity id: {img_id}",
                                      img_id)
        return img_id

    async def sync_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                                thumbnail_mode: UploadImageThumbnailMode) -> None:
        """Process the upload immediately, bypassing the queue."""
        await self._upload_task(img_data, img_bytes, skip_ocr, thumbnail_mode)

    def get_queue_size(self) -> int:
        """Return the number of items currently waiting in the upload queue."""
        return self._queue.qsize()

    async def on_exit(self):  # pragma: no cover Hard to test in UT.
        """Wait until every queued upload has been fully processed."""
        if self.get_queue_size() != 0:
            logger.warning("There are still {} images in the upload queue. Waiting for upload process to be completed.",
                           self.get_queue_size())
        # Joined unconditionally: ``qsize()`` does not count an item the worker
        # has already dequeued and is still processing, whereas ``join()``
        # returns only after ``task_done()`` has been called for every item
        # (and returns immediately when nothing is outstanding).
        await self._queue.join()