Spaces:
Running
Running
File size: 5,200 Bytes
21db53c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import asyncio
import gc
import io
import pathlib
from io import BytesIO
from PIL import Image
from loguru import logger
from app.Models.api_models.admin_query_params import UploadImageThumbnailMode
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Services.index_service import IndexService
from app.Services.lifespan_service import LifespanService
from app.Services.storage import StorageService
from app.Services.vector_db_context import VectorDbContext
from app.config import config
from app.util.generate_uuid import generate_uuid
class UploadService(LifespanService):
    """Index uploaded images and store them, either inline or through a background queue.

    A single worker task, created in ``__init__``, drains an ``asyncio.Queue`` of
    pending uploads one at a time. ``uploading_ids`` holds the ids of images that
    have been queued but not yet finished, so duplicate uploads can be rejected
    before the image reaches the database.
    """

    # In IF_NECESSARY mode a thumbnail is generated only for images larger than this (500 KiB).
    _THUMBNAIL_SIZE_THRESHOLD = 1024 * 500
    # Bounding box passed to ``Image.thumbnail``.
    _THUMBNAIL_BOX = (256, 256)
    # Run an explicit garbage collection every N processed queue items.
    _GC_INTERVAL = 50

    def __init__(self, storage_service: StorageService, db_context: VectorDbContext, index_service: IndexService):
        """Store the collaborating services and start the background upload worker.

        ``asyncio.create_task`` requires a running event loop, so the service must be
        constructed from within one.
        """
        self._storage_service = storage_service
        self._db_context = db_context
        self._index_service = index_service
        self._queue = asyncio.Queue(config.admin_index_queue_max_length)
        self._upload_worker_task = asyncio.create_task(self._upload_worker())
        self.uploading_ids = set()
        self._processed_count = 0

    async def _upload_worker(self):
        """Process queued uploads forever, logging (not propagating) per-image failures."""
        while True:
            img_data, *args = await self._queue.get()
            try:
                await self._upload_task(img_data, *args)
                logger.success("Image {} uploaded and indexed. Queue Length: {} [-1]", img_data.id, self._queue.qsize())
            except Exception as ex:
                logger.error("Error occurred while uploading image {}", img_data.id)
                logger.exception(ex)
            finally:
                self._queue.task_done()
                # discard() instead of remove(): a missing id must not raise from the
                # finally block, because that would terminate the only worker task.
                self.uploading_ids.discard(img_data.id)
                self._processed_count += 1
                if self._processed_count % self._GC_INTERVAL == 0:
                    gc.collect()

    async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                           thumbnail_mode: UploadImageThumbnailMode):
        """Index one image, then upload the original and/or a thumbnail as required.

        :param img_data: Metadata of the image; its ``url``, ``thumbnail_url`` and
            ``local_thumbnail`` fields are filled in here before indexing.
        :param img_bytes: Raw encoded image content.
        :param skip_ocr: Forwarded to ``IndexService.index_image``.
        :param thumbnail_mode: ``ALWAYS`` forces a thumbnail; ``IF_NECESSARY`` generates
            one only when the image is larger than 500 KiB.
        """
        img = Image.open(BytesIO(img_bytes))
        try:
            logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes))
            file_name = f"{img_data.id}.{img_data.format}"
            thumb_path = f"thumbnails/{img_data.id}.webp"
            gen_thumb = thumbnail_mode == UploadImageThumbnailMode.ALWAYS or (
                thumbnail_mode == UploadImageThumbnailMode.IF_NECESSARY
                and len(img_bytes) > self._THUMBNAIL_SIZE_THRESHOLD)

            # The URLs are resolved before indexing so that they are present on
            # img_data when it is handed to the index service.
            if img_data.local:
                img_data.url = await self._storage_service.active_storage.url(file_name)
            if gen_thumb:
                img_data.thumbnail_url = await self._storage_service.active_storage.url(thumb_path)
                img_data.local_thumbnail = True

            await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, background=True)
            logger.success("Image {} indexed.", img_data.id)

            if img_data.local:
                logger.info("Start uploading image {} to local storage.", img_data.id)
                await self._storage_service.active_storage.upload(img_bytes, file_name)
                logger.success("Image {} uploaded to local storage.", img_data.id)
            if gen_thumb:
                logger.info("Start generate and upload thumbnail for {}.", img_data.id)
                # thumbnail() resizes in place, so it must run after indexing has used
                # the full-size image.
                img.thumbnail(self._THUMBNAIL_BOX, resample=Image.Resampling.LANCZOS)
                thumb_buffer = BytesIO()
                img.save(thumb_buffer, 'WebP', save_all=True)
                await self._storage_service.active_storage.upload(thumb_buffer.getvalue(), thumb_path)
                logger.success("Thumbnail for {} generated and uploaded!", img_data.id)
        finally:
            # Release the decoded image even when indexing or uploading fails.
            img.close()

    async def queue_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                                 thumbnail_mode: UploadImageThumbnailMode):
        """Enqueue an image for background processing; waits while the queue is full."""
        self.uploading_ids.add(img_data.id)
        try:
            await self._queue.put((img_data, img_bytes, skip_ocr, thumbnail_mode))
        except BaseException:
            # The item never reached the queue (e.g. the caller was cancelled while
            # waiting for a free slot), so the worker will never clear this id.
            self.uploading_ids.discard(img_data.id)
            raise
        logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize())

    async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes):
        """Generate the id for an image and reject it if that id is already known.

        :param img_file: Image source handed to ``generate_uuid``.
        :return: The generated image id.
        :raises PointDuplicateError: If the id is currently being uploaded or
            ``validate_ids`` reports it as already present in the database.
        """
        img_id = generate_uuid(img_file)
        # check for duplicate points
        if img_id in self.uploading_ids or len(await self._db_context.validate_ids([str(img_id)])) != 0:
            logger.warning("Duplicate upload request for image id: {}", img_id)
            raise PointDuplicateError(f"The uploaded point is already contained in the database! entity id: {img_id}",
                                      img_id)
        return img_id

    async def sync_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
                                thumbnail_mode: UploadImageThumbnailMode):
        """Index and upload an image immediately, bypassing the queue."""
        await self._upload_task(img_data, img_bytes, skip_ocr, thumbnail_mode)

    def get_queue_size(self):
        """Return the number of images waiting in the queue (excludes one being processed)."""
        return self._queue.qsize()

    async def on_exit(self):  # pragma: no cover Hard to test in UT.
        """Wait until every queued upload has been fully processed."""
        if self.get_queue_size() != 0:
            logger.warning("There are still {} images in the upload queue. Waiting for upload process to be completed.",
                           self.get_queue_size())
        # Always join: qsize() is already 0 while the last item is still being
        # processed, and join() returns immediately when nothing is outstanding.
        await self._queue.join()
|