File size: 11,977 Bytes
21db53c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
from io import BytesIO
from typing import Annotated, List
from uuid import uuid4, UUID

from PIL import Image
from fastapi import APIRouter, HTTPException
from fastapi.params import File, Query, Path, Depends
from loguru import logger

from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services.authentication import force_access_token_verify
from app.Services.provider import ServiceProvider
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

# Router that groups every search endpoint under the "Search" tag.
# When config.access_protected is truthy, force_access_token_verify is attached as a
# router-level dependency so it applies to each route; otherwise no dependencies are set.
search_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
                          tags=["Search"])

# Module-level handle used by all endpoints below; it stays None until it is assigned
# from outside this module.
services: ServiceProvider | None = None  # The service provider will be injected in the webapp initialize


class SearchBasisParams:
    """Query-parameter dependency that selects which basis a search runs on.

    The chosen basis is exposed as ``self.basis``. It defaults to
    ``SearchBasisEnum.vision``.

    Raises:
        HTTPException: 400 when the OCR basis is requested while
            ``config.ocr_search.enable`` is falsy.
    """

    def __init__(self,
                 basis: Annotated[SearchBasisEnum, Query(
                     description="The basis used to search the image.")] = SearchBasisEnum.vision):
        # Only the OCR basis depends on an optional feature, so only that case is validated.
        if basis == SearchBasisEnum.ocr:
            if not config.ocr_search.enable:
                raise HTTPException(400, "OCR search is not enabled.")
        self.basis = basis


async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
    """Rewrite image and thumbnail URLs of locally stored results to presigned URLs.

    When ``config.storage.method.enabled`` is falsy the response is returned
    untouched. Otherwise each result item is updated in place:

    * items flagged ``img.local`` get ``img.url`` replaced by a presigned URL for
      ``"<id>.<ext>"``, where ``<ext>`` is ``img.format`` or, if that is falsy,
      the text after the last ``.`` of the current URL;
    * items that already have a thumbnail URL and are flagged ``img.local`` or
      ``img.local_thumbnail`` get ``img.thumbnail_url`` replaced by a presigned
      URL for ``"thumbnails/<id>.webp"``.

    Returns:
        The same ``resp`` object that was passed in.
    """
    if config.storage.method.enabled:
        for entry in resp.result:
            picture = entry.img
            if picture.local:
                extension = picture.format or picture.url.split('.')[-1]
                picture.url = await services.storage_service.active_storage.presign_url(
                    f"{picture.id}.{extension}")
            has_thumbnail = picture.thumbnail_url is not None
            if has_thumbnail and (picture.local or picture.local_thumbnail):
                picture.thumbnail_url = await services.storage_service.active_storage.presign_url(
                    f"thumbnails/{picture.id}.webp")
    return resp


@search_router.get("/text/{prompt}", description="Search images by text prompt")
async def textSearch(
        prompt: Annotated[
            str, Path(max_length=100, description="The image prompt text you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        exact: Annotated[bool, Query(
            description="If using OCR search, this option will require the ocr text contains **exactly** the "
                        "criteria you have given. This won't take any effect in vision search.")] = False
) -> SearchApiResponse:
    """Search images with a text prompt on the selected basis.

    The prompt is encoded with ``get_text_vector`` for the vision basis and with
    ``get_bert_vector`` for any other basis. When ``exact`` is set together with
    the OCR basis, the prompt is also stored in ``filter_param.ocr_text``.

    Returns:
        The post-processed search response carrying a fresh ``query_id``.
    """
    logger.info("Text search request received, prompt: {}", prompt)

    # Pick the encoder that matches the vector space being queried.
    if basis.basis == SearchBasisEnum.vision:
        text_vector = services.transformers_service.get_text_vector(prompt)
    else:
        text_vector = services.transformers_service.get_bert_vector(prompt)

    if exact and basis.basis == SearchBasisEnum.ocr:
        filter_param.ocr_text = prompt

    vector_name = services.db_context.vector_name_for_basis(basis.basis)
    results = await services.db_context.querySearch(text_vector,
                                                    query_vector_name=vector_name,
                                                    filter_param=filter_param,
                                                    top_k=paging.count,
                                                    skip=paging.skip)
    response = SearchApiResponse(result=results,
                                 message=f"Successfully get {len(results)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)


@search_router.post("/image", description="Search images by image")
async def imageSearch(
        image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*",
                                     description="The image you want to search.")],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    """Search images that are visually similar to an uploaded image.

    The uploaded bytes are opened with Pillow, encoded with
    ``get_image_vector`` and passed to ``querySearch`` without a
    ``query_vector_name``, so that method's own default applies.

    Returns:
        The post-processed search response carrying a fresh ``query_id``.

    Raises:
        HTTPException: 400 when the uploaded bytes cannot be opened as an image.
    """
    try:
        img = Image.open(BytesIO(image))
    except OSError as err:
        # PIL.UnidentifiedImageError is an OSError subclass; an undecodable upload is a
        # client error and should not surface as an unhandled 500.
        logger.warning("Image search request rejected, cannot open the uploaded image: {}", err)
        raise HTTPException(400, "Cannot open the uploaded image file.") from err
    logger.info("Image search request received")
    image_vector = services.transformers_service.get_image_vector(img)
    results = await services.db_context.querySearch(image_vector,
                                                    top_k=paging.count,
                                                    skip=paging.skip,
                                                    filter_param=filter_param)
    return await result_postprocessing(
        SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))


