neko-image-gallery / app /Services /index_service.py
eggacheb's picture
Upload 97 files
21db53c verified
from PIL import Image
from fastapi.concurrency import run_in_threadpool
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Services.lifespan_service import LifespanService
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
from app.Services.vector_db_context import VectorDbContext
from app.config import config
class IndexService(LifespanService):
    """Prepare images (dimensions, vectors, optional OCR text) and insert them into the vector database."""

    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext):
        """Store the collaborating services used for OCR, vector extraction and persistence."""
        self._ocr_service = ocr_service
        self._transformers_service = transformers_service
        self._db_context = db_context

    def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        """Populate ``image_data`` in place from ``image``.

        Sets ``width``, ``height``, ``aspect_ratio`` and ``image_vector``. When OCR
        is enabled in the config and ``skip_ocr`` is false, also sets ``ocr_text``
        and, for a non-empty OCR result, ``text_contain_vector``.

        This method is synchronous; callers decide whether to run it inline or
        in a worker thread.
        """
        image_data.width = image.width
        image_data.height = image.height
        image_data.aspect_ratio = float(image.width) / image.height

        # Work on an RGB image that is independent of the caller's object:
        # either a converted image or an explicit copy.
        if image.mode != 'RGB':
            image = image.convert('RGB')  # to reduce convert in next steps
        else:
            image = image.copy()

        image_data.image_vector = self._transformers_service.get_image_vector(image)

        if not skip_ocr and config.ocr_search.enable:
            image_data.ocr_text = self._ocr_service.ocr_interface(image)
            if image_data.ocr_text != "":
                image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
            else:
                # An empty OCR result is stored as None rather than as an empty string.
                image_data.ocr_text = None

    # currently, only a simple existence check is needed here
    async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
        """Return True when ``validate_ids`` reports a non-empty result for the given items' ids.

        NOTE(review): presumably ``validate_ids`` returns the subset of ids that
        already exist in the database -- confirm against VectorDbContext.
        """
        image_id_list = [str(item.id) for item in image_data]
        result = await self._db_context.validate_ids(image_id_list)
        return len(result) != 0

    async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False,
                          background=False):
        """Prepare a single image and insert it into the database.

        :param image: the image to index.
        :param image_data: metadata object, filled in place by the preparation step.
        :param skip_ocr: when true, the OCR step is not run.
        :param skip_duplicate_check: when true, the duplicate-id check is not run.
        :param background: when true, preparation runs via ``run_in_threadpool``;
            otherwise it runs inline in this coroutine.
        :raises PointDuplicateError: if the duplicate check is enabled and reports a hit.
        """
        if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
            raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id)

        if background:
            await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr)
        else:
            self._prepare_image(image, image_data, skip_ocr)

        await self._db_context.insertItems([image_data])

    async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData],
                                skip_ocr=False, allow_overwrite=False):
        """Prepare a batch of images and insert all of them in one database call.

        :param image: images to index; must be the same length as ``image_data``.
        :param image_data: metadata objects, paired with ``image`` by position and
            filled in place by the preparation step.
        :param skip_ocr: when true, the OCR step is not run.
        :param allow_overwrite: when true, the duplicate-id check is not run.
        :raises ValueError: if ``image`` and ``image_data`` differ in length.
        :raises PointDuplicateError: if the duplicate check is enabled and reports a hit.
        """
        # zip() silently stops at the shorter input, which would otherwise let
        # unprepared ImageData entries (no dimensions, no vector) reach insertItems.
        if len(image) != len(image_data):
            raise ValueError(
                f"image and image_data must have the same length, got {len(image)} and {len(image_data)}"
            )

        if not allow_overwrite and (await self._is_point_duplicate(image_data)):
            raise PointDuplicateError("The uploaded points are contained in the database!")

        for img, img_data in zip(image, image_data):
            self._prepare_image(img, img_data, skip_ocr)

        await self._db_context.insertItems(image_data)