from io import BytesIO
from typing import Annotated, List
from uuid import uuid4, UUID

from PIL import Image
from fastapi import APIRouter, HTTPException
from fastapi.params import File, Query, Path, Depends
from loguru import logger

from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services.authentication import force_access_token_verify
from app.Services.provider import ServiceProvider
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

search_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
                          tags=["Search"])

services: ServiceProvider | None = None  # The service provider is injected during webapp initialization
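

# Dependency that parses the `basis` query parameter (vision or OCR) and rejects
# OCR-based requests when OCR search is disabled in the configuration.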
class SearchBasisParams:
    def __init__(self,
                 basis: Annotated[SearchBasisEnum, Query(
                     description="The basis used to search the image.")] = SearchBasisEnum.vision):
        if basis == SearchBasisEnum.ocr and not config.ocr_search.enable:
            raise HTTPException(400, "OCR search is not enabled.")
        self.basis = basis
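

# When the configured storage method is enabled, rewrite the URLs of locally stored
# images and thumbnails in the response to presigned URLs from the active storage.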
async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
    if not config.storage.method.enabled:
        return resp
    for item in resp.result:
        if item.img.local:
            img_extension = item.img.format or item.img.url.split('.')[-1]
            img_remote_filename = f"{item.img.id}.{img_extension}"
            item.img.url = await services.storage_service.active_storage.presign_url(img_remote_filename)
        if item.img.thumbnail_url is not None and (item.img.local or item.img.local_thumbnail):
            thumbnail_remote_filename = f"thumbnails/{item.img.id}.webp"
            item.img.thumbnail_url = await services.storage_service.active_storage.presign_url(
                thumbnail_remote_filename)
    return resp
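

# Text search: embeds the prompt as a text vector (vision basis) or a BERT vector
# (OCR basis) and queries the vector database with it.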
@search_router.get("/text/{prompt}", description="Search images by text prompt")
async def textSearch(
prompt: Annotated[
str, Path(max_length=100, description="The image prompt text you want to search.")],
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
exact: Annotated[bool, Query(
description="If using OCR search, this option will require the ocr text contains **exactly** the "
"criteria you have given. This won't take any effect in vision search.")] = False
) -> SearchApiResponse:
logger.info("Text search request received, prompt: {}", prompt)
text_vector = services.transformers_service.get_text_vector(prompt) if basis.basis == SearchBasisEnum.vision \
else services.transformers_service.get_bert_vector(prompt)
if basis.basis == SearchBasisEnum.ocr and exact:
filter_param.ocr_text = prompt
results = await services.db_context.querySearch(text_vector,
query_vector_name=services.db_context.vector_name_for_basis(
basis.basis),
filter_param=filter_param,
top_k=paging.count,
skip=paging.skip)
return await result_postprocessing(
SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))
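

# Image search: embeds the uploaded image (up to 10 MiB) and queries the vector
# database for visually similar images.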
@search_router.post("/image", description="Search images by image")
async def imageSearch(
image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*",
description="The image you want to search.")],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
fakefile = BytesIO(image)
img = Image.open(fakefile)
logger.info("Image search request received")
image_vector = services.transformers_service.get_image_vector(img)
results = await services.db_context.querySearch(image_vector,
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param)
return await result_postprocessing(
SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))
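

# Similarity search: reuses the stored vector of an already indexed image, so no
# embedding has to be computed for the request.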
@search_router.get("/similar/{image_id}",
description="Search images similar to the image with given id. "
"Won't include the given image itself in the result.")
async def similarWith(
image_id: Annotated[UUID, Path(description="The id of the image you want to search.")],
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
logger.info("Similar search request received, id: {}", image_id)
results = await services.db_context.querySimilar(search_id=str(image_id),
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param,
query_vector_name=services.db_context.vector_name_for_basis(
basis.basis))
return await result_postprocessing(
SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))
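

# Advanced search: accepts multiple positive and negative criteria and combines their
# vectors in a single vector-database query.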
@search_router.post("/advanced", description="Search with multiple criteria")
async def advancedSearch(
model: AdvancedSearchModel,
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
logger.info("Advanced search request received: {}", model)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
return await result_postprocessing(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))
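

# Combined search: runs an advanced search first, then re-ranks the candidates by how
# well they match the extra prompt in the other modality (see
# calculate_and_sort_by_combined_scores below). Requires OCR search to be enabled.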
@search_router.post("/combined", description="Search with combined criteria")
async def combinedSearch(
model: CombinedSearchModel,
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
if not config.ocr_search.enable:
raise HTTPException(400, "You used combined search, but it needs OCR search which is not "
"enabled.")
logger.info("Combined search request received: {}", model)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging, True)
calculate_and_sort_by_combined_scores(model, basis, result)
result = result[:paging.count] if len(result) > paging.count else result
return await result_postprocessing(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))
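

# Random pick: queries the vector database with a random vector; an optional seed makes
# the pick reproducible.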
@search_router.get("/random", description="Get random images")
async def randomPick(
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
seed: Annotated[int | None, Query(
description="The seed for random pick. This is helpful for generating a reproducible random pick.")] = None,
) -> SearchApiResponse:
logger.info("Random pick request received")
random_vector = services.transformers_service.get_random_vector(seed)
result = await services.db_context.querySearch(random_vector, top_k=paging.count, skip=paging.skip,
filter_param=filter_param)
return await result_postprocessing(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))
# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
# async def recallQuery(query_id: str):
# raise NotImplementedError()
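

# Shared query logic for the advanced and combined endpoints: embeds the positive and
# negative criteria with the model matching the search basis and queries the vector
# database. For combined search, vectors are returned with the results and top_k is
# widened so the re-ranking step has enough candidates.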
async def process_advanced_and_combined_search_query(model: AdvancedSearchModel,
                                                     basis: SearchBasisParams,
                                                     filter_param: FilterParams,
                                                     paging: SearchPagingParams,
                                                     is_combined_search=False) -> List[SearchResult]:
    match basis.basis:
        case SearchBasisEnum.ocr:
            positive_vectors = [services.transformers_service.get_bert_vector(t) for t in model.criteria]
            negative_vectors = [services.transformers_service.get_bert_vector(t) for t in model.negative_criteria]
        case SearchBasisEnum.vision:
            positive_vectors = [services.transformers_service.get_text_vector(t) for t in model.criteria]
            negative_vectors = [services.transformers_service.get_text_vector(t) for t in model.negative_criteria]
        case _:  # pragma: no cover
            raise NotImplementedError()
    # For combined search, widen the actual top_k (3x the requested count, clamped between
    # 30 and 100) so the later re-ranking has enough candidates to work with.
    _query_top_k = min(max(30, paging.count * 3), 100) if is_combined_search else paging.count
    result = await services.db_context.querySimilar(
        query_vector_name=services.db_context.vector_name_for_basis(basis.basis),
        positive_vectors=positive_vectors,
        negative_vectors=negative_vectors,
        mode=model.mode,
        filter_param=filter_param,
        with_vectors=is_combined_search,
        top_k=_query_top_k,
        skip=paging.skip)
    return result
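

# Re-ranking step for combined search: scores each candidate against the extra prompt
# in the complementary modality (OCR basis -> image vector, vision basis -> OCR text
# vector) and sorts by the combined score.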
def calculate_and_sort_by_combined_scores(model: CombinedSearchModel,
                                          basis: SearchBasisParams,
                                          result: List[SearchResult]) -> None:
    # The extra prompt is embedded with the model complementary to the search basis, so it
    # can be compared against the other vector stored on each result.
    match basis.basis:
        case SearchBasisEnum.ocr:
            extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt)
        case SearchBasisEnum.vision:
            extra_prompt_vector = services.transformers_service.get_bert_vector(model.extra_prompt)
        case _:  # pragma: no cover
            raise NotImplementedError()
    # Calculate the combined score (original score * (1 + similar_score)) and write it
    # back to SearchResult.score
    for itm in result:
        match basis.basis:
            case SearchBasisEnum.ocr:
                extra_vector = itm.img.image_vector
            case SearchBasisEnum.vision:
                extra_vector = itm.img.text_contain_vector
            case _:  # pragma: no cover
                raise NotImplementedError()
        if extra_vector is not None:
            similar_score = calculate_vectors_cosine(extra_vector, extra_prompt_vector)
            itm.score = (1 + similar_score) * itm.score
    # Finally, sort the results by the combined score
    result.sort(key=lambda i: i.score, reverse=True)