@search_router.get("/similar/{image_id}",
                   description="Search images similar to the image with given id. "
                               "Won't include the given image itself in the result.")
async def similarWith(
        image_id: Annotated[UUID, Path(description="The id of the image you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    """Search images similar to an already indexed image, on the selected basis.

    The id is passed to ``querySimilar`` as a string, together with the vector
    name resolved for the requested basis.

    Returns:
        The post-processed search response carrying a fresh ``query_id``.
    """
    logger.info("Similar search request received, id: {}", image_id)
    vector_name = services.db_context.vector_name_for_basis(basis.basis)
    results = await services.db_context.querySimilar(search_id=str(image_id),
                                                     query_vector_name=vector_name,
                                                     filter_param=filter_param,
                                                     top_k=paging.count,
                                                     skip=paging.skip)
    response = SearchApiResponse(result=results,
                                 message=f"Successfully get {len(results)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)


@search_router.post("/advanced", description="Search with multiple criteria")
async def advancedSearch(
        model: AdvancedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    """Search with multiple positive and negative criteria on the selected basis.

    The query itself is delegated to
    ``process_advanced_and_combined_search_query``.

    Returns:
        The post-processed search response carrying a fresh ``query_id``.
    """
    logger.info("Advanced search request received: {}", model)
    hits = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
    response = SearchApiResponse(result=hits,
                                 message=f"Successfully get {len(hits)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)


@search_router.post("/combined", description="Search with combined criteria")
async def combinedSearch(
        model: CombinedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    """Search with multiple criteria, then re-rank the hits by an extra prompt.

    The candidate set is fetched in combined mode, re-scored and sorted in place
    by ``calculate_and_sort_by_combined_scores``, and finally cut down to at
    most ``paging.count`` items.

    Returns:
        The post-processed search response carrying a fresh ``query_id``.

    Raises:
        HTTPException: 400 when ``config.ocr_search.enable`` is falsy.
    """
    if not config.ocr_search.enable:
        raise HTTPException(400, "You used combined search, but it needs OCR search which is not "
                                 "enabled.")
    logger.info("Combined search request received: {}", model)
    candidates = await process_advanced_and_combined_search_query(model, basis, filter_param, paging,
                                                                  is_combined_search=True)
    calculate_and_sort_by_combined_scores(model, basis, candidates)
    # The combined query may return more candidates than requested; keep only the best ones.
    if len(candidates) > paging.count:
        candidates = candidates[:paging.count]
    response = SearchApiResponse(result=candidates,
                                 message=f"Successfully get {len(candidates)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)


@search_router.get("/random", description="Get random images")
async def randomPick(
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        seed: Annotated[int | None, Query(
            description="The seed for random pick. This is helpful for generating a reproducible random pick.")] = None,
) -> SearchApiResponse:
    """Return images picked by querying with a random vector.

    The vector comes from ``get_random_vector(seed)`` and is used as an ordinary
    search vector, with the given filters and paging.

    Returns:
        The post-processed search response carrying a fresh ``query_id``.
    """
    logger.info("Random pick request received")
    random_vector = services.transformers_service.get_random_vector(seed)
    picked = await services.db_context.querySearch(random_vector,
                                                   filter_param=filter_param,
                                                   top_k=paging.count,
                                                   skip=paging.skip)
    response = SearchApiResponse(result=picked,
                                 message=f"Successfully get {len(picked)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)


# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
# async def recallQuery(query_id: str):
#     raise NotImplementedError()

async def process_advanced_and_combined_search_query(model: AdvancedSearchModel,
                                                     basis: SearchBasisParams,
                                                     filter_param: FilterParams,
                                                     paging: SearchPagingParams,
                                                     is_combined_search=False) -> List[SearchResult]:
    """Run the vector query shared by the advanced and combined search endpoints.

    Every entry of ``model.criteria`` and ``model.negative_criteria`` is encoded
    with ``get_bert_vector`` for the OCR basis or ``get_text_vector`` for the
    vision basis, and the vectors are passed to ``querySimilar`` as positive and
    negative examples.

    Args:
        model: Carries the criteria, the negative criteria and the query mode.
        basis: Selects the encoder and the vector name that is queried.
        filter_param: Filters forwarded to the query.
        paging: Supplies ``count`` and ``skip``.
        is_combined_search: When truthy, the query is asked to return vectors and
            fetches ``count * 3`` items, clamped to the range 30..100, instead
            of ``count``.

    Returns:
        The list returned by ``querySimilar``.

    Raises:
        NotImplementedError: If the basis is neither OCR nor vision.
    """
    selected_basis = basis.basis
    if selected_basis == SearchBasisEnum.ocr:
        encode = services.transformers_service.get_bert_vector
    elif selected_basis == SearchBasisEnum.vision:
        encode = services.transformers_service.get_text_vector
    else:  # pragma: no cover
        raise NotImplementedError()
    positive_vectors = [encode(text) for text in model.criteria]
    negative_vectors = [encode(text) for text in model.negative_criteria]

    # In order to ensure the query effect of the combined query, modify the actual top_k
    if is_combined_search:
        _query_top_k = min(max(30, paging.count * 3), 100)
    else:
        _query_top_k = paging.count

    return await services.db_context.querySimilar(
        query_vector_name=services.db_context.vector_name_for_basis(selected_basis),
        positive_vectors=positive_vectors,
        negative_vectors=negative_vectors,
        mode=model.mode,
        filter_param=filter_param,
        with_vectors=is_combined_search,
        top_k=_query_top_k,
        skip=paging.skip)


def calculate_and_sort_by_combined_scores(model: CombinedSearchModel,
                                          basis: SearchBasisParams,
                                          result: List[SearchResult]) -> None:
    """Re-score results with the extra prompt and sort them in place, best first.

    For the OCR basis the extra prompt is encoded with ``get_text_vector`` and
    compared against each item's ``img.image_vector``. For the vision basis it is
    encoded with ``get_bert_vector`` and compared against
    ``img.text_contain_vector``. Items whose stored vector is ``None`` keep
    their original score.

    Args:
        model: Supplies ``extra_prompt``.
        basis: The basis the primary query ran on.
        result: Search results; each score and the list order are mutated.

    Raises:
        NotImplementedError: If the basis is neither OCR nor vision.
    """
    primary_is_ocr = basis.basis == SearchBasisEnum.ocr
    # Use a different method to calculate the extra prompt vector based on the basis
    if primary_is_ocr:
        extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt)
    elif basis.basis == SearchBasisEnum.vision:
        extra_prompt_vector = services.transformers_service.get_bert_vector(model.extra_prompt)
    else:  # pragma: no cover
        raise NotImplementedError()

    # Combined score = (1 + cosine(stored vector, extra prompt vector)) * original score
    for entry in result:
        stored_vector = entry.img.image_vector if primary_is_ocr else entry.img.text_contain_vector
        if stored_vector is None:
            continue
        similar_score = calculate_vectors_cosine(stored_vector, extra_prompt_vector)
        entry.score = (1 + similar_score) * entry.score

    # Finally, sort the result by the combined score, highest first
    result.sort(key=lambda entry: entry.score, reverse=True)