diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..0fbc9840155cbda2e4d305b5632bccab30e97e09 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,43 @@ +# Include any files or directories that you don't want to be copied to your +# container here (e.g., local build artifacts, temporary files, etc.). +# +# For more help, visit the .dockerignore file reference guide at +# https://docs.docker.com/engine/reference/builder/#dockerignore-file + +**/.DS_Store +**/__pycache__ +**/.venv +**/.classpath +**/.dockerignore +**/.env +**/.git +**/.gitignore +**/.project +**/.settings +**/.toolstarget +**/.vs +**/.vscode +**/*.*proj.user +**/*.dbmdl +**/*.jfm +**/.idea +**/bin +**/charts +**/docker-compose* +**/compose* +**/*Dockerfile* +**/node_modules +**/npm-debug.log +**/obj +**/secrets.dev.yaml +**/values.dev.yaml +**/venv + +.github/ +static/ +web/ +tests/ +pylintrc.toml +LICENSE +readme.md +readme_cn.md diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..762a4dee0941f2e1ccff4bc000a3a0408476653d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +web/screenshots/1.png filter=lfs diff=lfs merge=lfs -text +web/screenshots/2.png filter=lfs diff=lfs merge=lfs -text diff --git "a/.github/ISSUE_TEMPLATE/bug-report---\351\224\231\350\257\257\345\217\215\351\246\210.md" "b/.github/ISSUE_TEMPLATE/bug-report---\351\224\231\350\257\257\345\217\215\351\246\210.md" new file mode 100644 index 0000000000000000000000000000000000000000..3fe576d5a4457d1c2483e7938e6897e6aa465438 --- /dev/null +++ "b/.github/ISSUE_TEMPLATE/bug-report---\351\224\231\350\257\257\345\217\215\351\246\210.md" @@ -0,0 +1,28 @@ +--- +name: Bug report / 错误反馈 +about: Create a report to help us improve 报告您在使用本项目过程中遇到的Bug。 +title: '' +labels: bug +assignees: '' + +--- + +## Environment +NekoImageGallery version: Place the version of NekoImageGallery you're using here. 在此处添加您正在使用的NekoImageGallery版本。 + +Deployment Method: `Local / Docker` + +## Describe the bug + + +## To Reproduce + + +## Expected behavior + + +## Screenshots + + +## Additional context + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..c7d2476788fe8362f7572c7ceb3f88df583ec48b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: true +contact_links: + - name: Ask a question about the project. 询问有关本项目的问题。 + url: https://github.com/hv0905/NekoImageGallery/discussions/new?category=q-a + about: Ask a question if you encounter a problem when using NekoImageGallery. Please use this option instead of Bug Report unless you are sure your problem is caused by a bug. 
询问在您使用NekoImageGallery过程中遇到的问题。请优先使用此选项(而不是Bug Report),除非您认为您的问题是由NekoImageGallery的BUG造成。 diff --git "a/.github/ISSUE_TEMPLATE/feature-request---\345\212\237\350\203\275\350\257\267\346\261\202.md" "b/.github/ISSUE_TEMPLATE/feature-request---\345\212\237\350\203\275\350\257\267\346\261\202.md" new file mode 100644 index 0000000000000000000000000000000000000000..974b34fc1a229d0810440e6d7eb84d0e02460a4f --- /dev/null +++ "b/.github/ISSUE_TEMPLATE/feature-request---\345\212\237\350\203\275\350\257\267\346\261\202.md" @@ -0,0 +1,20 @@ +--- +name: Feature request / 功能请求 +about: Suggest an idea for this project 向本项目提交新的功能建议。 +title: '' +labels: enhancement +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..91abb11fdf507883caeeb2d2958e1c65fb6cbdc1 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/.github/page_build/_config.yml b/.github/page_build/_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..c4192631f25b34d77a7f159aa0da0e3ae99c4ef4 --- /dev/null +++ b/.github/page_build/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/.github/workflows/jekyll-gh-pages.yml b/.github/workflows/jekyll-gh-pages.yml new file mode 100644 index 0000000000000000000000000000000000000000..f2093c26c516e5e88d22a2fe9c1cad72a98a6c2a --- /dev/null +++ b/.github/workflows/jekyll-gh-pages.yml @@ -0,0 +1,61 @@ +name: Deploy project pages + +on: + # Runs on pushes targeting the default branch + push: + branches: ["master"] + paths: + - '**/*.md' + - '**/*.png' # for screenshots + - 'page_build/**' + - '.github/workflows/**' + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. +# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
+concurrency: + group: "pages" + cancel-in-progress: false + +jobs: + # Build job + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Pages + uses: actions/configure-pages@v4 + - name: prepare site src + run: | + mkdir -p .github/page_build + cp *.md .github/page_build/ + cp web/ .github/page_build/ -r + ls -lR .github/page_build/ + - name: Build with Jekyll + uses: actions/jekyll-build-pages@v1 + with: + source: ./.github/page_build/ + destination: ./_site + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + + # Deployment job + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/prod.yml b/.github/workflows/prod.yml new file mode 100644 index 0000000000000000000000000000000000000000..479660699a3190343cb5477df6004fe34afe714a --- /dev/null +++ b/.github/workflows/prod.yml @@ -0,0 +1,74 @@ +name: Check & deploy to DockerHub + +on: + push: + branches: + - 'master' + tags: + - '*' + workflow_dispatch: + +jobs: + perform-check: + uses: ./.github/workflows/test_lint.yml + secrets: inherit + docker: + runs-on: ubuntu-latest + environment: DockerHub + needs: + - perform-check + strategy: + matrix: + configurations: + - dockerfile: "Dockerfile" + suffixes: | + "" + "-cuda" + "-cuda12.1" + args: | + CUDA_VERSION=12.1 + - dockerfile: "Dockerfile" + suffixes: '"-cuda11.8"' + args: | + CUDA_VERSION=11.8 + - dockerfile: "cpu-only.Dockerfile" + suffixes: '"-cpu"' + args: "" + steps: + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Docker Meta + id: docker-meta + uses: docker/metadata-action@v5 + with: + images: edgeneko/neko-image-gallery + tags: | + type=edge,branch=master + type=semver,pattern=v{{version}} + type=semver,pattern=v{{major}}.{{minor}} + - name: Build combined tags + id: combine-tags + run: | + SUFFIXES=(${{ matrix.configurations.suffixes }}) + echo 'tags<<EOF' >> $GITHUB_OUTPUT + for SUFFIX in "${SUFFIXES[@]}"; do + echo '${{ steps.docker-meta.outputs.tags }}' | sed 's/$/'"$SUFFIX"'/' >> $GITHUB_OUTPUT + done + echo EOF >> $GITHUB_OUTPUT + + printf 'cache_tag=%s' "$(echo '${{ steps.docker-meta.outputs.tags }}' | tail -1 | sed 's/$/'"${SUFFIXES[0]}"'/')" >> $GITHUB_OUTPUT + - name: Build and push + uses: docker/build-push-action@v5 + with: + file: ${{ matrix.configurations.dockerfile }} + push: true + tags: ${{ steps.combine-tags.outputs.tags }} + build-args: ${{ matrix.configurations.args }} + labels: ${{ steps.docker-meta.outputs.labels }} + cache-from: type=registry,ref=${{steps.combine-tags.outputs.cache_tag}} + cache-to: type=inline diff --git a/.github/workflows/test_lint.yml b/.github/workflows/test_lint.yml new file mode 100644 index 0000000000000000000000000000000000000000..cf855b634e16680f6905448b62ffe212ae2b7202 --- /dev/null +++ b/.github/workflows/test_lint.yml @@ -0,0 +1,50 @@ +name: Test and Lint Project + +on: + workflow_call: + push: + branches-ignore: + - 'master' + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [ "3.10", "3.11" ] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version:
${{ matrix.python-version }} + cache: 'pip' + - name: Cache for models + id: cache-models + uses: actions/cache@v4 + with: + path: | + ~/.cache/huggingface/ + key: ${{ runner.os }}-models-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-models- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install -r requirements.dev.txt + - name: Test the code with pytest + run: | + pytest --cov=app . + - name: Upload coverage reports to Codecov with GitHub Action + uses: codecov/codecov-action@v4.2.0 + if: ${{ matrix.python-version == '3.11' }} # Only upload coverage reports for the latest Python version + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + - name: Analysing the code with pylint + run: | + pylint --rc-file pylintrc.toml -j 0 app scripts tests && lint_result=$? || lint_result=$? + exit $(( $lint_result & 35 )) + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b00d08f0478de8356c40780ea7e91993b8491685 --- /dev/null +++ b/.gitignore @@ -0,0 +1,246 @@ +### PyCharm template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+# .idea/ + +static/ +qdrant_data/ +images_metadata/ +local_*/ +.idea \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..35410cacdc5e87f985c93a96520f5e11a5c822e4 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/NekoImageGallery.iml b/.idea/NekoImageGallery.iml new file mode 100644 index 0000000000000000000000000000000000000000..8d250ed13592ed92045867c3c81941ad4adc6722 --- /dev/null +++ b/.idea/NekoImageGallery.iml @@ -0,0 +1,13 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000000000000000000000000000000000000..d1884fd831579aa22143eec4b259d85a382eec24 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,14 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..2d29711fe633c397aaf2cc9f02ee62041f8a9153 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..9d210c970cbe36900eb5d92b0bca2d33b341087c --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52 --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. 
+ + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. 
+ + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
+ + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. 
diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c3416c386508293e63debcd995ba407d959fc9 --- /dev/null +++ b/app/Controllers/admin.py @@ -0,0 +1,164 @@ +from datetime import datetime +from io import BytesIO +from pathlib import PurePath +from typing import Annotated +from uuid import UUID + +from PIL import Image, UnidentifiedImageError +from fastapi import APIRouter, Depends, HTTPException, params, UploadFile, File +from loguru import logger + +from app.Models.api_models.admin_api_model import ImageOptUpdateModel, DuplicateValidationModel +from app.Models.api_models.admin_query_params import UploadImageModel +from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse, \ + DuplicateValidationResponse +from app.Models.api_response.base import NekoProtocol +from app.Models.errors import PointDuplicateError +from app.Models.img_data import ImageData +from app.Services.authentication import force_admin_token_verify +from app.Services.provider import ServiceProvider +from app.Services.vector_db_context import PointNotFoundError +from app.config import config +from app.util.generate_uuid import generate_uuid_from_sha1 +from app.util.local_file_utility import VALID_IMAGE_EXTENSIONS + +admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"]) + +services: ServiceProvider | None = None + + +@admin_router.delete("/delete/{image_id}", + description="Delete image with the given id from database. " + "If the image is a local image, it will be moved to `/static/_deleted` folder.") +async def delete_image( + image_id: Annotated[UUID, params.Path(description="The id of the image you want to delete.")]) -> NekoProtocol: + try: + point = await services.db_context.retrieve_by_id(str(image_id)) + except PointNotFoundError as ex: + raise HTTPException(404, "Cannot find the image with the given ID.") from ex + await services.db_context.deleteItems([str(point.id)]) + logger.success("Image {} deleted from database.", point.id) + + if config.storage.method.enabled: # local image + if point.local: + image_files = [itm[0] async for itm in + services.storage_service.active_storage.list_files("", f"{point.id}.*")] + assert len(image_files) <= 1 + if not image_files: + logger.warning("Image {} is a local image but not found in static folder.", point.id) + else: + await services.storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}") + logger.success("Image {} removed.", image_files[0].name) + if point.thumbnail_url is not None and (point.local or point.local_thumbnail): + thumbnail_file = PurePath(f"thumbnails/{point.id}.webp") + if await services.storage_service.active_storage.is_exist(thumbnail_file): + await services.storage_service.active_storage.delete(thumbnail_file) + logger.success("Thumbnail {} removed.", thumbnail_file.name) + else: + logger.warning("Thumbnail {} not found.", thumbnail_file.name) + + return NekoProtocol(message="Image deleted.") + + +@admin_router.put("/update_opt/{image_id}", description="Update a image's optional information") +async def update_image(image_id: Annotated[UUID, params.Path(description="The id of the image you want to delete.")], + model: ImageOptUpdateModel) -> NekoProtocol: + if model.empty(): + raise HTTPException(422, "Nothing to update.") + try: + point = await services.db_context.retrieve_by_id(str(image_id)) + except PointNotFoundError as ex: + raise HTTPException(404, "Cannot find the 
image with the given ID.") from ex + + if model.thumbnail_url is not None: + if point.local or point.local_thumbnail: + raise HTTPException(422, "Cannot change the thumbnail URL of a local image.") + point.thumbnail_url = model.thumbnail_url + if model.url is not None: + if point.local: + raise HTTPException(422, "Cannot change the URL of a local image.") + point.url = model.url + if model.starred is not None: + point.starred = model.starred + if model.categories is not None: + point.categories = model.categories + + await services.db_context.updatePayload(point) + logger.success("Image {} updated.", point.id) + + return NekoProtocol(message="Image updated.") + + +IMAGE_MIMES = { + "image/jpeg": "jpeg", + "image/png": "png", + "image/webp": "webp", + "image/gif": "gif", +} + + +@admin_router.post("/upload", + description="Upload image to server. The image will be indexed and stored in the database. If " + "local is set to true, the image will be uploaded to local storage.") +async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")], + model: Annotated[UploadImageModel, Depends()]) -> ImageUploadResponse: + # generate an ID for the image + img_type = None + if image_file.content_type.lower() in IMAGE_MIMES: + img_type = IMAGE_MIMES[image_file.content_type.lower()] + elif image_file.filename: + extension = PurePath(image_file.filename).suffix.lower() + if extension in VALID_IMAGE_EXTENSIONS: + img_type = extension[1:] + if not img_type: + logger.warning("Failed to infer image format of the uploaded image. Content Type: {}, Filename: {}", + image_file.content_type, image_file.filename) + raise HTTPException(415, "Unsupported image format.") + img_bytes = await image_file.read() + try: + img_id = await services.upload_service.assign_image_id(img_bytes) + except PointDuplicateError as ex: + raise HTTPException(409, + f"The uploaded point is already contained in the database! entity id: {ex.entity_id}") \ + from ex + try: + image = Image.open(BytesIO(img_bytes)) + image.verify() + image.close() + except UnidentifiedImageError as ex: + logger.warning("Invalid image file from upload request. id: {}", img_id) + raise HTTPException(422, "Cannot open the image file.") from ex + + image_data = ImageData(id=img_id, + url=model.url, + thumbnail_url=model.thumbnail_url, + local=model.local, + categories=model.categories, + starred=model.starred, + format=img_type, + index_date=datetime.now()) + + await services.upload_service.queue_upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail) + return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id) + + +@admin_router.get("/server_info", description="Get server information") +async def server_info() -> ServerInfoResponse: + return ServerInfoResponse(message="Successfully get server information!", + image_count=await services.db_context.get_counts(exact=True), + index_queue_length=services.upload_service.get_queue_size()) + + +@admin_router.post("/duplication_validate", + description="Check if an image exists in the server by its SHA1 hash. 
If the image exists, " + "the image ID will be returned.\n" + "This is helpful for checking if an image is already in the server without " + "uploading the image.") +async def duplication_validate(model: DuplicateValidationModel) -> DuplicateValidationResponse: + ids = [generate_uuid_from_sha1(t) for t in model.hashes] + valid_ids = await services.db_context.validate_ids([str(t) for t in ids]) + exists_matrix = [str(t) in valid_ids or t in services.upload_service.uploading_ids for t in ids] + return DuplicateValidationResponse( + exists=exists_matrix, + entity_ids=[(str(t) if exists else None) for (t, exists) in zip(ids, exists_matrix)], + message="Validation completed.") diff --git a/app/Controllers/images.py b/app/Controllers/images.py new file mode 100644 index 0000000000000000000000000000000000000000..903ed95ee98dbbdc6853203713f130facbb8df38 --- /dev/null +++ b/app/Controllers/images.py @@ -0,0 +1,43 @@ +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, Path, HTTPException, Query + +from app.Models.api_response.images_api_response import QueryByIdApiResponse, ImageStatus, QueryImagesApiResponse +from app.Models.query_params import FilterParams +from app.Services.authentication import force_access_token_verify +from app.Services.provider import ServiceProvider +from app.Services.vector_db_context import PointNotFoundError +from app.config import config + +images_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None), + tags=["Images"]) + +services: ServiceProvider | None = None # The service provider will be injected in the webapp initialize + + +@images_router.get("/id/{image_id}", description="Query the image info with the given image ID. \n" + "This can also be used to check the status" + " of an image in the index queue.") +async def query_image_by_id(image_id: Annotated[UUID, Path(description="The id of the image you want to query.")]): + try: + return QueryByIdApiResponse(img=await services.db_context.retrieve_by_id(str(image_id)), + img_status=ImageStatus.MAPPED, + message="Success query the image with the given ID.") + except PointNotFoundError as ex: + if services.upload_service and image_id in services.upload_service.uploading_ids: + return QueryByIdApiResponse(img=None, + img_status=ImageStatus.IN_QUEUE, + message="The image is in the indexing queue.") + raise HTTPException(404, "Cannot find the image with the given ID.") from ex + + +@images_router.get("/", description="Query images in order of ID.") +async def scroll_images(filter_param: Annotated[FilterParams, Depends()], + prev_offset_id: Annotated[UUID, Query(description="The previous offset image ID.")] = None, + count: Annotated[int, Query(ge=1, le=100, description="The number of images to query.")] = 15): + # validate the offset ID + if prev_offset_id is not None and len(await services.db_context.validate_ids([str(prev_offset_id)])) == 0: + raise HTTPException(404, "The previous offset ID is invalid.") + images, offset = await services.db_context.scroll_points(str(prev_offset_id), count, filter_param=filter_param) + return QueryImagesApiResponse(images=images, next_page_offset=offset, message="Success query images.") diff --git a/app/Controllers/search.py b/app/Controllers/search.py new file mode 100644 index 0000000000000000000000000000000000000000..911d8038f1cd2f56718d61b747509e0d8d2ccda5 --- /dev/null +++ b/app/Controllers/search.py @@ -0,0 +1,214 @@ +from io import BytesIO +from typing import Annotated, List +from 
uuid import uuid4, UUID + +from PIL import Image +from fastapi import APIRouter, HTTPException +from fastapi.params import File, Query, Path, Depends +from loguru import logger + +from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum +from app.Models.api_response.search_api_response import SearchApiResponse +from app.Models.query_params import SearchPagingParams, FilterParams +from app.Models.search_result import SearchResult +from app.Services.authentication import force_access_token_verify +from app.Services.provider import ServiceProvider +from app.config import config +from app.util.calculate_vectors_cosine import calculate_vectors_cosine + +search_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None), + tags=["Search"]) + +services: ServiceProvider | None = None # The service provider will be injected in the webapp initialize + + +class SearchBasisParams: + def __init__(self, + basis: Annotated[SearchBasisEnum, Query( + description="The basis used to search the image.")] = SearchBasisEnum.vision): + if basis == SearchBasisEnum.ocr and not config.ocr_search.enable: + raise HTTPException(400, "OCR search is not enabled.") + self.basis = basis + + +async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse: + if not config.storage.method.enabled: + return resp + for item in resp.result: + if item.img.local: + img_extension = item.img.format or item.img.url.split('.')[-1] + img_remote_filename = f"{item.img.id}.{img_extension}" + item.img.url = await services.storage_service.active_storage.presign_url(img_remote_filename) + if item.img.thumbnail_url is not None and (item.img.local or item.img.local_thumbnail): + thumbnail_remote_filename = f"thumbnails/{item.img.id}.webp" + item.img.thumbnail_url = await services.storage_service.active_storage.presign_url( + thumbnail_remote_filename) + return resp + + +@search_router.get("/text/{prompt}", description="Search images by text prompt") +async def textSearch( + prompt: Annotated[ + str, Path(max_length=100, description="The image prompt text you want to search.")], + basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)], + exact: Annotated[bool, Query( + description="If using OCR search, this option will require the ocr text contains **exactly** the " + "criteria you have given. 
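The text-search route declared above exposes its basis, filter, and paging dependencies as plain query parameters. A hedged example call, assuming the search router is mounted under /search and access protection is enabled (base URL and token are placeholders):

```python
import requests

SERVER = "http://localhost:8000"    # assumed base URL
ACCESS_TOKEN = "change-me"          # only needed when access_protected is enabled

resp = requests.get(
    f"{SERVER}/search/text/a cat sitting on a keyboard",
    params={
        "basis": "ocr",          # "vision" (default) or "ocr"
        "exact": "true",         # exact OCR substring match; ignored for vision search
        "count": 20,             # SearchPagingParams
        "skip": 0,
        "starred": "true",       # FilterParams
        "categories": "stickers, cg",
    },
    headers={"X-Access-Token": ACCESS_TOKEN},
)
resp.raise_for_status()
for item in resp.json()["result"]:
    print(f'{item["score"]:.4f}', item["img"]["id"])
```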
This won't take any effect in vision search.")] = False +) -> SearchApiResponse: + logger.info("Text search request received, prompt: {}", prompt) + text_vector = services.transformers_service.get_text_vector(prompt) if basis.basis == SearchBasisEnum.vision \ + else services.transformers_service.get_bert_vector(prompt) + if basis.basis == SearchBasisEnum.ocr and exact: + filter_param.ocr_text = prompt + results = await services.db_context.querySearch(text_vector, + query_vector_name=services.db_context.vector_name_for_basis( + basis.basis), + filter_param=filter_param, + top_k=paging.count, + skip=paging.skip) + return await result_postprocessing( + SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4())) + + +@search_router.post("/image", description="Search images by image") +async def imageSearch( + image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*", + description="The image you want to search.")], + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)] +) -> SearchApiResponse: + fakefile = BytesIO(image) + img = Image.open(fakefile) + logger.info("Image search request received") + image_vector = services.transformers_service.get_image_vector(img) + results = await services.db_context.querySearch(image_vector, + top_k=paging.count, + skip=paging.skip, + filter_param=filter_param) + return await result_postprocessing( + SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4())) + + +@search_router.get("/similar/{image_id}", + description="Search images similar to the image with given id. " + "Won't include the given image itself in the result.") +async def similarWith( + image_id: Annotated[UUID, Path(description="The id of the image you want to search.")], + basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)] +) -> SearchApiResponse: + logger.info("Similar search request received, id: {}", image_id) + results = await services.db_context.querySimilar(search_id=str(image_id), + top_k=paging.count, + skip=paging.skip, + filter_param=filter_param, + query_vector_name=services.db_context.vector_name_for_basis( + basis.basis)) + return await result_postprocessing( + SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4())) + + +@search_router.post("/advanced", description="Search with multiple criteria") +async def advancedSearch( + model: AdvancedSearchModel, + basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse: + logger.info("Advanced search request received: {}", model) + result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging) + return await result_postprocessing( + SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4())) + + +@search_router.post("/combined", description="Search with combined criteria") +async def combinedSearch( + model: CombinedSearchModel, + basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, 
Depends(SearchPagingParams)]) -> SearchApiResponse: + if not config.ocr_search.enable: + raise HTTPException(400, "You used combined search, but it needs OCR search which is not " + "enabled.") + logger.info("Combined search request received: {}", model) + result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging, True) + calculate_and_sort_by_combined_scores(model, basis, result) + result = result[:paging.count] if len(result) > paging.count else result + return await result_postprocessing( + SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4())) + + +@search_router.get("/random", description="Get random images") +async def randomPick( + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)], + seed: Annotated[int | None, Query( + description="The seed for random pick. This is helpful for generating a reproducible random pick.")] = None, +) -> SearchApiResponse: + logger.info("Random pick request received") + random_vector = services.transformers_service.get_random_vector(seed) + result = await services.db_context.querySearch(random_vector, top_k=paging.count, skip=paging.skip, + filter_param=filter_param) + return await result_postprocessing( + SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4())) + + +# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId") +# async def recallQuery(query_id: str): +# raise NotImplementedError() + +async def process_advanced_and_combined_search_query(model: AdvancedSearchModel, + basis: SearchBasisParams, + filter_param: FilterParams, + paging: SearchPagingParams, + is_combined_search=False) -> List[SearchResult]: + match basis.basis: + case SearchBasisEnum.ocr: + positive_vectors = [services.transformers_service.get_bert_vector(t) for t in model.criteria] + negative_vectors = [services.transformers_service.get_bert_vector(t) for t in model.negative_criteria] + case SearchBasisEnum.vision: + positive_vectors = [services.transformers_service.get_text_vector(t) for t in model.criteria] + negative_vectors = [services.transformers_service.get_text_vector(t) for t in model.negative_criteria] + case _: # pragma: no cover + raise NotImplementedError() + # In order to ensure the query effect of the combined query, modify the actual top_k + _query_top_k = min(max(30, paging.count * 3), 100) if is_combined_search else paging.count + result = await services.db_context.querySimilar( + query_vector_name=services.db_context.vector_name_for_basis(basis.basis), + positive_vectors=positive_vectors, + negative_vectors=negative_vectors, + mode=model.mode, + filter_param=filter_param, + with_vectors=is_combined_search, + top_k=_query_top_k, + skip=paging.skip) + return result + + +def calculate_and_sort_by_combined_scores(model: CombinedSearchModel, + basis: SearchBasisParams, + result: List[SearchResult]) -> None: + # Use a different method to calculate the extra prompt vector based on the basis + match basis.basis: + case SearchBasisEnum.ocr: + extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt) + case SearchBasisEnum.vision: + extra_prompt_vector = services.transformers_service.get_bert_vector(model.extra_prompt) + case _: # pragma: no cover + raise NotImplementedError() + # Calculate combined_similar_score (original score * similar_score) and write to SearchResult.score + for itm in result: + 
match basis.basis: + case SearchBasisEnum.ocr: + extra_vector = itm.img.image_vector + case SearchBasisEnum.vision: + extra_vector = itm.img.text_contain_vector + case _: # pragma: no cover + raise NotImplementedError() + if extra_vector is not None: + similar_score = calculate_vectors_cosine(extra_vector, extra_prompt_vector) + itm.score = (1 + similar_score) * itm.score + # Finally, sort the result by combined_similar_score + result.sort(key=lambda i: i.score, reverse=True) diff --git a/app/Models/__init__.py b/app/Models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/Models/api_models/__init__.py b/app/Models/api_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/Models/api_models/admin_api_model.py b/app/Models/api_models/admin_api_model.py new file mode 100644 index 0000000000000000000000000000000000000000..12037aa54edcf320938ee4b86d79e84d049f15f4 --- /dev/null +++ b/app/Models/api_models/admin_api_model.py @@ -0,0 +1,31 @@ +from typing import Optional, Annotated + +from pydantic import BaseModel, Field, StringConstraints + + +class ImageOptUpdateModel(BaseModel): + starred: Optional[bool] = Field(None, + description="Whether the image is starred or not. Leave empty to keep the value " + "unchanged.") + categories: Optional[list[str]] = Field(None, + description="The categories of the image. Leave empty to keep the value " + "unchanged.") + url: Optional[str] = Field(None, + description="The url of the image. Leave empty to keep the value unchanged. Changing " + "the url of a local image is not allowed.") + + thumbnail_url: Optional[str] = Field(None, + description="The url of the thumbnail. Leave empty to keep the value " + "unchanged. Changing the thumbnail_url of an image with a local " + "thumbnail is not allowed.") + + def empty(self) -> bool: + return all([item is None for item in self.model_dump().values()]) + + +Sha1HashString = Annotated[ + str, StringConstraints(min_length=40, max_length=40, pattern=r"[0-9a-f]+", to_lower=True, strip_whitespace=True)] + + +class DuplicateValidationModel(BaseModel): + hashes: list[Sha1HashString] = Field(description="The SHA1 hash of the image.", min_length=1) diff --git a/app/Models/api_models/admin_query_params.py b/app/Models/api_models/admin_query_params.py new file mode 100644 index 0000000000000000000000000000000000000000..eba8dd4d11b9af185023df73e0fd7be48f21cc84 --- /dev/null +++ b/app/Models/api_models/admin_query_params.py @@ -0,0 +1,48 @@ +from enum import Enum +from typing import Optional + +from fastapi import Query, HTTPException + + +class UploadImageThumbnailMode(str, Enum): + IF_NECESSARY = "if_necessary" + ALWAYS = "always" + NEVER = "never" + + +class UploadImageModel: + def __init__(self, + url: Optional[str] = Query(None, + description="The image's url. If the image is local, this field will be " + "ignored. Otherwise it is required."), + thumbnail_url: Optional[str] = Query(None, + description="The image's thumbnail url. If the image is local " + "or local_thumbnail's value is always, " + "this field will be ignored. Currently setting a " + "external thumbnail for a local image is " + "unsupported due to compatibility issues."), + categories: Optional[str] = Query(None, + description="The categories of the image. 
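For reference, the rescaling applied by calculate_and_sort_by_combined_scores above is score ← (1 + cosine_similarity) × score, followed by a descending sort. A self-contained numpy sketch with made-up vectors and scores (the cosine helper here merely stands in for app.util.calculate_vectors_cosine, whose implementation is not shown in this diff):

```python
import numpy as np


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    # Stand-in for calculate_vectors_cosine: plain cosine similarity in [-1, 1].
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


rng = np.random.default_rng(0)
extra_prompt_vector = rng.normal(size=768)

# Pretend these came back from the first-stage vector query: a score plus the
# stored "extra" vector (text_contain_vector or image_vector) for each hit.
hits = [{"id": i, "score": s, "extra_vector": rng.normal(size=768)}
        for i, s in enumerate([0.31, 0.29, 0.27])]

for hit in hits:
    similar = cosine(hit["extra_vector"], extra_prompt_vector)
    # (1 + cosine) is a non-negative multiplier in [0, 2], so a hit whose stored
    # vector also matches the extra prompt gets boosted, a mismatching one damped.
    hit["score"] = (1 + similar) * hit["score"]

hits.sort(key=lambda h: h["score"], reverse=True)
print([(h["id"], round(h["score"], 4)) for h in hits])
```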
The entries should be " + "seperated by comma."), + starred: bool = Query(False, description="If the image is starred."), + local: bool = Query(False, + description="When set to true, the image will be uploaded to local storage. " + "Otherwise, it will only be indexed in the database."), + local_thumbnail: UploadImageThumbnailMode = + Query(default=None, + description="Whether to generate thumbnail locally. Possible values:\n" + "- `if_necessary`: Only generate thumbnail if the image is larger than 500KB. " + "This is the default value if `local=True`\n" + " - `always`: Always generate thumbnail.\n" + " - `never`: Never generate thumbnail. This is the default value if `local=False`."), + skip_ocr: bool = Query(False, description="Whether to skip the OCR process.")): + self.url = url + self.thumbnail_url = thumbnail_url + self.categories = [t.strip() for t in categories.split(',') if t.strip()] if categories else None + self.starred = starred + self.local = local + self.skip_ocr = skip_ocr + self.local_thumbnail = local_thumbnail if (local_thumbnail is not None) else ( + UploadImageThumbnailMode.IF_NECESSARY if local else UploadImageThumbnailMode.NEVER) + if not self.url and not self.local: + raise HTTPException(422, "A correspond url must be provided for a non-local image.") diff --git a/app/Models/api_models/search_api_model.py b/app/Models/api_models/search_api_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f2de44ea3510a04de487afd179d75f503041ea4a --- /dev/null +++ b/app/Models/api_models/search_api_model.py @@ -0,0 +1,30 @@ +from enum import Enum + +from pydantic import BaseModel, Field + + +class SearchBasisEnum(str, Enum): + vision = "vision" + ocr = "ocr" + + +class SearchModelEnum(str, Enum): + average = "average" + best = "best" + + +class AdvancedSearchModel(BaseModel): + criteria: list[str] = Field([], + description="The positive criteria you want to search with", + max_length=16, + min_length=1) + negative_criteria: list[str] = Field([], + description="The negative criteria you want to search with", + max_length=16) + mode: SearchModelEnum = Field(SearchModelEnum.average, + description="The mode you want to use to combine the criteria.") + + +class CombinedSearchModel(AdvancedSearchModel): + extra_prompt: str = Field(max_length=100, + description="The secondary prompt used for filtering the image.") diff --git a/app/Models/api_response/admin_api_response.py b/app/Models/api_response/admin_api_response.py new file mode 100644 index 0000000000000000000000000000000000000000..2a450f975a4ae3dab4fcd7217cfc06521de3d318 --- /dev/null +++ b/app/Models/api_response/admin_api_response.py @@ -0,0 +1,21 @@ +from uuid import UUID + +from pydantic import Field + +from .base import NekoProtocol + + +class ServerInfoResponse(NekoProtocol): + image_count: int + index_queue_length: int + + +class DuplicateValidationResponse(NekoProtocol): + entity_ids: list[UUID | None] = Field( + description="The image id for each hash. If the image does not exist in the server, the value will be null.") + exists: list[bool] = Field( + description="Whether the image exists in the server. 
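To make the advanced-search request shape above concrete, a hedged example for POST /advanced, assuming the search router is mounted under /search (add the X-Access-Token header when access protection is enabled); the combined variant adds an extra_prompt field on top of the same body:

```python
import requests

SERVER = "http://localhost:8000"   # assumed base URL

payload = {
    # 1-16 positive criteria; each becomes a CLIP text vector (or BERT vector when basis=ocr)
    "criteria": ["a white cat", "sitting on a sofa"],
    # up to 16 negative criteria pushing results away from these concepts
    "negative_criteria": ["dog"],
    # how criteria vectors are combined: "average" or "best"
    "mode": "average",
}

resp = requests.post(f"{SERVER}/search/advanced",
                     params={"basis": "vision", "count": 10},
                     json=payload)
resp.raise_for_status()
print(resp.json()["message"])
```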
True if the image exists, False otherwise.") + + +class ImageUploadResponse(NekoProtocol): + image_id: UUID diff --git a/app/Models/api_response/base.py b/app/Models/api_response/base.py new file mode 100644 index 0000000000000000000000000000000000000000..50c384d21d0caf78a84d127280ccdd83e810a234 --- /dev/null +++ b/app/Models/api_response/base.py @@ -0,0 +1,25 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class NekoProtocol(BaseModel): + message: str + + +class WelcomeApiAuthenticationResponse(BaseModel): + required: bool + passed: bool + + +class WelcomeApiAdminPortalAuthenticationResponse(BaseModel): + available: bool + passed: bool + + +class WelcomeApiResponse(NekoProtocol): + server_time: datetime + wiki: dict[str, str] + authorization: WelcomeApiAuthenticationResponse + admin_api: WelcomeApiAdminPortalAuthenticationResponse + available_basis: list[str] diff --git a/app/Models/api_response/images_api_response.py b/app/Models/api_response/images_api_response.py new file mode 100644 index 0000000000000000000000000000000000000000..875d19a500cd30abbb6cfdfc76bff6500ca9711c --- /dev/null +++ b/app/Models/api_response/images_api_response.py @@ -0,0 +1,25 @@ +from enum import Enum + +from pydantic import Field + +from app.Models.api_response.base import NekoProtocol +from app.Models.img_data import ImageData + + +class ImageStatus(str, Enum): + MAPPED = "mapped" + IN_QUEUE = "in_queue" + + +class QueryByIdApiResponse(NekoProtocol): + img_status: ImageStatus = Field(description="The status of the image.\n" + "Warning: If NekoImageGallery is deployed in a cluster, " + "the `in_queue` might not be accurate since the index queue " + "is independent of each service instance.") + img: ImageData | None = Field(description="The mapped image data. Only available when `img_status = mapped`.") + + +class QueryImagesApiResponse(NekoProtocol): + images: list[ImageData] = Field(description="The list of images.") + next_page_offset: str | None = Field(description="The offset ID for the next page query. 
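The next_page_offset field above drives cursor-style pagination of the image listing endpoint. A minimal paging loop, assuming the images router is mounted under /images (add the X-Access-Token header when access protection is enabled):

```python
import requests

SERVER = "http://localhost:8000"   # assumed base URL

offset = None
total = 0
while True:
    params = {"count": 100}                      # 1-100 images per page
    if offset is not None:
        params["prev_offset_id"] = offset        # cursor returned by the previous page
    resp = requests.get(f"{SERVER}/images/", params=params)
    resp.raise_for_status()
    page = resp.json()
    total += len(page["images"])
    offset = page["next_page_offset"]
    if offset is None:                           # null offset means no more images
        break
print("scanned", total, "images")
```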
" + "If there are no more images, this field will be null.") diff --git a/app/Models/api_response/search_api_response.py b/app/Models/api_response/search_api_response.py new file mode 100644 index 0000000000000000000000000000000000000000..966bb80b22638203e488ed9b652057cd9d88dca9 --- /dev/null +++ b/app/Models/api_response/search_api_response.py @@ -0,0 +1,8 @@ +from .base import NekoProtocol +from ..search_result import SearchResult +from uuid import UUID + + +class SearchApiResponse(NekoProtocol): + query_id: UUID + result: list[SearchResult] diff --git a/app/Models/errors.py b/app/Models/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..921399b6b3395c5ed6efdeb648c9da6d9f871b92 --- /dev/null +++ b/app/Models/errors.py @@ -0,0 +1,10 @@ +from uuid import UUID + + +class PointDuplicateError(ValueError): + def __init__(self, message: str, entity_id: UUID | None = None): + self.message = message + self.entity_id = entity_id + super().__init__(message) + + pass diff --git a/app/Models/img_data.py b/app/Models/img_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8a85498825f767dcac452517cc6636075d9b78 --- /dev/null +++ b/app/Models/img_data.py @@ -0,0 +1,53 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from numpy import ndarray +from pydantic import BaseModel, Field, ConfigDict + + +class ImageData(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, extra='ignore') + + id: UUID + url: Optional[str] = None + thumbnail_url: Optional[str] = None + ocr_text: Optional[str] = None + image_vector: Optional[ndarray] = Field(None, exclude=True) + text_contain_vector: Optional[ndarray] = Field(None, exclude=True) + index_date: datetime + width: Optional[int] = None + height: Optional[int] = None + aspect_ratio: Optional[float] = None + starred: Optional[bool] = False + categories: Optional[list[str]] = [] + local: Optional[bool] = False + local_thumbnail: Optional[bool] = False + format: Optional[str] = None # required for s3 local storage + + @property + def ocr_text_lower(self) -> str | None: + if self.ocr_text is None: + return None + return self.ocr_text.lower() + + @property + def payload(self): + result = self.model_dump(exclude={'id', 'index_date'}) + # Qdrant database cannot accept datetime object, so we have to convert it to string + result['index_date'] = self.index_date.isoformat() + # Qdrant doesn't support case-insensitive search, so we need to store a lowercase version of the text + result['ocr_text_lower'] = self.ocr_text_lower + return result + + @classmethod + def from_payload(cls, img_id: str, payload: dict, + image_vector: Optional[ndarray] = None, text_contain_vector: Optional[ndarray] = None): + # Convert the datetime string back to datetime object + index_date = datetime.fromisoformat(payload['index_date']) + del payload['index_date'] + return cls(id=UUID(img_id), + index_date=index_date, + **payload, + image_vector=image_vector if image_vector is not None else None, + text_contain_vector=text_contain_vector if text_contain_vector is not None else None) diff --git a/app/Models/query_params.py b/app/Models/query_params.py new file mode 100644 index 0000000000000000000000000000000000000000..ae807352372e981996c3d32133d96b34c6187c01 --- /dev/null +++ b/app/Models/query_params.py @@ -0,0 +1,56 @@ +from typing import Annotated + +from fastapi.params import Query + + +class SearchPagingParams: + def __init__( + self, + count: Annotated[int, Query(ge=1, le=100, 
description="The number of results you want to get.")] = 10, + skip: Annotated[int, Query(ge=0, description="The number of results you want to skip.")] = 0 + ): + self.count = count + self.skip = skip + + +class FilterParams: + def __init__( + self, + preferred_ratio: Annotated[ + float | None, Query(gt=0, description="The preferred aspect ratio of the image.")] = None, + ratio_tolerance: Annotated[ + float, Query(gt=0, lt=1, description="The tolerance of the aspect ratio.")] = 0.1, + min_width: Annotated[int | None, Query(geq=0, description="The minimum width of the image.")] = None, + min_height: Annotated[int | None, Query(geq=0, description="The minimum height of the image.")] = None, + starred: Annotated[bool | None, Query(description="Whether the image is starred.")] = None, + categories: Annotated[str | None, Query( + description="The categories whitelist of the image. Image with **any of** the given categories will " + "be included. The entries should be seperated by comma.", + examples=["stickers, cg"])] = None, + categories_negative: Annotated[ + str | None, Query( + description="The categories blacklist of the image. Image with **any of** the given categories " + "will be ignored. The entries should be seperated by comma.", + examples=["stickers, cg"])] = None, + ): + self.preferred_ratio = preferred_ratio + self.ratio_tolerance = ratio_tolerance + self.min_width = min_width + self.min_height = min_height + self.starred = starred + self.categories = [t.strip() for t in categories.split(',') if t.strip()] if categories else None + self.categories_negative = [t.strip() for t in categories_negative.split(',') if + t.strip()] if categories_negative else None + self.ocr_text = None # For exact search + + @property + def min_ratio(self) -> float | None: + if self.preferred_ratio is None: + return None + return self.preferred_ratio * (1 - self.ratio_tolerance) + + @property + def max_ratio(self) -> float | None: + if self.preferred_ratio is None: + return None + return self.preferred_ratio * (1 + self.ratio_tolerance) diff --git a/app/Models/search_result.py b/app/Models/search_result.py new file mode 100644 index 0000000000000000000000000000000000000000..c5bd4319f9214ebf4b0312b29b53de03ee18e78c --- /dev/null +++ b/app/Models/search_result.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel +from .img_data import ImageData + + +class SearchResult(BaseModel): + img: ImageData + score: float diff --git a/app/Services/__init__.py b/app/Services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/Services/authentication.py b/app/Services/authentication.py new file mode 100644 index 0000000000000000000000000000000000000000..e8143b79f4d2a63f326ae357651271d694ee8aed --- /dev/null +++ b/app/Services/authentication.py @@ -0,0 +1,32 @@ +from typing import Annotated + +from fastapi import HTTPException +from fastapi.params import Header, Depends + +from app.config import config + + +def verify_access_token(token: str | None) -> bool: + return (not config.access_protected) or (token is not None and token == config.access_token) + + +def permissive_access_token_verify( + x_access_token: Annotated[str | None, Header( + description="Access token set in configuration (if access_protected is enabled)")] = None) -> bool: + return verify_access_token(x_access_token) + + +def force_access_token_verify(token_passed: Annotated[bool, Depends(permissive_access_token_verify)]): + if not token_passed: + raise 
HTTPException(status_code=401, detail="Access token is not present or invalid.") + + +def permissive_admin_token_verify( + x_admin_token: Annotated[str | None, Header( + description="Admin token set in configuration (if admin_api_enable is enabled)")] = None) -> bool: + return config.admin_api_enable and x_admin_token == config.admin_token + + +def force_admin_token_verify(token_passed: Annotated[bool, Depends(permissive_admin_token_verify)]): + if not token_passed: + raise HTTPException(status_code=401, detail="Admin token is not present or invalid.") diff --git a/app/Services/index_service.py b/app/Services/index_service.py new file mode 100644 index 0000000000000000000000000000000000000000..00cdcf52b9b3103ccc72526a191598be327f0478 --- /dev/null +++ b/app/Services/index_service.py @@ -0,0 +1,60 @@ +from PIL import Image +from fastapi.concurrency import run_in_threadpool + +from app.Models.errors import PointDuplicateError +from app.Models.img_data import ImageData +from app.Services.lifespan_service import LifespanService +from app.Services.ocr_services import OCRService +from app.Services.transformers_service import TransformersService +from app.Services.vector_db_context import VectorDbContext +from app.config import config + + +class IndexService(LifespanService): + def __init__(self, ocr_service: OCRService, transformers_service: TransformersService, db_context: VectorDbContext): + self._ocr_service = ocr_service + self._transformers_service = transformers_service + self._db_context = db_context + + def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False): + image_data.width = image.width + image_data.height = image.height + image_data.aspect_ratio = float(image.width) / image.height + + if image.mode != 'RGB': + image = image.convert('RGB') # to reduce convert in next steps + else: + image = image.copy() + image_data.image_vector = self._transformers_service.get_image_vector(image) + if not skip_ocr and config.ocr_search.enable: + image_data.ocr_text = self._ocr_service.ocr_interface(image) + if image_data.ocr_text != "": + image_data.text_contain_vector = self._transformers_service.get_bert_vector(image_data.ocr_text) + else: + image_data.ocr_text = None + + # currently, here only need just a simple check + async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool: + image_id_list = [str(item.id) for item in image_data] + result = await self._db_context.validate_ids(image_id_list) + return len(result) != 0 + + async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False, + background=False): + if not skip_duplicate_check and (await self._is_point_duplicate([image_data])): + raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id) + + if background: + await run_in_threadpool(self._prepare_image, image, image_data, skip_ocr) + else: + self._prepare_image(image, image_data, skip_ocr) + + await self._db_context.insertItems([image_data]) + + async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], + skip_ocr=False, allow_overwrite=False): + if not allow_overwrite and (await self._is_point_duplicate(image_data)): + raise PointDuplicateError("The uploaded points are contained in the database!") + for img, img_data in zip(image, image_data): + self._prepare_image(img, img_data, skip_ocr) + await self._db_context.insertItems(image_data) diff --git a/app/Services/lifespan_service.py b/app/Services/lifespan_service.py new file 
mode 100644 index 0000000000000000000000000000000000000000..a6437f01beb015d0070e6c4e5839d73e6b940021 --- /dev/null +++ b/app/Services/lifespan_service.py @@ -0,0 +1,6 @@ +class LifespanService: + async def on_load(self): + pass + + async def on_exit(self): + pass diff --git a/app/Services/ocr_services.py b/app/Services/ocr_services.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6907b3ac35282cad9990b288611ce8249cc769 --- /dev/null +++ b/app/Services/ocr_services.py @@ -0,0 +1,115 @@ +from time import time + +import numpy as np +import torch +from PIL import Image +from loguru import logger + +from app.Services.lifespan_service import LifespanService +from app.config import config + + +class OCRService(LifespanService): + def __init__(self): + self._device = config.device + if self._device == "auto": + self._device = "cuda" if torch.cuda.is_available() else "cpu" + + @staticmethod + def _image_preprocess(img: Image.Image) -> Image.Image: + if img.mode != 'RGB': + img = img.convert('RGB') + if img.size[0] > 1024 or img.size[1] > 1024: + img.thumbnail((1024, 1024), Image.Resampling.LANCZOS) + new_img = Image.new('RGB', (1024, 1024), (0, 0, 0)) + new_img.paste(img, ((1024 - img.size[0]) // 2, (1024 - img.size[1]) // 2)) + return new_img + + def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: + pass + + +class EasyPaddleOCRService(OCRService): + def __init__(self): + super().__init__() + from easypaddleocr import EasyPaddleOCR + self._paddle_ocr_module = EasyPaddleOCR(use_angle_cls=True, + needWarmUp=True, + devices=self._device, + warmup_size=(960, 960), + model_local_dir=config.model.easypaddleocr if + config.model.easypaddleocr else None) + logger.success("EasyPaddleOCR loaded successfully") + + @staticmethod + def _image_preprocess(img: Image.Image) -> Image.Image: + # Optimized `easypaddleocr` doesn't require scaling preprocess + if img.mode != 'RGB': + img = img.convert('RGB') + return img + + def _easy_paddleocr_process(self, img: Image.Image) -> str: + _, ocr_result, _ = self._paddle_ocr_module.ocr(np.array(img)) + if ocr_result: + return "".join(itm[0] for itm in ocr_result if float(itm[1]) > config.ocr_search.ocr_min_confidence) + return "" + + def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: + start_time = time() + logger.info("Processing text with EasyPaddleOCR...") + res = self._easy_paddleocr_process(self._image_preprocess(img) if need_preprocess else img) + logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time) + return res + + +class EasyOCRService(OCRService): + def __init__(self): + super().__init__() + # noinspection PyPackageRequirements + import easyocr # pylint: disable=import-error + self._easy_ocr_module = easyocr.Reader(config.ocr_search.ocr_language, + gpu=self._device == "cuda") + logger.success("easyOCR loaded successfully") + + def _easyocr_process(self, img: Image.Image) -> str: + ocr_result = self._easy_ocr_module.readtext(np.array(img)) + return " ".join(itm[1] for itm in ocr_result if itm[2] > config.ocr_search.ocr_min_confidence) + + def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: + start_time = time() + logger.info("Processing text with easyOCR...") + res = self._easyocr_process(self._image_preprocess(img) if need_preprocess else img) + logger.success("OCR processed done. 
Time elapsed: {:.2f}s", time() - start_time) + return res + + +class PaddleOCRService(OCRService): + def __init__(self): + super().__init__() + # noinspection PyPackageRequirements + import paddleocr # pylint: disable=import-error + self._paddle_ocr_module = paddleocr.PaddleOCR(lang="ch", use_angle_cls=True, + use_gpu=self._device == "cuda") + logger.success("PaddleOCR loaded successfully") + + def _paddleocr_process(self, img: Image.Image) -> str: + ocr_result = self._paddle_ocr_module.ocr(np.array(img), cls=True) + if ocr_result[0]: + return "".join(itm[1][0] for itm in ocr_result[0] if itm[1][1] > config.ocr_search.ocr_min_confidence) + return "" + + def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: + start_time = time() + logger.info("Processing text with PaddleOCR...") + res = self._paddleocr_process(self._image_preprocess(img) if need_preprocess else img) + logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - start_time) + return res + + +class DisabledOCRService(OCRService): + def __init__(self): + super().__init__() + logger.warning("OCR search is disabled. Skipping OCR model loading.") + + def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str: + raise NotImplementedError("OCR module is disabled. Consider enable it in config.") diff --git a/app/Services/provider.py b/app/Services/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..6870aa9130a1b94ae81b3635b680229b1de892a4 --- /dev/null +++ b/app/Services/provider.py @@ -0,0 +1,56 @@ +import asyncio +from loguru import logger + +from .index_service import IndexService +from .lifespan_service import LifespanService +from .storage import StorageService +from .transformers_service import TransformersService +from .upload_service import UploadService +from .vector_db_context import VectorDbContext +from ..config import config, environment + + +class ServiceProvider: + def __init__(self): + self.transformers_service = TransformersService() + self.db_context = VectorDbContext() + self.ocr_service = None + + if config.ocr_search.enable and (environment.local_indexing or config.admin_api_enable): + match config.ocr_search.ocr_module: + case "easyocr": + from .ocr_services import EasyOCRService + + self.ocr_service = EasyOCRService() + case "easypaddleocr": + from .ocr_services import EasyPaddleOCRService + + self.ocr_service = EasyPaddleOCRService() + case "paddleocr": + from .ocr_services import PaddleOCRService + + self.ocr_service = PaddleOCRService() + case _: + raise NotImplementedError(f"OCR module {config.ocr_search.ocr_module} not implemented.") + else: + from .ocr_services import DisabledOCRService + + self.ocr_service = DisabledOCRService() + logger.info(f"OCR service '{type(self.ocr_service).__name__}' initialized.") + + self.index_service = IndexService(self.ocr_service, self.transformers_service, self.db_context) + self.storage_service = StorageService() + logger.info(f"Storage service '{type(self.storage_service.active_storage).__name__}' initialized.") + + self.upload_service = UploadService(self.storage_service, self.db_context, self.index_service) + logger.info(f"Upload service '{type(self.upload_service).__name__}' initialized") + + async def onload(self): + tasks = [service.on_load() for service_name in dir(self) + if isinstance((service := getattr(self, service_name)), LifespanService)] + await asyncio.gather(*tasks) + + async def onexit(self): + tasks = [service.on_exit() for service_name in dir(self) + if isinstance((service := 
getattr(self, service_name)), LifespanService)] + await asyncio.gather(*tasks) diff --git a/app/Services/storage/__init__.py b/app/Services/storage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..162679408c1666983003a6b29da793fd48bae6a7 --- /dev/null +++ b/app/Services/storage/__init__.py @@ -0,0 +1,27 @@ +from app.Services.lifespan_service import LifespanService +from app.Services.storage.base import BaseStorage +from app.Services.storage.disabled_storage import DisabledStorage +from app.Services.storage.local_storage import LocalStorage +from app.Services.storage.s3_compatible_storage import S3Storage +from app.config import config, StorageMode + + +class StorageService(LifespanService): + def __init__(self): + self.active_storage = None + match config.storage.method: + case StorageMode.LOCAL: + self.active_storage = LocalStorage() + case StorageMode.S3: + self.active_storage = S3Storage() + case StorageMode.DISABLED: + self.active_storage = DisabledStorage() + case _: + raise NotImplementedError(f"Storage method {config.storage.method} not implemented. " + f"Available methods: local, s3") + + async def on_load(self): + await self.active_storage.on_load() + + async def on_exit(self): + await self.active_storage.on_exit() diff --git a/app/Services/storage/base.py b/app/Services/storage/base.py new file mode 100644 index 0000000000000000000000000000000000000000..15e81bba6503a61b80588737d8012ff3f9b3f6cf --- /dev/null +++ b/app/Services/storage/base.py @@ -0,0 +1,146 @@ +import abc +import os +from typing import TypeVar, Generic, TypeAlias, Optional, AsyncGenerator + +from app.Services.lifespan_service import LifespanService + +FileMetaDataT = TypeVar('FileMetaDataT') + +PathLikeType: TypeAlias = str | os.PathLike +LocalFilePathType: TypeAlias = PathLikeType | bytes +RemoteFilePathType: TypeAlias = PathLikeType +LocalFileMetaDataType: TypeAlias = FileMetaDataT +RemoteFileMetaDataType: TypeAlias = FileMetaDataT + + +class BaseStorage(LifespanService, abc.ABC, Generic[FileMetaDataT]): + def __init__(self): + self.static_dir: os.PathLike + self.thumbnails_dir: os.PathLike + self.deleted_dir: os.PathLike + self.file_metadata: FileMetaDataT + + @abc.abstractmethod + async def is_exist(self, + remote_file: RemoteFilePathType) -> bool: + """ + Check if a remote_file exists. + :param remote_file: The file path relative to static_dir + :return: True if the file exists, False otherwise + """ + raise NotImplementedError + + @abc.abstractmethod + async def size(self, + remote_file: RemoteFilePathType) -> int: + """ + Get the size of a file in static_dir + :param remote_file: The file path relative to static_dir + :return: file's size + """ + raise NotImplementedError + + @abc.abstractmethod + async def url(self, + remote_file: RemoteFilePathType) -> str: + """ + Get the original URL of a file in static_dir. + This url will be placed in the payload field of the qdrant. + :param remote_file: The file path relative to static_dir + :return: file's "original URL" + """ + raise NotImplementedError + + @abc.abstractmethod + async def presign_url(self, + remote_file: RemoteFilePathType, + expire_second: int = 3600) -> str: + """ + Get the presign URL of a file in static_dir. 
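The ServiceProvider's onload and onexit above discover every attribute that is a LifespanService and await the corresponding hooks concurrently, so a new service only needs to inherit the base class to take part in startup and shutdown. A stripped-down sketch of that pattern with dummy services:

```python
import asyncio


class LifespanService:
    async def on_load(self): ...
    async def on_exit(self): ...


class DummyDb(LifespanService):
    async def on_load(self):
        print("db connected")


class DummyStorage(LifespanService):
    async def on_load(self):
        print("storage ready")


class Provider:
    def __init__(self):
        self.db = DummyDb()
        self.storage = DummyStorage()

    async def onload(self):
        # Reflectively collect every LifespanService attribute and run all hooks concurrently.
        tasks = [svc.on_load() for name in dir(self)
                 if isinstance((svc := getattr(self, name)), LifespanService)]
        await asyncio.gather(*tasks)


asyncio.run(Provider().onload())
```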
+ :param remote_file: The file path relative to static_dir + :param expire_second: Valid time for presign url + :return: file's "presign URL" + """ + raise NotImplementedError + + @abc.abstractmethod + async def fetch(self, + remote_file: RemoteFilePathType) -> bytes: + """ + Fetch a file from static_dir + :param remote_file: The file path relative to static_dir + :return: file's content + """ + raise NotImplementedError + + @abc.abstractmethod + async def upload(self, + local_file: "LocalFilePathType", + remote_file: RemoteFilePathType) -> None: + """ + Move a local picture file to the static_dir. + :param local_file: The absolute path to the local file or bytes. + :param remote_file: The file path relative to static_dir + """ + raise NotImplementedError + + @abc.abstractmethod + async def copy(self, + old_remote_file: RemoteFilePathType, + new_remote_file: RemoteFilePathType) -> None: + """ + Copy a file in static_dir. + :param old_remote_file: The file path relative to static_dir + :param new_remote_file: The file path relative to static_dir + """ + raise NotImplementedError + + @abc.abstractmethod + async def move(self, + old_remote_file: RemoteFilePathType, + new_remote_file: RemoteFilePathType) -> None: + """ + Move a file in static_dir. + :param old_remote_file: The file path relative to static_dir + :param new_remote_file: The file path relative to static_dir + """ + raise NotImplementedError + + @abc.abstractmethod + async def delete(self, + remote_file: RemoteFilePathType) -> None: + """ + Move a file in static_dir. + :param remote_file: The file path relative to static_dir + """ + raise NotImplementedError + + @abc.abstractmethod + async def list_files(self, + path: RemoteFilePathType, + pattern: Optional[str] = "*", + batch_max_files: Optional[int] = None, + valid_extensions: Optional[set[str]] = None) \ + -> AsyncGenerator[list[RemoteFilePathType], None]: + """ + Asynchronously generates a list of files from a given base directory path that match a specified pattern and set + of file extensions. + + :param path: The relative base directory path from which relative to static_dir to start listing files. + :param pattern: A glob pattern to filter files based on their names. Defaults to '*' which selects all files. + :param batch_max_files: The maximum number of files to return. If None, all matching files are returned. + :param valid_extensions: An extra set of file extensions to include (e.g., {".jpg", ".png"}). + If None, files are not filtered by extension. + :return: An asynchronous generator yielding lists of RemoteFilePathType objects representing the matching files. 
+ + Usage example: + async for batch in list_files(base_path=".", pattern="*", max_files=100, valid_extensions={".jpg", ".png"}): + print(f"Batch: {batch}") + """ + raise NotImplementedError + + @abc.abstractmethod + async def update_metadata(self, + local_file_metadata: LocalFileMetaDataType, + remote_file_metadata: RemoteFileMetaDataType) -> None: + raise NotImplementedError diff --git a/app/Services/storage/disabled_storage.py b/app/Services/storage/disabled_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..902b231a977d47309a8f3a0d37908e11ff5eaa63 --- /dev/null +++ b/app/Services/storage/disabled_storage.py @@ -0,0 +1,43 @@ +from typing import Optional, AsyncGenerator + +from app.Services.storage import BaseStorage +from app.Services.storage.base import RemoteFilePathType, LocalFileMetaDataType, RemoteFileMetaDataType, \ + LocalFilePathType + + +class DisabledStorage(BaseStorage): # pragma: no cover + async def size(self, remote_file: RemoteFilePathType) -> int: + raise NotImplementedError + + async def url(self, remote_file: RemoteFilePathType) -> str: + raise NotImplementedError + + async def presign_url(self, remote_file: RemoteFilePathType, expire_second: int = 3600) -> str: + raise NotImplementedError + + async def fetch(self, remote_file: RemoteFilePathType) -> bytes: + raise NotImplementedError + + async def upload(self, local_file: "LocalFilePathType", remote_file: RemoteFilePathType) -> None: + raise NotImplementedError + + async def copy(self, old_remote_file: RemoteFilePathType, new_remote_file: RemoteFilePathType) -> None: + raise NotImplementedError + + async def move(self, old_remote_file: RemoteFilePathType, new_remote_file: RemoteFilePathType) -> None: + raise NotImplementedError + + async def delete(self, remote_file: RemoteFilePathType) -> None: + raise NotImplementedError + + async def update_metadata(self, local_file_metadata: LocalFileMetaDataType, + remote_file_metadata: RemoteFileMetaDataType) -> None: + raise NotImplementedError + + async def list_files(self, path: RemoteFilePathType, pattern: Optional[str] = "*", + batch_max_files: Optional[int] = None, valid_extensions: Optional[set[str]] = None) -> \ + AsyncGenerator[list[RemoteFilePathType], None]: + raise NotImplementedError + + async def is_exist(self, remote_file: RemoteFilePathType) -> bool: + raise NotImplementedError diff --git a/app/Services/storage/exception.py b/app/Services/storage/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4b6cde97556ef09ab1d5fb45c6f1d34b012bfb --- /dev/null +++ b/app/Services/storage/exception.py @@ -0,0 +1,30 @@ +class StorageExtension(Exception): + pass + + +class LocalFileNotFoundError(StorageExtension): + pass + + +class LocalFileExistsError(StorageExtension): + pass + + +class LocalFilePermissionError(StorageExtension): + pass + + +class RemoteFileNotFoundError(StorageExtension): + pass + + +class RemoteFileExistsError(StorageExtension): + pass + + +class RemoteFilePermissionError(StorageExtension): + pass + + +class RemoteConnectError(StorageExtension): + pass diff --git a/app/Services/storage/local_storage.py b/app/Services/storage/local_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..fe04a8aaeeea9a7a1b7840e23a1d38b8e8892682 --- /dev/null +++ b/app/Services/storage/local_storage.py @@ -0,0 +1,145 @@ +import os +from asyncio import to_thread +from pathlib import Path as syncPath +from shutil import copy2, move +from typing import Optional, AsyncGenerator + +import aiofiles 
+from loguru import logger + +from app.Services.storage.base import BaseStorage, FileMetaDataT, RemoteFilePathType, LocalFilePathType +from app.Services.storage.exception import RemoteFileNotFoundError, LocalFileNotFoundError, RemoteFilePermissionError, \ + LocalFilePermissionError, LocalFileExistsError, RemoteFileExistsError +from app.config import config +from app.util.local_file_utility import glob_local_files + + +def transform_exception(param: str): + file_not_found_exp_map = {"local": LocalFileNotFoundError, "remote": RemoteFileNotFoundError} + permission_exp_map = {"remote": RemoteFilePermissionError, "local": LocalFilePermissionError} + file_exist_map = {"local": LocalFileExistsError, "remote": RemoteFileExistsError} + + def decorator(func): + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except FileNotFoundError as ex: + raise file_not_found_exp_map[param] from ex + except PermissionError as ex: + raise permission_exp_map[param] from ex + except FileExistsError as ex: + raise file_exist_map[param] from ex + + return wrapper + + return decorator + + +class LocalStorage(BaseStorage[FileMetaDataT: None]): + def __init__(self): + super().__init__() + self.static_dir = syncPath(os.path.abspath(config.storage.local.path)) + self.thumbnails_dir = self.static_dir / "thumbnails" + self.deleted_dir = self.static_dir / "_deleted" + self.file_metadata = None + self.file_path_warp = lambda x: self.static_dir / syncPath(x) + + def file_path_wrap(self, path: RemoteFilePathType) -> syncPath: + return self.static_dir / syncPath(path) + + async def on_load(self): + if not self.static_dir.is_dir(): + self.static_dir.mkdir(parents=True) + logger.warning(f"static_dir {self.static_dir} not found, created.") + if not self.thumbnails_dir.is_dir(): + self.thumbnails_dir.mkdir(parents=True) + logger.warning(f"thumbnails_dir {self.thumbnails_dir} not found, created.") + if not self.deleted_dir.is_dir(): + self.deleted_dir.mkdir(parents=True) + logger.warning(f"deleted_dir {self.deleted_dir} not found, created.") + + async def is_exist(self, + remote_file: "RemoteFilePathType") -> bool: + return self.file_path_warp(remote_file).exists() + + @transform_exception("remote") + async def size(self, + remote_file: "RemoteFilePathType") -> int: + _file = self.file_path_warp(remote_file) + return self.file_path_warp(remote_file).stat().st_size + + # noinspection PyMethodMayBeStatic + async def url(self, + remote_file: "RemoteFilePathType") -> str: + return f"/static/{str(remote_file)}" + + async def presign_url(self, + remote_file: "RemoteFilePathType", + expire_second: int = 3600) -> str: + return f"/static/{str(remote_file)}" + + @transform_exception("remote") + async def fetch(self, + remote_file: "RemoteFilePathType") -> bytes: + remote_file = self.file_path_warp(remote_file) + async with aiofiles.open(str(remote_file), 'rb') as file: + return await file.read() + + @transform_exception("local") + async def upload(self, + local_file: "LocalFilePathType", + remote_file: "RemoteFilePathType") -> None: + remote_file = self.file_path_warp(remote_file) + if isinstance(local_file, bytes): + async with aiofiles.open(str(remote_file), 'wb') as file: + await file.write(local_file) + else: + await to_thread(copy2, str(local_file), str(remote_file)) + local_file = f"{len(local_file)} bytes" if isinstance(local_file, bytes) else local_file + logger.success(f"Successfully uploaded file {str(local_file)} to {str(remote_file)} via local_storage.") + + @transform_exception("remote") + async 
def copy(self, + old_remote_file: "RemoteFilePathType", + new_remote_file: "RemoteFilePathType") -> None: + old_remote_file = self.file_path_warp(old_remote_file) + new_remote_file = self.file_path_warp(new_remote_file) + await to_thread(copy2, str(old_remote_file), str(new_remote_file)) + logger.success(f"Successfully copied file {str(old_remote_file)} to {str(new_remote_file)} via local_storage.") + + @transform_exception("remote") + async def move(self, + old_remote_file: "RemoteFilePathType", + new_remote_file: "RemoteFilePathType") -> None: + old_remote_file = self.file_path_warp(old_remote_file) + new_remote_file = self.file_path_warp(new_remote_file) + await to_thread(move, str(old_remote_file), str(new_remote_file), copy_function=copy2) + logger.success(f"Successfully moved file {str(old_remote_file)} to {str(new_remote_file)} via local_storage.") + + @transform_exception("remote") + async def delete(self, + remote_file: "RemoteFilePathType") -> None: + remote_file = self.file_path_warp(remote_file) + await to_thread(os.remove, str(remote_file)) + logger.success(f"Successfully deleted file {str(remote_file)} via local_storage.") + + async def list_files(self, + path: RemoteFilePathType, + pattern: Optional[str] = "*", + batch_max_files: Optional[int] = None, + valid_extensions: Optional[set[str]] = None) \ + -> AsyncGenerator[list[RemoteFilePathType], None]: + local_path = self.file_path_warp(path) + files = [] + for file in glob_local_files(local_path, pattern, valid_extensions): + files.append(file) + if batch_max_files is not None and len(files) == batch_max_files: + yield files + files = [] + if files: + yield files + + async def update_metadata(self, + local_file_metadata: None, + remote_file_metadata: None) -> None: + raise NotImplementedError diff --git a/app/Services/storage/s3_compatible_storage.py b/app/Services/storage/s3_compatible_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..195b751e35240d4f99d3b481b862777057544f9c --- /dev/null +++ b/app/Services/storage/s3_compatible_storage.py @@ -0,0 +1,173 @@ +# pylint now reporting `opendal` as a `no-name-in-module` error, so we need to disable it as a temporary workaround +# Related issue: https://github.com/pylint-dev/pylint/issues/9185 +# Remove below `# pylint` once the issue is resolved +# pylint: disable=import-error,no-name-in-module +import os +import urllib.parse +from pathlib import PurePosixPath +from typing import Optional, AsyncGenerator + +import aiofiles +from loguru import logger +from opendal import AsyncOperator +from opendal.exceptions import NotFound, PermissionDenied, AlreadyExists +from wcmatch import glob + +from app.Services.storage.base import BaseStorage, FileMetaDataT, RemoteFilePathType, LocalFilePathType, \ + LocalFileMetaDataType, RemoteFileMetaDataType +from app.Services.storage.exception import LocalFileNotFoundError, RemoteFileNotFoundError, RemoteFilePermissionError, \ + RemoteFileExistsError +from app.config import config +from app.util.local_file_utility import VALID_IMAGE_EXTENSIONS + + +def transform_exception(func): + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except FileNotFoundError as ex: + raise LocalFileNotFoundError from ex + except NotFound as ex: + raise RemoteFileNotFoundError from ex + except PermissionDenied as ex: + raise RemoteFilePermissionError from ex + except AlreadyExists as ex: + raise RemoteFileExistsError from ex + + return wrapper + + +class S3Storage(BaseStorage[FileMetaDataT: None]): + def 
__init__(self): + super().__init__() + + # Paths + self.static_dir = PurePosixPath(config.storage.s3.path) + self.thumbnails_dir = self.static_dir / "thumbnails" + self.deleted_dir = self.static_dir / "_deleted" + + self.file_metadata = None + self.bucket = config.storage.s3.bucket + self.region = config.storage.s3.region + self.endpoint = config.storage.s3.endpoint_url + + self.op = AsyncOperator("s3", + root=str(self.static_dir), + bucket=self.bucket, + region=self.region, + endpoint=self.endpoint, + access_key_id=config.storage.s3.access_key_id, + secret_access_key=config.storage.s3.secret_access_key) + + self._file_path_str_warp = lambda x: str(PurePosixPath(x)) + + @staticmethod + def _file_path_str_wrap(p: RemoteFilePathType): + return str(PurePosixPath(p)) + + async def is_exist(self, + remote_file: "RemoteFilePathType") -> bool: + try: + # the easiest way to confirm the existence of a file + await self.op.stat(self._file_path_str_warp(remote_file)) + return True + except NotFound: + return False + + @transform_exception + async def size(self, + remote_file: "RemoteFilePathType") -> int: + _stat = await self.op.stat(self._file_path_str_warp(remote_file)) + return _stat.content_length + + @transform_exception + async def url(self, + remote_file: "RemoteFilePathType") -> str: + return f"{self._res_endpoint}/{str(self.static_dir)}/{str(remote_file)}" + + @transform_exception + async def presign_url(self, + remote_file: "RemoteFilePathType", + expire_second: int = 3600) -> str: + _presign = await self.op.presign_read(self._file_path_str_warp(remote_file), expire_second) + return _presign.url + + @transform_exception + async def fetch(self, + remote_file: "RemoteFilePathType") -> bytes: + with await self.op.read(self._file_path_str_warp(remote_file)) as f: + return bytes(f) + + @transform_exception + async def upload(self, + local_file: "LocalFilePathType", + remote_file: "RemoteFilePathType") -> None: + if isinstance(local_file, bytes): + b = local_file + else: + async with aiofiles.open(local_file, "rb") as f: + b = await f.read() + await self.op.write(self._file_path_str_warp(remote_file), b) + local_file = f"{len(local_file)} bytes" if isinstance(local_file, bytes) else local_file + logger.success(f"Successfully uploaded file {str(local_file)} to {str(remote_file)} via s3_storage.") + + @transform_exception + async def copy(self, + old_remote_file: "RemoteFilePathType", + new_remote_file: "RemoteFilePathType") -> None: + await self.op.copy(self._file_path_str_warp(old_remote_file), self._file_path_str_warp(new_remote_file)) + logger.success(f"Successfully copied file {str(old_remote_file)} to {str(new_remote_file)} via s3_storage.") + + @transform_exception + async def move(self, + old_remote_file: "RemoteFilePathType", + new_remote_file: "RemoteFilePathType") -> None: + await self.op.copy(self._file_path_str_warp(old_remote_file), self._file_path_str_warp(new_remote_file)) + await self.op.delete(self._file_path_str_warp(old_remote_file)) + logger.success(f"Successfully moved file {str(old_remote_file)} to {str(new_remote_file)} via s3_storage.") + + @transform_exception + async def delete(self, + remote_file: "RemoteFilePathType") -> None: + await self.op.delete(self._file_path_str_warp(remote_file)) + logger.success(f"Successfully deleted file {str(remote_file)} via s3_storage.") + + async def list_files(self, + path: RemoteFilePathType, + pattern: Optional[str] = "*", + batch_max_files: Optional[int] = None, + valid_extensions: Optional[set[str]] = None) \ + -> 
AsyncGenerator[list[RemoteFilePathType], None]: + if valid_extensions is None: + valid_extensions = VALID_IMAGE_EXTENSIONS + files = [] + # In opendal, current path should be "" instead of "." + _path = "" if self._file_path_str_warp(path) == "." else self._file_path_str_warp(path) + async for itm in await self.op.scan(_path): + if self._list_files_check(itm.path, pattern, valid_extensions): + files.append(PurePosixPath(itm.path)) + if batch_max_files is not None and len(files) == batch_max_files: + yield files + files = [] + if files: + yield files + + async def update_metadata(self, + local_file_metadata: "LocalFileMetaDataType", + remote_file_metadata: "RemoteFileMetaDataType") -> None: + raise NotImplementedError + + @staticmethod + def _list_files_check(x: str, pattern: str, valid_extensions: Optional[set[str]] = None) -> bool: + matches_pattern = glob.globmatch(x, pattern, flags=glob.GLOBSTAR) + has_valid_extension = os.path.splitext(x)[-1] in valid_extensions + is_not_directory = not x.endswith("/") + return matches_pattern and has_valid_extension and is_not_directory + + @property + def _res_endpoint(self): + parsed_url = urllib.parse.urlparse(self.endpoint) + # If the endpoint is a subdomain of the bucket, then the endpoint is already resolved. + if self.bucket in parsed_url.netloc.split('.'): + return self.endpoint + return f"{self.endpoint}/{self.bucket}" diff --git a/app/Services/transformers_service.py b/app/Services/transformers_service.py new file mode 100644 index 0000000000000000000000000000000000000000..2048c254911cfd3732a5c875cb692bd03951aa18 --- /dev/null +++ b/app/Services/transformers_service.py @@ -0,0 +1,70 @@ +from time import time + +import numpy as np +import torch +from PIL import Image +from loguru import logger +from numpy import ndarray +from torch import FloatTensor, no_grad +from transformers import CLIPProcessor, CLIPModel, BertTokenizer, BertModel + +from app.Services.lifespan_service import LifespanService +from app.config import config + + +class TransformersService(LifespanService): + def __init__(self): + self.device = config.device + if self.device == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info("Using device: {}; CLIP Model: {}, BERT Model: {}", + self.device, config.model.clip, config.model.bert) + self._clip_model = CLIPModel.from_pretrained(config.model.clip).to(self.device) + self._clip_processor = CLIPProcessor.from_pretrained(config.model.clip) + logger.success("CLIP Model loaded successfully") + if config.ocr_search.enable: + self._bert_model = BertModel.from_pretrained(config.model.bert).to(self.device) + self._bert_tokenizer = BertTokenizer.from_pretrained(config.model.bert) + logger.success("BERT Model loaded successfully") + else: + logger.info("OCR search is disabled. Skipping BERT model loading.") + + @no_grad() + def get_image_vector(self, image: Image.Image) -> ndarray: + if image.mode != "RGB": + image = image.convert("RGB") + logger.info("Processing image...") + start_time = time() + inputs = self._clip_processor(images=image, return_tensors="pt").to(self.device) + logger.success("Image processed, now Inferring with CLIP model...") + outputs: FloatTensor = self._clip_model.get_image_features(**inputs) + logger.success("Inference done. 
Time elapsed: {:.2f}s", time() - start_time) + outputs /= outputs.norm(dim=-1, keepdim=True) + return outputs.numpy(force=True).reshape(-1) + + @no_grad() + def get_text_vector(self, text: str) -> ndarray: + logger.info("Processing text...") + start_time = time() + inputs = self._clip_processor(text=text, return_tensors="pt").to(self.device) + logger.success("Text processed, now Inferring with CLIP model...") + outputs: FloatTensor = self._clip_model.get_text_features(**inputs) + logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time) + outputs /= outputs.norm(dim=-1, keepdim=True) + return outputs.numpy(force=True).reshape(-1) + + @no_grad() + def get_bert_vector(self, text: str) -> ndarray: + start_time = time() + logger.info("Inferring with BERT model...") + inputs = self._bert_tokenizer(text.strip().lower(), return_tensors="pt", truncation=True).to(self.device) + outputs = self._bert_model(**inputs) + vector = outputs.last_hidden_state.mean(dim=1).squeeze() + logger.success("BERT inference done. Time elapsed: {:.2f}s", time() - start_time) + return vector.cpu().numpy() + + @staticmethod + def get_random_vector(seed: int | None = None) -> ndarray: + generator = np.random.default_rng(seed) + vec = generator.uniform(-1, 1, 768) + return vec diff --git a/app/Services/upload_service.py b/app/Services/upload_service.py new file mode 100644 index 0000000000000000000000000000000000000000..51b59fb7fde2d3ee1e3cbbc59d4997b1cb90409b --- /dev/null +++ b/app/Services/upload_service.py @@ -0,0 +1,108 @@ +import asyncio +import gc +import io +import pathlib +from io import BytesIO + +from PIL import Image +from loguru import logger + +from app.Models.api_models.admin_query_params import UploadImageThumbnailMode +from app.Models.errors import PointDuplicateError +from app.Models.img_data import ImageData +from app.Services.index_service import IndexService +from app.Services.lifespan_service import LifespanService +from app.Services.storage import StorageService +from app.Services.vector_db_context import VectorDbContext +from app.config import config +from app.util.generate_uuid import generate_uuid + + +class UploadService(LifespanService): + def __init__(self, storage_service: StorageService, db_context: VectorDbContext, index_service: IndexService): + self._storage_service = storage_service + self._db_context = db_context + self._index_service = index_service + + self._queue = asyncio.Queue(config.admin_index_queue_max_length) + self._upload_worker_task = asyncio.create_task(self._upload_worker()) + + self.uploading_ids = set() + self._processed_count = 0 + + async def _upload_worker(self): + while True: + img_data, *args = await self._queue.get() + try: + await self._upload_task(img_data, *args) + logger.success("Image {} uploaded and indexed. Queue Length: {} [-1]", img_data.id, self._queue.qsize()) + except Exception as ex: + logger.error("Error occurred while uploading image {}", img_data.id) + logger.exception(ex) + finally: + self._queue.task_done() + self.uploading_ids.remove(img_data.id) + self._processed_count += 1 + if self._processed_count % 50 == 0: + gc.collect() + + async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool, + thumbnail_mode: UploadImageThumbnailMode): + img = Image.open(BytesIO(img_bytes)) + logger.info('Start indexing image {}. Local: {}. 
Size: {}', img_data.id, img_data.local, len(img_bytes)) + file_name = f"{img_data.id}.{img_data.format}" + thumb_path = f"thumbnails/{img_data.id}.webp" + gen_thumb = thumbnail_mode == UploadImageThumbnailMode.ALWAYS or ( + thumbnail_mode == UploadImageThumbnailMode.IF_NECESSARY and len(img_bytes) > 1024 * 500) + + if img_data.local: + img_data.url = await self._storage_service.active_storage.url(file_name) + if gen_thumb: + img_data.thumbnail_url = await self._storage_service.active_storage.url( + f"thumbnails/{img_data.id}.webp") + img_data.local_thumbnail = True + + await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, background=True) + logger.success("Image {} indexed.", img_data.id) + + if img_data.local: + logger.info("Start uploading image {} to local storage.", img_data.id) + await self._storage_service.active_storage.upload(img_bytes, file_name) + logger.success("Image {} uploaded to local storage.", img_data.id) + if gen_thumb: + logger.info("Start generate and upload thumbnail for {}.", img_data.id) + img.thumbnail((256, 256), resample=Image.Resampling.LANCZOS) + img_byte_arr = BytesIO() + img.save(img_byte_arr, 'WebP', save_all=True) + await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path) + logger.success("Thumbnail for {} generated and uploaded!", img_data.id) + + img.close() + + async def queue_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool, + thumbnail_mode: UploadImageThumbnailMode): + self.uploading_ids.add(img_data.id) + await self._queue.put((img_data, img_bytes, skip_ocr, thumbnail_mode)) + logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize()) + + async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes): + img_id = generate_uuid(img_file) + # check for duplicate points + if img_id in self.uploading_ids or len(await self._db_context.validate_ids([str(img_id)])) != 0: + logger.warning("Duplicate upload request for image id: {}", img_id) + raise PointDuplicateError(f"The uploaded point is already contained in the database! entity id: {img_id}", + img_id) + return img_id + + async def sync_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool, + thumbnail_mode: UploadImageThumbnailMode): + await self._upload_task(img_data, img_bytes, skip_ocr, thumbnail_mode) + + def get_queue_size(self): + return self._queue.qsize() + + async def on_exit(self): # pragma: no cover Hard to test in UT. + if self.get_queue_size() != 0: + logger.warning("There are still {} images in the upload queue. 
Waiting for upload process to be completed.", + self.get_queue_size()) + await self._queue.join() diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1c22b4495d48897b9ae80f655b1156a8899747 --- /dev/null +++ b/app/Services/vector_db_context.py @@ -0,0 +1,334 @@ +from typing import Optional + +import numpy +from grpc.aio import AioRpcError +from httpx import HTTPError +from loguru import logger +from qdrant_client import AsyncQdrantClient +from qdrant_client.http import models +from qdrant_client.models import RecommendStrategy + +from app.Models.api_models.search_api_model import SearchModelEnum, SearchBasisEnum +from app.Models.img_data import ImageData +from app.Models.query_params import FilterParams +from app.Models.search_result import SearchResult +from app.Services.lifespan_service import LifespanService +from app.config import config, QdrantMode +from app.util.retry_deco_async import wrap_object, retry_async + + +class PointNotFoundError(ValueError): + def __init__(self, point_id: str): + self.point_id = point_id + super().__init__(f"Point {point_id} not found.") + + +class VectorDbContext(LifespanService): + IMG_VECTOR = "image_vector" + TEXT_VECTOR = "text_contain_vector" + AVAILABLE_POINT_TYPES = models.Record | models.ScoredPoint | models.PointStruct + + def __init__(self): + match config.qdrant.mode: + case QdrantMode.SERVER: + self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port, + grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key, + prefer_grpc=config.qdrant.prefer_grpc) + wrap_object(self._client, retry_async((AioRpcError, HTTPError))) + case QdrantMode.LOCAL: + self._client = AsyncQdrantClient(path=config.qdrant.local_path) + case QdrantMode.MEMORY: + logger.warning("Using in-memory Qdrant client. Data will be lost after application restart. " + "This should only be used for testing and debugging.") + self._client = AsyncQdrantClient(":memory:") + case _: + raise ValueError("Invalid Qdrant mode.") + self.collection_name = config.qdrant.coll + + async def on_load(self): + if not await self.check_collection(): + logger.warning("Collection not found. Initializing...") + await self.initialize_collection() + + async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData: + """ + Retrieve an item from database by id. Will raise PointNotFoundError if the given ID doesn't exist. + :param image_id: The ID to retrieve. + :param with_vectors: Whether to retrieve vectors. + :return: The retrieved item. + """ + logger.info("Retrieving item {} from database...", image_id) + result = await self._client.retrieve(collection_name=self.collection_name, + ids=[image_id], + with_payload=True, + with_vectors=with_vectors) + if len(result) != 1: + logger.error("Point not exist.") + raise PointNotFoundError(image_id) + return self._get_img_data_from_point(result[0]) + + async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list[ImageData]: + """ + Retrieve items from the database by IDs. + An exception is thrown if there are items in the IDs that do not exist in the database. + :param image_id: The list of IDs to retrieve. + :param with_vectors: Whether to retrieve vectors. + :return: The list of retrieved items. 
+ """ + logger.info("Retrieving {} items from database...", len(image_id)) + result = await self._client.retrieve(collection_name=self.collection_name, + ids=image_id, + with_payload=True, + with_vectors=with_vectors) + result_point_ids = {t.id for t in result} + missing_point_ids = set(image_id) - result_point_ids + if len(missing_point_ids) > 0: + logger.error("{} points not exist.", len(missing_point_ids)) + raise PointNotFoundError(str(missing_point_ids)) + return self._get_img_data_from_points(result) + + async def validate_ids(self, image_id: list[str]) -> list[str]: + """ + Validate a list of IDs. Will return a list of valid IDs. + :param image_id: The list of IDs to validate. + :return: The list of valid IDs. + """ + logger.info("Validating {} items from database...", len(image_id)) + result = await self._client.retrieve(collection_name=self.collection_name, + ids=image_id, + with_payload=False, + with_vectors=False) + return [t.id for t in result] + + async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR, + top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]: + logger.info("Querying Qdrant... top_k = {}", top_k) + result = await self._client.search(collection_name=self.collection_name, + query_vector=(query_vector_name, query_vector), + query_filter=self._get_filters_by_filter_param(filter_param), + limit=top_k, + offset=skip, + with_payload=True) + logger.success("Query completed!") + return [self._get_search_result_from_scored_point(t) for t in result] + + async def querySimilar(self, + query_vector_name: str = IMG_VECTOR, + search_id: Optional[str] = None, + positive_vectors: Optional[list[numpy.ndarray]] = None, + negative_vectors: Optional[list[numpy.ndarray]] = None, + mode: Optional[SearchModelEnum] = None, + with_vectors: bool = False, + filter_param: FilterParams | None = None, + top_k: int = 10, + skip: int = 0) -> list[SearchResult]: + _positive_vectors = [t.tolist() for t in positive_vectors] if positive_vectors is not None else [search_id] + _negative_vectors = [t.tolist() for t in negative_vectors] if negative_vectors is not None else None + _strategy = None if mode is None else (RecommendStrategy.AVERAGE_VECTOR if + mode == SearchModelEnum.average else RecommendStrategy.BEST_SCORE) + # since only combined_search need return vectors, We can define _combined_search_need_vectors like below + _combined_search_need_vectors = [ + self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.TEXT_VECTOR] if with_vectors else None + logger.info("Querying Qdrant... top_k = {}", top_k) + result = await self._client.recommend(collection_name=self.collection_name, + using=query_vector_name, + positive=_positive_vectors, + negative=_negative_vectors, + strategy=_strategy, + with_vectors=_combined_search_need_vectors, + query_filter=self._get_filters_by_filter_param(filter_param), + limit=top_k, + offset=skip, + with_payload=True) + logger.success("Query completed!") + + return [self._get_search_result_from_scored_point(t) for t in result] + + async def insertItems(self, items: list[ImageData]): + logger.info("Inserting {} items into Qdrant...", len(items)) + + points = [self._get_point_from_img_data(t) for t in items] + + response = await self._client.upsert(collection_name=self.collection_name, + wait=True, + points=points) + logger.success("Insert completed! 
Status: {}", response.status) + + async def deleteItems(self, ids: list[str]): + logger.info("Deleting {} items from Qdrant...", len(ids)) + response = await self._client.delete(collection_name=self.collection_name, + points_selector=models.PointIdsList( + points=ids + ), + ) + logger.success("Delete completed! Status: {}", response.status) + + async def updatePayload(self, new_data: ImageData): + """ + Update the payload of an existing item in the database. + Warning: This method will not update the vector of the item. + :param new_data: The new data to update. + """ + response = await self._client.set_payload(collection_name=self.collection_name, + payload=new_data.payload, + points=[str(new_data.id)], + wait=True) + logger.success("Update completed! Status: {}", response.status) + + async def updateVectors(self, new_points: list[ImageData]): + resp = await self._client.update_vectors(collection_name=self.collection_name, + points=[self._get_vector_from_img_data(t) for t in new_points], + ) + logger.success("Update vectors completed! Status: {}", resp.status) + + async def scroll_points(self, + from_id: str | None = None, + count=50, + with_vectors=False, + filter_param: FilterParams | None = None, + ) -> tuple[list[ImageData], str]: + resp, next_id = await self._client.scroll(collection_name=self.collection_name, + limit=count, + offset=from_id, + with_vectors=with_vectors, + scroll_filter=self._get_filters_by_filter_param(filter_param) + ) + + return [self._get_img_data_from_point(t) for t in resp], next_id + + async def get_counts(self, exact: bool) -> int: + resp = await self._client.count(collection_name=self.collection_name, exact=exact) + return resp.count + + async def check_collection(self) -> bool: + resp = await self._client.get_collections() + resp = [t.name for t in resp.collections] + return self.collection_name in resp + + async def initialize_collection(self): + if await self.check_collection(): + logger.warning("Collection already exists. 
Skip initialization.") + return + logger.info("Initializing database, collection name: {}", self.collection_name) + vectors_config = { + self.IMG_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE), + self.TEXT_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE) + } + await self._client.create_collection(collection_name=self.collection_name, + vectors_config=vectors_config) + logger.success("Collection created!") + + @classmethod + def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors: + vector = {} + if img_data.image_vector is not None: + vector[cls.IMG_VECTOR] = img_data.image_vector.tolist() + if img_data.text_contain_vector is not None: + vector[cls.TEXT_VECTOR] = img_data.text_contain_vector.tolist() + return models.PointVectors( + id=str(img_data.id), + vector=vector + ) + + @classmethod + def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct: + return models.PointStruct( + id=str(img_data.id), + payload=img_data.payload, + vector=cls._get_vector_from_img_data(img_data).vector + ) + + def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData: + return (ImageData + .from_payload(point.id, + point.payload, + image_vector=numpy.array(point.vector[self.IMG_VECTOR], dtype=numpy.float32) + if point.vector and self.IMG_VECTOR in point.vector else None, + text_contain_vector=numpy.array(point.vector[self.TEXT_VECTOR], dtype=numpy.float32) + if point.vector and self.TEXT_VECTOR in point.vector else None + )) + + def _get_img_data_from_points(self, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]: + return [self._get_img_data_from_point(t) for t in points] + + def _get_search_result_from_scored_point(self, point: models.ScoredPoint) -> SearchResult: + return SearchResult(img=self._get_img_data_from_point(point), score=point.score) + + @classmethod + def vector_name_for_basis(cls, basis: SearchBasisEnum) -> str: + match basis: + case SearchBasisEnum.vision: + return cls.IMG_VECTOR + case SearchBasisEnum.ocr: + return cls.TEXT_VECTOR + case _: + raise ValueError("Invalid basis") + + @staticmethod + def _get_filters_by_filter_param(filter_param: FilterParams | None) -> models.Filter | None: + if filter_param is None: + return None + + filters = [] + neg_filter = [] + if filter_param.min_width is not None and filter_param.min_width > 0: + filters.append(models.FieldCondition( + key="width", + range=models.Range( + gte=filter_param.min_width + ) + )) + + if filter_param.min_height is not None and filter_param.min_height > 0: + filters.append(models.FieldCondition( + key="height", + range=models.Range( + gte=filter_param.min_height + ) + )) + + if filter_param.min_ratio is not None: + filters.append(models.FieldCondition( + key="aspect_ratio", + range=models.Range( + gte=filter_param.min_ratio, + lte=filter_param.max_ratio + ) + )) + + if filter_param.starred is not None: + filters.append(models.FieldCondition( + key="starred", + match=models.MatchValue( + value=filter_param.starred + ) + )) + + if filter_param.ocr_text is not None: + filters.append(models.FieldCondition( + key="ocr_text_lower", + match=models.MatchText( + text=filter_param.ocr_text.lower() + ) + )) + + if filter_param.categories is not None: + filters.append(models.FieldCondition( + key="categories", + match=models.MatchAny( + any=filter_param.categories + ) + )) + + if filter_param.categories_negative is not None: + neg_filter.append(models.FieldCondition( + key="categories", + 
match=models.MatchAny(any=filter_param.categories_negative), + )) + + if not filters and not neg_filter: + return None + return models.Filter( + must=filters, + must_not=neg_filter + ) diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ecc3cfdc7c516ac7b1ca1d479ea9aee14a0559a --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,6 @@ +__title__ = 'NekoImageGallery' +__description__ = 'An AI-powered natural language & reverse Image Search Engine powered by CLIP & qdrant.' +__version__ = '1.2.0' +__author__ = 'EdgeNeko; pk5ls20' +__author_email__ = 'service@edgeneko.com' +__url__ = 'https://github.com/hv0905/NekoImageGallery' diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4c8afdfdebd046de3b6d010b3f5b249152c1600b --- /dev/null +++ b/app/config.py @@ -0,0 +1,113 @@ +import os +from enum import Enum + +from loguru import logger +from pydantic import BaseModel +from pydantic_settings import BaseSettings, SettingsConfigDict + +DOCKER_SECRETS_DIR = '/run/secrets' + + +class QdrantMode(str, Enum): + SERVER = 'server' + LOCAL = 'local' + MEMORY = 'memory' + + +class QdrantSettings(BaseModel): + mode: QdrantMode = QdrantMode.SERVER + + host: str = 'localhost' + port: int = 6333 + grpc_port: int = 6334 + coll: str = 'NekoImg' + prefer_grpc: bool = True + api_key: str | None = None + + local_path: str = './images_metadata' + + +class ModelsSettings(BaseModel): + clip: str = 'openai/clip-vit-large-patch14' + bert: str = 'bert-base-chinese' + easypaddleocr: str | None = None + + +class OCRSearchSettings(BaseModel): + enable: bool = True + ocr_module: str = 'easypaddleocr' + ocr_language: list[str] = ['ch_sim', 'en'] + ocr_min_confidence: float = 1e-2 + + +class S3StorageSettings(BaseModel): + path: str = "./static" + bucket: str | None = None + region: str | None = None + endpoint_url: str | None = None + access_key_id: str | None = None + secret_access_key: str | None = None + session_token: str | None = None + + +class LocalStorageSettings(BaseModel): + path: str = './static' + + +class StorageMode(str, Enum): + LOCAL = 'local' + S3 = 's3' + DISABLED = 'disabled' + + @property + def enabled(self): + return self != StorageMode.DISABLED + + +class StorageSettings(BaseModel): + method: StorageMode = StorageMode.LOCAL + s3: S3StorageSettings = S3StorageSettings() + local: LocalStorageSettings = LocalStorageSettings() + + +# [Deprecated] +class StaticFileSettings(BaseModel): + path: str = '[DEPRECATED]' + enable: bool = True # Deprecated + + +class Config(BaseSettings): + qdrant: QdrantSettings = QdrantSettings() + model: ModelsSettings = ModelsSettings() + ocr_search: OCRSearchSettings = OCRSearchSettings() + static_file: StaticFileSettings = StaticFileSettings() # [Deprecated] + storage: StorageSettings = StorageSettings() + + device: str = 'auto' + cors_origins: set[str] = {'*'} + admin_api_enable: bool = False + admin_token: str = '' + admin_index_queue_max_length: int = 200 + + access_protected: bool = False + access_token: str = '' + + model_config = SettingsConfigDict(env_prefix="app_", env_nested_delimiter='__', + env_file=('config/default.env', 'config/local.env'), + env_file_encoding='utf-8', + secrets_dir=DOCKER_SECRETS_DIR if os.path.exists( + DOCKER_SECRETS_DIR) else None) # for docker secret + + +class Environment(BaseModel): + local_indexing: bool = False + + +def _check_deprecated_settings(_config): + if _config.static_file.path != '[DEPRECATED]': + 
logger.warning("Config StaticFileSettings is deprecated and should not be set.") + + +config = Config() +environment = Environment() +_check_deprecated_settings(config) diff --git a/app/util/calculate_vectors_cosine.py b/app/util/calculate_vectors_cosine.py new file mode 100644 index 0000000000000000000000000000000000000000..a31288f00d85b90717204b510f111956a171dc01 --- /dev/null +++ b/app/util/calculate_vectors_cosine.py @@ -0,0 +1,5 @@ +import numpy as np + + +def calculate_vectors_cosine(image_vector, text_vector): + return np.dot(image_vector, text_vector) / (np.linalg.norm(image_vector) * np.linalg.norm(text_vector)) diff --git a/app/util/fastapi_log_handler.py b/app/util/fastapi_log_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ef4a3195a3b952182bb7de417a0537ce4d5d8edc --- /dev/null +++ b/app/util/fastapi_log_handler.py @@ -0,0 +1,46 @@ +import logging + +from loguru import logger + + +class InterceptHandler(logging.Handler): # pragma: no cover Hard to test in test environments + + def emit(self, record: logging.LogRecord): + # Get corresponding Loguru level if it exists + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the logged message + frame, depth = logging.currentframe(), 2 + while frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + +def init_logging(): + """ + Replaces logging handlers with a handler for using the custom handler. + """ + + # disable handlers for specific uvicorn loggers + # to redirect their output to the default uvicorn logger + # works with uvicorn==0.11.6 + intercept_handler = InterceptHandler() + loggers = ( + logging.getLogger(name) + for name in logging.root.manager.loggerDict + if name.startswith("uvicorn.") + ) + for uvicorn_logger in loggers: + uvicorn_logger.handlers = [intercept_handler] + + # change handler for default uvicorn logger + + # logging.getLogger("uvicorn").handlers = [intercept_handler] + logging.getLogger("uvicorn").handlers = [] diff --git a/app/util/generate_uuid.py b/app/util/generate_uuid.py new file mode 100644 index 0000000000000000000000000000000000000000..28f96c30b7383eafbf71556436a07073c317fd13 --- /dev/null +++ b/app/util/generate_uuid.py @@ -0,0 +1,26 @@ +import hashlib +import io +import pathlib +from uuid import UUID, uuid5, NAMESPACE_DNS + +NAMESPACE_STR = 'github.com/hv0905/NekoImageGallery' +namespace_uuid = uuid5(NAMESPACE_DNS, NAMESPACE_STR) + + +def generate_uuid(file_input: pathlib.Path | io.BytesIO | bytes) -> UUID: + if isinstance(file_input, pathlib.Path): + with open(file_input, 'rb') as f: + file_content = f.read() + elif isinstance(file_input, io.BytesIO): + file_input.seek(0) + file_content = file_input.read() + elif isinstance(file_input, bytes): + file_content = file_input + else: + raise ValueError("Unsupported file type. 
Must be pathlib.Path or io.BytesIO.") + file_hash = hashlib.sha1(file_content).hexdigest() + return generate_uuid_from_sha1(file_hash) + + +def generate_uuid_from_sha1(sha1_hash: str) -> UUID: + return uuid5(namespace_uuid, sha1_hash.lower()) diff --git a/app/util/local_file_utility.py b/app/util/local_file_utility.py new file mode 100644 index 0000000000000000000000000000000000000000..389dcc9d2e82e5b9104f7ec6b1eb9dcd5cb7410a --- /dev/null +++ b/app/util/local_file_utility.py @@ -0,0 +1,12 @@ +from pathlib import Path + +VALID_IMAGE_EXTENSIONS = {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'} + + +def glob_local_files(path: Path, pattern: str = "*", valid_extensions: set[str] = None): + if valid_extensions is None: + valid_extensions = VALID_IMAGE_EXTENSIONS + + for file in path.glob(pattern): + if file.suffix.lower() in valid_extensions: + yield file diff --git a/app/util/retry_deco_async.py b/app/util/retry_deco_async.py new file mode 100644 index 0000000000000000000000000000000000000000..c70535010a95193c0a136fe37830bfc9005fba84 --- /dev/null +++ b/app/util/retry_deco_async.py @@ -0,0 +1,31 @@ +import asyncio +import functools +from typing import Callable + +from loguru import logger + + +def retry_async(exceptions=Exception, tries=3, delay=0) -> Callable[[Callable], Callable]: + def deco_retry(f): + @functools.wraps(f) + async def f_retry(*args, **kwargs): + m_tries, m_delay = tries, delay + while m_tries > 1: + try: + return await f(*args, **kwargs) + except exceptions as e: + logger.warning(f"{e}, Retrying in {m_delay} seconds...") + if m_delay > 0: + await asyncio.sleep(m_delay) + m_tries -= 1 + return await f(*args, **kwargs) + + return f_retry + + return deco_retry + + +def wrap_object(obj: object, deco: Callable[[Callable], Callable]): + for attr in dir(obj): + if not attr.startswith('_') and asyncio.iscoroutinefunction(attr_val := getattr(obj, attr)): + setattr(obj, attr, deco(attr_val)) diff --git a/app/webapp.py b/app/webapp.py new file mode 100644 index 0000000000000000000000000000000000000000..3aec3a4b06ace7abb9436ca80a83d6194d6fc962 --- /dev/null +++ b/app/webapp.py @@ -0,0 +1,77 @@ +import pathlib +from contextlib import asynccontextmanager +from datetime import datetime +from typing import Annotated + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.params import Depends +from fastapi.staticfiles import StaticFiles + +import app +import app.Controllers.admin as admin_controller +import app.Controllers.images as images_controller +import app.Controllers.search as search_controller +from app.Services.authentication import permissive_access_token_verify, permissive_admin_token_verify +from app.Services.provider import ServiceProvider +from app.config import config +from .Models.api_response.base import WelcomeApiResponse, WelcomeApiAuthenticationResponse, \ + WelcomeApiAdminPortalAuthenticationResponse +from .util.fastapi_log_handler import init_logging + + +@asynccontextmanager +async def lifespan(_: FastAPI): + provider = ServiceProvider() + await provider.onload() + + search_controller.services = provider + admin_controller.services = provider + images_controller.services = provider + yield + + await provider.onexit() + + +app = FastAPI(lifespan=lifespan, title=app.__title__, description=app.__description__, version=app.__version__) +init_logging() + +# noinspection PyTypeChecker +app.add_middleware( + CORSMiddleware, + allow_origins=config.cors_origins, + allow_credentials=True, + allow_methods=["*"], + 
allow_headers=["*"],
+)
+
+app.include_router(search_controller.search_router, prefix="/search")
+app.include_router(images_controller.images_router, prefix="/images")
+if config.admin_api_enable:
+    app.include_router(admin_controller.admin_router, prefix="/admin")
+
+if config.storage.method == "local":
+    # Since we check & create the static directory later, when the StorageService is initialized, we don't need to
+    # check it here.
+    app.mount("/static", StaticFiles(directory=pathlib.Path(config.storage.local.path), check_dir=False), name="static")
+
+
+@app.get("/", description="Default portal. Test for server availability.")
+def welcome(request: Request,
+            token_passed: Annotated[bool, Depends(permissive_access_token_verify)],
+            admin_token_passed: Annotated[bool, Depends(permissive_admin_token_verify)],
+            ) -> WelcomeApiResponse:
+    root_path: str = request.scope.get('root_path').rstrip('/')
+    return WelcomeApiResponse(
+        message="Ciallo~ Welcome to NekoImageGallery API!",
+        server_time=datetime.now(),
+        wiki={
+            "openAPI": f"{root_path}/openapi.json",
+            "swagger UI": f"{root_path}/docs",
+            "redoc": f"{root_path}/redoc"
+        },
+        admin_api=WelcomeApiAdminPortalAuthenticationResponse(available=config.admin_api_enable,
+                                                              passed=admin_token_passed),
+        authorization=WelcomeApiAuthenticationResponse(required=config.access_protected, passed=token_passed),
+        available_basis=["vision", "ocr"] if config.ocr_search.enable else ["vision"]
+    )
diff --git a/config/.gitignore b/config/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..104922e03269e75ff5f385959d50b4f7029a2343
--- /dev/null
+++ b/config/.gitignore
@@ -0,0 +1 @@
+local.env
\ No newline at end of file
diff --git a/config/default.env b/config/default.env
new file mode 100644
index 0000000000000000000000000000000000000000..ca3a940519a79fa16e3299e31268f39c3fdff82b
--- /dev/null
+++ b/config/default.env
@@ -0,0 +1,111 @@
+# This is an example configuration file for the app.
+# All the values below are the default values. To change a value, uncomment the line and set the value you want.
+# You can also use environment variables or docker secrets to set these values (the key should correspond to the key below).
+# Check out https://docs.pydantic.dev/latest/concepts/pydantic_settings/ for more information.
+
+# ------
+# Vector Database Configuration
+# ------
+# Mode for the vector database, options include "server" (default), "local" and "memory"
+# - server: The preferred mode, uses a Qdrant server for vector storage.
+# - local: Stores vectors as a file on the local disk; this is not recommended for production use (see readme for more information)
+# - memory: Uses in-memory storage for vectors; this is not persistent and should only be used for testing and debugging.
+# APP_QDRANT__MODE=server
+
+# Remote Qdrant Server Configuration
+# Hostname or IP address of the Qdrant server
+# APP_QDRANT__HOST="localhost"
+# Port number for the Qdrant HTTP server
+# APP_QDRANT__PORT=6333
+# Port number for the Qdrant gRPC server
+# APP_QDRANT__GRPC_PORT=6334
+# Set to True if you want to use gRPC for the qdrant connection instead of HTTP
+# APP_QDRANT__PREFER_GRPC=True
+# Set your API key here if you have set one, otherwise leave it None
+# APP_QDRANT__API_KEY=
+# Collection name to use in Qdrant
+# APP_QDRANT__COLL="NekoImg"
+
+# Local Qdrant File Configuration
+# Path to the file where vectors will be stored
+# APP_QDRANT__LOCAL_PATH="./images_metadata"
+
+
+# ------
+# Server Configuration
+# ------
+# Specify the device to be used by PyTorch when inferring vectors. Setting this to "auto" allows the system to automatically detect and use available devices; otherwise, specify the device name
+# APP_DEVICE="auto"
+# List of allowed origins for CORS (Cross-Origin Resource Sharing)
+# APP_CORS_ORIGINS=["*"]
+
+
+# ------
+# Models Configuration
+# ------
+# Model used for CLIP embeddings (Vision Search), accepts both huggingface hub (transformers) model name and path to the model.
+# APP_MODEL__CLIP="openai/clip-vit-large-patch14"
+# Model used for BERT embeddings (OCR Search), accepts both huggingface hub (transformers) model name and path to the model.
+# APP_MODEL__BERT="bert-base-chinese"
+# Model used for easypaddleocr inference (OCR indexing), accepts a path to the model. Leaving it blank will download the model automatically from the huggingface hub.
+# APP_MODEL__EASYPADDLEOCR=""
+
+
+# ------
+# OCR Search Configuration
+# ------
+# Enable OCR search functionality
+# APP_OCR_SEARCH__ENABLE=True
+# OCR module to use for text extraction
+# APP_OCR_SEARCH__OCR_MODULE="easypaddleocr"
+# Minimum confidence level required for OCR results to be considered
+# APP_OCR_SEARCH__OCR_MIN_CONFIDENCE=1e-2
+# List of languages supported by the OCR module
+# APP_OCR_SEARCH__OCR_LANGUAGE=["ch_sim", "en"]
+
+
+# ------
+# Admin API Configuration
+# ------
+# Set to True to enable the admin API; this allows you to access the admin API using the token specified below.
+# APP_ADMIN_API_ENABLE=False
+# Uncomment the line below if you enabled the admin API. Use this token to access the admin API. For security reasons, an admin token is always required if you want to use the admin API.
+# APP_ADMIN_TOKEN="your-super-secret-admin-token"
+# Max length of the upload queue for the admin API. A higher value means more indexing requests can be queued, but also means more memory usage. Upload requests will be blocked when the queue is full.
+# APP_ADMIN_INDEX_QUEUE_MAX_LENGTH=200
+
+
+# ------
+# Access Protection Configuration
+# ------
+# Set to True to enable access protection using tokens
+# APP_ACCESS_PROTECTED=False
+# Use this token to access the API. This is required if you have enabled access protection.
+# APP_ACCESS_TOKEN="your-super-secret-access-token" + + +# ------ +# Storage Settings +# ------ +# Method for storing files, options includes "local", "s3" and "disabled" +# APP_STORAGE__METHOD="local" + +# Storage Settings - local +# Path where files will be stored locally +# APP_STORAGE__LOCAL__PATH="./static" + +# Storage Settings - S3 +# Name of the S3 bucket +# APP_STORAGE__S3__BUCKET="your-s3-bucket-name" +# Path where files will be stored in the S3 bucket +# APP_STORAGE__S3__PATH="./static" +# Region where the S3 bucket is located +# APP_STORAGE__S3__REGION="your-s3-region" +# Endpoint URL for the S3 service +# APP_STORAGE__S3__ENDPOINT_URL="your-s3-endpoint-url" +# Access key ID for accessing the S3 bucket +# APP_STORAGE__S3__ACCESS_KEY_ID="your-s3-access-key-id" +# Secret access key for accessing the S3 bucket +# APP_STORAGE__S3__SECRET_ACCESS_KEY="your-s3-secret-access-key" +# Session token for accessing the S3 bucket (optional) +# APP_STORAGE__S3__SESSION_TOKEN="your-s3-session-token" diff --git a/cpu-only.Dockerfile b/cpu-only.Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..e05fec3b24f05d62502427a30f9364279ad7e94c --- /dev/null +++ b/cpu-only.Dockerfile @@ -0,0 +1,38 @@ +FROM python:3.11-slim-bookworm + +RUN PYTHONDONTWRITEBYTECODE=1 pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu --no-cache-dir + +WORKDIR /opt/NekoImageGallery + +COPY requirements.txt . + +RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -r requirements.txt + +RUN mkdir -p /opt/models && \ + export PYTHONDONTWRITEBYTECODE=1 && \ + huggingface-cli download openai/clip-vit-large-patch14 'model.safetensors' '*.txt' '*.json' --local-dir /opt/models/clip && \ + huggingface-cli download google-bert/bert-base-chinese 'model.safetensors' '*.txt' '*.json' --local-dir /opt/models/bert && \ + huggingface-cli download pk5ls20/PaddleModel 'PaddleOCR2Pytorch/ch_ptocr_v4_det_infer.pth' 'PaddleOCR2Pytorch/ch_ptocr_v4_rec_infer.pth' \ + 'PaddleOCR2Pytorch/ch_ptocr_mobile_v2.0_cls_infer.pth' 'PaddleOCR2Pytorch/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml' \ + 'PaddleOCR2Pytorch/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml' 'ppocr_keys_v1.txt' --local-dir /opt/models/ocr && \ + rm -rf /root/.cache/huggingface + +ENV APP_MODEL__CLIP=/opt/models/clip +ENV APP_MODEL__BERT=/opt/models/bert +ENV APP_MODEL__EASYPADDLEOCR=/opt/models/ocr + +COPY . . + +EXPOSE 8000 + +VOLUME ["/opt/NekoImageGallery/static"] + +ENV APP_CLIP__DEVICE="cpu" + +LABEL org.opencontainers.image.authors="EdgeNeko" \ + org.opencontainers.image.url="https://github.com/hv0905/NekoImageGallery" \ + org.opencontainers.image.source="https://github.com/hv0905/NekoImageGallery" \ + org.opencontainers.image.title="NekoImageGallery" \ + org.opencontainers.image.description="An AI-powered natural language & reverse Image Search Engine powered by CLIP & qdrant." + +ENTRYPOINT ["python", "main.py"] diff --git a/docker-compose-cpu.yml b/docker-compose-cpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..1363b3c911532f1888f807fc69e51016f307d89f --- /dev/null +++ b/docker-compose-cpu.yml @@ -0,0 +1,28 @@ +services: + qdrant-database: + image: qdrant/qdrant:latest + ports: + - "127.0.0.1:6333:6333" + - "127.0.0.1:6334:6334" + volumes: + - "./qdrant_data:/qdrant/storage:z" + neko-image-gallery: + # Uncomment this section to build image from source code + # build: + # context: . 
+ # dockerfile: cpu-only.Dockerfile + image: edgeneko/neko-image-gallery:latest-cpu + ports: + - "8000:8000" + volumes: + - "./static:/opt/NekoImageGallery/static" + environment: + - APP_QDRANT__HOST=qdrant-database + - APP_QDRANT__PORT=6333 + - APP_QDRANT__GRPC_PORT=6334 + - APP_QDRANT__PREFER_GRPC=True + depends_on: + - qdrant-database +networks: + default: + name: neko-image-gallery diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..099683808bf20649f745aa13e4ca1e2dce91fd1a --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,36 @@ +services: + qdrant-database: + image: qdrant/qdrant:latest + ports: + - "127.0.0.1:6333:6333" + - "127.0.0.1:6334:6334" + volumes: + - "./qdrant_data:/qdrant/storage:z" + neko-image-gallery: + # Uncomment this section to build image from source code + # build: + # context: . + # dockerfile: Dockerfile + image: edgeneko/neko-image-gallery:latest + ports: + - "8000:8000" + volumes: + - "./static:/opt/NekoImageGallery/static" + environment: + - APP_QDRANT__HOST=qdrant-database + - APP_QDRANT__PORT=6333 + - APP_QDRANT__GRPC_PORT=6334 + - APP_QDRANT__PREFER_GRPC=True + depends_on: + - qdrant-database + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: + - gpu +networks: + default: + name: neko-image-gallery diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa22172c5f1564b5cf390d7515e2e817b0714d1 --- /dev/null +++ b/main.py @@ -0,0 +1,110 @@ +import asyncio +from pathlib import Path +from typing import Annotated, Optional + +import rich +import typer +import uvicorn + +import app +from app.Models.api_models.admin_query_params import UploadImageThumbnailMode + +parser = typer.Typer(name=app.__title__, + epilog="Build with ♥ By EdgeNeko. Github: " + "https://github.com/hv0905/NekoImageGallery", + rich_markup_mode='markdown' + ) + + +def version_callback(value: bool): + if value: + print(f"{app.__title__} v{app.__version__}") + raise typer.Exit() + + +@parser.callback(invoke_without_command=True) +def server(ctx: typer.Context, + host: Annotated[str, typer.Option(help='The host to bind on.')] = '0.0.0.0', + port: Annotated[int, typer.Option(help='The port to listen on.')] = 8000, + root_path: Annotated[str, typer.Option( + help='Root path of the server if your server is deployed behind a reverse proxy. See ' + 'https://fastapi.tiangolo.com/advanced/behind-a-proxy/ for detail.')] = '', + _: Annotated[ + Optional[bool], typer.Option("--version", callback=version_callback, is_eager=True, + help="Show version and exit.") + ] = None + ): + """ + Ciallo~ Welcome to NekoImageGallery Server. + + - Website: https://image-insights.edgeneko.com + + - Repository & Issue tracker: https://github.com/hv0905/NekoImageGallery + + + + By default, running without command will start the server. + You can perform other actions by using the commands below. + """ + if ctx.invoked_subcommand is not None: + return + uvicorn.run("app.webapp:app", host=host, port=port, root_path=root_path) + + +@parser.command('show-config') +def show_config(): + """ + Print the current configuration and exit. + """ + from app.config import config + rich.print_json(config.model_dump_json()) + + +@parser.command('init-database') +def init_database(): + """ + Initialize qdrant database using connection settings in configuration. + Note. The server will automatically initialize the database if it's not initialized. 
So you don't need to run this + command unless you want to explicitly initialize the database. + """ + from scripts import qdrant_create_collection + asyncio.run(qdrant_create_collection.main()) + + +@parser.command("local-index") +def local_index( + target_dir: Annotated[ + list[Path], typer.Argument(dir_okay=True, file_okay=False, exists=True, resolve_path=True, readable=True, + help="Directories you want to index.")], + categories: Annotated[Optional[list[str]], typer.Option(help="Categories for the indexed images.")] = None, + starred: Annotated[bool, typer.Option(help="Whether the indexed images are starred.")] = False, + thumbnail_mode: Annotated[ + UploadImageThumbnailMode, typer.Option( + help="Whether to generate thumbnail for images. Possible values:\n" + "- `if_necessary`:(Recommended) Only generate thumbnail if the image is larger than 500KB.\n" + "- `always`: Always generate thumbnail.\n" + "- `never`: Never generate thumbnail.")] = UploadImageThumbnailMode.IF_NECESSARY +): + """ + Index all the images in the specified directory. + The images will be copied to the local storage directory set in configuration. + """ + from scripts import local_indexing + if categories is None: + categories = [] + asyncio.run(local_indexing.main(target_dir, categories, starred, thumbnail_mode)) + + +@parser.command('local-create-thumbnail', deprecated=True) +def local_create_thumbnail(): + """ + Create thumbnail for all local images in static folder, this won't affect non-local images. + This is generally not required since the server will automatically create thumbnails for new images by default. + This option will be refactored in the future. + """ + from scripts import local_create_thumbnail + asyncio.run(local_create_thumbnail.main()) + + +if __name__ == '__main__': + parser() diff --git a/pylintrc.toml b/pylintrc.toml new file mode 100644 index 0000000000000000000000000000000000000000..7b6b2fc1d0a9989d6a29f5894daa424f910c77c1 --- /dev/null +++ b/pylintrc.toml @@ -0,0 +1,538 @@ +[tool.pylint.main] +# Analyse import fallback blocks. This can be used to support both Python 2 and 3 +# compatible code, which means that the block might have code that exists only in +# one or another interpreter, leading to false positives when analysed. +# analyse-fallback-blocks = + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint in +# a server-like mode. +# clear-cache-post-run = + +# Always return a 0 (non-error) status code, even if lint errors are found. This +# is primarily useful in continuous integration scripts. +# exit-zero = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +# extension-pkg-allow-list = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +# extension-pkg-whitelist = + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +# fail-on = + +# Specify a score threshold under which the program will exit with error. 
+fail-under = 10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +# from-stdin = + +# Files or directories to be skipped. They should be base names, not paths. +ignore = ["CVS"] + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, it +# can't be used as an escape character. +# ignore-paths = + +# Files or directories matching the regular expression patterns are skipped. The +# regex matches against base names, not paths. The default value ignores Emacs +# file locks +ignore-patterns = ["^\\.#"] + +# List of module names for which member attributes should not be checked (useful +# for modules/projects where namespaces are manipulated during runtime and thus +# existing member attributes cannot be deduced by static analysis). It supports +# qualified module names, as well as Unix pattern matching. +# ignored-modules = + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +# init-hook = + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs = 1 + +# Control the amount of potential inferred values when inferring a single object. +# This can help the performance when dealing with large functions or complex, +# nested conditions. +limit-inference-results = 100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +# load-plugins = + +# Pickle collected data for later comparisons. +persistent = true + +# Minimum Python version to use for version dependent checks. Will default to the +# version used to run pylint. +py-version = "3.10" + +# Discover python modules and packages in the file system subtree. +# recursive = + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +# source-roots = + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode = true + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +# unsafe-load-any-extension = + +[tool.pylint.basic] +# Naming style matching correct argument names. +argument-naming-style = "snake_case" + +# Regular expression matching correct argument names. Overrides argument-naming- +# style. If left empty, argument names will be checked with the set naming style. +# argument-rgx = + +# Naming style matching correct attribute names. +attr-naming-style = "snake_case" + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +# attr-rgx = + +# Bad variable names which should always be refused, separated by a comma. +bad-names = ["foo", "bar", "baz", "toto", "tutu", "tata"] + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +# bad-names-rgxs = + +# Naming style matching correct class attribute names. 
+class-attribute-naming-style = "any" + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +# class-attribute-rgx = + +# Naming style matching correct class constant names. +class-const-naming-style = "UPPER_CASE" + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +# class-const-rgx = + +# Naming style matching correct class names. +class-naming-style = "PascalCase" + +# Regular expression matching correct class names. Overrides class-naming-style. +# If left empty, class names will be checked with the set naming style. +# class-rgx = + +# Naming style matching correct constant names. +const-naming-style = "UPPER_CASE" + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming style. +# const-rgx = + +# Minimum line length for functions/classes that require docstrings, shorter ones +# are exempt. +docstring-min-length = -1 + +# Naming style matching correct function names. +function-naming-style = "snake_case" + +# Regular expression matching correct function names. Overrides function-naming- +# style. If left empty, function names will be checked with the set naming style. +# function-rgx = + +# Good variable names which should always be accepted, separated by a comma. +good-names = ["i", "j", "k", "ex", "Run", "_"] + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +# good-names-rgxs = + +# Include a hint for the correct naming format with invalid-name. +# include-naming-hint = + +# Naming style matching correct inline iteration names. +inlinevar-naming-style = "any" + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +# inlinevar-rgx = + +# Naming style matching correct method names. +method-naming-style = "snake_case" + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +# method-rgx = + +# Naming style matching correct module names. +module-naming-style = "snake_case" + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +# module-rgx = + +# Colon-delimited sets of names that determine each other's naming style when the +# name regexes allow several styles. +# name-group = + +# Regular expression which should only match function or class names that do not +# require a docstring. +no-docstring-rgx = "^_" + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. These +# decorators are taken in consideration only for invalid-name. +property-classes = ["abc.abstractproperty"] + +# Regular expression matching correct type alias names. If left empty, type alias +# names will be checked with the set naming style. +# typealias-rgx = + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. 
+# typevar-rgx = + +# Naming style matching correct variable names. +variable-naming-style = "snake_case" + +# Regular expression matching correct variable names. Overrides variable-naming- +# style. If left empty, variable names will be checked with the set naming style. +# variable-rgx = + +[tool.pylint.classes] +# Warn about protected attribute access inside special methods +# check-protected-access-in-special-methods = + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods = ["__init__", "__new__", "setUp", "asyncSetUp", "__post_init__"] + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected = ["_asdict", "_fields", "_replace", "_source", "_make", "os._exit"] + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg = ["cls"] + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg = ["mcs"] + +[tool.pylint.design] +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +# exclude-too-few-public-methods = + +# List of qualified class names to ignore when counting class parents (see R0901) +# ignored-parents = + +# Maximum number of arguments for function / method. +max-args = 5 + +# Maximum number of attributes for a class (see R0902). +max-attributes = 7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr = 5 + +# Maximum number of branch for function / method body. +max-branches = 12 + +# Maximum number of locals for function / method body. +max-locals = 15 + +# Maximum number of parents for a class (see R0901). +max-parents = 7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods = 20 + +# Maximum number of return / yield for function / method body. +max-returns = 6 + +# Maximum number of statements in function / method body. +max-statements = 50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods = 2 + +[tool.pylint.exceptions] +# Exceptions that will emit a warning when caught. +overgeneral-exceptions = ["builtins.BaseException", "builtins.Exception"] + +[tool.pylint.format] +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +# expected-line-ending-format = + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines = "^\\s*(# )??$" + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren = 4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string = " " + +# Maximum number of characters on a single line. +max-line-length = 120 + +# Maximum number of lines in a module. +max-module-lines = 1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +# single-line-class-stmt = + +# Allow the body of an if to be on the same line as the test if there is no else. +# single-line-if-stmt = + +[tool.pylint.imports] +# List of modules that can be imported at any level, not just the top level one. +# allow-any-import-level = + +# Allow explicit reexports by alias from a package __init__. +# allow-reexport-from-package = + +# Allow wildcard imports from modules that define __all__. +# allow-wildcard-with-all = + +# Deprecated modules which should not be used, separated by a comma. 
+# deprecated-modules = + +# Output a graph (.gv or any supported image format) of external dependencies to +# the given file (report RP0402 must not be disabled). +# ext-import-graph = + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be disabled). +# import-graph = + +# Output a graph (.gv or any supported image format) of internal dependencies to +# the given file (report RP0402 must not be disabled). +# int-import-graph = + +# Force import order to recognize a module as part of the standard compatibility +# libraries. +# known-standard-library = + +# Force import order to recognize a module as part of a third party library. +known-third-party = ["enchant"] + +# Couples of modules and preferred modules, separated by a comma. +# preferred-modules = + +[tool.pylint.logging] +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style = "old" + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules = ["logging"] + +[tool.pylint."messages control"] +# Only show warnings with the listed confidence levels. Leave empty to show all. +# Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence = ["HIGH", "CONTROL_FLOW", "INFERENCE", "INFERENCE_FAILURE", "UNDEFINED"] + +# Disable the message, report, category or checker with the given id(s). You can +# either give multiple identifiers separated by comma (,) or put this option +# multiple times (only on the command line, not in the configuration file where +# it should appear only once). You can also use "--disable=all" to disable +# everything first and then re-enable specific checks. For example, if you want +# to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable = ["raw-checker-failed", "bad-inline-option", "locally-disabled", "file-ignored", "suppressed-message", "useless-suppression", "deprecated-pragma", "use-symbolic-message-instead", "use-implicit-booleaness-not-comparison-to-string", "use-implicit-booleaness-not-comparison-to-zero", "missing-function-docstring", "missing-class-docstring", "missing-module-docstring"] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where it +# should appear only once). See also the "--disable" option for examples. +# enable = + +[tool.pylint.method_args] +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods = ["requests.api.delete", "requests.api.get", "requests.api.head", "requests.api.options", "requests.api.patch", "requests.api.post", "requests.api.put", "requests.api.request"] + +[tool.pylint.miscellaneous] +# List of note tags to take in consideration, separated by a comma. +notes = ["FIXME", "XXX", "TODO"] + +# Regular expression of note tags to take in consideration. +# notes-rgx = + +[tool.pylint.refactoring] +# Maximum number of nested blocks for function / method body +max-nested-blocks = 5 + +# Complete name of functions that never returns. 
When checking for inconsistent- +# return-statements if a never returning function is called then it will be +# considered as an explicit return statement and no message will be printed. +never-returning-functions = ["sys.exit", "argparse.parse_error"] + +[tool.pylint.reports] +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each category, +# as well as 'statement' which is the total number of statements analyzed. This +# score is used by the global evaluation report (RP0004). +evaluation = "max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))" + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +# msg-template = + +# Set the output format. Available formats are: text, parseable, colorized, json2 +# (improved json format), json (old json format) and msvs (visual studio). You +# can also give a reporter class, e.g. mypackage.mymodule.MyReporterClass. +# output-format = + +# Tells whether to display a full report or only the messages. +# reports = + +# Activate the evaluation score. +score = true + +[tool.pylint.similarities] +# Comments are removed from the similarity computation +ignore-comments = true + +# Docstrings are removed from the similarity computation +ignore-docstrings = true + +# Imports are removed from the similarity computation +ignore-imports = true + +# Signatures are removed from the similarity computation +ignore-signatures = true + +# Minimum lines number of a similarity. +min-similarity-lines = 4 + +[tool.pylint.spelling] +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions = 4 + +# Spelling dictionary name. No available dictionaries : You need to install both +# the python package and the system dependency for enchant to work. +# spelling-dict = + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:" + +# List of comma separated words that should not be checked. +# spelling-ignore-words = + +# A path to a file that contains the private dictionary; one word per line. +# spelling-private-dict-file = + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +# spelling-store-unknown-words = + +[tool.pylint.typecheck] +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators = ["contextlib.contextmanager"] + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +# generated-members = + +# Tells whether missing members accessed in mixin class should be ignored. A +# class is considered mixin if its name matches the mixin-class-rgx option. +# Tells whether to warn about missing members when the owner of the attribute is +# inferred to be None. 
+ignore-none = true + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference can +# return multiple potential results while evaluating a Python object, but some +# branches might not be evaluated, which results in partial inference. In that +# case, it might be useful to still emit no-member and other checks for the rest +# of the inferred objects. +ignore-on-opaque-inference = true + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins = ["no-member", "not-async-context-manager", "not-context-manager", "attribute-defined-outside-init"] + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes = ["optparse.Values", "thread._local", "_thread._local", "argparse.Namespace"] + +# Show a hint with possible names when a member name was not found. The aspect of +# finding the hint is based on edit distance. +missing-member-hint = true + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance = 1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices = 1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx = ".*[Mm]ixin" + +# List of decorators that change the signature of a decorated function. +# signature-mutators = + +[tool.pylint.variables] +# List of additional names supposed to be defined in builtins. Remember that you +# should avoid defining new builtins when possible. +# additional-builtins = + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables = true + +# List of names allowed to shadow builtins +# allowed-redefined-builtins = + +# List of strings which can identify a callback function by name. A callback name +# must start or end with one of those strings. +callbacks = ["cb_", "_cb"] + +# A regular expression matching the name of dummy variables (i.e. expected to not +# be used). +dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_" + +# Argument names that match this expression will be ignored. +ignored-argument-names = "_.*|^ignored_|^unused_" + +# Tells whether we should check for unused import in __init__ files. +# init-import = + +# List of qualified module names which can have objects that can redefine +# builtins. 
+redefining-builtins-modules = ["six.moves", "past.builtins", "future.builtins", "builtins", "io"] + + diff --git a/readme.md b/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..df6b48a4b3d374d61830474e5289332e55aa888e --- /dev/null +++ b/readme.md @@ -0,0 +1,206 @@ +# NekoImageGallery + +[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/hv0905/NekoImageGallery/prod.yml?logo=github)](https://github.com/hv0905/NekoImageGallery/actions) +[![codecov](https://codecov.io/gh/hv0905/NekoImageGallery/branch/master/graph/badge.svg?token=JK2KZBDIYP)](https://codecov.io/gh/hv0905/NekoImageGallery) +[![Maintainability](https://api.codeclimate.com/v1/badges/ac97a1146648996b68ea/maintainability)](https://codeclimate.com/github/hv0905/NekoImageGallery/maintainability) +![Man hours](https://img.shields.io/endpoint?url=https%3A%2F%2Fmanhours.aiursoft.cn%2Fr%2Fgithub.com%2Fhv0905%2FNekoImageGallery.json) +[![Docker Pulls](https://img.shields.io/docker/pulls/edgeneko/neko-image-gallery)](https://hub.docker.com/r/edgeneko/neko-image-gallery) + +An online AI image search engine based on the Clip model and Qdrant vector database. Supports keyword search and similar +image search. + +[中文文档](readme_cn.md) + +## ✨ Features + +- Use the Clip model to generate 768-dimensional vectors for each image as the basis for search. No need for manual + annotation or classification, unlimited classification categories. +- OCR Text search is supported, use PaddleOCR to extract text from images and use BERT to generate text vectors for + search. +- Use Qdrant vector database for efficient vector search. + +## 📷Screenshots + +![Screenshot1](web/screenshots/1.png) +![Screenshot2](web/screenshots/2.png) +![Screenshot3](web/screenshots/3.png) +![Screenshot4](web/screenshots/4.png) +![Screenshot5](web/screenshots/5.png) +![Screenshot6](web/screenshots/6.png) + +> The above screenshots may contain copyrighted images from different artists, please do not use them for other +> purposes. + +## ✈️ Deployment + +### 🖥️ Local Deployment + +#### Choose a metadata storage method + +##### Qdrant Database (Recommended) + +In most cases, we recommend using the Qdrant database to store metadata. The Qdrant database provides efficient +retrieval performance, flexible scalability, and better data security. + +Please deploy the Qdrant database according to +the [Qdrant documentation](https://qdrant.tech/documentation/quick-start/). It is recommended to use Docker for +deployment. + +If you don't want to deploy Qdrant yourself, you can use +the [online service provided by Qdrant](https://qdrant.tech/documentation/cloud/). + +##### Local File Storage + +Local file storage directly stores image metadata (including feature vectors, etc.) in a local SQLite database. It is +only recommended for small-scale deployments or development deployments. + +Local file storage does not require an additional database deployment process, but has the following disadvantages: + +- Local storage does not index and optimize vectors, so the time complexity of all searches is `O(n)`. Therefore, if the + data scale is large, the performance of search and indexing will decrease. +- Using local file storage will make NekoImageGallery stateful, so it will lose horizontal scalability. +- When you want to migrate to Qdrant database for storage, the indexed metadata may be difficult to migrate directly. + +#### Deploy NekoImageGallery + +1. 
Clone the project directory to your own PC or server, then check out a specific version tag (like `v1.0.0`).
+2. It is highly recommended to install the dependencies required for this project in a Python venv virtual environment.
+ Run the following command:
+ ```shell
+ python -m venv .venv
+ . .venv/bin/activate
+ ```
+3. Install PyTorch. Follow the [PyTorch documentation](https://pytorch.org/get-started/locally/) to install the torch
+ version suitable for your system using pip.
+ > If you want to use CUDA acceleration for inference, be sure to install a CUDA-supported PyTorch version in this
+ step. After installation, you can use `torch.cuda.is_available()` to confirm whether CUDA is available.
+4. Install other dependencies required for this project:
+ ```shell
+ pip install -r requirements.txt
+ ```
+5. Modify the project configuration file inside `config/`. You can edit `default.env` directly, but it's recommended to
+ create a new file named `local.env` and override the configuration in `default.env`.
+6. Run this application:
+ ```shell
+ python main.py
+ ```
+ You can use `--host` to specify the IP address you want to bind to (default is 0.0.0.0) and `--port` to specify the
+ port you want to bind to (default is 8000).
+ You can see all available commands and options by running `python main.py --help`.
+7. (Optional) Deploy the front-end application: [NekoImageGallery.App](https://github.com/hv0905/NekoImageGallery.App)
+ is a simple web front-end application for this project. If you want to deploy it, please refer to
+ its [deployment documentation](https://github.com/hv0905/NekoImageGallery.App).
+
+### 🐋 Docker Deployment
+
+#### About Docker images
+
+NekoImageGallery's Docker images are built and released on Docker Hub in several variants:
+
+| Tags | Description | Latest Image Size |
+|------|-------------|-------------------|
+| `edgeneko/neko-image-gallery:<version>`<br/>`edgeneko/neko-image-gallery:<version>-cuda`<br/>`edgeneko/neko-image-gallery:<version>-cuda12.1` | Supports GPU inferencing with CUDA12.1 | [![Docker Image Size (tag)](https://img.shields.io/docker/image-size/edgeneko/neko-image-gallery/latest?label=Image%20(cuda))](https://hub.docker.com/r/edgeneko/neko-image-gallery) |
+| `edgeneko/neko-image-gallery:<version>-cuda11.8` | Supports GPU inferencing with CUDA11.8 | [![Docker Image Size (tag)](https://img.shields.io/docker/image-size/edgeneko/neko-image-gallery/latest-cuda11.8?label=Image%20(cuda11.8))](https://hub.docker.com/r/edgeneko/neko-image-gallery) |
+| `edgeneko/neko-image-gallery:<version>-cpu` | Only supports CPU inferencing | [![Docker Image Size (tag)](https://img.shields.io/docker/image-size/edgeneko/neko-image-gallery/latest-cpu?label=Image%20(cpu))](https://hub.docker.com/r/edgeneko/neko-image-gallery) |
+
+Where `<version>` is the version number or version alias of NekoImageGallery, as follows:
+
+| Version | Description |
+|-------------------|-------------|
+| `latest` | The latest stable version of NekoImageGallery |
+| `v*.*.*` / `v*.*` | The specific version number (corresponding to Git tags) |
+| `edge` | The latest development version of NekoImageGallery, may contain unstable features and breaking changes |
+
+In each image, we have bundled the necessary dependencies, `openai/clip-vit-large-patch14` model
+weights, `bert-base-chinese` model weights and `easy-paddle-ocr` models to provide a complete and ready-to-use image.
+
+The images use `/opt/NekoImageGallery/static` as a volume to store image files; mount it to your own volume or directory
+if local storage is required.
+
+For configuration, we suggest using environment variables to override the default configuration. Secrets (like API
+tokens) can be provided by [docker secrets](https://docs.docker.com/engine/swarm/secrets/).
+
+#### Prepare `nvidia-container-runtime` (CUDA users only)
+
+If you want to use CUDA acceleration, you need to install `nvidia-container-runtime` on your system. Please refer to
+the [official documentation](https://docs.docker.com/config/containers/resource_constraints/#gpu) for installation.
+
+> Related Documents:
+> 1. https://docs.docker.com/config/containers/resource_constraints/#gpu
+> 2. https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker
+> 3. https://nvidia.github.io/nvidia-container-runtime/
+
+#### Run the server
+
+1. Download the `docker-compose.yml` file from the repository.
+ ```shell
+ # For cuda deployment (default)
+ wget https://raw.githubusercontent.com/hv0905/NekoImageGallery/master/docker-compose.yml
+ # For CPU-only deployment
+ wget https://raw.githubusercontent.com/hv0905/NekoImageGallery/master/docker-compose-cpu.yml && mv docker-compose-cpu.yml docker-compose.yml
+ ```
+2. Modify the docker-compose.yml file as needed.
+3. Run the following command to start the server:
+ ```shell
+ # start in foreground
+ docker compose up
+ # start in background (detached mode)
+ docker compose up -d
+ ```
+
+### Upload images to NekoImageGallery
+
+There are several ways to upload images to NekoImageGallery:
+
+- Through the web interface: You can use the web interface to upload images to the server. The web interface is provided
+ by [NekoImageGallery.App](https://github.com/hv0905/NekoImageGallery.App). Make sure you have enabled the
+ **Admin API** and set your **Admin Token** in the configuration file.
+- Through local indexing: This is suitable for local deployment or when the images you want to upload are already on the
+ server.
+ Use the following command to index your local image directory:
+ ```shell
+ python main.py local-index <path-to-your-image-directory>
+ ```
+ The above command will recursively upload all images in the specified directory and its subdirectories to the server;
+ see `python main.py local-index --help` for how to specify categories or starred status for the images you upload.
+- Through the API: You can use the upload API provided by NekoImageGallery to upload images. With this method, the
+ server can avoid saving the image files locally and only store their URLs and metadata.
+ Make sure you have enabled the **Admin API** and set your **Admin Token** in the configuration file.
+ This method is suitable for automated image uploading or syncing NekoImageGallery with external systems.
+ Check out the [API documentation](#-api-documentation) for more information; an illustrative request is sketched at the end of this readme.
+
+## 📚 API Documentation
+
+The API documentation is provided by FastAPI's built-in Swagger UI. You can access the API documentation by visiting
+the `/docs` or `/redoc` path of the server.
+
+## ⚡ Related Projects
+
+These projects work with NekoImageGallery :D
+
+[![NekoImageGallery.App](https://github-readme-stats.vercel.app/api/pin/?username=hv0905&repo=NekoImageGallery.App&show_owner=true)](https://github.com/hv0905/NekoImageGallery.App)
+[![LiteLoaderQQNT-NekoImageGallerySearch](https://github-readme-stats.vercel.app/api/pin/?username=pk5ls20&repo=LiteLoaderQQNT-NekoImageGallerySearch&show_owner=true)](https://github.com/pk5ls20/LiteLoaderQQNT-NekoImageGallerySearch)
+[![nonebot-plugin-nekoimage](https://github-readme-stats.vercel.app/api/pin/?username=pk5ls20&repo=nonebot-plugin-nekoimage&show_owner=true)](https://github.com/pk5ls20/pk5ls20/nonebot-plugin-nekoimage)
+
+## 📊 Repository Summary
+
+![Alt](https://repobeats.axiom.co/api/embed/ac080afa0d2d8af0345f6818b9b7c35bf8de1d31.svg "Repobeats analytics image")
+
+## ♥ Contributing
+
+There are many ways to contribute to the project: logging bugs, submitting pull requests, reporting issues, and creating
+suggestions.
+
+Even if you have push access to the repository, you should create personal feature branches when you need them.
+This keeps the main repository clean and your workflow cruft out of sight.
+
+We're also interested in your feedback on the future of this project. You can submit a suggestion or feature request
+through the issue tracker. To make this process more effective, we're asking that these include more information to help
+define them more clearly.
+
+## Copyright
+
+Copyright 2023 EdgeNeko
+
+Licensed under AGPLv3 license.
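+
+## 📎 Appendix: Example upload request (illustrative)
+
+The sketch below is illustrative only: it uses the `/admin/upload` endpoint, the `image_file` form field and the
+`x-admin-token` header exercised by the test suite in this change, and assumes the server runs on the default
+`0.0.0.0:8000` with the Admin API enabled. Replace the token value and file path with your own.
+
+```shell
+# Upload a local file and let the server store it (same as local=True in the tests)
+curl -X POST "http://localhost:8000/admin/upload?local=true" \
+  -H "x-admin-token: your-admin-token" \
+  -F "image_file=@/path/to/image.jpg"
+```
+
+If access protection is enabled, also pass the `x-access-token` header, as shown in `tests/api/conftest.py`.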
diff --git a/readme_cn.md b/readme_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..85b7df590c1840efa909f74a5546bd264590e9fe --- /dev/null +++ b/readme_cn.md @@ -0,0 +1,184 @@ +# NekoImageGallery + +[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/hv0905/NekoImageGallery/prod.yml?logo=github)](https://github.com/hv0905/NekoImageGallery/actions) +[![codecov](https://codecov.io/gh/hv0905/NekoImageGallery/branch/master/graph/badge.svg?token=JK2KZBDIYP)](https://codecov.io/gh/hv0905/NekoImageGallery) +[![Maintainability](https://api.codeclimate.com/v1/badges/ac97a1146648996b68ea/maintainability)](https://codeclimate.com/github/hv0905/NekoImageGallery/maintainability) +![Man hours](https://img.shields.io/endpoint?url=https%3A%2F%2Fmanhours.aiursoft.cn%2Fr%2Fgithub.com%2Fhv0905%2FNekoImageGallery.json) +[![Docker Pulls](https://img.shields.io/docker/pulls/edgeneko/neko-image-gallery)](https://hub.docker.com/r/edgeneko/neko-image-gallery) + +基于Clip模型与Qdrant向量数据库的在线AI图片搜索引擎。支持关键字搜索以及相似图片搜索。 + +[English Document](readme.md) + +## ✨特性 + +- 使用Clip模型为每张图片生成768维向量作为搜索依据。无需人工标注或分类,无限分类类别。 +- 支持OCR文本搜索,使用PaddleOCR提取图片文本并使用BERT模型生成文本特征向量。 +- 使用Qdrant向量数据库进行高效的向量搜索。 + +## 📷截图 + +![Screenshot1](web/screenshots/1.png) +![Screenshot2](web/screenshots/2.png) +![Screenshot3](web/screenshots/3.png) +![Screenshot4](web/screenshots/4.png) +![Screenshot5](web/screenshots/5.png) +![Screenshot6](web/screenshots/6.png) + +> 以上截图可能包含来自不同画师的版权图片,请不要将其用作其它用途。 + +## ✈️部署 + +### 🖥️ 本地部署 + +#### 选择元数据存储方式 + +NekoImageGallery支持两种元数据存储方式:Qdrant数据库存储与本地文件存储。您可以根据自己的需求选择其中一种方式。 + +##### Qdrant数据库 (推荐) + +在大多数情况下,我们推荐使用Qdrant数据库存储元数据。Qdrant数据库提供了高效的检索性能,灵活的扩展性以及更好的数据安全性。 + +请根据[Qdrant文档](https://qdrant.tech/documentation/quick-start/)部署Qdrant数据库,推荐使用docker部署。 + +如果你不想自己部署Qdrant,可以使用[Qdrant官方提供的在线服务](https://qdrant.tech/documentation/cloud/)。 + +##### 本地文件存储 + +本地文件存储直接将图片元数据(包括特征向量等)存在本地的Sqlite数据库中。仅建议在小规模部署或开发部署中使用。 + +本地文件存储不需要额外的数据库部署流程,但是存在以下缺点: + +- 本地存储没有对向量进行索引和优化,所有搜索的时间复杂度为`O(n)`,因此若数据规模较大,搜索与索引的性能会下降。 +- 使用本地文件存储会使得NekoImageGallery变得有状态,因此会丧失横向扩展能力。 +- 当你希望迁移到Qdrant数据库进行存储时,已索引的元数据可能难以直接迁移。 + +#### 部署NekoImageGallery + +1. 将项目目录clone到你自己的PC或服务器中,然后按需checkout到特定版本tag(如`v1.0.0`)。 +2. 强烈建议在python venv虚拟环境中安装本项目所需依赖, 运行下面命令: + ```shell + python -m venv .venv + . .venv/bin/activate + ``` +3. 安装PyTorch. 按照[PyTorch文档](https://pytorch.org/get-started/locally/)使用pip安装适合你的系统的torch版本 + > 如果您希望使用CUDA加速推理,务必在本步中安装支持Cuda的pytorch版本,安装完成后可以使用`torch.cuda.is_available()` + 确认CUDA是否可用。 +4. 安装其它本项目所需依赖: + ```shell + pip install -r requirements.txt + ``` +5. 按需修改位于`config`目录下的配置文件,您可以直接修改`default.env`,但是建议创建一个名为`local.env` + 的文件,覆盖`default.env`中的配置。 +6. 运行本应用: + ```shell + python main.py + ``` + 你可以通过`--host`指定希望绑定到的ip地址(默认为0.0.0.0),通过`--port`指定希望绑定到的端口(默认为8000)。 + 通过`python main.py --help`可以查看所有可用命令和选项。 +7. 
(可选)部署前端应用:[NekoImageGallery.App](https://github.com/hv0905/NekoImageGallery.App) + 是本项目的一个简易web前端应用,如需部署请参照它的[部署文档](https://github.com/hv0905/NekoImageGallery.App)。 + +### 🐋 Docker 部署 + +#### 关于Docker镜像 + +NekoImageGallery镜像发布在DockerHub上,并包含多个变种,设计于在不同的环境使用。 + +| Tags | 介绍 | Latest 镜像尺寸 | +|---------------------------------------------------------------------------------------------------------------------------------------------|------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `edgeneko/neko-image-gallery:`
`edgeneko/neko-image-gallery:-cuda`
`edgeneko/neko-image-gallery:-cuda12.1` | 基于CUDA12.1, 支持GPU推理的镜像 | [![Docker Image Size (tag)](https://img.shields.io/docker/image-size/edgeneko/neko-image-gallery/latest?label=Image%20(cuda))](https://hub.docker.com/r/edgeneko/neko-image-gallery) | +| `edgeneko/neko-image-gallery:-cuda11.8` | 基于CUDA11.8, 支持GPU推理的镜像 | [![Docker Image Size (tag)](https://img.shields.io/docker/image-size/edgeneko/neko-image-gallery/latest-cuda11.8?label=Image%20(cuda11.8))](https://hub.docker.com/r/edgeneko/neko-image-gallery) | +| `edgeneko/neko-image-gallery:-cpu` | 仅支持CPU推理的镜像 | [![Docker Image Size (tag)](https://img.shields.io/docker/image-size/edgeneko/neko-image-gallery/latest-cpu?label=Image%20(cpu))](https://hub.docker.com/r/edgeneko/neko-image-gallery) | + +其中,``为NekoImageGallery的版本号或版本代称,具体如下: + +| Version | 介绍 | +|-------------------|------------------------------------------------------| +| `latest` | 最新的稳定版本 | +| `v*.*.*` / `v*.*` | 特定版本号(与GitHub Tag对应) | +| `edge` | 最新的开发版本,与master分支同步更新,可能包含未经完善测试的功能和breaking changes | + +在每个镜像中,我们捆绑了必要的依赖项,包括 `openai/clip-vit-large-patch14` 模型权重、`bert-base-chinese` +模型权重和 `easy-paddle-ocr` 模型,以提供一个完整且可直接使用的镜像。 + +镜像使用 `/opt/NekoImageGallery/static` 作为存储图像文件的卷,如果需要本地存储,可以将其挂载到您自己的卷或目录。 + +对于配置,我们建议使用环境变量来覆盖默认配置。机密信息(如 API +令牌)可以通过 [docker secrets](https://docs.docker.com/engine/swarm/secrets/) 提供。 + +#### 准备`nvidia-container-runtime` + +如果你希望在推理时支持CUDA加速,请参考[Docker GPU相关文档](https://docs.docker.com/config/containers/resource_constraints/#gpu) +准备支持GPU的容器运行时。 + +> 相关文档: +> 1. https://docs.docker.com/config/containers/resource_constraints/#gpu +> 2. https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker +> 3. https://nvidia.github.io/nvidia-container-runtime/ + +#### 运行 + +1. 下载`docker-compose.yml`文件。 + ```shell + # For cuda deployment (default) + wget https://raw.githubusercontent.com/hv0905/NekoImageGallery/master/docker-compose.yml + # For CPU-only deployment + wget https://raw.githubusercontent.com/hv0905/NekoImageGallery/master/docker-compose-cpu.yml && mv docker-compose-cpu.yml docker-compose.yml + ``` +2. 按需修改docker-compose.yml文件 +3. 运行下面命令启动docker-compose + ```shell + # start in foreground + docker compose up + # start in background(detached mode) + docker compose up -d + ``` + +### 将图片上传至NekoImageGallery + +有几种方法可以将图片上传至NekoImageGallery: + +- 通过网页界面:您可以使用网页界面将图片上传到服务器。网页界面由 [NekoImageGallery.App](https://github.com/hv0905/NekoImageGallery.App) + 提供。请确保您已启用 **Admin API** 并在配置文件中设置了您的 **Admin Token**。 +- 通过本地索引:这适用于本地部署或当您想上传的图片已经在服务器上时。 + 使用以下命令来索引您的本地图片目录: + ```shell + python main.py local-index + ``` + 上述命令将递归地将指定目录及其子目录中的所有图片上传到服务器。 + 你可以通过附加选项指定上传的图片的类别和星标状态,具体参考`python main.py local-index --help`。 +- 通过API:您可以使用NekoImageGallery提供的上传API来上传图片。通过此方法,可允许服务器本地不保存图片文件而仅仅存储其URL以及元数据。 + 请确保您已启用 **Admin API** 并在配置文件中设置了您的 **Admin Token**。 + 此方法适用于自动化图片上传或将NekoImageGallery与外部系统进行同步。更多信息请查看 [API文档](#-api文档)。 + +## 📚 API文档 + +API文档由FastAPI内置的Swagger UI提供。您可以通过访问服务器的`/docs`或`/redoc`路径来查看API文档。 + +## ⚡ 相关项目 + +以下项目基于NekoImageGallery工作! 
+ +[![NekoImageGallery.App](https://github-readme-stats.vercel.app/api/pin/?username=hv0905&repo=NekoImageGallery.App&show_owner=true)](https://github.com/hv0905/NekoImageGallery.App) +[![LiteLoaderQQNT-NekoImageGallerySearch](https://github-readme-stats.vercel.app/api/pin/?username=pk5ls20&repo=LiteLoaderQQNT-NekoImageGallerySearch&show_owner=true)](https://github.com/pk5ls20/LiteLoaderQQNT-NekoImageGallerySearch) +[![nonebot-plugin-nekoimage](https://github-readme-stats.vercel.app/api/pin/?username=pk5ls20&repo=nonebot-plugin-nekoimage&show_owner=true)](https://github.com/pk5ls20/pk5ls20/nonebot-plugin-nekoimage) + +## 📊仓库信息 + +![Alt](https://repobeats.axiom.co/api/embed/ac080afa0d2d8af0345f6818b9b7c35bf8de1d31.svg "Repobeats analytics image") + +## ❤️贡献指南 + +有很多种可以为本项目提供贡献的方式:记录 Bug,提交 Pull Request,报告问题,提出建议等等。 + +即使您拥有对本仓库的写入权限,您也应该在有需要时创建自己的功能分支并通过 Pull Request 的方式提交您的变更。 +这有助于让我们的主仓库保持整洁并使您的个人工作流程不可见。 + +我们也很感兴趣听到您关于这个项目未来的反馈。您可以通过 Issues 追踪器提交建议或功能请求。为了使这个过程更加有效,我们希望这些内容包含更多信息,以更清晰地定义它们。 + +## Copyright + +Copyright 2023 EdgeNeko + +Licensed under GPLv3 license. diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..2b30bbc5bfe79dc15ba58620193cbb337d9c452b --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,6 @@ +# Requirements for development and testing + +pytest +pytest-asyncio +pytest-cov +pylint \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b4f416d079cfab99987fdba11e05140c46efceb1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,35 @@ +# Fastapi +fastapi>=0.110.2 +python-multipart>=0.0.9 +uvicorn[standard] +pydantic +pydantic-settings +typer + +# AI - Manually install cuda-capable pytorch +torch>=2.1.0 +torchvision +transformers>4.35.2 +pillow>9.3.0 +numpy + +# OCR - you can choose other option if necessary, or completely disable it if you don't need this feature +easypaddleocr>=0.2.1 +# easyocr +# paddleocr + +# Vector Database +qdrant-client>=1.9.2 + +# Storage +opendal + +# Misc +aiofiles +aiopath +wcmatch +pyyaml +loguru +httpx +pytest +rich \ No newline at end of file diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/db_migrations.py b/scripts/db_migrations.py new file mode 100644 index 0000000000000000000000000000000000000000..a457a266ef35f335516b56368a9c0291c94f2e06 --- /dev/null +++ b/scripts/db_migrations.py @@ -0,0 +1,44 @@ +from loguru import logger + +from app.Services.provider import ServiceProvider + +CURRENT_VERSION = 2 + +services: ServiceProvider | None = None + + +async def migrate_v1_v2(): + logger.info("Migrating from v1 to v2...") + next_id = None + count = 0 + while True: + points, next_id = await services.db_context.scroll_points(next_id, count=100) + for point in points: + count += 1 + logger.info("[{}] Migrating point {}", count, point.id) + if point.url.startswith('/'): + # V1 database assuming all image with '/' as begins is a local image, + # v2 migrate to a more strict approach + point.local = True + await services.db_context.updatePayload(point) # This will also store ocr_text_lower field, if present + if point.ocr_text is not None: + point.text_contain_vector = services.transformers_service.get_bert_vector(point.ocr_text_lower) + + logger.info("Updating vectors...") + # Update vectors for this group of points + await 
services.db_context.updateVectors([t for t in points if t.text_contain_vector is not None]) + if next_id is None: + break + + +async def migrate(from_version: int): + global services + services = ServiceProvider() + await services.onload() + match from_version: + case 1: + await migrate_v1_v2() + case 2: + logger.info("Already up to date.") + case _: + raise ValueError(f"Unknown version {from_version}") diff --git a/scripts/local_create_thumbnail.py b/scripts/local_create_thumbnail.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad69108341a46db30a36422a3a30c39dea20fb0 --- /dev/null +++ b/scripts/local_create_thumbnail.py @@ -0,0 +1,57 @@ +import io +import uuid + +from PIL import Image +from loguru import logger + +from app.Services.provider import ServiceProvider + + +async def main(): + services = ServiceProvider() + await services.onload() + # Here path maybe either local path or pure path + count = 0 + async for item in services.storage_service.active_storage.list_files("", '*.*', batch_max_files=1): + item = item[0] + count += 1 + logger.info("[{}] Processing {}", str(count), str(item)) + size = await services.storage_service.active_storage.size(item) + if size < 1024 * 500: + logger.warning("File size too small: {}. Skip...", size) + continue + try: + if await services.storage_service.active_storage.is_exist(f'thumbnails/{item.stem}.webp'): + logger.warning("Thumbnail for {} already exists. Skip...", item.stem) + continue + image_id = uuid.UUID(item.stem) + except ValueError: + logger.warning("Invalid file name: {}. Skip...", item.stem) + continue + try: + imgdata = await services.db_context.retrieve_by_id(str(image_id)) + except Exception as e: + logger.error("Error when retrieving image {}: {}", image_id, e) + continue + try: + img_byte = await services.storage_service.active_storage.fetch(item) + img = Image.open(io.BytesIO(img_byte)) + except Exception as e: + logger.error("Error when opening image {}: {}", item, e) + continue + + # generate thumbnail max size 256*256 + img.thumbnail((256, 256)) + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, 'WebP', save_all=True) + await services.storage_service.active_storage.upload(img_byte_arr.getvalue(), + f'thumbnails/{str(image_id)}.webp') + logger.success("Thumbnail for {} generated!", image_id) + + # update payload + imgdata.thumbnail_url = await services.storage_service.active_storage.url(f'thumbnails/{str(image_id)}.webp') + imgdata.local_thumbnail = True + await services.db_context.updatePayload(imgdata) + logger.success("Payload for {} updated!", image_id) + + logger.success("OK. 
Updated {} items.", count) diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..11c086e63e7b8edf192bf36bdcb8dbba0bd994fb --- /dev/null +++ b/scripts/local_indexing.py @@ -0,0 +1,53 @@ +import sys +from datetime import datetime +from pathlib import Path + +import PIL +from loguru import logger +from rich.progress import Progress + +from app.Models.api_models.admin_query_params import UploadImageThumbnailMode +from app.Models.errors import PointDuplicateError +from app.Models.img_data import ImageData +from app.Services.provider import ServiceProvider +from app.util.local_file_utility import glob_local_files + +services: ServiceProvider | None = None + + +async def index_task(file_path: Path, categories: list[str], starred: bool, thumbnail_mode: UploadImageThumbnailMode): + try: + img_id = await services.upload_service.assign_image_id(file_path) + image_data = ImageData(id=img_id, + local=True, + categories=categories, + starred=starred, + format=file_path.suffix[1:], # remove the dot + index_date=datetime.now()) + await services.upload_service.sync_upload_image(image_data, file_path.read_bytes(), skip_ocr=False, + thumbnail_mode=thumbnail_mode) + except PointDuplicateError as ex: + logger.warning("Image {} already exists in the database", file_path) + except PIL.UnidentifiedImageError as e: + logger.error("Error when processing image {}: {}", file_path, e) + + +@logger.catch() +async def main(root_directory: list[Path], categories: list[str], starred: bool, + thumbnail_mode: UploadImageThumbnailMode): + global services + services = ServiceProvider() + await services.onload() + files = [] + for root in root_directory: + files.extend(list(glob_local_files(root, '**/*'))) + with Progress() as progress: + # A workaround for the loguru logger to work with rich progressbar + logger.remove() + logger.add(sys.stderr, colorize=True) + for idx, item in enumerate(progress.track(files, description="Indexing...")): + logger.info("[{} / {}] Indexing {}", idx + 1, len(files), str(item)) + + await index_task(item, categories, starred, thumbnail_mode) + + logger.success("Indexing completed!") diff --git a/scripts/qdrant_create_collection.py b/scripts/qdrant_create_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..3095485b8a7a3778d7fab3f25fbf4512c9d66c03 --- /dev/null +++ b/scripts/qdrant_create_collection.py @@ -0,0 +1,6 @@ +from app.Services.vector_db_context import VectorDbContext + + +async def main(): + context = VectorDbContext() + await context.initialize_collection() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/api/conftest.py b/tests/api/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ed2711f6be5a7d9281c2e953ad14dc461eeec6 --- /dev/null +++ b/tests/api/conftest.py @@ -0,0 +1,69 @@ +import asyncio +import importlib +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from app import config + +TEST_ACCESS_TOKEN = 'test_token' +TEST_ADMIN_TOKEN = 'test_admin_token' + + +@pytest.fixture(scope="session") +def unauthorized_test_client(tmp_path_factory) -> TestClient: + # Modify the configuration for testing + 
config.config.qdrant.mode = "memory" + config.config.admin_api_enable = True + config.config.access_protected = True + config.config.access_token = TEST_ACCESS_TOKEN + config.config.admin_token = TEST_ADMIN_TOKEN + config.config.storage.method = config.StorageMode.LOCAL + config.config.storage.local.path = tmp_path_factory.mktemp("static_files") + # Start the application + + with TestClient(importlib.import_module('app.webapp').app) as client: + yield client + + +@pytest.fixture(scope="module") +def test_client(unauthorized_test_client): + unauthorized_test_client.headers = {'x-access-token': TEST_ACCESS_TOKEN, 'x-admin-token': TEST_ADMIN_TOKEN} + yield unauthorized_test_client + unauthorized_test_client.headers = {} + + +def check_local_dir_empty(): + dir = Path(config.config.storage.local.path) + files = [f for f in dir.glob('*.*') if f.is_file()] + assert len(files) == 0 + + thumbnail_dir = dir / 'thumbnails' + if thumbnail_dir.exists(): + thumbnail_files = [f for f in thumbnail_dir.glob('*.*') if f.is_file()] + assert len(thumbnail_files) == 0 + + +@pytest.fixture() +def ensure_local_dir_empty(): + yield + check_local_dir_empty() + + +@pytest.fixture(scope="module") +def wait_for_background_task(test_client): + async def func(expected_image_count): + while True: + resp = test_client.get('/admin/server_info') + if resp.json()['image_count'] >= expected_image_count: + break + await asyncio.sleep(0.2) + assert resp.json()['index_queue_length'] == 0 + + return func + + +@pytest.fixture +def anyio_backend(): + return 'asyncio' diff --git a/tests/api/test_home.py b/tests/api/test_home.py new file mode 100644 index 0000000000000000000000000000000000000000..b69826f868a5e8a01a21d34863e1604e9cefe607 --- /dev/null +++ b/tests/api/test_home.py @@ -0,0 +1,24 @@ +class TestHome: + + def test_get_home_no_tokens(self, unauthorized_test_client): + response = unauthorized_test_client.get("/") + assert response.status_code == 200 + assert response.json()['authorization']['required'] + assert not response.json()['authorization']['passed'] + assert response.json()['admin_api']['available'] + assert not response.json()['admin_api']['passed'] + + def test_get_home_access_token(self, unauthorized_test_client): + response = unauthorized_test_client.get("/", headers={'x-access-token': 'test_token'}) + assert response.status_code == 200 + assert response.json()['authorization']['required'] + assert response.json()['authorization']['passed'] + + def test_get_home_admin_token(self, unauthorized_test_client): + response = unauthorized_test_client.get("/", headers={'x-admin-token': 'test_admin_token', + 'x-access-token': 'test_token'}) + assert response.status_code == 200 + assert response.json()['admin_api']['available'] + assert response.json()['admin_api']['passed'] + assert response.json()['authorization']['required'] + assert response.json()['authorization']['passed'] diff --git a/tests/api/test_search.py b/tests/api/test_search.py new file mode 100644 index 0000000000000000000000000000000000000000..93aded84cb3e2514a63c3d9be50c9937d775c2f7 --- /dev/null +++ b/tests/api/test_search.py @@ -0,0 +1,126 @@ +import itertools +import uuid + +import pytest_asyncio + +from .conftest import check_local_dir_empty +from ..assets import assets_path + +test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'], + 'cat': ['cat_0.jpg', 'cat_1.jpg'], + 'cg': ['cg_0.jpg', 'cg_1.png']} + + +@pytest_asyncio.fixture(scope="module") +async def img_ids(test_client, wait_for_background_task): + img_ids = {} + for img_cls, 
item_images in test_images.items(): + img_ids[img_cls] = [] + for image in item_images: + print(f'upload image {image}...') + with open(assets_path / 'test_images' / image, 'rb') as f: + resp = test_client.post('/admin/upload', + files={'image_file': f}, + params={'local': True}) + assert resp.status_code == 200 + img_ids[img_cls].append(resp.json()['image_id']) + + print('Waiting for images to be processed...') + + await wait_for_background_task(sum(len(v) for v in test_images.values())) + + yield img_ids + + # cleanup + for img_cls in test_images.keys(): + for img_id in img_ids[img_cls]: + resp = test_client.delete(f"/admin/delete/{img_id}") + assert resp.status_code == 200 + + check_local_dir_empty() + + +def test_search_text(test_client, img_ids): + resp = test_client.get('/search/text/hatsune+miku') + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['cg'] + + +def test_search_image(test_client, img_ids): + with open(assets_path / 'test_images' / test_images['cat'][0], 'rb') as f: + resp = test_client.post('/search/image', + files={'image': f}) + + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['cat'] + + +def test_search_similar(test_client, img_ids): + resp = test_client.get(f"/search/similar/{img_ids['bsn'][0]}") + + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['bsn'] + + +def test_search_advanced(test_client, img_ids): + resp = test_client.post("/search/advanced", + json={'criteria': ['white background', 'grayscale image'], + 'negative_criteria': ['cat', 'hatsune miku']}) + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['bsn'] + + +def test_search_combined(test_client, img_ids): + resp = test_client.post('/search/combined', json={'criteria': ['hatsune miku'], + 'negative_criteria': ['grayscale image', 'cat'], + 'extra_prompt': 'hatsunemiku'}) + + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] == img_ids['cg'][1] + + resp = test_client.post('/search/combined?basis=ocr', + json={'criteria': ['hatsunemiku'], 'extra_prompt': 'hatsune miku'}) + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] == img_ids['cg'][1] + + +def test_search_filters(test_client, img_ids): + resp = test_client.put(f"/admin/update_opt/{img_ids['bsn'][0]}", json={'categories': ['bsn'], 'starred': True}) + assert resp.status_code == 200 + + resp = test_client.get("/search/text/cat", params={'categories': 'bsn'}) + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] == img_ids['bsn'][0] + + resp = test_client.get("/search/text/cat", params={'starred': True}) + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] == img_ids['bsn'][0] + + +def test_images_query_by_id(test_client, img_ids): + resp = test_client.get(f"/images/id/{img_ids['bsn'][0]}") + assert resp.status_code == 200 + assert resp.json()['img']['id'] == img_ids['bsn'][0] + + +def test_images_query_not_exist(test_client, img_ids): + resp = test_client.get(f"/images/id/{uuid.uuid4()}") + assert resp.status_code == 404 + + +def test_images_query_scroll(test_client, img_ids): + resp = test_client.get("/images/", params={'count': 50}) + assert resp.status_code == 200 + resp_imgs = resp.json()['images'] + all_images_id = list(itertools.chain(*img_ids.values())) + for item in resp_imgs: + assert item['id'] in all_images_id + + paging_test = test_client.get(f'/images', + 
params={'prev_offset_id': resp_imgs[len(resp_imgs) // 2]['id']}) + assert paging_test.status_code == 200 + assert paging_test.json()['images'][0]['id'] == resp_imgs[len(resp_imgs) // 2]['id'] + + no_exist_test = test_client.get(f'/images', + params={'prev_offset_id': uuid.uuid4()}) + assert no_exist_test.status_code == 404 diff --git a/tests/api/test_upload.py b/tests/api/test_upload.py new file mode 100644 index 0000000000000000000000000000000000000000..03e6ffa0e6a6de8a1825dc5c42034cb1483a1895 --- /dev/null +++ b/tests/api/test_upload.py @@ -0,0 +1,223 @@ +import io +import random + +import pytest + +from ..assets import assets_path + +test_file_path = assets_path / 'test_images' / 'bsn_0.jpg' +test_file_2_path = assets_path / 'test_images' / 'bsn_1.jpg' + +test_file_hashes = ['648351F7CBD472D0CA23EADCCF3B9E619EC9ADDA', 'C5DE90DAC2F75FBDBE48023DF4DE7585A86B2392'] + + +def get_single_img_info(test_client, image_id): + query = test_client.get('/search/random') + assert query.status_code == 200 + assert query.json()['result'][0]['img']['id'] == image_id + + return query.json()['result'][0]['img'] + + +def test_upload_bad_img_file(test_client): + bad_img_file = io.BytesIO(bytearray(random.getrandbits(8) for _ in range(1024 * 1024))) + bad_img_file.name = 'bad_image.jpg' + + resp = test_client.post('/admin/upload', + files={'image_file': bad_img_file}, + params={'local': True}) + assert resp.status_code == 422 + + +def test_upload_unsupported_types(test_client): + bad_img_file = io.BytesIO(bytearray(random.getrandbits(8) for _ in range(1024 * 1024))) + bad_img_file.name = 'bad_image.tga' + + resp = test_client.post('/admin/upload', + files={'image_file': ('bad_img.tga', bad_img_file, 'image/tga')}, + params={'local': True}) + assert resp.status_code == 415 + + +@pytest.mark.asyncio +async def test_upload_duplicate(test_client, ensure_local_dir_empty, wait_for_background_task): + def upload(file): + return test_client.post('/admin/upload', + files={'image_file': file}, + params={'local': True}) + + def validate(hashes): + return test_client.post('/admin/duplication_validate', + json={'hashes': hashes}) + + with open(test_file_path, 'rb') as f: + # Validate 1# + val_resp = validate(test_file_hashes) + assert val_resp.status_code == 200 + assert val_resp.json()['exists'] == [False, False] + assert val_resp.json()['entity_ids'] == [None, None] + + # Upload + resp = upload(f) + assert resp.status_code == 200 + image_id = resp.json()['image_id'] + + for i in range(0, 2): + # Re-upload + resp = upload(f) + assert resp.status_code == 409, i + + # Query by ID + query = test_client.get(f'/images/id/{image_id}') + assert query.status_code == 200 + assert query.json()['img_status'] == 'mapped' if i == 1 else 'in_queue' + + # Validate + val_resp = validate(test_file_hashes) + assert val_resp.status_code == 200, i + assert val_resp.json()['exists'] == [True, False], i + assert val_resp.json()['entity_ids'] == [str(image_id), None], i + + # Wait for the image to be indexed + if i == 0: + await wait_for_background_task(1) + + # cleanup + resp = test_client.delete(f'/admin/delete/{image_id}') + assert resp.status_code == 200 + + +TEST_FAKE_URL = 'fake-url' +TEST_FAKE_THUMBNAIL_URL = 'fake-thumbnail-url' + +TEST_UPLOAD_THUMBNAILS_PARAMS = [ + (True, {'local': True}, True, 'local'), + (True, {'local': True, 'local_thumbnail': 'never'}, True, 'none'), + (False, {'local': True, 'local_thumbnail': 'always'}, True, 'local'), + (False, {'local': True}, True, 'none'), + (False, {'local': False, 'url': TEST_FAKE_URL, 
'thumbnail_url': TEST_FAKE_THUMBNAIL_URL}, False, 'fake'), + (False, {'local': False, 'url': TEST_FAKE_URL, 'local_thumbnail': 'always'}, False, 'local'), + (False, {'local': False, 'url': TEST_FAKE_URL}, False, 'none'), +] + + +@pytest.mark.parametrize('add_trailing_bytes,params,expect_local_url,expect_thumbnail_mode', + TEST_UPLOAD_THUMBNAILS_PARAMS) +@pytest.mark.asyncio +async def test_upload_thumbnails(test_client, ensure_local_dir_empty, wait_for_background_task, # Fixtures + add_trailing_bytes, params, expect_local_url, expect_thumbnail_mode): # Parameters + with open(test_file_path, 'rb') as f: + # append 500KB to the image, to make it large enough to generate a thumbnail + if add_trailing_bytes: + img_bytes = f.read() + img_bytes += bytearray(random.getrandbits(8) for _ in range(1024 * 500)) + f_patched = io.BytesIO(img_bytes) + f_patched.name = 'bsn_0.jpg' + else: + f_patched = f + resp = test_client.post('/admin/upload', + files={'image_file': f_patched}, + params=params) + assert resp.status_code == 200 + image_id = resp.json()['image_id'] + await wait_for_background_task(1) + + query = get_single_img_info(test_client, image_id) + + if expect_local_url: + assert query['url'].startswith(f'/static/{image_id}.') + img_request = test_client.get(query['url']) + assert img_request.status_code == 200 + else: + assert query['url'] == TEST_FAKE_URL + + match expect_thumbnail_mode: + case 'local': + assert query['thumbnail_url'] == f'/static/thumbnails/{image_id}.webp' + + thumbnail_request = test_client.get(query['thumbnail_url']) + assert thumbnail_request.status_code == 200 + # IDK why starlette doesn't return the correct content type, but it works on the browser anyway + # assert thumbnail_request.headers['Content-Type'] == 'image/webp' + case 'fake': + assert query['thumbnail_url'] == TEST_FAKE_THUMBNAIL_URL + case 'none': + assert query['thumbnail_url'] is None + + # cleanup + resp = test_client.delete(f'/admin/delete/{image_id}') + assert resp.status_code == 200 + + +TEST_FAKE_URL_NEW = 'fake-url-new' +TEST_FAKE_THUMBNAIL_URL_NEW = 'fake-thumbnail-url-new' + +TEST_UPDATE_OPT_PARAMS = [ + ({'url': TEST_FAKE_URL}, {'url': TEST_FAKE_URL_NEW, 'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, + {'url': TEST_FAKE_URL_NEW, 'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, 200), + ({'local_thumbnail': 'always', 'url': TEST_FAKE_URL}, {'url': TEST_FAKE_URL_NEW}, {'url': TEST_FAKE_URL_NEW}, 200), + ({'local': True}, {'categories': ['1'], 'starred': True}, {'categories': ['1'], 'starred': True}, 200), + ({'local': True}, {'url': TEST_FAKE_URL_NEW}, {}, 422), + ({'local': True}, {'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, {}, 422), + ({'local_thumbnail': 'always', 'url': TEST_FAKE_URL}, {'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, {}, 422), + ({'local': True}, {}, {}, 422), +] + + +@pytest.mark.parametrize('initial_param,update_param,expected_param,resp_code', TEST_UPDATE_OPT_PARAMS) +@pytest.mark.asyncio +async def test_update_opt(test_client, ensure_local_dir_empty, wait_for_background_task, # Fixtures + initial_param, update_param, expected_param, resp_code): # Parameters + with open(test_file_path, 'rb') as f: + resp = test_client.post('/admin/upload', + files={'image_file': f}, + params=initial_param) + assert resp.status_code == 200 + image_id = resp.json()['image_id'] + await wait_for_background_task(1) + + old_info = get_single_img_info(test_client, image_id) + + resp = test_client.put(f'/admin/update_opt/{image_id}', json=update_param) + assert resp.status_code == resp_code + + 
new_info = get_single_img_info(test_client, image_id) + # Ensure expected keys are updated + for key, value in expected_param.items(): + assert new_info[key] == value + del new_info[key] + + # Ensure that the other keys are kept untouched + for key, value in new_info.items(): + assert old_info[key] == value + + # cleanup + resp = test_client.delete(f'/admin/delete/{image_id}') + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_delete(test_client, ensure_local_dir_empty, wait_for_background_task): + with open(test_file_path, 'rb') as f: + resp = test_client.post('/admin/upload', + files={'image_file': f}, + params={'local': True}) + assert resp.status_code == 200 + image_id = resp.json()['image_id'] + await wait_for_background_task(1) + + img_query = test_client.get(f'/static/{image_id}.jpeg') + assert img_query.status_code == 200 + + resp = test_client.delete(f'/admin/delete/{image_id}') + assert resp.status_code == 200 + + img_query = test_client.get(f'/static/{image_id}.jpeg') + assert img_query.status_code == 404 + + query = test_client.get('/search/random') + assert query.status_code == 200 + assert not query.json()['result'] + + resp = test_client.delete(f'/admin/delete/{image_id}') + assert resp.status_code == 404 diff --git a/tests/assets/__init__.py b/tests/assets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..838176457960b0653d443b064f274f3477973ac6 --- /dev/null +++ b/tests/assets/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +assets_path = Path(__file__).parent diff --git a/tests/assets/test_images/bsn_0.jpg b/tests/assets/test_images/bsn_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..966a7efcb0f1261c70ba52e5295894599bada081 Binary files /dev/null and b/tests/assets/test_images/bsn_0.jpg differ diff --git a/tests/assets/test_images/bsn_1.jpg b/tests/assets/test_images/bsn_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a1eb294cd041b8f99d06321eaf62c6e2146efc36 Binary files /dev/null and b/tests/assets/test_images/bsn_1.jpg differ diff --git a/tests/assets/test_images/bsn_2.jpg b/tests/assets/test_images/bsn_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..35f8cb5d34480125dbf533e6cb1409c0bc558c8d Binary files /dev/null and b/tests/assets/test_images/bsn_2.jpg differ diff --git a/tests/assets/test_images/cat_0.jpg b/tests/assets/test_images/cat_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7157dc20f721fdfc7af62d33039870db114ed8b2 Binary files /dev/null and b/tests/assets/test_images/cat_0.jpg differ diff --git a/tests/assets/test_images/cat_1.jpg b/tests/assets/test_images/cat_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a10363135b1341b34232b1a3bc76dd55aefe3c0b Binary files /dev/null and b/tests/assets/test_images/cat_1.jpg differ diff --git a/tests/assets/test_images/cg_0.jpg b/tests/assets/test_images/cg_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5580f09eaa11dc6dbe20edee8b55533d31c57112 Binary files /dev/null and b/tests/assets/test_images/cg_0.jpg differ diff --git a/tests/assets/test_images/cg_1.png b/tests/assets/test_images/cg_1.png new file mode 100644 index 0000000000000000000000000000000000000000..ef66247ba3212a2d6f06d32d85f653b79b0f3724 Binary files /dev/null and b/tests/assets/test_images/cg_1.png differ diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/unit/test_image_uuid.py b/tests/unit/test_image_uuid.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3a41a630d30139e4f496c3af7738a8257a4de9 --- /dev/null +++ b/tests/unit/test_image_uuid.py @@ -0,0 +1,19 @@ +import io +from uuid import UUID + +from app.util.generate_uuid import generate_uuid +from ..assets import assets_path + +BSN_UUID = UUID('b3aff1e9-8085-5300-8e06-37b522384659') # To test consistency of UUID across versions + + +def test_uuid_consistency(): + file_path = assets_path / 'test_images' / 'bsn_0.jpg' + with open(file_path, 'rb') as f: + file_content = f.read() + + uuid1 = generate_uuid(file_path) + uuid2 = generate_uuid(io.BytesIO(file_content)) + uuid3 = generate_uuid(file_content) + + assert uuid1 == uuid2 == uuid3 == BSN_UUID diff --git a/tests/unit/test_retry_deco.py b/tests/unit/test_retry_deco.py new file mode 100644 index 0000000000000000000000000000000000000000..4a37eef4f59654fb7c4e5135f60aaf37b849c475 --- /dev/null +++ b/tests/unit/test_retry_deco.py @@ -0,0 +1,47 @@ +import asyncio + +import pytest + +from app.util.retry_deco_async import retry_async, wrap_object + + +class TestRetryDeco: + class ExampleClass: + def __init__(self): + self.counter = 0 + self.counter2 = 0 + self.not_func = 'not a function' + + async def example_method(self): + await asyncio.sleep(0) + self.counter += 1 + if self.counter < 3: + raise ValueError("Counter is less than 3") + return self.counter + + async def example_method_must_raise(self): + await asyncio.sleep(0) + self.counter2 += 1 + raise NotImplementedError("This method must raise an exception.") + + @pytest.mark.asyncio + async def test_decorator(self): + obj = self.ExampleClass() + + @retry_async(tries=3) + def caller(): + return obj.example_method() + + assert await caller() == 3 + + @pytest.mark.asyncio + async def test_object_wrapper(self): + obj = self.ExampleClass() + wrap_object(obj, retry_async(ValueError, tries=2)) + assert isinstance(obj.not_func, str) + with pytest.raises(ValueError): + await obj.example_method() + assert await obj.example_method() == 3 + with pytest.raises(NotImplementedError): + await obj.example_method_must_raise() + assert obj.counter2 == 1 diff --git a/tests/unit/test_transformers_service.py b/tests/unit/test_transformers_service.py new file mode 100644 index 0000000000000000000000000000000000000000..fec9abec5798fcda55389946ae0b1febe4d1e34b --- /dev/null +++ b/tests/unit/test_transformers_service.py @@ -0,0 +1,38 @@ +from PIL import Image + +from app.Services.transformers_service import TransformersService +from app.util.calculate_vectors_cosine import calculate_vectors_cosine +from ..assets import assets_path + + +class TestTransformersService: + + def setup_class(self): + self.transformers_service = TransformersService() + + def test_get_image_vector(self): + vector1 = self.transformers_service.get_image_vector(Image.open(assets_path / 'test_images/cat_0.jpg')) + vector2 = self.transformers_service.get_image_vector(Image.open(assets_path / 'test_images/cat_1.jpg')) + assert vector1.shape == (768,) + assert vector2.shape == (768,) + assert calculate_vectors_cosine(vector1, vector2) > 0.8 + + def test_get_text_vector(self): + vector1 = self.transformers_service.get_text_vector('1girl') + vector2 = self.transformers_service.get_text_vector('girl, solo') + assert vector1.shape == (768,) + assert vector2.shape == (768,) + assert calculate_vectors_cosine(vector1, vector2) > 
0.8 + + def test_get_bert_vector(self): + vector1 = self.transformers_service.get_bert_vector('hi') + vector2 = self.transformers_service.get_bert_vector('hello') + assert vector1.shape == (768,) + assert vector2.shape == (768,) + assert calculate_vectors_cosine(vector1, vector2) > 0.8 + + def test_get_bert_vector_long_text(self): + vector1 = self.transformers_service.get_bert_vector('The quick brown fox jumps over the lazy dog ' * 100) + vector2 = self.transformers_service.get_bert_vector('我可以吞下玻璃而不伤身体' * 100) + assert vector1.shape == (768,) + assert vector2.shape == (768,) diff --git a/web/screenshots/1.png b/web/screenshots/1.png new file mode 100644 index 0000000000000000000000000000000000000000..16981f6ffc6488d7d97a2acce2ac3e6218ecb1f1 --- /dev/null +++ b/web/screenshots/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3198603e3cbd1bf9c18e9c677fef6de1c7d33418c385c8aae7dd1bdb2e81cab9 +size 1027053 diff --git a/web/screenshots/2.png b/web/screenshots/2.png new file mode 100644 index 0000000000000000000000000000000000000000..a8bcae41c8da1a7f9b5ae0c0308b8002d66c6f50 --- /dev/null +++ b/web/screenshots/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:581f865b1a4d59c7e6e12652b2d9382e4d4d8dabb5e87ee703f8cedf4a056b71 +size 1076203 diff --git a/web/screenshots/3.png b/web/screenshots/3.png new file mode 100644 index 0000000000000000000000000000000000000000..aa970642af286898e34267575d45f73337b54ba9 Binary files /dev/null and b/web/screenshots/3.png differ diff --git a/web/screenshots/4.png b/web/screenshots/4.png new file mode 100644 index 0000000000000000000000000000000000000000..83abe4ee183576f41e9774ff5db411ff12d01364 Binary files /dev/null and b/web/screenshots/4.png differ diff --git a/web/screenshots/5.png b/web/screenshots/5.png new file mode 100644 index 0000000000000000000000000000000000000000..50179dd16dfff066b868de43a131000f91c1453f Binary files /dev/null and b/web/screenshots/5.png differ diff --git a/web/screenshots/6.png b/web/screenshots/6.png new file mode 100644 index 0000000000000000000000000000000000000000..02ed89c91c2ec60e0536658aab047486c98c9884 Binary files /dev/null and b/web/screenshots/6.png differ