Spaces:
Running
Running
File size: 11,977 Bytes
21db53c |
|
from io import BytesIO
from typing import Annotated, List
from uuid import uuid4, UUID
from PIL import Image
from fastapi import APIRouter, HTTPException
from fastapi.params import File, Query, Path, Depends
from loguru import logger
from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services.authentication import force_access_token_verify
from app.Services.provider import ServiceProvider
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine
search_router = APIRouter(
    tags=["Search"],
    # Every route requires a verified access token only when the deployment is access-protected.
    dependencies=[Depends(force_access_token_verify)] if config.access_protected else None,
)
# Populated by the webapp initialization code (service provider injection); None until then.
services: ServiceProvider | None = None
class SearchBasisParams:
    """FastAPI dependency that parses the ``basis`` query parameter.

    Attributes:
        basis: the SearchBasisEnum value selected by the client (defaults to vision).

    Raises:
        HTTPException: 400 when OCR basis is requested but OCR search is disabled in config.
    """

    def __init__(self,
                 basis: Annotated[SearchBasisEnum, Query(
                     description="The basis used to search the image.")] = SearchBasisEnum.vision):
        wants_ocr = basis == SearchBasisEnum.ocr
        if wants_ocr and not config.ocr_search.enable:
            raise HTTPException(400, "OCR search is not enabled.")
        self.basis = basis
async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
    """Swap locally stored image/thumbnail URLs for presigned storage URLs.

    When ``config.storage.method.enabled`` is false the response is returned untouched.
    Otherwise each result's ``img`` is mutated in place:
      * local images get ``url`` presigned for ``"{id}.{ext}"``, where ``ext`` is the
        image format or, failing that, the suffix after the last '.' in the current URL;
      * images with a thumbnail that is local (or whose image is local) get
        ``thumbnail_url`` presigned for ``"thumbnails/{id}.webp"``.
    The same ``resp`` object is returned.
    """
    if config.storage.method.enabled:
        for entry in resp.result:
            image = entry.img
            if image.local:
                extension = image.format or image.url.split('.')[-1]
                image.url = await services.storage_service.active_storage.presign_url(
                    f"{image.id}.{extension}")
            has_local_thumbnail = image.thumbnail_url is not None and (image.local or image.local_thumbnail)
            if has_local_thumbnail:
                image.thumbnail_url = await services.storage_service.active_storage.presign_url(
                    f"thumbnails/{image.id}.webp")
    return resp
@search_router.get("/text/{prompt}", description="Search images by text prompt")
async def textSearch(
        prompt: Annotated[
            str, Path(max_length=100, description="The image prompt text you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        exact: Annotated[bool, Query(
            description="If using OCR search, this option will require the ocr text contains **exactly** the "
                        "criteria you have given. This won't take any effect in vision search.")] = False
) -> SearchApiResponse:
    """Search images by a text prompt.

    The prompt is embedded with ``get_text_vector`` for vision basis and ``get_bert_vector``
    otherwise, then queried against the vector named for the chosen basis. With OCR basis
    and ``exact`` set, ``filter_param.ocr_text`` is also set to the prompt.
    """
    logger.info("Text search request received, prompt: {}", prompt)
    encoders = services.transformers_service
    if basis.basis == SearchBasisEnum.vision:
        query_vector = encoders.get_text_vector(prompt)
    else:
        query_vector = encoders.get_bert_vector(prompt)
    if exact and basis.basis == SearchBasisEnum.ocr:
        # Restrict results to images whose OCR text literally contains the prompt.
        filter_param.ocr_text = prompt
    vector_name = services.db_context.vector_name_for_basis(basis.basis)
    matches = await services.db_context.querySearch(
        query_vector,
        query_vector_name=vector_name,
        filter_param=filter_param,
        top_k=paging.count,
        skip=paging.skip,
    )
    response = SearchApiResponse(result=matches,
                                 message=f"Successfully get {len(matches)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)
@search_router.post("/image", description="Search images by image")
async def imageSearch(
        image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*",
                                     description="The image you want to search.")],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    """Search images using an uploaded image as the query.

    The uploaded bytes (declared with a 10 MiB ``max_length``) are opened with PIL,
    embedded via ``get_image_vector`` and passed to ``querySearch``.

    Raises:
        HTTPException: 422 when the uploaded bytes are not an image format PIL recognises.
    """
    # PIL raises UnidentifiedImageError (an OSError subclass) for unrecognised data;
    # previously that escaped the handler and surfaced to the client as a 500.
    from PIL import UnidentifiedImageError

    logger.info("Image search request received")
    try:
        # The buffer is intentionally left open: Image.open is lazy and reads from it later.
        img = Image.open(BytesIO(image))
    except UnidentifiedImageError as ex:
        logger.warning("Invalid image file from image search request.")
        raise HTTPException(422, "Cannot open the image file.") from ex
    image_vector = services.transformers_service.get_image_vector(img)
    results = await services.db_context.querySearch(image_vector,
                                                    top_k=paging.count,
                                                    skip=paging.skip,
                                                    filter_param=filter_param)
    return await result_postprocessing(
        SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))
@search_router.get("/similar/{image_id}",
                   description="Search images similar to the image with given id. "
                               "Won't include the given image itself in the result.")
