diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..94cfb9d7a9324c457e117cf2b403fca79b88b302 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +# from .gitignore +venv/ +ui/__pycache__/ +outputs/ +modules/__pycache__/ +models/ +modules/yt_tmp.wav + +.git +.github diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000000000000000000000000000000000000..e1e01f8cf68969bacc8af57d793bea9a4b0a3c6f --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,13 @@ +# These are supported funding model platforms + +github: [] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: jhj0517 +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000000000000000000000000000000000..99e25a330c8f674a17c5431a323b146c95d265f8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,11 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: jhj0517 + +--- + +**Which OS are you using?** + - OS: [e.g. iOS or Windows.. If you are using Google Colab, just Colab.] diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000000000000000000000000000000000..74981022b47d3038fdb054c3cf338d93232012ba --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,10 @@ +--- +name: Feature request +about: Any feature you want +title: '' +labels: enhancement +assignees: jhj0517 + +--- + + diff --git a/.github/ISSUE_TEMPLATE/hallucination.md b/.github/ISSUE_TEMPLATE/hallucination.md new file mode 100644 index 0000000000000000000000000000000000000000..ba43584f7765d84d41a565d96d8b76c3f187414e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/hallucination.md @@ -0,0 +1,12 @@ +--- +name: Hallucination +about: Whisper hallucinations. ( Repeating certain words or subtitles starting too + early, etc. ) +title: '' +labels: hallucination +assignees: jhj0517 + +--- + +**Download URL for sample audio** +- Please upload download URL for sample audio file so I can test with some settings for better result. You can use https://easyupload.io/ or any other service to share. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000000000000000000000000000000000..d33d497792bd3ca3415376b6b8daf32835b04692 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,5 @@ +## Related issues +- #0 + +## Changed +1. 
Changes diff --git a/.github/workflows/ci-shell.yml b/.github/workflows/ci-shell.yml new file mode 100644 index 0000000000000000000000000000000000000000..7f8e77a9f3b5a66445cae629225576a012fa69b4 --- /dev/null +++ b/.github/workflows/ci-shell.yml @@ -0,0 +1,43 @@ +name: CI-Shell Script + +on: + workflow_dispatch: + + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + test-shell-script: + + runs-on: ubuntu-latest + strategy: + matrix: + python: [ "3.10" ] + + steps: + - name: Clean up space for action + run: rm -rf /opt/hostedtoolcache + + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install git and ffmpeg + run: sudo apt-get update && sudo apt-get install -y git ffmpeg + + - name: Execute Install.sh + run: | + chmod +x ./Install.sh + ./Install.sh + + - name: Execute start-webui.sh + run: | + chmod +x ./start-webui.sh + timeout 60s ./start-webui.sh || true + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..33a084802a8302a86f32c31ad57b4b480d59ed24 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,41 @@ +name: CI + +on: + workflow_dispatch: + + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python: ["3.10"] + + env: + DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }} + + steps: + - name: Clean up space for action + run: rm -rf /opt/hostedtoolcache + + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install git and ffmpeg + run: sudo apt-get update && sudo apt-get install -y git ffmpeg + + - name: Install dependencies + run: pip install -r requirements.txt pytest + + - name: Run test + run: python -m pytest -rs tests \ No newline at end of file diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml new file mode 100644 index 0000000000000000000000000000000000000000..99da1b6da36e42b5eebc872bf2a05d6118fc7e50 --- /dev/null +++ b/.github/workflows/publish-docker.yml @@ -0,0 +1,37 @@ +name: Publish to Docker Hub + +on: + push: + branches: + - master + +jobs: + build-and-push: + runs-on: ubuntu-latest + + steps: + - name: Log in to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . 
+ file: ./Dockerfile + push: true + tags: ${{ secrets.DOCKER_USERNAME }}/whisper-webui:latest + + - name: Log out of Docker Hub + run: docker logout diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9cd9e3837fec930d3df2ff4430af26f40e672a82 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +*.wav +*.png +*.mp4 +*.mp3 +.idea/ +.pytest_cache/ +venv/ +modules/ui/__pycache__/ +outputs/ +modules/__pycache__/ +models/ +modules/yt_tmp.wav +configs/default_parameters.yaml \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..5604def4a9c7cf253d1d2fcbbcf005615691cf69 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM debian:bookworm-slim AS builder + +RUN apt-get update && \ + apt-get install -y curl git python3 python3-pip python3-venv && \ + rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && \ + mkdir -p /Whisper-WebUI + +WORKDIR /Whisper-WebUI + +COPY requirements.txt . + +RUN python3 -m venv venv && \ + . venv/bin/activate && \ + pip install --no-cache-dir -r requirements.txt + + +FROM debian:bookworm-slim AS runtime + +RUN apt-get update && \ + apt-get install -y curl ffmpeg python3 && \ + rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +WORKDIR /Whisper-WebUI + +COPY . . +COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv + +VOLUME [ "/Whisper-WebUI/models" ] +VOLUME [ "/Whisper-WebUI/outputs" ] + +ENV PATH="/Whisper-WebUI/venv/bin:$PATH" +ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib + +ENTRYPOINT [ "python", "app.py" ] diff --git a/Install.bat b/Install.bat new file mode 100644 index 0000000000000000000000000000000000000000..7c3f496a2091ba89b3e6f8582cbfaa35d50b7b19 --- /dev/null +++ b/Install.bat @@ -0,0 +1,20 @@ +@echo off + +if not exist "%~dp0\venv\Scripts" ( + echo Creating venv... + python -m venv venv +) +echo checked the venv folder. now installing requirements.. + +call "%~dp0\venv\scripts\activate" + +pip install -r requirements.txt + +if errorlevel 1 ( + echo. + echo Requirements installation failed. please remove venv folder and run install.bat again. +) else ( + echo. + echo Requirements installed successfully. +) +pause \ No newline at end of file diff --git a/Install.sh b/Install.sh new file mode 100644 index 0000000000000000000000000000000000000000..6ba3148ebd904101496e7198040d65569c1fb6b5 --- /dev/null +++ b/Install.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +if [ ! -d "venv" ]; then + echo "Creating virtual environment..." + python -m venv venv +fi + +source venv/bin/activate + +pip install -r requirements.txt && echo "Requirements installed successfully." || { + echo "" + echo "Requirements installation failed. Please remove the venv folder and run the script again." + deactivate + exit 1 +} + +deactivate diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..dd84d44b86260eb3206817f7c184f0534c1bd5a8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 jhj0517 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md index 0f8816a69a7e9d6cc1402378fce440682bbfec29..af4b32a98373691f297ee0c1f2805c4b6f7a6f03 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,117 @@ ---- -title: Whisper WebUI -emoji: πŸš€ -colorFrom: red -colorTo: pink -sdk: gradio -sdk_version: 5.5.0 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Whisper-WebUI +A Gradio-based browser interface for [Whisper](https://github.com/openai/whisper). You can use it as an Easy Subtitle Generator! + +![Whisper WebUI](https://github.com/jhj0517/Whsiper-WebUI/blob/master/screenshot.png) + +## Notebook +If you wish to try this on Colab, you can do it in [here](https://colab.research.google.com/github/jhj0517/Whisper-WebUI/blob/master/notebook/whisper-webui.ipynb)! + +# Feature +- Select the Whisper implementation you want to use between : + - [openai/whisper](https://github.com/openai/whisper) + - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper) (used by default) + - [Vaibhavs10/insanely-fast-whisper](https://github.com/Vaibhavs10/insanely-fast-whisper) +- Generate subtitles from various sources, including : + - Files + - Youtube + - Microphone +- Currently supported subtitle formats : + - SRT + - WebVTT + - txt ( only text file without timeline ) +- Speech to Text Translation + - From other languages to English. ( This is Whisper's end-to-end speech-to-text translation feature ) +- Text to Text Translation + - Translate subtitle files using Facebook NLLB models + - Translate subtitle files using DeepL API +- Pre-processing audio input with [Silero VAD](https://github.com/snakers4/silero-vad). +- Pre-processing audio input to separate BGM with [UVR](https://github.com/Anjok07/ultimatevocalremovergui), [UVR-api](https://github.com/NextAudioGen/ultimatevocalremover_api). +- Post-processing with speaker diarization using the [pyannote](https://huggingface.co/pyannote/speaker-diarization-3.1) model. + - To download the pyannote model, you need to have a Huggingface token and manually accept their terms in the pages below. + 1. https://huggingface.co/pyannote/speaker-diarization-3.1 + 2. https://huggingface.co/pyannote/segmentation-3.0 + +# Installation and Running +### Prerequisite +To run this WebUI, you need to have `git`, `python` version 3.8 ~ 3.10, `FFmpeg`.
+And if you're not using an Nvida GPU, or using a different `CUDA` version than 12.4, edit the [`requirements.txt`](https://github.com/jhj0517/Whisper-WebUI/blob/master/requirements.txt) to match your environment. + +Please follow the links below to install the necessary software: +- git : [https://git-scm.com/downloads](https://git-scm.com/downloads) +- python : [https://www.python.org/downloads/](https://www.python.org/downloads/) **( If your python version is too new, torch will not install properly.)** +- FFmpeg : [https://ffmpeg.org/download.html](https://ffmpeg.org/download.html) +- CUDA : [https://developer.nvidia.com/cuda-downloads](https://developer.nvidia.com/cuda-downloads) + +After installing FFmpeg, **make sure to add the `FFmpeg/bin` folder to your system PATH!** + +### Automatic Installation + +1. Download `Whisper-WebUI.zip` with the file corresponding to your OS from [v1.0.0](https://github.com/jhj0517/Whisper-WebUI/releases/tag/v1.0.0) and extract its contents. +2. Run `install.bat` or `install.sh` to install dependencies. (This will create a `venv` directory and install dependencies there.) +3. Start WebUI with `start-webui.bat` or `start-webui.sh` +4. To update the WebUI, run `update.bat` or `update.sh` + +And you can also run the project with command line arguments if you like to, see [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for a guide to arguments. + +- ## Running with Docker + +1. Install and launch [Docker-Desktop](https://www.docker.com/products/docker-desktop/). + +2. Git clone the repository + +```sh +git clone https://github.com/jhj0517/Whisper-WebUI.git +``` + +3. Build the image ( Image is about 7GB~ ) + +```sh +docker compose build +``` + +4. Run the container + +```sh +docker compose up +``` + +5. Connect to the WebUI with your browser at `http://localhost:7860` + +If needed, update the [`docker-compose.yaml`](https://github.com/jhj0517/Whisper-WebUI/blob/master/docker-compose.yaml) to match your environment. + +# VRAM Usages +This project is integrated with [faster-whisper](https://github.com/guillaumekln/faster-whisper) by default for better VRAM usage and transcription speed. + +According to faster-whisper, the efficiency of the optimized whisper model is as follows: +| Implementation | Precision | Beam size | Time | Max. GPU memory | Max. CPU memory | +|-------------------|-----------|-----------|-------|-----------------|-----------------| +| openai/whisper | fp16 | 5 | 4m30s | 11325MB | 9439MB | +| faster-whisper | fp16 | 5 | 54s | 4755MB | 3244MB | + +If you want to use an implementation other than faster-whisper, use `--whisper_type` arg and the repository name.
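+For example, a minimal sketch of such an invocation (assuming you launch `app.py` directly, with the project's `venv` active, rather than through the start scripts):
+
+```sh
+# Illustrative only: the flag and its accepted values ("whisper", "faster-whisper",
+# "insanely-fast-whisper") come from the --whisper_type argument defined in app.py.
+python app.py --whisper_type insanely-fast-whisper --server_port 7860
+```
+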
+Read [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for more info about CLI args. + +## Available models +This is Whisper's original VRAM usage table for models. + +| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed | +|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:| +| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x | +| base | 74 M | `base.en` | `base` | ~1 GB | ~16x | +| small | 244 M | `small.en` | `small` | ~2 GB | ~6x | +| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x | +| large | 1550 M | N/A | `large` | ~10 GB | 1x | + + +`.en` models are for English only, and the cool thing is that you can use the `Translate to English` option from the "large" models! + +## TODOπŸ—“ + +- [x] Add DeepL API translation +- [x] Add NLLB Model translation +- [x] Integrate with faster-whisper +- [x] Integrate with insanely-fast-whisper +- [x] Integrate with whisperX ( Only speaker diarization part ) +- [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui) +- [ ] Add fast api script +- [ ] Support real-time transcription for microphone diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..46bb0e9febf8039f84f2d4483ac9b5fc15fb8c4a --- /dev/null +++ b/app.py @@ -0,0 +1,359 @@ +import os +import argparse +import gradio as gr +import yaml + +from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR, + INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, + UVR_MODELS_DIR) +from modules.utils.files_manager import load_yaml +from modules.whisper.whisper_factory import WhisperFactory +from modules.whisper.faster_whisper_inference import FasterWhisperInference +from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference +from modules.translation.nllb_inference import NLLBInference +from modules.ui.htmls import * +from modules.utils.cli_manager import str2bool +from modules.utils.youtube_manager import get_ytmetas +from modules.translation.deepl_api import DeepLAPI +from modules.whisper.whisper_parameter import * + +### Device info ### +import torch +import torchaudio +import torch.cuda as cuda +import platform +from transformers import __version__ as transformers_version + +device = "cuda" if torch.cuda.is_available() else "cpu" +num_gpus = cuda.device_count() if torch.cuda.is_available() else 0 +cuda_version = torch.version.cuda if torch.cuda.is_available() else "N/A" +cudnn_version = torch.backends.cudnn.version() if torch.cuda.is_available() else "N/A" +os_info = platform.system() + " " + platform.release() + " " + platform.machine() + +# Get the available VRAM for each GPU (if available) +vram_info = [] +if torch.cuda.is_available(): + for i in range(cuda.device_count()): + gpu_properties = cuda.get_device_properties(i) + vram_info.append(f"**GPU {i}: {gpu_properties.total_memory / 1024**3:.2f} GB**") + +pytorch_version = torch.__version__ +torchaudio_version = torchaudio.__version__ if 'torchaudio' in dir() else "N/A" + +device_info = f"""Running on: **{device}** + + Number of GPUs available: **{num_gpus}** + + CUDA version: **{cuda_version}** + + CuDNN version: **{cudnn_version}** + + PyTorch version: **{pytorch_version}** + + Torchaudio version: **{torchaudio_version}** + + Transformers version: **{transformers_version}** + + Operating system: **{os_info}** + + 
Available VRAM: + \t {', '.join(vram_info) if vram_info else '**N/A**'} +""" +### End Device info ### + +class App: + def __init__(self, args): + self.args = args + #self.app = gr.Blocks(css=CSS, theme=self.args.theme, delete_cache=(60, 3600)) + self.app = gr.Blocks(css=CSS, theme=gr.themes.Ocean(), delete_cache=(60, 3600)) + self.whisper_inf = WhisperFactory.create_whisper_inference( + whisper_type=self.args.whisper_type, + whisper_model_dir=self.args.whisper_model_dir, + faster_whisper_model_dir=self.args.faster_whisper_model_dir, + insanely_fast_whisper_model_dir=self.args.insanely_fast_whisper_model_dir, + uvr_model_dir=self.args.uvr_model_dir, + output_dir=self.args.output_dir, + ) + self.nllb_inf = NLLBInference( + model_dir=self.args.nllb_model_dir, + output_dir=os.path.join(self.args.output_dir, "translations") + ) + self.deepl_api = DeepLAPI( + output_dir=os.path.join(self.args.output_dir, "translations") + ) + self.default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) + print(f"Use \"{self.args.whisper_type}\" implementation") + print(f"Device \"{self.whisper_inf.device}\" is detected") + + def create_whisper_parameters(self): + + whisper_params = self.default_params["whisper"] + diarization_params = self.default_params["diarization"] + vad_params = self.default_params["vad"] + uvr_params = self.default_params["bgm_separation"] + + with gr.Row(): + dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value=whisper_params["model_size"],label="Model") + dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,value=whisper_params["lang"], label="Language") + #dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format") + dd_file_format = gr.Dropdown(choices=["SRT", "txt"], value="SRT", label="Output format") + + with gr.Row(): + cb_timestamp = gr.Checkbox(value=whisper_params["add_timestamp"], label="Add timestamp to output file",interactive=True) + cb_diarize = gr.Checkbox(label="Speaker diarization", value=diarization_params["is_diarize"]) + cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label="Translate to English",interactive=True) + + with gr.Accordion("Diarization options", open=False): + tb_hf_token = gr.Text(label="HuggingFace Token", value=diarization_params["hf_token"], + info="This is only needed the first time you download the model. If you already have" + " models, you don't need to enter. 
To download the model, you must manually go " + "to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to" + " their requirement.") + dd_diarization_device = gr.Dropdown(label="Device", + choices=self.whisper_inf.diarizer.get_available_device(), + value=self.whisper_inf.diarizer.get_device()) + + with gr.Accordion("Advanced options", open=False): + nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, interactive=True, + info="Beam size to use for decoding.") + nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=whisper_params["log_prob_threshold"], interactive=True, + info="If the average log probability over sampled tokens is below this value, treat as failed.") + nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"], interactive=True, + info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.") + dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, + value=self.whisper_inf.current_compute_type, interactive=True, + allow_custom_value=True, + info="Select the type of computation to perform.") + nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True, + info="Number of candidates when sampling with non-zero temperature.") + nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True, + info="Beam search patience factor.") + cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=whisper_params["condition_on_previous_text"], + interactive=True, + info="Condition on previous text during decoding.") + sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=whisper_params["prompt_reset_on_temperature"], + minimum=0, maximum=1, step=0.01, interactive=True, + info="Resets prompt if temperature is above this value." + " Arg has effect only if 'Condition On Previous Text' is True.") + tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True, + info="Initial prompt to use for decoding.") + sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0, + step=0.01, maximum=1.0, interactive=True, + info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.") + nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=whisper_params["compression_ratio_threshold"], + interactive=True, + info="If the gzip compression ratio is above this value, treat as failed.") + nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"], + precision=0, + info="The length of audio segments. 
If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.") + with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)): + nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"], + info="Exponential length penalty constant.") + nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=whisper_params["repetition_penalty"], + info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).") + nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=whisper_params["no_repeat_ngram_size"], + precision=0, + info="Prevent repetitions of n-grams with this size (set 0 to disable).") + tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"], + info="Optional text to provide as a prefix for the first window.") + cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"], + info="Suppress blank outputs at the beginning of the sampling.") + tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"], + info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.") + nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=whisper_params["max_initial_timestamp"], + info="The initial timestamp cannot be later than this.") + cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"], + info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.") + tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value=whisper_params["prepend_punctuations"], + info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.") + tb_append_punctuations = gr.Textbox(label="Append Punctuations", value=whisper_params["append_punctuations"], + info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.") + nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"], + precision=0, + info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.") + nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)", + value=lambda: whisper_params["hallucination_silence_threshold"], + info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.") + tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"], + info="Hotwords/hint phrases to provide the model with. 
Has no effect if prefix is not None.") + nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=lambda: whisper_params["language_detection_threshold"], + info="If the maximum probability of the language tokens is higher than this value, the language is detected.") + nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=lambda: whisper_params["language_detection_segments"], + precision=0, + info="Number of segments to consider for the language detection.") + with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)): + nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0) + + with gr.Accordion("Background Music Remover Filter", open=False): + cb_bgm_separation = gr.Checkbox(label="Enable Background Music Remover Filter", value=uvr_params["is_separate_bgm"], + interactive=True, + info="Enabling this will remove background music by submodel before" + " transcribing ") + dd_uvr_device = gr.Dropdown(label="Device", value=self.whisper_inf.music_separator.device, + choices=self.whisper_inf.music_separator.available_devices) + dd_uvr_model_size = gr.Dropdown(label="Model", value=uvr_params["model_size"], + choices=self.whisper_inf.music_separator.available_models) + nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0) + cb_uvr_save_file = gr.Checkbox(label="Save separated files to output", value=uvr_params["save_file"]) + cb_uvr_enable_offload = gr.Checkbox(label="Offload sub model after removing background music", + value=uvr_params["enable_offload"]) + + with gr.Accordion("Voice Detection Filter", open=False): + cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=vad_params["vad_filter"], + interactive=True, + info="Enable this to transcribe only detected voice parts by submodel.") + sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", + value=vad_params["threshold"], + info="Lower it to be more sensitive to small sounds.") + nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, + value=vad_params["min_speech_duration_ms"], + info="Final speech chunks shorter than this time are thrown out") + nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", + value=vad_params["max_speech_duration_s"], + info="Maximum duration of speech chunks in \"seconds\".") + nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, + value=vad_params["min_silence_duration_ms"], + info="In the end of each speech chunk wait for this time" + " before separating it") + nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"], + info="Final speech chunks are padded by this time each side") + + dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) + + return ( + WhisperParameters( + model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size, + log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold, + compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience, + condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt, + temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold, + vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms, + 
max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms, + speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size, + is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device, + length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty, + no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank, + suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp, + word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations, + append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens, + hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords, + language_detection_threshold=nb_language_detection_threshold, + language_detection_segments=nb_language_detection_segments, + prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation, + uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size, + uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload + ), + dd_file_format, + cb_timestamp + ) + + def launch(self): + translation_params = self.default_params["translation"] + deepl_params = translation_params["deepl"] + nllb_params = translation_params["nllb"] + uvr_params = self.default_params["bgm_separation"] + + with self.app: + with gr.Row(): + with gr.Column(): + gr.Markdown(MARKDOWN, elem_id="md_project") + with gr.Tabs(): + with gr.TabItem("Audio"): # tab1 + with gr.Column(): + #input_file = gr.Files(type="filepath", label="Upload File here") + input_file = gr.Audio(type='filepath', elem_id="audio_input") + tb_input_folder = gr.Textbox(label="Input Folder Path (Optional)", + info="Optional: Specify the folder path where the input files are located, if you prefer to use local files instead of uploading them." 
+ " Leave this field empty if you do not wish to use a local path.", + visible=self.args.colab, + value="") + + whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + + with gr.Row(): + btn_run = gr.Button("Transcribe", variant="primary") + btn_reset = gr.Button(value="Reset") + btn_reset.click(None,js="window.location.reload()") + with gr.Row(): + with gr.Column(scale=3): + tb_indicator = gr.Textbox(label="Output result") + with gr.Column(scale=1): + tb_info = gr.Textbox(label="Output info", interactive=False, scale=3) + files_subtitles = gr.Files(label="Output file", interactive=False, scale=2) + # btn_openfolder = gr.Button('πŸ“‚', scale=1) + + params = [input_file, tb_input_folder, dd_file_format, cb_timestamp] + btn_run.click(fn=self.whisper_inf.transcribe_file, + inputs=params + whisper_params.as_list(), + outputs=[tb_indicator, files_subtitles, tb_info]) + # btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) + + with gr.TabItem("Device info"): # tab2 + with gr.Column(): + gr.Markdown(device_info, label="Hardware info & installed packages") + + # Launch the app with optional gradio settings + args = self.args + + self.app.queue( + api_open=args.api_open + ).launch( + share=args.share, + server_name=args.server_name, + server_port=args.server_port, + auth=(args.username, args.password) if args.username and args.password else None, + root_path=args.root_path, + inbrowser=args.inbrowser + ) + + @staticmethod + def open_folder(folder_path: str): + if os.path.exists(folder_path): + os.system(f"start {folder_path}") + else: + os.makedirs(folder_path, exist_ok=True) + print(f"The directory path {folder_path} has newly created.") + + @staticmethod + def on_change_models(model_size: str): + translatable_model = ["large", "large-v1", "large-v2", "large-v3"] + if model_size not in translatable_model: + return gr.Checkbox(visible=False, value=False, interactive=False) + #return gr.Checkbox(visible=True, value=False, label="Translate to English (large models only)", interactive=False) + else: + return gr.Checkbox(visible=True, value=False, label="Translate to English", interactive=True) + + +# Create the parser for command-line arguments +parser = argparse.ArgumentParser() +parser.add_argument('--whisper_type', type=str, default="faster-whisper", + help='A type of the whisper implementation between: ["whisper", "faster-whisper", "insanely-fast-whisper"]') +parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value') +parser.add_argument('--server_name', type=str, default=None, help='Gradio server host') +parser.add_argument('--server_port', type=int, default=None, help='Gradio server port') +parser.add_argument('--root_path', type=str, default=None, help='Gradio root path') +parser.add_argument('--username', type=str, default=None, help='Gradio authentication username') +parser.add_argument('--password', type=str, default=None, help='Gradio authentication password') +parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme') +parser.add_argument('--colab', type=str2bool, default=False, nargs='?', const=True, help='Is colab user or not') +parser.add_argument('--api_open', type=str2bool, default=False, nargs='?', const=True, help='Enable api or not in Gradio') +parser.add_argument('--inbrowser', type=str2bool, default=True, nargs='?', const=True, help='Whether to automatically start Gradio app or not') +parser.add_argument('--whisper_model_dir', type=str, 
default=WHISPER_MODELS_DIR, + help='Directory path of the whisper model') +parser.add_argument('--faster_whisper_model_dir', type=str, default=FASTER_WHISPER_MODELS_DIR, + help='Directory path of the faster-whisper model') +parser.add_argument('--insanely_fast_whisper_model_dir', type=str, + default=INSANELY_FAST_WHISPER_MODELS_DIR, + help='Directory path of the insanely-fast-whisper model') +parser.add_argument('--diarization_model_dir', type=str, default=DIARIZATION_MODELS_DIR, + help='Directory path of the diarization model') +parser.add_argument('--nllb_model_dir', type=str, default=NLLB_MODELS_DIR, + help='Directory path of the Facebook NLLB model') +parser.add_argument('--uvr_model_dir', type=str, default=UVR_MODELS_DIR, + help='Directory path of the UVR model') +parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='Directory path of the outputs') +_args = parser.parse_args() + +if __name__ == "__main__": + app = App(args=_args) + app.launch() diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8eace9295fd6f6ccece85ecfe53573e9b70367f9 --- /dev/null +++ b/configs/default_parameters.yaml @@ -0,0 +1,64 @@ +whisper: + model_size: "large-v3" + lang: "Automatic Detection" + is_translate: false + beam_size: 5 + log_prob_threshold: -1 + no_speech_threshold: 0.6 + best_of: 5 + patience: 1 + condition_on_previous_text: true + prompt_reset_on_temperature: 0.5 + initial_prompt: null + temperature: 0 + compression_ratio_threshold: 2.4 + chunk_length: 30 + batch_size: 24 + length_penalty: 1 + repetition_penalty: 1 + no_repeat_ngram_size: 0 + prefix: null + suppress_blank: true + suppress_tokens: "[-1]" + max_initial_timestamp: 1 + word_timestamps: false + prepend_punctuations: "\"'β€œΒΏ([{-" + append_punctuations: "\"'.。,,!!??:οΌšβ€)]}、" + max_new_tokens: null + hallucination_silence_threshold: null + hotwords: null + language_detection_threshold: null + language_detection_segments: 1 + add_timestamp: false + +vad: + vad_filter: false + threshold: 0.5 + min_speech_duration_ms: 250 + max_speech_duration_s: 9999 + min_silence_duration_ms: 1000 + speech_pad_ms: 2000 + +diarization: + is_diarize: false + hf_token: "" + +bgm_separation: + is_separate_bgm: false + model_size: "UVR-MDX-NET-Inst_HQ_4" + segment_size: 256 + save_file: false + enable_offload: true + +translation: + deepl: + api_key: "" + is_pro: false + source_lang: "Automatic Detection" + target_lang: "English" + nllb: + model_size: "facebook/nllb-200-1.3B" + source_lang: null + target_lang: null + max_length: 200 + add_timestamp: true diff --git a/demo/audio.wav b/demo/audio.wav new file mode 100644 index 0000000000000000000000000000000000000000..34a90504fb551b558e93ca3a899221ceda3b2686 Binary files /dev/null and b/demo/audio.wav differ diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0ccfbc3a817ee5ec2a9b747070f8a49f5aef055 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,29 @@ +services: + app: + build: . 
+ image: whisper-webui:latest + + volumes: + # Update paths to mount models and output paths to your custom paths like this, e.g: + # - C:/whisper-models/custom-path:/Whisper-WebUI/models + # - C:/whisper-webui-outputs/custom-path:/Whisper-WebUI/outputs + - /Whisper-WebUI/models + - /Whisper-WebUI/outputs + + ports: + - "7860:7860" + + stdin_open: true + tty: true + + entrypoint: ["python", "app.py", "--server_port", "7860", "--server_name", "0.0.0.0",] + + # If you're not using nvidia GPU, Update device to match yours. + # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [ gpu ] diff --git a/models/models will be saved here.txt b/models/models will be saved here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/diarize/__init__.py b/modules/diarize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/diarize/audio_loader.py b/modules/diarize/audio_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..d90e52c3eea45c1e737e789cde9bcc637c46da90 --- /dev/null +++ b/modules/diarize/audio_loader.py @@ -0,0 +1,179 @@ +# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py + +import os +import subprocess +from functools import lru_cache +from typing import Optional, Union +from scipy.io.wavfile import write +import tempfile + +import numpy as np +import torch +import torch.nn.functional as F + +def exact_div(x, y): + assert x % y == 0 + return x // y + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token + + +def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray: + """ + Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary. + + Parameters + ---------- + file: Union[str, np.ndarray] + The audio file to open or a numpy array containing the audio data. + + sr: int + The sample rate to resample the audio if necessary. + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. 
+ """ + if isinstance(file, np.ndarray): + if file.dtype != np.float32: + file = file.astype(np.float32) + if file.ndim > 1: + file = np.mean(file, axis=1) + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16)) + temp_file_path = temp_file.name + temp_file.close() + else: + temp_file_path = file + + try: + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", + "0", + "-i", + temp_file_path, + "-f", + "s16le", + "-ac", + "1", + "-acodec", + "pcm_s16le", + "-ar", + str(sr), + "-", + ] + out = subprocess.run(cmd, capture_output=True, check=True).stdout + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + finally: + if isinstance(file, np.ndarray): + os.remove(temp_file_path) + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
+ Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec \ No newline at end of file diff --git a/modules/diarize/diarize_pipeline.py b/modules/diarize/diarize_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..b4109e8474a37a35f423cd1f0daa0ddc22ab1bef --- /dev/null +++ b/modules/diarize/diarize_pipeline.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py + +import numpy as np +import pandas as pd +import os +from pyannote.audio import Pipeline +from typing import Optional, Union +import torch + +from modules.utils.paths import DIARIZATION_MODELS_DIR +from modules.diarize.audio_loader import load_audio, SAMPLE_RATE + + +class DiarizationPipeline: + def __init__( + self, + model_name="pyannote/speaker-diarization-3.1", + cache_dir: str = DIARIZATION_MODELS_DIR, + use_auth_token=None, + device: Optional[Union[str, torch.device]] = "cpu", + ): + if isinstance(device, str): + device = torch.device(device) + self.model = Pipeline.from_pretrained( + model_name, + use_auth_token=use_auth_token, + cache_dir=cache_dir + ).to(device) + + def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None): + if isinstance(audio, str): + audio = load_audio(audio) + audio_data = { + 'waveform': torch.from_numpy(audio[None, :]), + 'sample_rate': SAMPLE_RATE + } + segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers) + diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) + diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) + diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end) + return diarize_df + + +def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): + 
transcript_segments = transcript_result["segments"] + for seg in transcript_segments: + # assign speaker to segment (if any) + diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], + seg['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start']) + + intersected = diarize_df[diarize_df["intersection"] > 0] + + speaker = None + if len(intersected) > 0: + # Choosing most strong intersection + speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + elif fill_nearest: + # Otherwise choosing closest + speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0] + + if speaker is not None: + seg["speaker"] = speaker + + # assign speaker to words + if 'words' in seg: + for word in seg['words']: + if 'start' in word: + diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum( + diarize_df['start'], word['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], + word['start']) + + intersected = diarize_df[diarize_df["intersection"] > 0] + + word_speaker = None + if len(intersected) > 0: + # Choosing most strong intersection + word_speaker = \ + intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + elif fill_nearest: + # Otherwise choosing closest + word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0] + + if word_speaker is not None: + word["speaker"] = word_speaker + + return transcript_result + + +class Segment: + def __init__(self, start, end, speaker=None): + self.start = start + self.end = end + self.speaker = speaker diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e24adc75f2b65ae99976423424466af194f55552 --- /dev/null +++ b/modules/diarize/diarizer.py @@ -0,0 +1,133 @@ +import os +import torch +from typing import List, Union, BinaryIO, Optional +import numpy as np +import time +import logging + +from modules.utils.paths import DIARIZATION_MODELS_DIR +from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers +from modules.diarize.audio_loader import load_audio + + +class Diarizer: + def __init__(self, + model_dir: str = DIARIZATION_MODELS_DIR + ): + self.device = self.get_device() + self.available_device = self.get_available_device() + self.compute_type = "float16" + self.model_dir = model_dir + os.makedirs(self.model_dir, exist_ok=True) + self.pipe = None + + def run(self, + audio: Union[str, BinaryIO, np.ndarray], + transcribed_result: List[dict], + use_auth_token: str, + device: Optional[str] = None + ): + """ + Diarize transcribed result as a post-processing + + Parameters + ---------- + audio: Union[str, BinaryIO, np.ndarray] + Audio input. This can be file path or binary type. + transcribed_result: List[dict] + transcribed result through whisper. + use_auth_token: str + Huggingface token with READ permission. This is only needed the first time you download the model. + You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model. + device: Optional[str] + Device for diarization. 
+ + Returns + ---------- + segments_result: List[dict] + list of dicts that includes start, end timestamps and transcribed text + elapsed_time: float + elapsed time for running + """ + start_time = time.time() + + if device is None: + device = self.device + + if device != self.device or self.pipe is None: + self.update_pipe( + device=device, + use_auth_token=use_auth_token + ) + + audio = load_audio(audio) + + diarization_segments = self.pipe(audio) + diarized_result = assign_word_speakers( + diarization_segments, + {"segments": transcribed_result} + ) + + for segment in diarized_result["segments"]: + speaker = "None" + if "speaker" in segment: + speaker = segment["speaker"] + segment["text"] = speaker + ": " + segment["text"].strip() + + elapsed_time = time.time() - start_time + return diarized_result["segments"], elapsed_time + + def update_pipe(self, + use_auth_token: str, + device: str + ): + """ + Set pipeline for diarization + + Parameters + ---------- + use_auth_token: str + Huggingface token with READ permission. This is only needed the first time you download the model. + You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model. + device: str + Device for diarization. + """ + self.device = device + + os.makedirs(self.model_dir, exist_ok=True) + + if (not os.listdir(self.model_dir) and + not use_auth_token): + print( + "\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n" + "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n" + ) + return + + logger = logging.getLogger("speechbrain.utils.train_logger") + # Disable redundant torchvision warning message + logger.disabled = True + self.pipe = DiarizationPipeline( + use_auth_token=use_auth_token, + device=device, + cache_dir=self.model_dir + ) + logger.disabled = False + + @staticmethod + def get_device(): + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + + @staticmethod + def get_available_device(): + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + elif torch.backends.mps.is_available(): + devices.append("mps") + return devices \ No newline at end of file diff --git a/modules/translation/__init__.py b/modules/translation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/translation/deepl_api.py b/modules/translation/deepl_api.py new file mode 100644 index 0000000000000000000000000000000000000000..385b3a14bfa201021d45d6a0ecf1dc83c1c955f6 --- /dev/null +++ b/modules/translation/deepl_api.py @@ -0,0 +1,226 @@ +import requests +import time +import os +from datetime import datetime +import gradio as gr + +from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH +from modules.utils.subtitle_manager import * +from modules.utils.files_manager import load_yaml, save_yaml + +""" +This is written with reference to the DeepL API documentation. 
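A hedged usage sketch for the Diarizer class above (not part of the diff; the Hugging Face token and audio path are placeholders, and the transcribed segments would normally come from one of the Whisper wrappers added later in this diff):

```python
# Hypothetical driver for the Diarizer defined above (not part of the diff).
# The token and the audio path are placeholders.
from modules.diarize.diarizer import Diarizer

transcribed = [
    {"start": 0.0, "end": 2.3, "text": "Hello there."},
    {"start": 2.3, "end": 4.8, "text": "Hi, how are you?"},
]

diarizer = Diarizer()
segments, elapsed = diarizer.run(
    audio="sample.wav",
    transcribed_result=transcribed,
    use_auth_token="hf_xxx",   # READ token; only needed for the first model download
)
for seg in segments:
    print(seg["text"])         # e.g. "SPEAKER_00: Hello there."
```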
+If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents +""" + +DEEPL_AVAILABLE_TARGET_LANGS = { + 'Bulgarian': 'BG', + 'Czech': 'CS', + 'Danish': 'DA', + 'German': 'DE', + 'Greek': 'EL', + 'English': 'EN', + 'English (British)': 'EN-GB', + 'English (American)': 'EN-US', + 'Spanish': 'ES', + 'Estonian': 'ET', + 'Finnish': 'FI', + 'French': 'FR', + 'Hungarian': 'HU', + 'Indonesian': 'ID', + 'Italian': 'IT', + 'Japanese': 'JA', + 'Korean': 'KO', + 'Lithuanian': 'LT', + 'Latvian': 'LV', + 'Norwegian (BokmΓ₯l)': 'NB', + 'Dutch': 'NL', + 'Polish': 'PL', + 'Portuguese': 'PT', + 'Portuguese (Brazilian)': 'PT-BR', + 'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT', + 'Romanian': 'RO', + 'Russian': 'RU', + 'Slovak': 'SK', + 'Slovenian': 'SL', + 'Swedish': 'SV', + 'Turkish': 'TR', + 'Ukrainian': 'UK', + 'Chinese (simplified)': 'ZH' +} + +DEEPL_AVAILABLE_SOURCE_LANGS = { + 'Automatic Detection': None, + 'Bulgarian': 'BG', + 'Czech': 'CS', + 'Danish': 'DA', + 'German': 'DE', + 'Greek': 'EL', + 'English': 'EN', + 'Spanish': 'ES', + 'Estonian': 'ET', + 'Finnish': 'FI', + 'French': 'FR', + 'Hungarian': 'HU', + 'Indonesian': 'ID', + 'Italian': 'IT', + 'Japanese': 'JA', + 'Korean': 'KO', + 'Lithuanian': 'LT', + 'Latvian': 'LV', + 'Norwegian (BokmΓ₯l)': 'NB', + 'Dutch': 'NL', + 'Polish': 'PL', + 'Portuguese (all Portuguese varieties mixed)': 'PT', + 'Romanian': 'RO', + 'Russian': 'RU', + 'Slovak': 'SK', + 'Slovenian': 'SL', + 'Swedish': 'SV', + 'Turkish': 'TR', + 'Ukrainian': 'UK', + 'Chinese': 'ZH' +} + + +class DeepLAPI: + def __init__(self, + output_dir: str = TRANSLATION_OUTPUT_DIR + ): + self.api_interval = 1 + self.max_text_batch_size = 50 + self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS + self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS + self.output_dir = output_dir + + def translate_deepl(self, + auth_key: str, + fileobjs: list, + source_lang: str, + target_lang: str, + is_pro: bool = False, + add_timestamp: bool = True, + progress=gr.Progress()) -> list: + """ + Translate subtitle files using DeepL API + Parameters + ---------- + auth_key: str + API Key for DeepL from gr.Textbox() + fileobjs: list + List of files to transcribe from gr.Files() + source_lang: str + Source language of the file to transcribe from gr.Dropdown() + target_lang: str + Target language of the file to transcribe from gr.Dropdown() + is_pro: str + Boolean value that is about pro user or not from gr.Checkbox(). + add_timestamp: bool + Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. + progress: gr.Progress + Indicator to show progress directly in gradio. 
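A hedged sketch of calling translate_deepl() directly, outside the Gradio UI (the API key and subtitle path are placeholders):

```python
# Sketch only (not part of the diff): calling DeepLAPI.translate_deepl without the UI.
# The API key and the subtitle path are placeholders.
from modules.translation.deepl_api import DeepLAPI

deepl = DeepLAPI()
message, output_files = deepl.translate_deepl(
    auth_key="your-deepl-api-key",
    fileobjs=["outputs/sample.srt"],     # .srt or .vtt files
    source_lang="Automatic Detection",
    target_lang="Korean",
    is_pro=False,                        # False -> free endpoint (api-free.deepl.com)
    add_timestamp=True,
)
print(message)          # summary string shown in the UI textbox
print(output_files)     # paths of the translated subtitle files
```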
+ + Returns + ---------- + A List of + String to return to gr.Textbox() + Files to return to gr.Files() + """ + if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString): + fileobjs = [fileobj.name for fileobj in fileobjs] + + self.cache_parameters( + api_key=auth_key, + is_pro=is_pro, + source_lang=source_lang, + target_lang=target_lang, + add_timestamp=add_timestamp + ) + + files_info = {} + for fileobj in fileobjs: + file_path = fileobj + file_name, file_ext = os.path.splitext(os.path.basename(fileobj)) + + if file_ext == ".srt": + parsed_dicts = parse_srt(file_path=file_path) + + elif file_ext == ".vtt": + parsed_dicts = parse_vtt(file_path=file_path) + + batch_size = self.max_text_batch_size + for batch_start in range(0, len(parsed_dicts), batch_size): + batch_end = min(batch_start + batch_size, len(parsed_dicts)) + sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]] + translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang, + target_lang, is_pro) + for i, translated_text in enumerate(translated_texts): + parsed_dicts[batch_start + i]["sentence"] = translated_text["text"] + progress(batch_end / len(parsed_dicts), desc="Translating..") + + if file_ext == ".srt": + subtitle = get_serialized_srt(parsed_dicts) + elif file_ext == ".vtt": + subtitle = get_serialized_vtt(parsed_dicts) + + if add_timestamp: + timestamp = datetime.now().strftime("%m%d%H%M%S") + file_name += f"-{timestamp}" + + output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}") + write_file(subtitle, output_path) + + files_info[file_name] = {"subtitle": subtitle, "path": output_path} + + total_result = '' + for file_name, info in files_info.items(): + total_result += '------------------------------------\n' + total_result += f'{file_name}\n\n' + total_result += f'{info["subtitle"]}' + gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}" + + output_file_paths = [item["path"] for key, item in files_info.items()] + return [gr_str, output_file_paths] + + def request_deepl_translate(self, + auth_key: str, + text: list, + source_lang: str, + target_lang: str, + is_pro: bool = False): + """Request API response to DeepL server""" + if source_lang not in list(DEEPL_AVAILABLE_SOURCE_LANGS.keys()): + raise ValueError(f"Source language {source_lang} is not supported." + f"Use one of {list(DEEPL_AVAILABLE_SOURCE_LANGS.keys())}") + if target_lang not in list(DEEPL_AVAILABLE_TARGET_LANGS.keys()): + raise ValueError(f"Target language {target_lang} is not supported." 
+ f"Use one of {list(DEEPL_AVAILABLE_TARGET_LANGS.keys())}") + + url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate' + headers = { + 'Authorization': f'DeepL-Auth-Key {auth_key}' + } + data = { + 'text': text, + 'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang], + 'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang] + } + response = requests.post(url, headers=headers, data=data).json() + time.sleep(self.api_interval) + return response["translations"] + + @staticmethod + def cache_parameters(api_key: str, + is_pro: bool, + source_lang: str, + target_lang: str, + add_timestamp: bool): + cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) + cached_params["translation"]["deepl"] = { + "api_key": api_key, + "is_pro": is_pro, + "source_lang": source_lang, + "target_lang": target_lang + } + cached_params["translation"]["add_timestamp"] = add_timestamp + save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH) diff --git a/modules/translation/nllb_inference.py b/modules/translation/nllb_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7987bb61e5702e3e66ec9f9c7095d6980e2fbcd3 --- /dev/null +++ b/modules/translation/nllb_inference.py @@ -0,0 +1,287 @@ +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline +import gradio as gr +import os + +from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR +from modules.translation.translation_base import TranslationBase + + +class NLLBInference(TranslationBase): + def __init__(self, + model_dir: str = NLLB_MODELS_DIR, + output_dir: str = TRANSLATION_OUTPUT_DIR + ): + super().__init__( + model_dir=model_dir, + output_dir=output_dir + ) + self.tokenizer = None + self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"] + self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys()) + self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys()) + self.pipeline = None + + def translate(self, + text: str, + max_length: int + ): + result = self.pipeline( + text, + max_length=max_length + ) + return result[0]['translation_text'] + + def update_model(self, + model_size: str, + src_lang: str, + tgt_lang: str, + progress: gr.Progress = gr.Progress() + ): + def validate_language(lang: str) -> str: + if lang in NLLB_AVAILABLE_LANGS: + return NLLB_AVAILABLE_LANGS[lang] + elif lang not in NLLB_AVAILABLE_LANGS.values(): + raise ValueError( + f"Language '{lang}' is not supported. 
Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}") + return lang + + src_lang = validate_language(src_lang) + tgt_lang = validate_language(tgt_lang) + + if model_size != self.current_model_size or self.model is None: + print("\nInitializing NLLB Model..\n") + progress(0, desc="Initializing NLLB Model..") + self.current_model_size = model_size + local_files_only = self.is_model_exists(self.current_model_size) + self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size, + cache_dir=self.model_dir, + local_files_only=local_files_only) + self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size, + cache_dir=os.path.join(self.model_dir, "tokenizers"), + local_files_only=local_files_only) + + self.pipeline = pipeline("translation", + model=self.model, + tokenizer=self.tokenizer, + src_lang=src_lang, + tgt_lang=tgt_lang, + device=self.device) + + def is_model_exists(self, + model_size: str): + """Check if model exists or not (Only facebook model)""" + prefix = "models--facebook--" + _id, model_size_name = model_size.split("/") + model_dir_name = prefix + model_size_name + model_dir_path = os.path.join(self.model_dir, model_dir_name) + if os.path.exists(model_dir_path) and os.listdir(model_dir_path): + return True + return False + + +NLLB_AVAILABLE_LANGS = { + "Acehnese (Arabic script)": "ace_Arab", + "Acehnese (Latin script)": "ace_Latn", + "Mesopotamian Arabic": "acm_Arab", + "Ta’izzi-Adeni Arabic": "acq_Arab", + "Tunisian Arabic": "aeb_Arab", + "Afrikaans": "afr_Latn", + "South Levantine Arabic": "ajp_Arab", + "Akan": "aka_Latn", + "Amharic": "amh_Ethi", + "North Levantine Arabic": "apc_Arab", + "Modern Standard Arabic": "arb_Arab", + "Modern Standard Arabic (Romanized)": "arb_Latn", + "Najdi Arabic": "ars_Arab", + "Moroccan Arabic": "ary_Arab", + "Egyptian Arabic": "arz_Arab", + "Assamese": "asm_Beng", + "Asturian": "ast_Latn", + "Awadhi": "awa_Deva", + "Central Aymara": "ayr_Latn", + "South Azerbaijani": "azb_Arab", + "North Azerbaijani": "azj_Latn", + "Bashkir": "bak_Cyrl", + "Bambara": "bam_Latn", + "Balinese": "ban_Latn", + "Belarusian": "bel_Cyrl", + "Bemba": "bem_Latn", + "Bengali": "ben_Beng", + "Bhojpuri": "bho_Deva", + "Banjar (Arabic script)": "bjn_Arab", + "Banjar (Latin script)": "bjn_Latn", + "Standard Tibetan": "bod_Tibt", + "Bosnian": "bos_Latn", + "Buginese": "bug_Latn", + "Bulgarian": "bul_Cyrl", + "Catalan": "cat_Latn", + "Cebuano": "ceb_Latn", + "Czech": "ces_Latn", + "Chokwe": "cjk_Latn", + "Central Kurdish": "ckb_Arab", + "Crimean Tatar": "crh_Latn", + "Welsh": "cym_Latn", + "Danish": "dan_Latn", + "German": "deu_Latn", + "Southwestern Dinka": "dik_Latn", + "Dyula": "dyu_Latn", + "Dzongkha": "dzo_Tibt", + "Greek": "ell_Grek", + "English": "eng_Latn", + "Esperanto": "epo_Latn", + "Estonian": "est_Latn", + "Basque": "eus_Latn", + "Ewe": "ewe_Latn", + "Faroese": "fao_Latn", + "Fijian": "fij_Latn", + "Finnish": "fin_Latn", + "Fon": "fon_Latn", + "French": "fra_Latn", + "Friulian": "fur_Latn", + "Nigerian Fulfulde": "fuv_Latn", + "Scottish Gaelic": "gla_Latn", + "Irish": "gle_Latn", + "Galician": "glg_Latn", + "Guarani": "grn_Latn", + "Gujarati": "guj_Gujr", + "Haitian Creole": "hat_Latn", + "Hausa": "hau_Latn", + "Hebrew": "heb_Hebr", + "Hindi": "hin_Deva", + "Chhattisgarhi": "hne_Deva", + "Croatian": "hrv_Latn", + "Hungarian": "hun_Latn", + "Armenian": "hye_Armn", + "Igbo": "ibo_Latn", + "Ilocano": "ilo_Latn", + "Indonesian": "ind_Latn", + "Icelandic": "isl_Latn", + "Italian": "ita_Latn", + "Javanese": "jav_Latn", 
+ "Japanese": "jpn_Jpan", + "Kabyle": "kab_Latn", + "Jingpho": "kac_Latn", + "Kamba": "kam_Latn", + "Kannada": "kan_Knda", + "Kashmiri (Arabic script)": "kas_Arab", + "Kashmiri (Devanagari script)": "kas_Deva", + "Georgian": "kat_Geor", + "Central Kanuri (Arabic script)": "knc_Arab", + "Central Kanuri (Latin script)": "knc_Latn", + "Kazakh": "kaz_Cyrl", + "KabiyΓ¨": "kbp_Latn", + "Kabuverdianu": "kea_Latn", + "Khmer": "khm_Khmr", + "Kikuyu": "kik_Latn", + "Kinyarwanda": "kin_Latn", + "Kyrgyz": "kir_Cyrl", + "Kimbundu": "kmb_Latn", + "Northern Kurdish": "kmr_Latn", + "Kikongo": "kon_Latn", + "Korean": "kor_Hang", + "Lao": "lao_Laoo", + "Ligurian": "lij_Latn", + "Limburgish": "lim_Latn", + "Lingala": "lin_Latn", + "Lithuanian": "lit_Latn", + "Lombard": "lmo_Latn", + "Latgalian": "ltg_Latn", + "Luxembourgish": "ltz_Latn", + "Luba-Kasai": "lua_Latn", + "Ganda": "lug_Latn", + "Luo": "luo_Latn", + "Mizo": "lus_Latn", + "Standard Latvian": "lvs_Latn", + "Magahi": "mag_Deva", + "Maithili": "mai_Deva", + "Malayalam": "mal_Mlym", + "Marathi": "mar_Deva", + "Minangkabau (Arabic script)": "min_Arab", + "Minangkabau (Latin script)": "min_Latn", + "Macedonian": "mkd_Cyrl", + "Plateau Malagasy": "plt_Latn", + "Maltese": "mlt_Latn", + "Meitei (Bengali script)": "mni_Beng", + "Halh Mongolian": "khk_Cyrl", + "Mossi": "mos_Latn", + "Maori": "mri_Latn", + "Burmese": "mya_Mymr", + "Dutch": "nld_Latn", + "Norwegian Nynorsk": "nno_Latn", + "Norwegian BokmΓ₯l": "nob_Latn", + "Nepali": "npi_Deva", + "Northern Sotho": "nso_Latn", + "Nuer": "nus_Latn", + "Nyanja": "nya_Latn", + "Occitan": "oci_Latn", + "West Central Oromo": "gaz_Latn", + "Odia": "ory_Orya", + "Pangasinan": "pag_Latn", + "Eastern Panjabi": "pan_Guru", + "Papiamento": "pap_Latn", + "Western Persian": "pes_Arab", + "Polish": "pol_Latn", + "Portuguese": "por_Latn", + "Dari": "prs_Arab", + "Southern Pashto": "pbt_Arab", + "Ayacucho Quechua": "quy_Latn", + "Romanian": "ron_Latn", + "Rundi": "run_Latn", + "Russian": "rus_Cyrl", + "Sango": "sag_Latn", + "Sanskrit": "san_Deva", + "Santali": "sat_Olck", + "Sicilian": "scn_Latn", + "Shan": "shn_Mymr", + "Sinhala": "sin_Sinh", + "Slovak": "slk_Latn", + "Slovenian": "slv_Latn", + "Samoan": "smo_Latn", + "Shona": "sna_Latn", + "Sindhi": "snd_Arab", + "Somali": "som_Latn", + "Southern Sotho": "sot_Latn", + "Spanish": "spa_Latn", + "Tosk Albanian": "als_Latn", + "Sardinian": "srd_Latn", + "Serbian": "srp_Cyrl", + "Swati": "ssw_Latn", + "Sundanese": "sun_Latn", + "Swedish": "swe_Latn", + "Swahili": "swh_Latn", + "Silesian": "szl_Latn", + "Tamil": "tam_Taml", + "Tatar": "tat_Cyrl", + "Telugu": "tel_Telu", + "Tajik": "tgk_Cyrl", + "Tagalog": "tgl_Latn", + "Thai": "tha_Thai", + "Tigrinya": "tir_Ethi", + "Tamasheq (Latin script)": "taq_Latn", + "Tamasheq (Tifinagh script)": "taq_Tfng", + "Tok Pisin": "tpi_Latn", + "Tswana": "tsn_Latn", + "Tsonga": "tso_Latn", + "Turkmen": "tuk_Latn", + "Tumbuka": "tum_Latn", + "Turkish": "tur_Latn", + "Twi": "twi_Latn", + "Central Atlas Tamazight": "tzm_Tfng", + "Uyghur": "uig_Arab", + "Ukrainian": "ukr_Cyrl", + "Umbundu": "umb_Latn", + "Urdu": "urd_Arab", + "Northern Uzbek": "uzn_Latn", + "Venetian": "vec_Latn", + "Vietnamese": "vie_Latn", + "Waray": "war_Latn", + "Wolof": "wol_Latn", + "Xhosa": "xho_Latn", + "Eastern Yiddish": "ydd_Hebr", + "Yoruba": "yor_Latn", + "Yue Chinese": "yue_Hant", + "Chinese (Simplified)": "zho_Hans", + "Chinese (Traditional)": "zho_Hant", + "Standard Malay": "zsm_Latn", + "Zulu": "zul_Latn", +} diff --git a/modules/translation/translation_base.py 
b/modules/translation/translation_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2551f0e9e1c82d4f71fd5d3a848b56c63eea4baf --- /dev/null +++ b/modules/translation/translation_base.py @@ -0,0 +1,177 @@ +import os +import torch +import gradio as gr +from abc import ABC, abstractmethod +from typing import List +from datetime import datetime + +from modules.whisper.whisper_parameter import * +from modules.utils.subtitle_manager import * +from modules.utils.files_manager import load_yaml, save_yaml +from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR + + +class TranslationBase(ABC): + def __init__(self, + model_dir: str = NLLB_MODELS_DIR, + output_dir: str = TRANSLATION_OUTPUT_DIR + ): + super().__init__() + self.model = None + self.model_dir = model_dir + self.output_dir = output_dir + os.makedirs(self.model_dir, exist_ok=True) + os.makedirs(self.output_dir, exist_ok=True) + self.current_model_size = None + self.device = self.get_device() + + @abstractmethod + def translate(self, + text: str, + max_length: int + ): + pass + + @abstractmethod + def update_model(self, + model_size: str, + src_lang: str, + tgt_lang: str, + progress: gr.Progress = gr.Progress() + ): + pass + + def translate_file(self, + fileobjs: list, + model_size: str, + src_lang: str, + tgt_lang: str, + max_length: int = 200, + add_timestamp: bool = True, + progress=gr.Progress()) -> list: + """ + Translate subtitle file from source language to target language + + Parameters + ---------- + fileobjs: list + List of files to transcribe from gr.Files() + model_size: str + Whisper model size from gr.Dropdown() + src_lang: str + Source language of the file to translate from gr.Dropdown() + tgt_lang: str + Target language of the file to translate from gr.Dropdown() + max_length: int + Max length per line to translate + add_timestamp: bool + Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. + progress: gr.Progress + Indicator to show progress directly in gradio. + I use a forked version of whisper for this. 
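A hedged usage sketch of the NLLB wrapper above driven through translate_file() (not part of the diff; the subtitle path is a placeholder, and the model and language names must match the tables shown earlier):

```python
# Hypothetical call (not part of the diff) to the NLLB wrapper above.
# The subtitle path is a placeholder.
from modules.translation.nllb_inference import NLLBInference

nllb = NLLBInference()
message, output_files = nllb.translate_file(
    fileobjs=["outputs/sample.srt"],
    model_size="facebook/nllb-200-distilled-600M",
    src_lang="English",
    tgt_lang="Korean",
    max_length=200,
    add_timestamp=True,
)
print(message)        # summary text shown in the UI
print(output_files)   # translated subtitle path(s)
```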
To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback + + Returns + ---------- + A List of + String to return to gr.Textbox() + Files to return to gr.Files() + """ + try: + if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString): + fileobjs = [file.name for file in fileobjs] + + self.cache_parameters(model_size=model_size, + src_lang=src_lang, + tgt_lang=tgt_lang, + max_length=max_length, + add_timestamp=add_timestamp) + + self.update_model(model_size=model_size, + src_lang=src_lang, + tgt_lang=tgt_lang, + progress=progress) + + files_info = {} + for fileobj in fileobjs: + file_name, file_ext = os.path.splitext(os.path.basename(fileobj)) + if file_ext == ".srt": + parsed_dicts = parse_srt(file_path=fileobj) + total_progress = len(parsed_dicts) + for index, dic in enumerate(parsed_dicts): + progress(index / total_progress, desc="Translating..") + translated_text = self.translate(dic["sentence"], max_length=max_length) + dic["sentence"] = translated_text + subtitle = get_serialized_srt(parsed_dicts) + + elif file_ext == ".vtt": + parsed_dicts = parse_vtt(file_path=fileobj) + total_progress = len(parsed_dicts) + for index, dic in enumerate(parsed_dicts): + progress(index / total_progress, desc="Translating..") + translated_text = self.translate(dic["sentence"], max_length=max_length) + dic["sentence"] = translated_text + subtitle = get_serialized_vtt(parsed_dicts) + + if add_timestamp: + timestamp = datetime.now().strftime("%m%d%H%M%S") + file_name += f"-{timestamp}" + + output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}") + write_file(subtitle, output_path) + + files_info[file_name] = {"subtitle": subtitle, "path": output_path} + + total_result = '' + for file_name, info in files_info.items(): + total_result += '------------------------------------\n' + total_result += f'{file_name}\n\n' + total_result += f'{info["subtitle"]}' + gr_str = f"Done! 
Subtitle is in the outputs/translation folder.\n\n{total_result}" + + output_file_paths = [item["path"] for key, item in files_info.items()] + return [gr_str, output_file_paths] + + except Exception as e: + print(f"Error: {str(e)}") + finally: + self.release_cuda_memory() + + @staticmethod + def get_device(): + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + + @staticmethod + def release_cuda_memory(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + @staticmethod + def remove_input_files(file_paths: List[str]): + if not file_paths: + return + + for file_path in file_paths: + if file_path and os.path.exists(file_path): + os.remove(file_path) + + @staticmethod + def cache_parameters(model_size: str, + src_lang: str, + tgt_lang: str, + max_length: int, + add_timestamp: bool): + cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) + cached_params["translation"]["nllb"] = { + "model_size": model_size, + "source_lang": src_lang, + "target_lang": tgt_lang, + "max_length": max_length, + } + cached_params["translation"]["add_timestamp"] = add_timestamp + save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH) diff --git a/modules/ui/__init__.py b/modules/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/ui/htmls.py b/modules/ui/htmls.py new file mode 100644 index 0000000000000000000000000000000000000000..241705344a869259e3873e7ecbaef9a1ab883442 --- /dev/null +++ b/modules/ui/htmls.py @@ -0,0 +1,97 @@ +CSS = """ +.bmc-button { + padding: 2px 5px; + border-radius: 5px; + background-color: #FF813F; + color: white; + box-shadow: 0px 1px 2px rgba(0, 0, 0, 0.3); + text-decoration: none; + display: inline-block; + font-size: 20px; + margin: 2px; + cursor: pointer; + -webkit-transition: background-color 0.3s ease; + -ms-transition: background-color 0.3s ease; + transition: background-color 0.3s ease; +} +.bmc-button:hover, +.bmc-button:active, +.bmc-button:focus { + background-color: #FF5633; +} +.markdown { + margin-bottom: 0; + padding-bottom: 0; +} +.tabs { + margin-top: 0; + padding-top: 0; +} + +#md_project a { + color: black; + text-decoration: none; +} +#md_project a:hover { + text-decoration: underline; +} +""" + +MARKDOWN = """ +# Automatic speech recognition +""" + + +NLLB_VRAM_TABLE = """ + + + + + + + + + +
+      <details>
+        <summary>VRAM usage for each model</summary>
+        <table>
+          <tr><th>Model name</th><th>Required VRAM</th></tr>
+          <tr><td>nllb-200-3.3B</td><td>~16GB</td></tr>
+          <tr><td>nllb-200-1.3B</td><td>~8GB</td></tr>
+          <tr><td>nllb-200-distilled-600M</td><td>~4GB</td></tr>
+        </table>
+        <p>Note: Be mindful of your VRAM! The table above provides an approximate VRAM usage for each model.</p>
+      </details>
+ + + +""" \ No newline at end of file diff --git a/modules/utils/__init__.py b/modules/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/utils/cli_manager.py b/modules/utils/cli_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..67f91b61d3c75ea96c2a8462059c880970cfb85d --- /dev/null +++ b/modules/utils/cli_manager.py @@ -0,0 +1,12 @@ +import argparse + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') \ No newline at end of file diff --git a/modules/utils/files_manager.py b/modules/utils/files_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac0c6344ff97e31adb43c6cc3aad16384e08893 --- /dev/null +++ b/modules/utils/files_manager.py @@ -0,0 +1,69 @@ +import os +import fnmatch +from ruamel.yaml import YAML +from gradio.utils import NamedString + +from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH + + +def load_yaml(path: str = DEFAULT_PARAMETERS_CONFIG_PATH): + yaml = YAML(typ="safe") + yaml.preserve_quotes = True + with open(path, 'r', encoding='utf-8') as file: + config = yaml.load(file) + return config + + +def save_yaml(data: dict, path: str = DEFAULT_PARAMETERS_CONFIG_PATH): + yaml = YAML(typ="safe") + yaml.map_indent = 2 + yaml.sequence_indent = 4 + yaml.sequence_dash_offset = 2 + yaml.preserve_quotes = True + yaml.default_flow_style = False + yaml.sort_base_mapping_type_on_output = False + + with open(path, 'w', encoding='utf-8') as file: + yaml.dump(data, file) + return path + + +def get_media_files(folder_path, include_sub_directory=False): + video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv', '*.webm', '*.m4v', '*.mpeg', '*.mpg', + '*.3gp', '*.f4v', '*.ogv', '*.vob', '*.mts', '*.m2ts', '*.divx', '*.mxf', '*.rm', '*.rmvb'] + audio_extensions = ['*.mp3', '*.wav', '*.aac', '*.flac', '*.ogg', '*.m4a'] + media_extensions = video_extensions + audio_extensions + + media_files = [] + + if include_sub_directory: + for root, _, files in os.walk(folder_path): + for extension in media_extensions: + media_files.extend( + os.path.join(root, file) for file in fnmatch.filter(files, extension) + if os.path.exists(os.path.join(root, file)) + ) + else: + for extension in media_extensions: + media_files.extend( + os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension) + if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file)) + ) + + return media_files + + +def format_gradio_files(files: list): + if not files: + return files + + gradio_files = [] + for file in files: + gradio_files.append(NamedString(file)) + return gradio_files + + +def is_video(file_path): + video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp'] + extension = os.path.splitext(file_path)[1].lower() + return extension in video_extensions diff --git a/modules/utils/paths.py b/modules/utils/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..630ab40bcbd037c0f1fdfa35601ed13c50745446 --- /dev/null +++ b/modules/utils/paths.py @@ -0,0 +1,31 @@ +import os + +WEBUI_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +MODELS_DIR = os.path.join(WEBUI_DIR, 
"models") +WHISPER_MODELS_DIR = os.path.join(MODELS_DIR, "Whisper") +FASTER_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "faster-whisper") +INSANELY_FAST_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "insanely-fast-whisper") +NLLB_MODELS_DIR = os.path.join(MODELS_DIR, "NLLB") +DIARIZATION_MODELS_DIR = os.path.join(MODELS_DIR, "Diarization") +UVR_MODELS_DIR = os.path.join(MODELS_DIR, "UVR", "MDX_Net_Models") +CONFIGS_DIR = os.path.join(WEBUI_DIR, "configs") +DEFAULT_PARAMETERS_CONFIG_PATH = os.path.join(CONFIGS_DIR, "default_parameters.yaml") +OUTPUT_DIR = os.path.join(WEBUI_DIR, "outputs") +TRANSLATION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "translations") +UVR_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "UVR") +UVR_INSTRUMENTAL_OUTPUT_DIR = os.path.join(UVR_OUTPUT_DIR, "instrumental") +UVR_VOCALS_OUTPUT_DIR = os.path.join(UVR_OUTPUT_DIR, "vocals") + +for dir_path in [MODELS_DIR, + WHISPER_MODELS_DIR, + FASTER_WHISPER_MODELS_DIR, + INSANELY_FAST_WHISPER_MODELS_DIR, + NLLB_MODELS_DIR, + DIARIZATION_MODELS_DIR, + UVR_MODELS_DIR, + CONFIGS_DIR, + OUTPUT_DIR, + TRANSLATION_OUTPUT_DIR, + UVR_INSTRUMENTAL_OUTPUT_DIR, + UVR_VOCALS_OUTPUT_DIR]: + os.makedirs(dir_path, exist_ok=True) diff --git a/modules/utils/subtitle_manager.py b/modules/utils/subtitle_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4b484254064517836bc36790bf7a4bc3508cbb82 --- /dev/null +++ b/modules/utils/subtitle_manager.py @@ -0,0 +1,132 @@ +import re + + +def timeformat_srt(time): + hours = time // 3600 + minutes = (time - hours * 3600) // 60 + seconds = time - hours * 3600 - minutes * 60 + milliseconds = (time - int(time)) * 1000 + return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}" + + +def timeformat_vtt(time): + hours = time // 3600 + minutes = (time - hours * 3600) // 60 + seconds = time - hours * 3600 - minutes * 60 + milliseconds = (time - int(time)) * 1000 + return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}" + + +def write_file(subtitle, output_file): + with open(output_file, 'w', encoding='utf-8') as f: + f.write(subtitle) + + +def get_srt(segments): + output = "" + for i, segment in enumerate(segments): + output += f"{i + 1}\n" + output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n" + if segment['text'].startswith(' '): + segment['text'] = segment['text'][1:] + output += f"{segment['text']}\n\n" + return output + + +def get_vtt(segments): + output = "WebVTT\n\n" + for i, segment in enumerate(segments): + output += f"{i + 1}\n" + output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n" + if segment['text'].startswith(' '): + segment['text'] = segment['text'][1:] + output += f"{segment['text']}\n\n" + return output + + +def get_txt(segments): + output = "" + for i, segment in enumerate(segments): + if segment['text'].startswith(' '): + segment['text'] = segment['text'][1:] + output += f"{segment['text']}\n" + return output + + +def parse_srt(file_path): + """Reads SRT file and returns as dict""" + with open(file_path, 'r', encoding='utf-8') as file: + srt_data = file.read() + + data = [] + blocks = srt_data.split('\n\n') + + for block in blocks: + if block.strip() != '': + lines = block.strip().split('\n') + index = lines[0] + timestamp = lines[1] + sentence = ' '.join(lines[2:]) + + data.append({ + "index": index, + "timestamp": timestamp, + "sentence": sentence + }) + return data + + +def parse_vtt(file_path): + """Reads WebVTT file and 
returns as dict""" + with open(file_path, 'r', encoding='utf-8') as file: + webvtt_data = file.read() + + data = [] + blocks = webvtt_data.split('\n\n') + + for block in blocks: + if block.strip() != '' and not block.strip().startswith("WebVTT"): + lines = block.strip().split('\n') + index = lines[0] + timestamp = lines[1] + sentence = ' '.join(lines[2:]) + + data.append({ + "index": index, + "timestamp": timestamp, + "sentence": sentence + }) + + return data + + +def get_serialized_srt(dicts): + output = "" + for dic in dicts: + output += f'{dic["index"]}\n' + output += f'{dic["timestamp"]}\n' + output += f'{dic["sentence"]}\n\n' + return output + + +def get_serialized_vtt(dicts): + output = "WebVTT\n\n" + for dic in dicts: + output += f'{dic["index"]}\n' + output += f'{dic["timestamp"]}\n' + output += f'{dic["sentence"]}\n\n' + return output + + +def safe_filename(name): + INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]' + safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name) + # Truncate the filename if it exceeds the max_length (20) + if len(safe_name) > 20: + file_extension = safe_name.split('.')[-1] + if len(file_extension) + 1 < 20: + truncated_name = safe_name[:20 - len(file_extension) - 1] + safe_name = truncated_name + '.' + file_extension + else: + safe_name = safe_name[:20] + return safe_name diff --git a/modules/utils/youtube_manager.py b/modules/utils/youtube_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..26929364564215f9e650ac1927c43df5c2465ccb --- /dev/null +++ b/modules/utils/youtube_manager.py @@ -0,0 +1,33 @@ +from pytubefix import YouTube +import subprocess +import os + + +def get_ytdata(link): + return YouTube(link) + + +def get_ytmetas(link): + yt = YouTube(link) + return yt.thumbnail_url, yt.title, yt.description + + +def get_ytaudio(ytdata: YouTube): + # Somehow the audio is corrupted so need to convert to valid audio file. 
+ # Fix for : https://github.com/jhj0517/Whisper-WebUI/issues/304 + + audio_path = ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav")) + temp_audio_path = os.path.join("modules", "yt_tmp_fixed.wav") + + try: + subprocess.run([ + 'ffmpeg', '-y', + '-i', audio_path, + temp_audio_path + ], check=True) + + os.replace(temp_audio_path, audio_path) + return audio_path + except subprocess.CalledProcessError as e: + print(f"Error during ffmpeg conversion: {e}") + return None diff --git a/modules/uvr/music_separator.py b/modules/uvr/music_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..42294d17f83075ffe497b6ae1a33b562feb8ac70 --- /dev/null +++ b/modules/uvr/music_separator.py @@ -0,0 +1,183 @@ +from typing import Optional, Union, List, Dict +import numpy as np +import torchaudio +import soundfile as sf +import os +import torch +import gc +import gradio as gr +from datetime import datetime + +from uvr.models import MDX, Demucs, VrNetwork, MDXC +from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH +from modules.utils.files_manager import load_yaml, save_yaml, is_video +from modules.diarize.audio_loader import load_audio + +class MusicSeparator: + def __init__(self, + model_dir: Optional[str] = None, + output_dir: Optional[str] = None): + self.model = None + self.device = self.get_device() + self.available_devices = ["cpu", "cuda"] + self.model_dir = model_dir + self.output_dir = output_dir + instrumental_output_dir = os.path.join(self.output_dir, "instrumental") + vocals_output_dir = os.path.join(self.output_dir, "vocals") + os.makedirs(instrumental_output_dir, exist_ok=True) + os.makedirs(vocals_output_dir, exist_ok=True) + self.audio_info = None + self.available_models = ["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] + self.default_model = self.available_models[0] + self.current_model_size = self.default_model + self.model_config = { + "segment": 256, + "split": True + } + + def update_model(self, + model_name: str = "UVR-MDX-NET-Inst_1", + device: Optional[str] = None, + segment_size: int = 256): + """ + Update model with the given model name + + Args: + model_name (str): Model name. + device (str): Device to use for the model. + segment_size (int): Segment size for the prediction. + """ + if device is None: + device = self.device + + self.device = device + self.model_config = { + "segment": segment_size, + "split": True + } + self.model = MDX(name=model_name, + other_metadata=self.model_config, + device=self.device, + logger=None, + model_dir=self.model_dir) + + def separate(self, + audio: Union[str, np.ndarray], + model_name: str, + device: Optional[str] = None, + segment_size: int = 256, + save_file: bool = False, + progress: gr.Progress = gr.Progress()) -> tuple[np.ndarray, np.ndarray, List]: + """ + Separate the background music from the audio. + + Args: + audio (Union[str, np.ndarray]): Audio path or numpy array. + model_name (str): Model name. + device (str): Device to use for the model. + segment_size (int): Segment size for the prediction. + save_file (bool): Whether to save the separated audio to output path or not. + progress (gr.Progress): Gradio progress indicator. + + Returns: + A Tuple of + np.ndarray: Instrumental numpy arrays. + np.ndarray: Vocals numpy arrays. + file_paths: List of file paths where the separated audio is saved. Return empty when save_file is False. 
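A hedged usage sketch for the UVR separator above (not part of the diff; the directories come from modules/utils/paths.py and the input file name is a placeholder):

```python
# Hedged usage sketch (not part of the diff). "song.wav" is a placeholder path.
from modules.utils.paths import UVR_MODELS_DIR, UVR_OUTPUT_DIR
from modules.uvr.music_separator import MusicSeparator

separator = MusicSeparator(model_dir=UVR_MODELS_DIR, output_dir=UVR_OUTPUT_DIR)
instrumental, vocals, files = separator.separate(
    audio="song.wav",
    model_name="UVR-MDX-NET-Inst_HQ_4",   # one of separator.available_models
    device=separator.device,
    segment_size=256,
    save_file=True,                       # writes <name>-instrumental.wav / <name>-vocals.wav
)
print(files)                              # saved instrumental and vocals paths
```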
+ """ + if isinstance(audio, str): + output_filename, ext = os.path.basename(audio), ".wav" + output_filename, orig_ext = os.path.splitext(output_filename) + + if is_video(audio): + audio = load_audio(audio) + sample_rate = 16000 + else: + self.audio_info = torchaudio.info(audio) + sample_rate = self.audio_info.sample_rate + else: + timestamp = datetime.now().strftime("%m%d%H%M%S") + output_filename, ext = f"UVR-{timestamp}", ".wav" + sample_rate = 16000 + + model_config = { + "segment": segment_size, + "split": True + } + + if (self.model is None or + self.current_model_size != model_name or + self.model_config != model_config or + self.model.sample_rate != sample_rate or + self.device != device): + progress(0, desc="Initializing UVR Model..") + self.update_model( + model_name=model_name, + device=device, + segment_size=segment_size + ) + self.model.sample_rate = sample_rate + + progress(0, desc="Separating background music from the audio..") + result = self.model(audio) + instrumental, vocals = result["instrumental"].T, result["vocals"].T + + file_paths = [] + if save_file: + instrumental_output_path = os.path.join(self.output_dir, "instrumental", f"{output_filename}-instrumental{ext}") + vocals_output_path = os.path.join(self.output_dir, "vocals", f"{output_filename}-vocals{ext}") + sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV") + sf.write(vocals_output_path, vocals, sample_rate, format="WAV") + file_paths += [instrumental_output_path, vocals_output_path] + + return instrumental, vocals, file_paths + + def separate_files(self, + files: List, + model_name: str, + device: Optional[str] = None, + segment_size: int = 256, + save_file: bool = True, + progress: gr.Progress = gr.Progress()) -> List[str]: + """Separate the background music from the audio files. 
Returns only last Instrumental and vocals file paths + to display into gr.Audio()""" + self.cache_parameters(model_size=model_name, segment_size=segment_size) + + for file_path in files: + instrumental, vocals, file_paths = self.separate( + audio=file_path, + model_name=model_name, + device=device, + segment_size=segment_size, + save_file=save_file, + progress=progress + ) + return file_paths + + @staticmethod + def get_device(): + """Get device for the model""" + return "cuda" if torch.cuda.is_available() else "cpu" + + def offload(self): + """Offload the model and free up the memory""" + if self.model is not None: + del self.model + self.model = None + if self.device == "cuda": + torch.cuda.empty_cache() + gc.collect() + self.audio_info = None + + @staticmethod + def cache_parameters(model_size: str, + segment_size: int): + cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) + cached_uvr_params = cached_params["bgm_separation"] + uvr_params_to_cache = { + "model_size": model_size, + "segment_size": segment_size + } + cached_uvr_params = {**cached_uvr_params, **uvr_params_to_cache} + cached_params["bgm_separation"] = cached_uvr_params + save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH) diff --git a/modules/vad/__init__.py b/modules/vad/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/vad/silero_vad.py b/modules/vad/silero_vad.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5c91921e8d4cf9fb564394fc40715a886fbee4 --- /dev/null +++ b/modules/vad/silero_vad.py @@ -0,0 +1,264 @@ +# Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py + +from faster_whisper.vad import VadOptions, get_vad_model +import numpy as np +from typing import BinaryIO, Union, List, Optional, Tuple +import warnings +import faster_whisper +from faster_whisper.transcribe import SpeechTimestampsMap, Segment +import gradio as gr + + +class SileroVAD: + def __init__(self): + self.sampling_rate = 16000 + self.window_size_samples = 512 + self.model = None + + def run(self, + audio: Union[str, BinaryIO, np.ndarray], + vad_parameters: VadOptions, + progress: gr.Progress = gr.Progress() + ) -> Tuple[np.ndarray, List[dict]]: + """ + Run VAD + + Parameters + ---------- + audio: Union[str, BinaryIO, np.ndarray] + Audio path or file binary or Audio numpy array + vad_parameters: + Options for VAD processing. + progress: gr.Progress + Indicator to show progress directly in gradio. 
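A hedged sketch of running the VAD wrapper above as a pre-processing step (not part of the diff; the audio path is a placeholder):

```python
# Sketch (not part of the diff): trimming non-speech with the SileroVAD wrapper above
# before transcription. "sample.wav" is a placeholder path.
from faster_whisper.vad import VadOptions
from modules.vad.silero_vad import SileroVAD

vad = SileroVAD()
speech_audio, speech_chunks = vad.run(
    audio="sample.wav",
    vad_parameters=VadOptions(threshold=0.5, min_speech_duration_ms=250),
)
print(speech_audio.shape[0] / vad.sampling_rate)   # duration in seconds after removing silence
# After transcribing speech_audio, the original timings can be restored with:
#   segments = vad.restore_speech_timestamps(segments, speech_chunks)
```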
+ + Returns + ---------- + np.ndarray + Pre-processed audio with VAD + List[dict] + Chunks of speeches to be used to restore the timestamps later + """ + + sampling_rate = self.sampling_rate + + if not isinstance(audio, np.ndarray): + audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate) + + duration = audio.shape[0] / sampling_rate + duration_after_vad = duration + + if vad_parameters is None: + vad_parameters = VadOptions() + elif isinstance(vad_parameters, dict): + vad_parameters = VadOptions(**vad_parameters) + speech_chunks = self.get_speech_timestamps( + audio=audio, + vad_options=vad_parameters, + progress=progress + ) + audio = self.collect_chunks(audio, speech_chunks) + duration_after_vad = audio.shape[0] / sampling_rate + + return audio, speech_chunks + + def get_speech_timestamps( + self, + audio: np.ndarray, + vad_options: Optional[VadOptions] = None, + progress: gr.Progress = gr.Progress(), + **kwargs, + ) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + vad_options: Options for VAD processing. + kwargs: VAD options passed as keyword arguments for backward compatibility. + progress: Gradio progress to indicate progress. + + Returns: + List of dicts containing begin and end samples of each speech chunk. + """ + + if self.model is None: + self.update_model() + + if vad_options is None: + vad_options = VadOptions(**kwargs) + + threshold = vad_options.threshold + min_speech_duration_ms = vad_options.min_speech_duration_ms + max_speech_duration_s = vad_options.max_speech_duration_s + min_silence_duration_ms = vad_options.min_silence_duration_ms + window_size_samples = self.window_size_samples + speech_pad_ms = vad_options.speech_pad_ms + sampling_rate = 16000 + min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 + speech_pad_samples = sampling_rate * speech_pad_ms / 1000 + max_speech_samples = ( + sampling_rate * max_speech_duration_s + - window_size_samples + - 2 * speech_pad_samples + ) + min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 + min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 + + audio_length_samples = len(audio) + + state, context = self.model.get_initial_states(batch_size=1) + + speech_probs = [] + for current_start_sample in range(0, audio_length_samples, window_size_samples): + progress(current_start_sample/audio_length_samples, desc="Detecting speeches only using VAD...") + + chunk = audio[current_start_sample: current_start_sample + window_size_samples] + if len(chunk) < window_size_samples: + chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) + speech_prob, state, context = self.model(chunk, state, context, sampling_rate) + speech_probs.append(speech_prob) + + triggered = False + speeches = [] + current_speech = {} + neg_threshold = threshold - 0.15 + + # to save potential segment end (and tolerate some silence) + temp_end = 0 + # to save potential segment limits in case of maximum segment size reached + prev_end = next_start = 0 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 + if next_start < prev_end: + next_start = window_size_samples * i + + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech["start"] = window_size_samples * i + continue + + if ( + triggered + and (window_size_samples * i) - current_speech["start"] > max_speech_samples + ): + if prev_end: + current_speech["end"] = 
prev_end + speeches.append(current_speech) + current_speech = {} + # previously reached silence (< neg_thres) and is still not speech (< thres) + if next_start < prev_end: + triggered = False + else: + current_speech["start"] = next_start + prev_end = next_start = temp_end = 0 + else: + current_speech["end"] = window_size_samples * i + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + # condition to avoid cutting in very short silence + if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: + prev_end = temp_end + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech["end"] = temp_end + if ( + current_speech["end"] - current_speech["start"] + ) > min_speech_samples: + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if ( + current_speech + and (audio_length_samples - current_speech["start"]) > min_speech_samples + ): + current_speech["end"] = audio_length_samples + speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]["start"] - speech["end"] + if silence_duration < 2 * speech_pad_samples: + speech["end"] += int(silence_duration // 2) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - silence_duration // 2) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - speech_pad_samples) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + + return speeches + + def update_model(self): + self.model = get_vad_model() + + @staticmethod + def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: + """Collects and concatenates audio chunks.""" + if not chunks: + return np.array([], dtype=np.float32) + + return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks]) + + @staticmethod + def format_timestamp( + seconds: float, + always_include_hours: bool = False, + decimal_marker: str = ".", + ) -> str: + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) + + def restore_speech_timestamps( + self, + segments: List[dict], + speech_chunks: List[dict], + sampling_rate: Optional[int] = None, + ) -> List[dict]: + if sampling_rate is None: + sampling_rate = self.sampling_rate + + ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) + + for segment in segments: + segment["start"] = ts_map.get_original_time(segment["start"]) + segment["end"] = ts_map.get_original_time(segment["end"]) + + return segments + diff --git a/modules/whisper/__init__.py b/modules/whisper/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..30a4cc386e1ae5bb9f1c33c2703ddeaba1d6eadd --- /dev/null +++ b/modules/whisper/faster_whisper_inference.py @@ -0,0 +1,192 @@ +import os +import time +import numpy as np +import torch +from typing import BinaryIO, Union, Tuple, List +import faster_whisper +from faster_whisper.vad import VadOptions +import ast +import ctranslate2 +import whisper +import gradio as gr +from argparse import Namespace + +from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) +from modules.whisper.whisper_parameter import * +from modules.whisper.whisper_base import WhisperBase + + +class FasterWhisperInference(WhisperBase): + def __init__(self, + model_dir: str = FASTER_WHISPER_MODELS_DIR, + diarization_model_dir: str = DIARIZATION_MODELS_DIR, + uvr_model_dir: str = UVR_MODELS_DIR, + output_dir: str = OUTPUT_DIR, + ): + super().__init__( + model_dir=model_dir, + diarization_model_dir=diarization_model_dir, + uvr_model_dir=uvr_model_dir, + output_dir=output_dir + ) + self.model_dir = model_dir + os.makedirs(self.model_dir, exist_ok=True) + + self.model_paths = self.get_model_paths() + self.device = self.get_device() + self.available_models = self.model_paths.keys() + self.available_compute_types = ctranslate2.get_supported_compute_types( + "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu") + + def transcribe(self, + audio: Union[str, BinaryIO, np.ndarray], + progress: gr.Progress = gr.Progress(), + *whisper_params, + ) -> Tuple[List[dict], float]: + """ + transcribe method for faster-whisper. + + Parameters + ---------- + audio: Union[str, BinaryIO, np.ndarray] + Audio path or file binary or Audio numpy array + progress: gr.Progress + Indicator to show progress directly in gradio. + *whisper_params: tuple + Parameters related with whisper. 
This will be dealt with "WhisperParameters" data class + + Returns + ---------- + segments_result: List[dict] + list of dicts that includes start, end timestamps and transcribed text + elapsed_time: float + elapsed time for transcription + """ + start_time = time.time() + + params = WhisperParameters.as_value(*whisper_params) + + if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: + self.update_model(params.model_size, params.compute_type, progress) + + # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723 + if not params.initial_prompt: + params.initial_prompt = None + if not params.prefix: + params.prefix = None + if not params.hotwords: + params.hotwords = None + + params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens) + + segments, info = self.model.transcribe( + audio=audio, + language=params.lang, + task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", + beam_size=params.beam_size, + log_prob_threshold=params.log_prob_threshold, + no_speech_threshold=params.no_speech_threshold, + best_of=params.best_of, + patience=params.patience, + temperature=params.temperature, + initial_prompt=params.initial_prompt, + compression_ratio_threshold=params.compression_ratio_threshold, + length_penalty=params.length_penalty, + repetition_penalty=params.repetition_penalty, + no_repeat_ngram_size=params.no_repeat_ngram_size, + prefix=params.prefix, + suppress_blank=params.suppress_blank, + suppress_tokens=params.suppress_tokens, + max_initial_timestamp=params.max_initial_timestamp, + word_timestamps=params.word_timestamps, + prepend_punctuations=params.prepend_punctuations, + append_punctuations=params.append_punctuations, + max_new_tokens=params.max_new_tokens, + chunk_length=params.chunk_length, + hallucination_silence_threshold=params.hallucination_silence_threshold, + hotwords=params.hotwords, + language_detection_threshold=params.language_detection_threshold, + language_detection_segments=params.language_detection_segments, + prompt_reset_on_temperature=params.prompt_reset_on_temperature, + ) + progress(0, desc="Loading audio..") + + segments_result = [] + for segment in segments: + progress(segment.start / info.duration, desc="Transcribing..") + segments_result.append({ + "start": segment.start, + "end": segment.end, + "text": segment.text + }) + + elapsed_time = time.time() - start_time + return segments_result, elapsed_time + + def update_model(self, + model_size: str, + compute_type: str, + progress: gr.Progress = gr.Progress() + ): + """ + Update current model setting + + Parameters + ---------- + model_size: str + Size of whisper model + compute_type: str + Compute type for transcription. + see more info : https://opennmt.net/CTranslate2/quantization.html + progress: gr.Progress + Indicator to show progress directly in gradio. + """ + progress(0, desc="Initializing Model..") + self.current_model_size = self.model_paths[model_size] + self.current_compute_type = compute_type + self.model = faster_whisper.WhisperModel( + device=self.device, + model_size_or_path=self.current_model_size, + download_root=self.model_dir, + compute_type=self.current_compute_type + ) + + def get_model_paths(self): + """ + Get available models from models path including fine-tuned model. 
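The segment dictionaries returned by transcribe() above have the same shape the subtitle helpers in modules/utils/subtitle_manager.py consume; a small sketch (the segment values and output path are placeholders):

```python
# Sketch (not part of the diff): transcribe() output plugs straight into the subtitle
# helpers from modules/utils/subtitle_manager.py. Paths and values are placeholders.
from modules.utils.subtitle_manager import get_srt, write_file

segments = [
    {"start": 0.0, "end": 2.4, "text": " Hello there."},
    {"start": 2.4, "end": 5.1, "text": " How are you?"},
]
srt_text = get_srt(segments)               # "1\n00:00:00,000 --> 00:00:02,400\nHello there.\n\n..."
write_file(srt_text, "outputs/sample.srt")
```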
+ + Returns + ---------- + Name list of models + """ + model_paths = {model:model for model in faster_whisper.available_models()} + faster_whisper_prefix = "models--Systran--faster-whisper-" + + existing_models = os.listdir(self.model_dir) + wrong_dirs = [".locks"] + existing_models = list(set(existing_models) - set(wrong_dirs)) + + for model_name in existing_models: + if faster_whisper_prefix in model_name: + model_name = model_name[len(faster_whisper_prefix):] + + if model_name not in whisper.available_models(): + model_paths[model_name] = os.path.join(self.model_dir, model_name) + return model_paths + + @staticmethod + def get_device(): + if torch.cuda.is_available(): + return "cuda" + else: + return "auto" + + @staticmethod + def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]: + try: + suppress_tokens = ast.literal_eval(suppress_tokens_str) + if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens): + raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") + return suppress_tokens + except Exception as e: + raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..21eb930d510cccf5083d87a1da2321d12a55754a --- /dev/null +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -0,0 +1,195 @@ +import os +import time +import numpy as np +from typing import BinaryIO, Union, Tuple, List +import torch +from transformers import pipeline +from transformers.utils import is_flash_attn_2_available +import gradio as gr +from huggingface_hub import hf_hub_download +import whisper +from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn +from argparse import Namespace + +from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) +from modules.whisper.whisper_parameter import * +from modules.whisper.whisper_base import WhisperBase + + +class InsanelyFastWhisperInference(WhisperBase): + def __init__(self, + model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, + diarization_model_dir: str = DIARIZATION_MODELS_DIR, + uvr_model_dir: str = UVR_MODELS_DIR, + output_dir: str = OUTPUT_DIR, + ): + super().__init__( + model_dir=model_dir, + output_dir=output_dir, + diarization_model_dir=diarization_model_dir, + uvr_model_dir=uvr_model_dir + ) + self.model_dir = model_dir + os.makedirs(self.model_dir, exist_ok=True) + + openai_models = whisper.available_models() + distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] + self.available_models = openai_models + distil_models + self.available_compute_types = ["float16"] + + def transcribe(self, + audio: Union[str, np.ndarray, torch.Tensor], + progress: gr.Progress = gr.Progress(), + *whisper_params, + ) -> Tuple[List[dict], float]: + """ + transcribe method for faster-whisper. + + Parameters + ---------- + audio: Union[str, BinaryIO, np.ndarray] + Audio path or file binary or Audio numpy array + progress: gr.Progress + Indicator to show progress directly in gradio. + *whisper_params: tuple + Parameters related with whisper. 
This will be dealt with "WhisperParameters" data class + + Returns + ---------- + segments_result: List[dict] + list of dicts that includes start, end timestamps and transcribed text + elapsed_time: float + elapsed time for transcription + """ + start_time = time.time() + params = WhisperParameters.as_value(*whisper_params) + + if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: + self.update_model(params.model_size, params.compute_type, progress) + + progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.") + with Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(style="yellow1", pulse_style="white"), + TimeElapsedColumn(), + ) as progress: + progress.add_task("[yellow]Transcribing...", total=None) + + kwargs = { + "no_speech_threshold": params.no_speech_threshold, + "temperature": params.temperature, + "compression_ratio_threshold": params.compression_ratio_threshold, + "logprob_threshold": params.log_prob_threshold, + } + + if self.current_model_size.endswith(".en"): + pass + else: + kwargs["language"] = params.lang + kwargs["task"] = "translate" if params.is_translate else "transcribe" + + segments = self.model( + inputs=audio, + return_timestamps=True, + chunk_length_s=params.chunk_length, + batch_size=params.batch_size, + generate_kwargs=kwargs + ) + + segments_result = self.format_result( + transcribed_result=segments, + ) + elapsed_time = time.time() - start_time + return segments_result, elapsed_time + + def update_model(self, + model_size: str, + compute_type: str, + progress: gr.Progress = gr.Progress(), + ): + """ + Update current model setting + + Parameters + ---------- + model_size: str + Size of whisper model + compute_type: str + Compute type for transcription. + see more info : https://opennmt.net/CTranslate2/quantization.html + progress: gr.Progress + Indicator to show progress directly in gradio. + """ + progress(0, desc="Initializing Model..") + model_path = os.path.join(self.model_dir, model_size) + if not os.path.isdir(model_path) or not os.listdir(model_path): + self.download_model( + model_size=model_size, + download_root=model_path, + progress=progress + ) + + self.current_compute_type = compute_type + self.current_model_size = model_size + self.model = pipeline( + "automatic-speech-recognition", + model=os.path.join(self.model_dir, model_size), + torch_dtype=self.current_compute_type, + device=self.device, + model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}, + ) + + @staticmethod + def format_result( + transcribed_result: dict + ) -> List[dict]: + """ + Format the transcription result of insanely_fast_whisper as the same with other implementation. 
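The transformers pipeline reports chunks as {"timestamp": (start, end), "text": ...}, and the final chunk's end can be None; the formatting below maps that onto the start/end/text dicts used by the rest of the app. A standalone illustration with fabricated data:

# Fabricated example: convert pipeline "chunks" into segment dicts with start/end/text.
chunks = [
    {"timestamp": (0.0, 2.4), "text": " You've got"},
    {"timestamp": (2.4, None), "text": " a friend in me."},  # open-ended final chunk
]
segments = []
for chunk in chunks:
    start, end = chunk["timestamp"]
    if end is None:  # mirror the fallback: reuse the start time when no end is reported
        end = start
    segments.append({"start": start, "end": end, "text": chunk["text"]})
# segments -> [{'start': 0.0, 'end': 2.4, ...}, {'start': 2.4, 'end': 2.4, ...}]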
+ + Parameters + ---------- + transcribed_result: dict + Transcription result of the insanely_fast_whisper + + Returns + ---------- + result: List[dict] + Formatted result as the same with other implementation + """ + result = transcribed_result["chunks"] + for item in result: + start, end = item["timestamp"][0], item["timestamp"][1] + if end is None: + end = start + item["start"] = start + item["end"] = end + return result + + @staticmethod + def download_model( + model_size: str, + download_root: str, + progress: gr.Progress + ): + progress(0, 'Initializing model..') + print(f'Downloading {model_size} to "{download_root}"....') + + os.makedirs(download_root, exist_ok=True) + download_list = [ + "model.safetensors", + "config.json", + "generation_config.json", + "preprocessor_config.json", + "tokenizer.json", + "tokenizer_config.json", + "added_tokens.json", + "special_tokens_map.json", + "vocab.json", + ] + + if model_size.startswith("distil"): + repo_id = f"distil-whisper/{model_size}" + else: + repo_id = f"openai/whisper-{model_size}" + for item in download_list: + hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root) diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f87fbe5d34f0e9adff809e793ec3126db084fd69 --- /dev/null +++ b/modules/whisper/whisper_Inference.py @@ -0,0 +1,104 @@ +import whisper +import gradio as gr +import time +from typing import BinaryIO, Union, Tuple, List +import numpy as np +import torch +import os +from argparse import Namespace + +from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR) +from modules.whisper.whisper_base import WhisperBase +from modules.whisper.whisper_parameter import * + + +class WhisperInference(WhisperBase): + def __init__(self, + model_dir: str = WHISPER_MODELS_DIR, + diarization_model_dir: str = DIARIZATION_MODELS_DIR, + uvr_model_dir: str = UVR_MODELS_DIR, + output_dir: str = OUTPUT_DIR, + ): + super().__init__( + model_dir=model_dir, + output_dir=output_dir, + diarization_model_dir=diarization_model_dir, + uvr_model_dir=uvr_model_dir + ) + + def transcribe(self, + audio: Union[str, np.ndarray, torch.Tensor], + progress: gr.Progress = gr.Progress(), + *whisper_params, + ) -> Tuple[List[dict], float]: + """ + transcribe method for faster-whisper. + + Parameters + ---------- + audio: Union[str, BinaryIO, np.ndarray] + Audio path or file binary or Audio numpy array + progress: gr.Progress + Indicator to show progress directly in gradio. + *whisper_params: tuple + Parameters related with whisper. 
This will be dealt with "WhisperParameters" data class + + Returns + ---------- + segments_result: List[dict] + list of dicts that includes start, end timestamps and transcribed text + elapsed_time: float + elapsed time for transcription + """ + start_time = time.time() + params = WhisperParameters.as_value(*whisper_params) + + if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: + self.update_model(params.model_size, params.compute_type, progress) + + def progress_callback(progress_value): + progress(progress_value, desc="Transcribing..") + + segments_result = self.model.transcribe(audio=audio, + language=params.lang, + verbose=False, + beam_size=params.beam_size, + logprob_threshold=params.log_prob_threshold, + no_speech_threshold=params.no_speech_threshold, + task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", + fp16=True if params.compute_type == "float16" else False, + best_of=params.best_of, + patience=params.patience, + temperature=params.temperature, + compression_ratio_threshold=params.compression_ratio_threshold, + progress_callback=progress_callback,)["segments"] + elapsed_time = time.time() - start_time + + return segments_result, elapsed_time + + def update_model(self, + model_size: str, + compute_type: str, + progress: gr.Progress = gr.Progress(), + ): + """ + Update current model setting + + Parameters + ---------- + model_size: str + Size of whisper model + compute_type: str + Compute type for transcription. + see more info : https://opennmt.net/CTranslate2/quantization.html + progress: gr.Progress + Indicator to show progress directly in gradio. + """ + progress(0, desc="Initializing Model..") + self.current_compute_type = compute_type + self.current_model_size = model_size + self.model = whisper.load_model( + name=model_size, + device=self.device, + download_root=self.model_dir + ) \ No newline at end of file diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py new file mode 100644 index 0000000000000000000000000000000000000000..22efea6aeb85b141eca4c933f3a263dc9d08ecac --- /dev/null +++ b/modules/whisper/whisper_base.py @@ -0,0 +1,542 @@ +import os +import torch +import whisper +import gradio as gr +import torchaudio +from abc import ABC, abstractmethod +from typing import BinaryIO, Union, Tuple, List +import numpy as np +from datetime import datetime +from faster_whisper.vad import VadOptions +from dataclasses import astuple + +from modules.uvr.music_separator import MusicSeparator +from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, + UVR_MODELS_DIR) +from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename +from modules.utils.youtube_manager import get_ytdata, get_ytaudio +from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml +from modules.whisper.whisper_parameter import * +from modules.diarize.diarizer import Diarizer +from modules.vad.silero_vad import SileroVAD + + +class WhisperBase(ABC): + def __init__(self, + model_dir: str = WHISPER_MODELS_DIR, + diarization_model_dir: str = DIARIZATION_MODELS_DIR, + uvr_model_dir: str = UVR_MODELS_DIR, + output_dir: str = OUTPUT_DIR, + ): + self.model_dir = model_dir + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + os.makedirs(self.model_dir, exist_ok=True) + self.diarizer = Diarizer( + 
model_dir=diarization_model_dir + ) + self.vad = SileroVAD() + self.music_separator = MusicSeparator( + model_dir=uvr_model_dir, + output_dir=os.path.join(output_dir, "UVR") + ) + + self.model = None + self.current_model_size = None + self.available_models = whisper.available_models() + self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values())) + self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"] + self.device = self.get_device() + self.available_compute_types = ["float16", "float32"] + self.current_compute_type = "float16" if self.device == "cuda" else "float32" + + @abstractmethod + def transcribe(self, + audio: Union[str, BinaryIO, np.ndarray], + progress: gr.Progress = gr.Progress(), + *whisper_params, + ): + """Inference whisper model to transcribe""" + pass + + @abstractmethod + def update_model(self, + model_size: str, + compute_type: str, + progress: gr.Progress = gr.Progress() + ): + """Initialize whisper model""" + pass + + def run(self, + audio: Union[str, BinaryIO, np.ndarray], + progress: gr.Progress = gr.Progress(), + add_timestamp: bool = True, + *whisper_params, + ) -> Tuple[List[dict], float]: + """ + Run transcription with conditional pre-processing and post-processing. + The VAD will be performed to remove noise from the audio input in pre-processing, if enabled. + The diarization will be performed in post-processing, if enabled. + + Parameters + ---------- + audio: Union[str, BinaryIO, np.ndarray] + Audio input. This can be file path or binary type. + progress: gr.Progress + Indicator to show progress directly in gradio. + add_timestamp: bool + Whether to add a timestamp at the end of the filename. + *whisper_params: tuple + Parameters related with whisper. This will be dealt with "WhisperParameters" data class + + Returns + ---------- + segments_result: List[dict] + list of dicts that includes start, end timestamps and transcribed text + elapsed_time: float + elapsed time for running + """ + params = WhisperParameters.as_value(*whisper_params) + + self.cache_parameters( + whisper_params=params, + add_timestamp=add_timestamp + ) + + if params.lang is None: + pass + elif params.lang == "Automatic Detection": + params.lang = None + else: + language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()} + params.lang = language_code_dict[params.lang] + + if params.is_bgm_separate: + music, audio, _ = self.music_separator.separate( + audio=audio, + model_name=params.uvr_model_size, + device=params.uvr_device, + segment_size=params.uvr_segment_size, + save_file=params.uvr_save_file, + progress=progress + ) + + if audio.ndim >= 2: + audio = audio.mean(axis=1) + if self.music_separator.audio_info is None: + origin_sample_rate = 16000 + else: + origin_sample_rate = self.music_separator.audio_info.sample_rate + audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate) + + if params.uvr_enable_offload: + self.music_separator.offload() + + if params.vad_filter: + # Explicit value set for float('inf') from gr.Number() + if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999: + params.max_speech_duration_s = float('inf') + + vad_options = VadOptions( + threshold=params.threshold, + min_speech_duration_ms=params.min_speech_duration_ms, + max_speech_duration_s=params.max_speech_duration_s, + min_silence_duration_ms=params.min_silence_duration_ms, + speech_pad_ms=params.speech_pad_ms + ) + + audio, speech_chunks = self.vad.run( + audio=audio, + vad_parameters=vad_options, + 
progress=progress + ) + + result, elapsed_time = self.transcribe( + audio, + progress, + *astuple(params) + ) + + if params.vad_filter: + result = self.vad.restore_speech_timestamps( + segments=result, + speech_chunks=speech_chunks, + ) + + if params.is_diarize: + result, elapsed_time_diarization = self.diarizer.run( + audio=audio, + use_auth_token=params.hf_token, + transcribed_result=result, + ) + elapsed_time += elapsed_time_diarization + return result, elapsed_time + + def transcribe_file(self, + files: Optional[List] = None, + input_folder_path: Optional[str] = None, + file_format: str = "SRT", + add_timestamp: bool = True, + progress=gr.Progress(), + *whisper_params, + ) -> list: + """ + Write subtitle file from Files + + Parameters + ---------- + files: list + List of files to transcribe from gr.Files() + input_folder_path: str + Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and + this will be used instead. + file_format: str + Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt] + add_timestamp: bool + Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename. + progress: gr.Progress + Indicator to show progress directly in gradio. + *whisper_params: tuple + Parameters related with whisper. This will be dealt with "WhisperParameters" data class + + Returns + ---------- + result_str: + Result of transcription to return to gr.Textbox() + result_file_path: + Output file path to return to gr.Files() + """ + try: + if input_folder_path: + files = get_media_files(input_folder_path) + if isinstance(files, str): + files = [files] + if files and isinstance(files[0], gr.utils.NamedString): + files = [file.name for file in files] + + files_info = {} + for file in files: + + ## Detect language + #model = whisper.load_model("base") + params = WhisperParameters.as_value(*whisper_params) + model = whisper.load_model(params.model_size) + mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device) + _, probs = model.detect_language(mel) + file_language = "not" + for key,value in whisper.tokenizer.LANGUAGES.items(): + if key == str(max(probs, key=probs.get)): + file_language = value.capitalize() + break + + transcribed_segments, time_for_task = self.run( + file, + progress, + add_timestamp, + *whisper_params, + ) + + file_name, file_ext = os.path.splitext(os.path.basename(file)) + subtitle, file_path = self.generate_and_write_file( + file_name=file_name, + transcribed_segments=transcribed_segments, + add_timestamp=add_timestamp, + file_format=file_format, + output_dir=self.output_dir + ) + + files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path, "lang": file_language} + + total_result = '' + total_info = '' + total_time = 0 + for file_name, info in files_info.items(): + total_result += f'{info["subtitle"]}' + total_time += info["time_for_task"] + #total_info += f'{info["lang"]}' + total_info += f"Language {info['lang']} detected for file '{file_name}{file_ext}'" + + #result_str = f"Processing of file '{file_name}{file_ext}' done in {self.format_time(total_time)}:\n\n{total_result}" + total_info += f"\nTranscription process done in {self.format_time(total_time)}" + result_str = total_result + result_file_path = [info['path'] for info in files_info.values()] + + return [result_str, result_file_path, total_info] + + except Exception as e: + print(f"Error transcribing file: {e}") + 
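            # Editor's note: the exception is only printed here; execution falls through to
            # `finally` and the method implicitly returns None, so the Gradio outputs receive
            # no value when transcription fails. The same pattern applies to transcribe_mic
            # and transcribe_youtube below.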
finally: + self.release_cuda_memory() + + def transcribe_mic(self, + mic_audio: str, + file_format: str = "SRT", + add_timestamp: bool = True, + progress=gr.Progress(), + *whisper_params, + ) -> list: + """ + Write subtitle file from microphone + + Parameters + ---------- + mic_audio: str + Audio file path from gr.Microphone() + file_format: str + Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt] + add_timestamp: bool + Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. + progress: gr.Progress + Indicator to show progress directly in gradio. + *whisper_params: tuple + Parameters related with whisper. This will be dealt with "WhisperParameters" data class + + Returns + ---------- + result_str: + Result of transcription to return to gr.Textbox() + result_file_path: + Output file path to return to gr.Files() + """ + try: + progress(0, desc="Loading Audio..") + transcribed_segments, time_for_task = self.run( + mic_audio, + progress, + add_timestamp, + *whisper_params, + ) + progress(1, desc="Completed!") + + subtitle, result_file_path = self.generate_and_write_file( + file_name="Mic", + transcribed_segments=transcribed_segments, + add_timestamp=add_timestamp, + file_format=file_format, + output_dir=self.output_dir + ) + + result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}" + return [result_str, result_file_path] + except Exception as e: + print(f"Error transcribing file: {e}") + finally: + self.release_cuda_memory() + + def transcribe_youtube(self, + youtube_link: str, + file_format: str = "SRT", + add_timestamp: bool = True, + progress=gr.Progress(), + *whisper_params, + ) -> list: + """ + Write subtitle file from Youtube + + Parameters + ---------- + youtube_link: str + URL of the Youtube video to transcribe from gr.Textbox() + file_format: str + Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt] + add_timestamp: bool + Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. + progress: gr.Progress + Indicator to show progress directly in gradio. + *whisper_params: tuple + Parameters related with whisper. This will be dealt with "WhisperParameters" data class + + Returns + ---------- + result_str: + Result of transcription to return to gr.Textbox() + result_file_path: + Output file path to return to gr.Files() + """ + try: + progress(0, desc="Loading Audio from Youtube..") + yt = get_ytdata(youtube_link) + audio = get_ytaudio(yt) + + transcribed_segments, time_for_task = self.run( + audio, + progress, + add_timestamp, + *whisper_params, + ) + + progress(1, desc="Completed!") + + file_name = safe_filename(yt.title) + subtitle, result_file_path = self.generate_and_write_file( + file_name=file_name, + transcribed_segments=transcribed_segments, + add_timestamp=add_timestamp, + file_format=file_format, + output_dir=self.output_dir + ) + result_str = f"Done in {self.format_time(time_for_task)}! 
Subtitle file is in the outputs folder.\n\n{subtitle}" + + if os.path.exists(audio): + os.remove(audio) + + return [result_str, result_file_path] + + except Exception as e: + print(f"Error transcribing file: {e}") + finally: + self.release_cuda_memory() + + @staticmethod + def generate_and_write_file(file_name: str, + transcribed_segments: list, + add_timestamp: bool, + file_format: str, + output_dir: str + ) -> str: + """ + Writes subtitle file + + Parameters + ---------- + file_name: str + Output file name + transcribed_segments: list + Text segments transcribed from audio + add_timestamp: bool + Determines whether to add a timestamp to the end of the filename. + file_format: str + File format to write. Supported formats: [SRT, WebVTT, txt] + output_dir: str + Directory path of the output + + Returns + ---------- + content: str + Result of the transcription + output_path: str + output file path + """ + if add_timestamp: + timestamp = datetime.now().strftime("%m%d%H%M%S") + output_path = os.path.join(output_dir, f"{file_name} - {timestamp}") + else: + output_path = os.path.join(output_dir, f"{file_name}") + + file_format = file_format.strip().lower() + if file_format == "srt": + content = get_srt(transcribed_segments) + output_path += '.srt' + + elif file_format == "webvtt": + content = get_vtt(transcribed_segments) + output_path += '.vtt' + + elif file_format == "txt": + content = get_txt(transcribed_segments) + output_path += '.txt' + + write_file(content, output_path) + return content, output_path + + @staticmethod + def format_time(elapsed_time: float) -> str: + """ + Get {hours} {minutes} {seconds} time format string + + Parameters + ---------- + elapsed_time: str + Elapsed time for transcription + + Returns + ---------- + Time format string + """ + hours, rem = divmod(elapsed_time, 3600) + minutes, seconds = divmod(rem, 60) + + time_str = "" + if hours: + time_str += f"{hours} hours " + if minutes: + time_str += f"{minutes} minutes " + seconds = round(seconds) + time_str += f"{seconds} seconds" + + return time_str.strip() + + @staticmethod + def get_device(): + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + if not WhisperBase.is_sparse_api_supported(): + # Device `SparseMPS` is not supported for now. 
See : https://github.com/pytorch/pytorch/issues/87886 + return "cpu" + return "mps" + else: + return "cpu" + + @staticmethod + def is_sparse_api_supported(): + if not torch.backends.mps.is_available(): + return False + + try: + device = torch.device("mps") + sparse_tensor = torch.sparse_coo_tensor( + indices=torch.tensor([[0, 1], [2, 3]]), + values=torch.tensor([1, 2]), + size=(4, 4), + device=device + ) + return True + except RuntimeError: + return False + + @staticmethod + def release_cuda_memory(): + """Release memory""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + @staticmethod + def remove_input_files(file_paths: List[str]): + """Remove gradio cached files""" + if not file_paths: + return + + for file_path in file_paths: + if file_path and os.path.exists(file_path): + os.remove(file_path) + + @staticmethod + def cache_parameters( + whisper_params: WhisperValues, + add_timestamp: bool + ): + """cache parameters to the yaml file""" + cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) + cached_whisper_param = whisper_params.to_yaml() + cached_yaml = {**cached_params, **cached_whisper_param} + cached_yaml["whisper"]["add_timestamp"] = add_timestamp + + save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) + + @staticmethod + def resample_audio(audio: Union[str, np.ndarray], + new_sample_rate: int = 16000, + original_sample_rate: Optional[int] = None,) -> np.ndarray: + """Resamples audio to 16k sample rate, standard on Whisper model""" + if isinstance(audio, str): + audio, original_sample_rate = torchaudio.load(audio) + else: + if original_sample_rate is None: + raise ValueError("original_sample_rate must be provided when audio is numpy array.") + audio = torch.from_numpy(audio) + resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate) + resampled_audio = resampler(audio).numpy() + return resampled_audio diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..6bda8c58272bafa7e7e1fa175fa914ad304dee6f --- /dev/null +++ b/modules/whisper/whisper_factory.py @@ -0,0 +1,90 @@ +from typing import Optional +import os + +from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, + INSANELY_FAST_WHISPER_MODELS_DIR, WHISPER_MODELS_DIR, UVR_MODELS_DIR) +from modules.whisper.faster_whisper_inference import FasterWhisperInference +from modules.whisper.whisper_Inference import WhisperInference +from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference +from modules.whisper.whisper_base import WhisperBase + + +class WhisperFactory: + @staticmethod + def create_whisper_inference( + whisper_type: str, + whisper_model_dir: str = WHISPER_MODELS_DIR, + faster_whisper_model_dir: str = FASTER_WHISPER_MODELS_DIR, + insanely_fast_whisper_model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, + diarization_model_dir: str = DIARIZATION_MODELS_DIR, + uvr_model_dir: str = UVR_MODELS_DIR, + output_dir: str = OUTPUT_DIR, + ) -> "WhisperBase": + """ + Create a whisper inference class based on the provided whisper_type. + + Parameters + ---------- + whisper_type : str + The type of Whisper implementation to use. 
Supported values (case-insensitive): + - "faster-whisper": https://github.com/openai/whisper + - "whisper": https://github.com/openai/whisper + - "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper + whisper_model_dir : str + Directory path for the Whisper model. + faster_whisper_model_dir : str + Directory path for the Faster Whisper model. + insanely_fast_whisper_model_dir : str + Directory path for the Insanely Fast Whisper model. + diarization_model_dir : str + Directory path for the diarization model. + uvr_model_dir : str + Directory path for the UVR model. + output_dir : str + Directory path where output files will be saved. + + Returns + ------- + WhisperBase + An instance of the appropriate whisper inference class based on the whisper_type. + """ + # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144 + os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' + + whisper_type = whisper_type.lower().strip() + + faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"] + whisper_typos = ["whisper"] + insanely_fast_whisper_typos = [ + "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper", + "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper" + ] + + if whisper_type in faster_whisper_typos: + return FasterWhisperInference( + model_dir=faster_whisper_model_dir, + output_dir=output_dir, + diarization_model_dir=diarization_model_dir, + uvr_model_dir=uvr_model_dir + ) + elif whisper_type in whisper_typos: + return WhisperInference( + model_dir=whisper_model_dir, + output_dir=output_dir, + diarization_model_dir=diarization_model_dir, + uvr_model_dir=uvr_model_dir + ) + elif whisper_type in insanely_fast_whisper_typos: + return InsanelyFastWhisperInference( + model_dir=insanely_fast_whisper_model_dir, + output_dir=output_dir, + diarization_model_dir=diarization_model_dir, + uvr_model_dir=uvr_model_dir + ) + else: + return FasterWhisperInference( + model_dir=faster_whisper_model_dir, + output_dir=output_dir, + diarization_model_dir=diarization_model_dir, + uvr_model_dir=uvr_model_dir + ) diff --git a/modules/whisper/whisper_parameter.py b/modules/whisper/whisper_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..0aec677b2826636e98542e41ed587e349e663d2a --- /dev/null +++ b/modules/whisper/whisper_parameter.py @@ -0,0 +1,369 @@ +from dataclasses import dataclass, fields +import gradio as gr +from typing import Optional, Dict +import yaml + + +@dataclass +class WhisperParameters: + model_size: gr.Dropdown + lang: gr.Dropdown + is_translate: gr.Checkbox + beam_size: gr.Number + log_prob_threshold: gr.Number + no_speech_threshold: gr.Number + compute_type: gr.Dropdown + best_of: gr.Number + patience: gr.Number + condition_on_previous_text: gr.Checkbox + prompt_reset_on_temperature: gr.Slider + initial_prompt: gr.Textbox + temperature: gr.Slider + compression_ratio_threshold: gr.Number + vad_filter: gr.Checkbox + threshold: gr.Slider + min_speech_duration_ms: gr.Number + max_speech_duration_s: gr.Number + min_silence_duration_ms: gr.Number + speech_pad_ms: gr.Number + batch_size: gr.Number + is_diarize: gr.Checkbox + hf_token: gr.Textbox + diarization_device: gr.Dropdown + length_penalty: gr.Number + repetition_penalty: gr.Number + no_repeat_ngram_size: gr.Number + prefix: gr.Textbox + suppress_blank: gr.Checkbox + suppress_tokens: gr.Textbox + max_initial_timestamp: gr.Number + word_timestamps: gr.Checkbox + prepend_punctuations: gr.Textbox + append_punctuations: 
gr.Textbox + max_new_tokens: gr.Number + chunk_length: gr.Number + hallucination_silence_threshold: gr.Number + hotwords: gr.Textbox + language_detection_threshold: gr.Number + language_detection_segments: gr.Number + is_bgm_separate: gr.Checkbox + uvr_model_size: gr.Dropdown + uvr_device: gr.Dropdown + uvr_segment_size: gr.Number + uvr_save_file: gr.Checkbox + uvr_enable_offload: gr.Checkbox + """ + A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing. + This data class is used to mitigate the key-value problem between Gradio components and function parameters. + Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 + See more about Gradio pre-processing: https://www.gradio.app/docs/components + + Attributes + ---------- + model_size: gr.Dropdown + Whisper model size. + + lang: gr.Dropdown + Source language of the file to transcribe. + + is_translate: gr.Checkbox + Boolean value that determines whether to translate to English. + It's Whisper's feature to translate speech from another language directly into English end-to-end. + + beam_size: gr.Number + Int value that is used for decoding option. + + log_prob_threshold: gr.Number + If the average log probability over sampled tokens is below this value, treat as failed. + + no_speech_threshold: gr.Number + If the no_speech probability is higher than this value AND + the average log probability over sampled tokens is below `log_prob_threshold`, + consider the segment as silent. + + compute_type: gr.Dropdown + compute type for transcription. + see more info : https://opennmt.net/CTranslate2/quantization.html + + best_of: gr.Number + Number of candidates when sampling with non-zero temperature. + + patience: gr.Number + Beam search patience factor. + + condition_on_previous_text: gr.Checkbox + if True, the previous output of the model is provided as a prompt for the next window; + disabling may make the text inconsistent across windows, but the model becomes less prone to + getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + + initial_prompt: gr.Textbox + Optional text to provide as a prompt for the first window. This can be used to provide, or + "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns + to make it more likely to predict those word correctly. + + temperature: gr.Slider + Temperature for sampling. It can be a tuple of temperatures, + which will be successively used upon failures according to either + `compression_ratio_threshold` or `log_prob_threshold`. + + compression_ratio_threshold: gr.Number + If the gzip compression ratio is above this value, treat as failed + + vad_filter: gr.Checkbox + Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + + threshold: gr.Slider + This parameter is related with Silero VAD. Speech threshold. + Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune this + parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. + + min_speech_duration_ms: gr.Number + This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out. + + max_speech_duration_s: gr.Number + This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. 
Chunks longer + than max_speech_duration_s will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be + split aggressively just before max_speech_duration_s. + + min_silence_duration_ms: gr.Number + This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms + before separating it + + speech_pad_ms: gr.Number + This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side + + batch_size: gr.Number + This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe + + is_diarize: gr.Checkbox + This parameter is related with whisperx. Boolean value that determines whether to diarize or not. + + hf_token: gr.Textbox + This parameter is related with whisperx. Huggingface token is needed to download diarization models. + Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements + + diarization_device: gr.Dropdown + This parameter is related with whisperx. Device to run diarization model + + length_penalty: gr.Number + This parameter is related to faster-whisper. Exponential length penalty constant. + + repetition_penalty: gr.Number + This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + + no_repeat_ngram_size: gr.Number + This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable). + + prefix: gr.Textbox + This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window. + + suppress_blank: gr.Checkbox + This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling. + + suppress_tokens: gr.Textbox + This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in the model config.json file. + + max_initial_timestamp: gr.Number + This parameter is related to faster-whisper. The initial timestamp cannot be later than this. + + word_timestamps: gr.Checkbox + This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + + prepend_punctuations: gr.Textbox + This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols + with the next word. + + append_punctuations: gr.Textbox + This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols + with the previous word. + + max_new_tokens: gr.Number + This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set, + the maximum will be set by the default max_length. + + chunk_length: gr.Number + This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds. + If it is not None, it will overwrite the default chunk_length of the FeatureExtractor. + + hallucination_silence_threshold: gr.Number + This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold + (in seconds) when a possible hallucination is detected. + + hotwords: gr.Textbox + This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None. 
+ + language_detection_threshold: gr.Number + This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected. + + language_detection_segments: gr.Number + This parameter is related to faster-whisper. Number of segments to consider for the language detection. + + is_separate_bgm: gr.Checkbox + This parameter is related to UVR. Boolean value that determines whether to separate bgm or not. + + uvr_model_size: gr.Dropdown + This parameter is related to UVR. UVR model size. + + uvr_device: gr.Dropdown + This parameter is related to UVR. Device to run UVR model. + + uvr_segment_size: gr.Number + This parameter is related to UVR. Segment size for UVR model. + + uvr_save_file: gr.Checkbox + This parameter is related to UVR. Boolean value that determines whether to save the file or not. + + uvr_enable_offload: gr.Checkbox + This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not + after each transcription. + """ + + def as_list(self) -> list: + """ + Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing. + See more about Gradio pre-processing: : https://www.gradio.app/docs/components + + Returns + ---------- + A list of Gradio components + """ + return [getattr(self, f.name) for f in fields(self)] + + @staticmethod + def as_value(*args) -> 'WhisperValues': + """ + To use Whisper parameters in function after Gradio post-processing. + See more about Gradio post-processing: : https://www.gradio.app/docs/components + + Returns + ---------- + WhisperValues + Data class that has values of parameters + """ + return WhisperValues(*args) + + +@dataclass +class WhisperValues: + model_size: str = "large-v2" + lang: Optional[str] = None + is_translate: bool = False + beam_size: int = 5 + log_prob_threshold: float = -1.0 + no_speech_threshold: float = 0.6 + compute_type: str = "float16" + best_of: int = 5 + patience: float = 1.0 + condition_on_previous_text: bool = True + prompt_reset_on_temperature: float = 0.5 + initial_prompt: Optional[str] = None + temperature: float = 0.0 + compression_ratio_threshold: float = 2.4 + vad_filter: bool = False + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + speech_pad_ms: int = 400 + batch_size: int = 24 + is_diarize: bool = False + hf_token: str = "" + diarization_device: str = "cuda" + length_penalty: float = 1.0 + repetition_penalty: float = 1.0 + no_repeat_ngram_size: int = 0 + prefix: Optional[str] = None + suppress_blank: bool = True + suppress_tokens: Optional[str] = "[-1]" + max_initial_timestamp: float = 0.0 + word_timestamps: bool = False + prepend_punctuations: Optional[str] = "\"'β€œΒΏ([{-" + append_punctuations: Optional[str] = "\"'.。,,!!??:οΌšβ€)]}、" + max_new_tokens: Optional[int] = None + chunk_length: Optional[int] = 30 + hallucination_silence_threshold: Optional[float] = None + hotwords: Optional[str] = None + language_detection_threshold: Optional[float] = None + language_detection_segments: int = 1 + is_bgm_separate: bool = False + uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4" + uvr_device: str = "cuda" + uvr_segment_size: int = 256 + uvr_save_file: bool = False + uvr_enable_offload: bool = True + """ + A data class to use Whisper parameters. 
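As a rough usage sketch (values are illustrative): the plain-value dataclass above can be built directly and serialized with to_yaml(), defined just below, which is how parameters are cached between runs; a lang of None is written back as the "Automatic Detection" UI label.

# Sketch: build a plain-value parameter object and inspect the YAML-shaped dict it caches.
from modules.whisper.whisper_parameter import WhisperValues

values = WhisperValues(model_size="large-v3", vad_filter=True)
cached = values.to_yaml()
assert cached["whisper"]["model_size"] == "large-v3"
assert cached["whisper"]["lang"] == "Automatic Detection"  # lang=None maps back to the UI label
assert cached["vad"]["vad_filter"] is True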
+ """ + + def to_yaml(self) -> Dict: + data = { + "whisper": { + "model_size": self.model_size, + "lang": "Automatic Detection" if self.lang is None else self.lang, + "is_translate": self.is_translate, + "beam_size": self.beam_size, + "log_prob_threshold": self.log_prob_threshold, + "no_speech_threshold": self.no_speech_threshold, + "best_of": self.best_of, + "patience": self.patience, + "condition_on_previous_text": self.condition_on_previous_text, + "prompt_reset_on_temperature": self.prompt_reset_on_temperature, + "initial_prompt": None if not self.initial_prompt else self.initial_prompt, + "temperature": self.temperature, + "compression_ratio_threshold": self.compression_ratio_threshold, + "batch_size": self.batch_size, + "length_penalty": self.length_penalty, + "repetition_penalty": self.repetition_penalty, + "no_repeat_ngram_size": self.no_repeat_ngram_size, + "prefix": None if not self.prefix else self.prefix, + "suppress_blank": self.suppress_blank, + "suppress_tokens": self.suppress_tokens, + "max_initial_timestamp": self.max_initial_timestamp, + "word_timestamps": self.word_timestamps, + "prepend_punctuations": self.prepend_punctuations, + "append_punctuations": self.append_punctuations, + "max_new_tokens": self.max_new_tokens, + "chunk_length": self.chunk_length, + "hallucination_silence_threshold": self.hallucination_silence_threshold, + "hotwords": None if not self.hotwords else self.hotwords, + "language_detection_threshold": self.language_detection_threshold, + "language_detection_segments": self.language_detection_segments, + }, + "vad": { + "vad_filter": self.vad_filter, + "threshold": self.threshold, + "min_speech_duration_ms": self.min_speech_duration_ms, + "max_speech_duration_s": self.max_speech_duration_s, + "min_silence_duration_ms": self.min_silence_duration_ms, + "speech_pad_ms": self.speech_pad_ms, + }, + "diarization": { + "is_diarize": self.is_diarize, + "hf_token": self.hf_token + }, + "bgm_separation": { + "is_separate_bgm": self.is_bgm_separate, + "model_size": self.uvr_model_size, + "segment_size": self.uvr_segment_size, + "save_file": self.uvr_save_file, + "enable_offload": self.uvr_enable_offload + }, + } + return data + + def as_list(self) -> list: + """ + Converts the data class attributes into a list + + Returns + ---------- + A list of Whisper parameters + """ + return [getattr(self, f.name) for f in fields(self)] diff --git a/notebook/whisper-webui.ipynb b/notebook/whisper-webui.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..bebf6e45b6e1015ebb586332ab892049d18fb4f6 --- /dev/null +++ b/notebook/whisper-webui.ipynb @@ -0,0 +1,132 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "---\n", + "\n", + "πŸ“Œ **This notebook has been updated [here](https://github.com/jhj0517/Whisper-WebUI.git)!**\n", + "\n", + "πŸ–‹ **Author**: [jhj0517](https://github.com/jhj0517/Whisper-WebUI/blob/master/notebook/whisper-webui.ipynb)\n", + "\n", + "😎 **Support the Project**:\n", + "\n", + "If you find this project useful, please consider supporting it:\n", + "\n", + "\n", + " \"Buy\n", + "\n", + "\n", + "---" + ], + "metadata": { + "id": "doKhBBXIfS21" + } + }, + { + "cell_type": "code", + "source": [ + "#@title #(Optional) Check GPU\n", + "#@markdown Some models may not function correctly on a CPU runtime.\n", + "\n", + "#@markdown so you should check your GPU setup before run.\n", + "!nvidia-smi" + ], + "metadata": { + "id": "23yZvUlagEsx", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kNbSbsctxahq", + "cellView": "form" + }, + "outputs": [], + "source": [ + "#@title #Installation\n", + "#@markdown This cell will install dependencies for Whisper-WebUI!\n", + "!git clone https://github.com/jhj0517/Whisper-WebUI.git\n", + "%cd Whisper-WebUI\n", + "!pip install git+https://github.com/jhj0517/jhj0517-whisper.git\n", + "!pip install faster-whisper==1.0.3\n", + "!pip install gradio==4.43.0\n", + "# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/256\n", + "!pip install git+https://github.com/JuanBindez/pytubefix.git\n", + "!pip install tokenizers==0.19.1\n", + "!pip install pyannote.audio==3.3.1\n", + "!pip install git+https://github.com/jhj0517/ultimatevocalremover_api.git" + ] + }, + { + "cell_type": "code", + "source": [ + "#@title # (Optional) Configure arguments\n", + "#@markdown This section is used to configure some command line arguments.\n", + "\n", + "#@markdown You can simply ignore this section and the default values will be used.\n", + "\n", + "USERNAME = '' #@param {type: \"string\"}\n", + "PASSWORD = '' #@param {type: \"string\"}\n", + "WHISPER_TYPE = 'faster-whisper' # @param [\"whisper\", \"faster-whisper\", \"insanely-fast-whisper\"]\n", + "THEME = '' #@param {type: \"string\"}\n", + "\n", + "arguments = \"\"\n", + "if USERNAME:\n", + " arguments += f\" --username {USERNAME}\"\n", + "if PASSWORD:\n", + " arguments += f\" --password {PASSWORD}\"\n", + "if THEME:\n", + " arguments += f\" --theme {THEME}\"\n", + "if WHISPER_TYPE:\n", + " arguments += f\" --whisper_type {WHISPER_TYPE}\"\n", + "\n", + "\n", + "#@markdown If you wonder how these arguments are used, you can see the [Wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments)." + ], + "metadata": { + "id": "Qosz9BFlGui3", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "PQroYRRZzQiN", + "cellView": "form" + }, + "outputs": [], + "source": [ + "#@title #Run\n", + "#@markdown Once the installation is complete, you can use public URL that is displayed.\n", + "if 'arguments' in locals():\n", + " !python app.py --share --colab{arguments}\n", + "else:\n", + " !python app.py --share --colab" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/outputs/outputs are saved here.txt b/outputs/outputs are saved here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/outputs/translations/outputs for translation are saved here.txt b/outputs/translations/outputs for translation are saved here.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..afe8b6c7d16d8875d5eb1d55e9d031433b83c639 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +# Remove the --extra-index-url line below if you're not using Nvidia GPU. 
+# If you're using it, update url to your CUDA version (CUDA 12.1 is minimum requirement): +# For CUDA 12.1, use : https://download.pytorch.org/whl/cu121 +# For CUDA 12.4, use : https://download.pytorch.org/whl/cu124 +--extra-index-url https://download.pytorch.org/whl/cu121 + + +torch==2.3.1 +torchaudio==2.3.1 +git+https://github.com/jhj0517/jhj0517-whisper.git +faster-whisper==1.0.3 +transformers +gradio +pytubefix +ruamel.yaml==0.18.6 +pyannote.audio==3.3.1 +git+https://github.com/jhj0517/ultimatevocalremover_api.git +git+https://github.com/jhj0517/pyrubberband.git \ No newline at end of file diff --git a/screenshot.png b/screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..c12b18ded02179356efa498dde1a1150ed64680b Binary files /dev/null and b/screenshot.png differ diff --git a/start-webui.bat b/start-webui.bat new file mode 100644 index 0000000000000000000000000000000000000000..58422504d0b96c6131abf6860714dfb8d8977351 --- /dev/null +++ b/start-webui.bat @@ -0,0 +1,7 @@ +@echo off + +call venv\scripts\activate +python app.py %* + +echo "launching the app" +pause diff --git a/start-webui.sh b/start-webui.sh new file mode 100644 index 0000000000000000000000000000000000000000..901d412849d93ee2f6ef165c9e539b6f1b25b3f0 --- /dev/null +++ b/start-webui.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +source venv/bin/activate +python app.py "$@" + +echo "launching the app" diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4a6f80531a29fb3a6e764bf425dd15923e1ca5 --- /dev/null +++ b/tests/test_bgm_separation.py @@ -0,0 +1,53 @@ +from modules.utils.paths import * +from modules.whisper.whisper_factory import WhisperFactory +from modules.whisper.whisper_parameter import WhisperValues +from test_config import * +from test_transcription import download_file, test_transcribe + +import gradio as gr +import pytest +import torch +import os + + +@pytest.mark.skipif( + not is_cuda_available(), + reason="Skipping because the test only works on GPU" +) +@pytest.mark.parametrize( + "whisper_type,vad_filter,bgm_separation,diarization", + [ + ("whisper", False, True, False), + ("faster-whisper", False, True, False), + ("insanely_fast_whisper", False, True, False) + ] +) +def test_bgm_separation_pipeline( + whisper_type: str, + vad_filter: bool, + bgm_separation: bool, + diarization: bool, +): + test_transcribe(whisper_type, vad_filter, bgm_separation, diarization) + + +@pytest.mark.skipif( + not is_cuda_available(), + reason="Skipping because the test only works on GPU" +) +@pytest.mark.parametrize( + "whisper_type,vad_filter,bgm_separation,diarization", + [ + ("whisper", True, True, False), + ("faster-whisper", True, True, False), + ("insanely_fast_whisper", True, True, False) + ] +) +def test_bgm_separation_with_vad_pipeline( + whisper_type: str, + vad_filter: bool, + bgm_separation: bool, + diarization: bool, +): + test_transcribe(whisper_type, vad_filter, bgm_separation, diarization) + diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0f60aa582d440a2765bdc0b524f3d283de18cc9f --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,17 @@ +from modules.utils.paths import * + +import os +import torch + +TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" +TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav") +TEST_YOUTUBE_URL = 
"https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer" +TEST_WHISPER_MODEL = "tiny" +TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4" +TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M" +TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt") +TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt") + + +def is_cuda_available(): + return torch.cuda.is_available() diff --git a/tests/test_diarization.py b/tests/test_diarization.py new file mode 100644 index 0000000000000000000000000000000000000000..54e7244ef97fb23fa9fffae9d8a2adf1876f9ad2 --- /dev/null +++ b/tests/test_diarization.py @@ -0,0 +1,31 @@ +from modules.utils.paths import * +from modules.whisper.whisper_factory import WhisperFactory +from modules.whisper.whisper_parameter import WhisperValues +from test_config import * +from test_transcription import download_file, test_transcribe + +import gradio as gr +import pytest +import os + + +@pytest.mark.skipif( + not is_cuda_available(), + reason="Skipping because the test only works on GPU" +) +@pytest.mark.parametrize( + "whisper_type,vad_filter,bgm_separation,diarization", + [ + ("whisper", False, False, True), + ("faster-whisper", False, False, True), + ("insanely_fast_whisper", False, False, True) + ] +) +def test_diarization_pipeline( + whisper_type: str, + vad_filter: bool, + bgm_separation: bool, + diarization: bool, +): + test_transcribe(whisper_type, vad_filter, bgm_separation, diarization) + diff --git a/tests/test_srt.srt b/tests/test_srt.srt new file mode 100644 index 0000000000000000000000000000000000000000..4874c7f4ae0e3ecfd5f3d0b89a981633cd108c20 --- /dev/null +++ b/tests/test_srt.srt @@ -0,0 +1,7 @@ +1 +00:00:00,000 --> 00:00:02,240 +You've got + +2 +00:00:02,240 --> 00:00:04,160 +a friend in me. 
diff --git a/tests/test_transcription.py b/tests/test_transcription.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5ab98f2aa59d6c497c8ea7529b79d19a232b74 --- /dev/null +++ b/tests/test_transcription.py @@ -0,0 +1,97 @@ +from modules.whisper.whisper_factory import WhisperFactory +from modules.whisper.whisper_parameter import WhisperValues +from modules.utils.paths import WEBUI_DIR +from test_config import * + +import requests +import pytest +import gradio as gr +import os + + +@pytest.mark.parametrize( + "whisper_type,vad_filter,bgm_separation,diarization", + [ + ("whisper", False, False, False), + ("faster-whisper", False, False, False), + ("insanely_fast_whisper", False, False, False) + ] +) +def test_transcribe( + whisper_type: str, + vad_filter: bool, + bgm_separation: bool, + diarization: bool, +): + audio_path_dir = os.path.join(WEBUI_DIR, "tests") + audio_path = os.path.join(audio_path_dir, "jfk.wav") + if not os.path.exists(audio_path): + download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir) + + whisper_inferencer = WhisperFactory.create_whisper_inference( + whisper_type=whisper_type, + ) + print( + f"""Whisper Device : {whisper_inferencer.device}\n""" + f"""BGM Separation Device: {whisper_inferencer.music_separator.device}\n""" + f"""Diarization Device: {whisper_inferencer.diarizer.device}""" + ) + + hparams = WhisperValues( + model_size=TEST_WHISPER_MODEL, + vad_filter=vad_filter, + is_bgm_separate=bgm_separation, + compute_type=whisper_inferencer.current_compute_type, + uvr_enable_offload=True, + is_diarize=diarization, + ).as_list() + + subtitle_str, file_path = whisper_inferencer.transcribe_file( + [audio_path], + None, + "SRT", + False, + gr.Progress(), + *hparams, + ) + + assert isinstance(subtitle_str, str) and subtitle_str + assert isinstance(file_path[0], str) and file_path + + whisper_inferencer.transcribe_youtube( + TEST_YOUTUBE_URL, + "SRT", + False, + gr.Progress(), + *hparams, + ) + assert isinstance(subtitle_str, str) and subtitle_str + assert isinstance(file_path[0], str) and file_path + + whisper_inferencer.transcribe_mic( + audio_path, + "SRT", + False, + gr.Progress(), + *hparams, + ) + assert isinstance(subtitle_str, str) and subtitle_str + assert isinstance(file_path[0], str) and file_path + + +def download_file(url, save_dir): + if os.path.exists(TEST_FILE_PATH): + return + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + file_name = url.split("/")[-1] + file_path = os.path.join(save_dir, file_name) + + response = requests.get(url) + + with open(file_path, "wb") as file: + file.write(response.content) + + print(f"File downloaded to: {file_path}") diff --git a/tests/test_translation.py b/tests/test_translation.py new file mode 100644 index 0000000000000000000000000000000000000000..7e63e1a6a77a10011b997a5f9daa2929691403c3 --- /dev/null +++ b/tests/test_translation.py @@ -0,0 +1,52 @@ +from modules.translation.deepl_api import DeepLAPI +from modules.translation.nllb_inference import NLLBInference +from test_config import * + +import os +import pytest + + +@pytest.mark.parametrize("model_size, file_path", [ + (TEST_NLLB_MODEL, TEST_SUBTITLE_SRT_PATH), + (TEST_NLLB_MODEL, TEST_SUBTITLE_VTT_PATH), +]) +def test_nllb_inference( + model_size: str, + file_path: str +): + nllb_inferencer = NLLBInference() + print(f"NLLB Device : {nllb_inferencer.device}") + + result_str, file_paths = nllb_inferencer.translate_file( + fileobjs=[file_path], + model_size=model_size, + src_lang="eng_Latn", + tgt_lang="kor_Hang", + ) + + assert 
+    assert isinstance(file_paths[0], str)
+
+
+@pytest.mark.skipif(
+    not os.getenv("DEEPL_API_KEY"),
+    reason="Skipping because DEEPL_API_KEY is not set"
+)
+@pytest.mark.parametrize("file_path", [
+    TEST_SUBTITLE_SRT_PATH,
+    TEST_SUBTITLE_VTT_PATH,
+])
+def test_deepl_api(
+    file_path: str
+):
+    deepl_api = DeepLAPI()
+
+    api_key = os.getenv("DEEPL_API_KEY")
+
+    result_str, file_paths = deepl_api.translate_deepl(
+        auth_key=api_key,
+        fileobjs=[file_path],
+        source_lang="English",
+        target_lang="Korean",
+        is_pro=False,
+        add_timestamp=True,
+    )
+
+    assert isinstance(result_str, str)
+    assert isinstance(file_paths[0], str)
diff --git a/tests/test_vad.py b/tests/test_vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..124a043dda69d86801d5141cdf2b68fcca3e7852
--- /dev/null
+++ b/tests/test_vad.py
@@ -0,0 +1,26 @@
+from modules.utils.paths import *
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.whisper.whisper_parameter import WhisperValues
+from test_config import *
+from test_transcription import download_file, test_transcribe
+
+import gradio as gr
+import pytest
+import os
+
+
+@pytest.mark.parametrize(
+    "whisper_type,vad_filter,bgm_separation,diarization",
+    [
+        ("whisper", True, False, False),
+        ("faster-whisper", True, False, False),
+        ("insanely_fast_whisper", True, False, False)
+    ]
+)
+def test_vad_pipeline(
+    whisper_type: str,
+    vad_filter: bool,
+    bgm_separation: bool,
+    diarization: bool,
+):
+    test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
diff --git a/tests/test_vtt.vtt b/tests/test_vtt.vtt
new file mode 100644
index 0000000000000000000000000000000000000000..2157e25e64951fded8103a47dda2c5f5f6a61671
--- /dev/null
+++ b/tests/test_vtt.vtt
@@ -0,0 +1,6 @@
+WEBVTT
+00:00:00.500 --> 00:00:02.000
+You've got
+
+00:00:02.500 --> 00:00:04.300
+a friend in me.
\ No newline at end of file
diff --git a/user-start-webui.bat b/user-start-webui.bat
new file mode 100644
index 0000000000000000000000000000000000000000..a1455d05a2ae1f44700b2b069129324f93dd8dfc
--- /dev/null
+++ b/user-start-webui.bat
@@ -0,0 +1,58 @@
+@echo off
+:: This batch file is for launching with command line args
+:: See the wiki for a guide to command line arguments: https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments
+:: Set the values below to whatever you want. See the wiki above for how to set each argument.
+set SERVER_NAME=
+set SERVER_PORT=
+set USERNAME=
+set PASSWORD=
+set SHARE=
+set THEME=
+set API_OPEN=
+set WHISPER_TYPE=
+set WHISPER_MODEL_DIR=
+set FASTER_WHISPER_MODEL_DIR=
+set INSANELY_FAST_WHISPER_MODEL_DIR=
+set DIARIZATION_MODEL_DIR=
+
+
+if not "%SERVER_NAME%"=="" (
+    set SERVER_NAME_ARG=--server_name %SERVER_NAME%
+)
+if not "%SERVER_PORT%"=="" (
+    set SERVER_PORT_ARG=--server_port %SERVER_PORT%
+)
+if not "%USERNAME%"=="" (
+    set USERNAME_ARG=--username %USERNAME%
+)
+if not "%PASSWORD%"=="" (
+    set PASSWORD_ARG=--password %PASSWORD%
+)
+if /I "%SHARE%"=="true" (
+    set SHARE_ARG=--share
+)
+if not "%THEME%"=="" (
+    set THEME_ARG=--theme %THEME%
+)
+if /I "%API_OPEN%"=="true" (
+    set API_OPEN_ARG=--api_open
+)
+if not "%WHISPER_TYPE%"=="" (
+    set WHISPER_TYPE_ARG=--whisper_type %WHISPER_TYPE%
+)
+if not "%WHISPER_MODEL_DIR%"=="" (
+    set WHISPER_MODEL_DIR_ARG=--whisper_model_dir "%WHISPER_MODEL_DIR%"
+)
+if not "%FASTER_WHISPER_MODEL_DIR%"=="" (
+    set FASTER_WHISPER_MODEL_DIR_ARG=--faster_whisper_model_dir "%FASTER_WHISPER_MODEL_DIR%"
+)
+if not "%INSANELY_FAST_WHISPER_MODEL_DIR%"=="" (
+    set INSANELY_FAST_WHISPER_MODEL_DIR_ARG=--insanely_fast_whisper_model_dir "%INSANELY_FAST_WHISPER_MODEL_DIR%"
+)
+if not "%DIARIZATION_MODEL_DIR%"=="" (
+    set DIARIZATION_MODEL_DIR_ARG=--diarization_model_dir "%DIARIZATION_MODEL_DIR%"
+)
+
+:: Call the original .bat script with cli arguments
+call start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %API_OPEN_ARG% %WHISPER_TYPE_ARG% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG% %INSANELY_FAST_WHISPER_MODEL_DIR_ARG% %DIARIZATION_MODEL_DIR_ARG%
+pause
\ No newline at end of file
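For reference, user-start-webui.bat forwards only the variables that are actually filled in, mapping each non-empty value to the matching CLI option and adding the boolean switches (SHARE, API_OPEN) only when they are set to "true". The sketch below restates that mapping in Python purely for illustration (build_webui_args.py is a hypothetical helper, not part of this change; the option names are taken from the script above):

# build_webui_args.py - hypothetical illustration of the flag mapping in user-start-webui.bat
import os

# Value-taking options: forwarded only when the corresponding variable is non-empty.
FLAG_MAP = {
    "SERVER_NAME": "--server_name",
    "SERVER_PORT": "--server_port",
    "USERNAME": "--username",
    "PASSWORD": "--password",
    "THEME": "--theme",
    "WHISPER_TYPE": "--whisper_type",
    "WHISPER_MODEL_DIR": "--whisper_model_dir",
    "FASTER_WHISPER_MODEL_DIR": "--faster_whisper_model_dir",
    "INSANELY_FAST_WHISPER_MODEL_DIR": "--insanely_fast_whisper_model_dir",
    "DIARIZATION_MODEL_DIR": "--diarization_model_dir",
}


def build_args() -> list:
    args = []
    for env_name, flag in FLAG_MAP.items():
        value = os.getenv(env_name, "")
        if value:
            args += [flag, value]
    # Boolean switches take no value and are added only when explicitly set to "true".
    if os.getenv("SHARE", "").lower() == "true":
        args.append("--share")
    if os.getenv("API_OPEN", "").lower() == "true":
        args.append("--api_open")
    return args


if __name__ == "__main__":
    # Prints the command line the batch script would build, e.g.:
    #   start-webui.bat --server_port 7860 --whisper_type faster-whisper
    print("start-webui.bat " + " ".join(build_args()))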