"""Integration tests for the /admin, /search and /images HTTP endpoints."""
import itertools
import uuid

import pytest_asyncio

from ..assets import assets_path
from .conftest import check_local_dir_empty
# Test image file names grouped by class; the files are read from
# ``assets_path / 'test_images'``.  Insertion order matters: the ``img_ids``
# fixture uploads in this order and tests index into the resulting ID lists.
test_images = {
    'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'],
    'cat': ['cat_0.jpg', 'cat_1.jpg'],
    'cg': ['cg_0.jpg', 'cg_1.png'],
}
@pytest_asyncio.fixture
async def img_ids(test_client, wait_for_background_task):
    """Upload every image in ``test_images`` and yield the returned IDs by class.

    Each file is posted to ``/admin/upload`` with ``local=True``.  After all
    uploads, ``wait_for_background_task`` is awaited with the total number of
    test images, then a dict mapping each class name to the list of
    ``image_id`` values (in upload order) is yielded.

    On teardown every uploaded image is deleted through
    ``/admin/delete/{id}`` and ``check_local_dir_empty()`` is called.
    """
    # NOTE(review): the fixture decorator was missing from this copy even
    # though ``pytest_asyncio`` is imported and every test requests
    # ``img_ids``.  The default (function) scope is used here; widen it if the
    # scopes of ``test_client`` / ``wait_for_background_task`` allow - confirm
    # against conftest.
    image_dir = assets_path / 'test_images'
    uploaded = {}
    for img_cls, file_names in test_images.items():
        uploaded[img_cls] = []
        for file_name in file_names:
            print(f'upload image {file_name}...')
            with open(image_dir / file_name, 'rb') as f:
                resp = test_client.post('/admin/upload',
                                        files={'image_file': f},
                                        params={'local': True})
            assert resp.status_code == 200
            uploaded[img_cls].append(resp.json()['image_id'])

    print('Waiting for images to be processed...')
    await wait_for_background_task(sum(len(v) for v in test_images.values()))

    yield uploaded

    # cleanup
    for ids in uploaded.values():
        for img_id in ids:
            resp = test_client.delete(f"/admin/delete/{img_id}")
            assert resp.status_code == 200
    check_local_dir_empty()
def test_search_text(test_client, img_ids):
    """Text search for ``hatsune+miku`` must rank an uploaded 'cg' image first."""
    resp = test_client.get('/search/text/hatsune+miku')
    assert resp.status_code == 200
    top_hit = resp.json()['result'][0]
    assert top_hit['img']['id'] in img_ids['cg']
def test_search_image(test_client, img_ids):
    """Searching with the first 'cat' test image must rank a 'cat' image first."""
    query_path = assets_path / 'test_images' / test_images['cat'][0]
    with open(query_path, 'rb') as f:
        resp = test_client.post('/search/image',
                                files={'image': f})
    assert resp.status_code == 200
    top_hit = resp.json()['result'][0]
    assert top_hit['img']['id'] in img_ids['cat']
def test_search_similar(test_client, img_ids):
    """Similar-image search on the first 'bsn' ID must rank a 'bsn' image first."""
    resp = test_client.get(f"/search/similar/{img_ids['bsn'][0]}")
    assert resp.status_code == 200
    top_hit = resp.json()['result'][0]
    assert top_hit['img']['id'] in img_ids['bsn']
def test_search_advanced(test_client, img_ids):
    """POST ``/search/advanced`` with positive and negative criteria.

    With the criteria below the top result is expected to be one of the
    uploaded 'bsn' images.
    """
    payload = {
        'criteria': ['white background', 'grayscale image'],
        'negative_criteria': ['cat', 'hatsune miku'],
    }
    resp = test_client.post("/search/advanced", json=payload)
    assert resp.status_code == 200
    top_hit = resp.json()['result'][0]
    assert top_hit['img']['id'] in img_ids['bsn']
def test_search_combined(test_client, img_ids):
    """POST ``/search/combined`` twice; both must rank the second 'cg' image first.

    The first request uses the default query string; the second adds
    ``basis=ocr`` and swaps which text is sent as criteria vs. extra prompt.
    """
    expected_id = img_ids['cg'][1]

    resp = test_client.post('/search/combined',
                            json={'criteria': ['hatsune miku'],
                                  'negative_criteria': ['grayscale image', 'cat'],
                                  'extra_prompt': 'hatsunemiku'})
    assert resp.status_code == 200
    assert resp.json()['result'][0]['img']['id'] == expected_id

    resp = test_client.post('/search/combined?basis=ocr',
                            json={'criteria': ['hatsunemiku'],
                                  'extra_prompt': 'hatsune miku'})
    assert resp.status_code == 200
    assert resp.json()['result'][0]['img']['id'] == expected_id
def test_search_filters(test_client, img_ids):
    """Tag the first 'bsn' image, then check that search filters select it.

    The image is given category ``bsn`` and marked starred through
    ``/admin/update_opt``.  A text search for ``cat`` filtered by that
    category, and again filtered by ``starred``, must each return the tagged
    image as the top result.
    """
    target_id = img_ids['bsn'][0]

    resp = test_client.put(f"/admin/update_opt/{target_id}",
                           json={'categories': ['bsn'], 'starred': True})
    assert resp.status_code == 200

    resp = test_client.get("/search/text/cat", params={'categories': 'bsn'})
    assert resp.status_code == 200
    assert resp.json()['result'][0]['img']['id'] == target_id

    resp = test_client.get("/search/text/cat", params={'starred': True})
    assert resp.status_code == 200
    assert resp.json()['result'][0]['img']['id'] == target_id
def test_images_query_by_id(test_client, img_ids):
    """GET ``/images/id/{id}`` for the first 'bsn' image must return that image."""
    target_id = img_ids['bsn'][0]
    resp = test_client.get(f"/images/id/{target_id}")
    assert resp.status_code == 200
    assert resp.json()['img']['id'] == target_id
def test_images_query_not_exist(test_client, img_ids):
    """GET ``/images/id/{id}`` with a freshly generated UUID must return 404.

    ``img_ids`` is not read here; it is requested so the upload fixture is
    set up when this test runs.
    """
    resp = test_client.get(f"/images/id/{uuid.uuid4()}")
    assert resp.status_code == 404
def test_images_query_scroll(test_client, img_ids):
    """Exercise the image listing endpoint and its ``prev_offset_id`` paging.

    Checks that:
    * listing with ``count=50`` returns only IDs uploaded by ``img_ids``;
    * requesting with ``prev_offset_id`` set to the middle item's ID returns a
      page whose first item has that same ID;
    * an unknown ``prev_offset_id`` yields 404.
    """
    resp = test_client.get("/images/", params={'count': 50})
    assert resp.status_code == 200
    resp_imgs = resp.json()['images']
    # An empty listing would make the loop below pass vacuously and then fail
    # with an IndexError at the midpoint lookup; fail clearly instead.
    assert resp_imgs, 'image listing returned no items'

    known_ids = set(itertools.chain.from_iterable(img_ids.values()))
    for item in resp_imgs:
        assert item['id'] in known_ids

    pivot_id = resp_imgs[len(resp_imgs) // 2]['id']
    paging_test = test_client.get('/images',
                                  params={'prev_offset_id': pivot_id})
    assert paging_test.status_code == 200
    assert paging_test.json()['images'][0]['id'] == pivot_id

    no_exist_test = test_client.get('/images',
                                    params={'prev_offset_id': uuid.uuid4()})
    assert no_exist_test.status_code == 404