Spaces:
Running
Running
from PIL import Image | |
from fastapi.concurrency import run_in_threadpool | |
from app.Models.errors import PointDuplicateError | |
from app.Models.img_data import ImageData | |
from app.Services.lifespan_service import LifespanService | |
from app.Services.ocr_services import OCRService | |
from app.Services.transformers_service import TransformersService | |
from app.Services.vector_db_context import VectorDbContext | |
from app.config import config | |
class IndexService(LifespanService):
    """Derive image metadata/vectors and insert the resulting records via the DB context."""

    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService,
                 db_context: VectorDbContext):
        """Keep references to the services the indexing methods delegate to.

        Args:
            ocr_service: Used for ``ocr_interface`` when OCR is enabled.
            transformers_service: Used for ``get_image_vector`` and ``get_bert_vector``.
            db_context: Used for ``validate_ids`` and ``insertItems``.
        """
        self._ocr_service = ocr_service
        self._transformers_service = transformers_service
        self._db_context = db_context

    def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        """Populate ``image_data`` in place with fields derived from ``image``.

        Sets ``width``, ``height``, ``aspect_ratio`` and ``image_vector``. When OCR
        runs, also sets ``ocr_text`` and, if text was found, ``text_contain_vector``.

        Args:
            image: Source picture. The name is rebound to a converted/copied image,
                so the caller's object is not replaced.
            image_data: Record that receives the computed fields.
            skip_ocr: When true, OCR is skipped regardless of ``config.ocr_search.enable``.
        """
        image_data.width = image.width
        image_data.height = image.height
        image_data.aspect_ratio = float(image.width) / image.height

        # Normalise to RGB once so the following steps do not each have to convert.
        # NOTE(review): the copy in the already-RGB case presumably isolates the work
        # from the caller's image object -- confirm before removing it.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        else:
            image = image.copy()
        image_data.image_vector = self._transformers_service.get_image_vector(image)

        if skip_ocr or not config.ocr_search.enable:
            return

        image_data.ocr_text = self._ocr_service.ocr_interface(image)
        # Truthiness (rather than `!= ""`) also keeps a None OCR result away from
        # get_bert_vector; both "no text" outcomes are stored as None.
        if image_data.ocr_text:
            image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
        else:
            image_data.ocr_text = None

    # currently, here only need just a simple check
    async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
        """Return True when ``validate_ids`` yields a non-empty result for the records' ids.

        The ids are passed as strings. Presumably ``validate_ids`` returns the ids that
        are already stored -- verify against ``VectorDbContext``.
        """
        image_id_list = [str(item.id) for item in image_data]
        result = await self._db_context.validate_ids(image_id_list)
        return len(result) != 0

    async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False,
                          skip_duplicate_check=False, background=False):
        """Prepare a single image and insert its record.

        Args:
            image: Picture to process.
            image_data: Record to fill in and insert.
            skip_ocr: Forwarded to ``_prepare_image``.
            skip_duplicate_check: When true, the duplicate-id check is not performed.
            background: When true, preparation runs through ``run_in_threadpool``
                instead of being called directly inside this coroutine.

        Raises:
            PointDuplicateError: The duplicate check is enabled and reports a match.
        """
        if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
            raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id)

        if background:
            await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr)
        else:
            self._prepare_image(image, image_data, skip_ocr)

        await self._db_context.insertItems([image_data])

    async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData],
                                skip_ocr=False, allow_overwrite=False):
        """Prepare several images and insert all their records in one call.

        Args:
            image: Pictures to process; ``image[i]`` is paired with ``image_data[i]``.
            image_data: Records to fill in and insert.
            skip_ocr: Forwarded to ``_prepare_image`` for every pair.
            allow_overwrite: When true, the duplicate-id check is not performed.

        Raises:
            ValueError: ``image`` and ``image_data`` differ in length.
            PointDuplicateError: The duplicate check is enabled and reports a match.
        """
        # zip() stops at the shorter list, while insertItems receives the whole
        # image_data list; a mismatch would insert records that were never prepared.
        if len(image) != len(image_data):
            raise ValueError(
                f"image and image_data must have the same length, got {len(image)} and {len(image_data)}"
            )

        if not allow_overwrite and (await self._is_point_duplicate(image_data)):
            raise PointDuplicateError("The uploaded points are contained in the database!")

        for img, img_data in zip(image, image_data):
            self._prepare_image(img, img_data, skip_ocr)

        await self._db_context.insertItems(image_data)