Spaces:
Running
Running
File size: 2,945 Bytes
21db53c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from PIL import Image
from fastapi.concurrency import run_in_threadpool
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Services.lifespan_service import LifespanService
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
from app.Services.vector_db_context import VectorDbContext
from app.config import config
class IndexService(LifespanService):
    """Fill in ImageData records from PIL images and insert them into the vector DB.

    For each image the service records its dimensions, computes an image
    vector via the transformers service and, when OCR is enabled in the
    config and not skipped by the caller, extracts text and a vector for it.
    """

    def __init__(self, ocr_service: OCRService, transformers_service: TransformersService,
                 db_context: VectorDbContext):
        """Store the collaborating services used for OCR, vectorisation and persistence."""
        self._ocr_service = ocr_service
        self._transformers_service = transformers_service
        self._db_context = db_context

    def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
        """Populate ``image_data`` in place from ``image``.

        Sets ``width``, ``height``, ``aspect_ratio`` and ``image_vector``; when OCR
        runs it also sets ``ocr_text`` and, for non-empty text, ``text_contain_vector``.

        Args:
            image: Source image. It is never modified; a converted/copied
                instance is handed to the downstream services.
            image_data: Record to fill in (mutated in place).
            skip_ocr: When true, the OCR step is skipped even if it is enabled
                in the config.
        """
        image_data.width = image.width
        image_data.height = image.height
        image_data.aspect_ratio = float(image.width) / image.height

        if image.mode != 'RGB':
            # Convert once up front to reduce repeated conversions in the next steps.
            image = image.convert('RGB')
        else:
            # Work on a copy; presumably so the downstream services cannot
            # alter the caller's image -- TODO confirm.
            image = image.copy()

        image_data.image_vector = self._transformers_service.get_image_vector(image)

        if not skip_ocr and config.ocr_search.enable:
            image_data.ocr_text = self._ocr_service.ocr_interface(image)
            # Truthiness check: both "" and None mean "no text recognised", so
            # neither is forwarded to the text vectoriser.
            if image_data.ocr_text:
                image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text)
            else:
                image_data.ocr_text = None

    # Currently only a simple check is needed here.
    async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
        """Return True when ``validate_ids`` reports a non-empty result for the given records.

        The ids are stringified before being sent to the DB context. The result
        is presumably the subset of ids already stored -- verify against
        ``VectorDbContext.validate_ids``.
        """
        image_id_list = [str(item.id) for item in image_data]
        result = await self._db_context.validate_ids(image_id_list)
        return len(result) != 0

    async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False,
                          background=False):
        """Prepare a single image and insert its record into the vector DB.

        Args:
            image: Image to index.
            image_data: Record describing the image; filled in place.
            skip_ocr: Skip the OCR step.
            skip_duplicate_check: Do not query the DB for an existing id first.
            background: Run the preparation in a worker thread via
                ``run_in_threadpool`` instead of inline in this coroutine.

        Raises:
            PointDuplicateError: If the duplicate check is enabled and reports
                a match for ``image_data.id``.
        """
        if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
            raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id)

        if background:
            await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr)
        else:
            self._prepare_image(image, image_data, skip_ocr)

        await self._db_context.insertItems([image_data])

    async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData],
                                skip_ocr=False, allow_overwrite=False):
        """Prepare several images and insert their records in one DB call.

        Args:
            image: Images to index, positionally paired with ``image_data``.
            image_data: Records to fill in place, one per image.
            skip_ocr: Skip the OCR step for every image.
            allow_overwrite: Do not query the DB for existing ids first.

        Raises:
            ValueError: If ``image`` and ``image_data`` differ in length.
            PointDuplicateError: If the duplicate check is enabled and reports
                a match for any of the records.
        """
        # zip() stops at the shorter input, which would leave trailing records
        # unprepared while still inserting them below; reject the mismatch.
        if len(image) != len(image_data):
            raise ValueError(
                f"image and image_data must have the same length "
                f"(got {len(image)} and {len(image_data)})"
            )

        if not allow_overwrite and (await self._is_point_duplicate(image_data)):
            raise PointDuplicateError("The uploaded points are contained in the database!")

        for img, img_data in zip(image, image_data):
            self._prepare_image(img, img_data, skip_ocr)

        await self._db_context.insertItems(image_data)
|