async def similarWith(
        image_id: Annotated[UUID, Path(description="The id of the image you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    """Find images similar to the already-indexed image ``image_id``.

    Delegates to ``querySimilar`` on the vector named for the requested basis.
    """
    logger.info("Similar search request received, id: {}", image_id)
    db = services.db_context
    similar = await db.querySimilar(
        search_id=str(image_id),
        query_vector_name=db.vector_name_for_basis(basis.basis),
        filter_param=filter_param,
        top_k=paging.count,
        skip=paging.skip,
    )
    response = SearchApiResponse(result=similar,
                                 message=f"Successfully get {len(similar)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)
@search_router.post("/advanced", description="Search with multiple criteria")
async def advancedSearch(
        model: AdvancedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    """Search with positive/negative text criteria, embedded according to ``basis``."""
    logger.info("Advanced search request received: {}", model)
    matches = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
    response = SearchApiResponse(result=matches,
                                 message=f"Successfully get {len(matches)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)
@search_router.post("/combined", description="Search with combined criteria")
async def combinedSearch(
        model: CombinedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    """Advanced search followed by a rerank against ``model.extra_prompt``.

    Requires OCR search to be enabled. Candidates are fetched in combined-search mode,
    rescored and re-sorted in place, then trimmed to at most ``paging.count`` results.

    Raises:
        HTTPException: 400 when OCR search is disabled.
    """
    if not config.ocr_search.enable:
        raise HTTPException(400, "You used combined search, but it needs OCR search which is not enabled.")
    logger.info("Combined search request received: {}", model)
    candidates = await process_advanced_and_combined_search_query(
        model, basis, filter_param, paging, is_combined_search=True)
    calculate_and_sort_by_combined_scores(model, basis, candidates)
    page = candidates[:paging.count]
    response = SearchApiResponse(result=page,
                                 message=f"Successfully get {len(page)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)
@search_router.get("/random", description="Get random images")
async def randomPick(
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        seed: Annotated[int | None, Query(
            description="The seed for random pick. This is helpful for generating a reproducible random pick.")] = None,
) -> SearchApiResponse:
    """Return images closest to a random query vector built from ``seed``.

    ``seed`` is forwarded to ``get_random_vector``; per the parameter description, a fixed
    seed is intended to make the pick reproducible.
    """
    logger.info("Random pick request received")
    probe = services.transformers_service.get_random_vector(seed)
    picked = await services.db_context.querySearch(
        probe,
        top_k=paging.count,
        skip=paging.skip,
        filter_param=filter_param,
    )
    response = SearchApiResponse(result=picked,
                                 message=f"Successfully get {len(picked)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)
# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
# async def recallQuery(query_id: str):
# raise NotImplementedError()
async def process_advanced_and_combined_search_query(model: AdvancedSearchModel,
                                                     basis: SearchBasisParams,
                                                     filter_param: FilterParams,
                                                     paging: SearchPagingParams,
                                                     is_combined_search=False) -> List[SearchResult]:
    """Embed the model's positive/negative criteria and run a similarity query.

    Criteria are encoded with ``get_bert_vector`` for OCR basis and ``get_text_vector`` for
    vision basis. In combined-search mode the query also returns vectors and fetches an
    enlarged candidate pool (3x ``paging.count``, clamped to [30, 100]) so the later rerank
    has more to choose from; otherwise exactly ``paging.count`` results are requested.
    """
    if basis.basis == SearchBasisEnum.ocr:
        encode = services.transformers_service.get_bert_vector
    elif basis.basis == SearchBasisEnum.vision:
        encode = services.transformers_service.get_text_vector
    else:  # pragma: no cover
        raise NotImplementedError()
    positive_vectors = [encode(text) for text in model.criteria]
    negative_vectors = [encode(text) for text in model.negative_criteria]

    if is_combined_search:
        fetch_count = min(max(30, paging.count * 3), 100)
    else:
        fetch_count = paging.count

    return await services.db_context.querySimilar(
        query_vector_name=services.db_context.vector_name_for_basis(basis.basis),
        positive_vectors=positive_vectors,
        negative_vectors=negative_vectors,
        mode=model.mode,
        filter_param=filter_param,
        with_vectors=is_combined_search,
        top_k=fetch_count,
        skip=paging.skip,
    )
def calculate_and_sort_by_combined_scores(model: CombinedSearchModel,
                                          basis: SearchBasisParams,
                                          result: List[SearchResult]) -> None:
    """Rerank ``result`` in place using ``model.extra_prompt`` on the *other* modality.

    For OCR basis the extra prompt is embedded with ``get_text_vector`` and compared to each
    item's ``image_vector``; for vision basis it is embedded with ``get_bert_vector`` and
    compared to ``text_contain_vector``. Items with a vector get
    ``score = (1 + cosine) * score``; items without one keep their score. The list is then
    sorted by score, highest first.
    """
    basis_kind = basis.basis
    if basis_kind == SearchBasisEnum.ocr:
        prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt)
    elif basis_kind == SearchBasisEnum.vision:
        prompt_vector = services.transformers_service.get_bert_vector(model.extra_prompt)
    else:  # pragma: no cover
        raise NotImplementedError()

    for item in result:
        if basis_kind == SearchBasisEnum.ocr:
            candidate_vector = item.img.image_vector
        else:
            candidate_vector = item.img.text_contain_vector
        if candidate_vector is None:
            continue
        cosine = calculate_vectors_cosine(candidate_vector, prompt_vector)
        item.score = (1 + cosine) * item.score

    result.sort(key=lambda entry: entry.score, reverse=True)
|