mkozak commited on
Commit
629adbe
·
unverified ·
1 Parent(s): df76b39

do not use replicate, run model locally

Browse files
Files changed (2) hide show
  1. main.py +33 -6
  2. requirements.txt +68 -1
main.py CHANGED
@@ -1,4 +1,7 @@
1
- import replicate
 
 
 
2
  from pydantic import BaseModel
3
  from fastapi import FastAPI
4
 
@@ -9,12 +12,36 @@ class URLPayload(BaseModel):
9
  app = FastAPI()
10
 
11
  def process_audio(url: str):
12
- deployment = replicate.deployments.get("meal/incredibly-fast-whisper")
13
- prediction = deployment.predictions.create(
14
- input={ "audio": url }
 
 
 
 
 
 
 
15
  )
16
- prediction.wait()
17
- return prediction.output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  @app.post("/process/")
20
  async def process_audio_endpoint(payload: URLPayload):
 
1
+ import torch
2
+ import requests
3
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
4
+ from datasets import load_dataset
5
  from pydantic import BaseModel
6
  from fastapi import FastAPI
7
 
 
12
  app = FastAPI()
13
 
14
  def process_audio(url: str):
15
+ response = requests.get(url)
16
+ with open("audio.mp3", mode="wb") as file:
17
+ file.write(response.content)
18
+
19
+
20
+ device = "cpu"
21
+
22
+ model_id = "openai/whisper-large-v3"
23
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
24
+ model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True, use_safetensors=True
25
  )
26
+ model.to(device)
27
+
28
+ processor = AutoProcessor.from_pretrained(model_id)
29
+ pipe = pipeline(
30
+ "automatic-speech-recognition",
31
+ model=model,
32
+ tokenizer=processor.tokenizer,
33
+ feature_extractor=processor.feature_extractor,
34
+ max_new_tokens=8192,
35
+ chunk_length_s=30,
36
+ batch_size=16,
37
+ return_timestamps=True,
38
+ torch_dtype=torch.float32,
39
+ device=device
40
+ )
41
+ dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
42
+ whisper_result = pipe("audio.mp3")
43
+ return whisper_result
44
+
45
 
46
  @app.post("/process/")
47
  async def process_audio_endpoint(payload: URLPayload):
requirements.txt CHANGED
@@ -1,21 +1,88 @@
 
 
 
1
  annotated-types==0.6.0
2
  anyio==4.2.0
 
 
 
3
  certifi==2023.11.17
 
 
4
  click==8.1.7
 
 
 
5
  distro==1.9.0
6
  exceptiongroup==1.2.0
7
  fastapi==0.108.0
 
 
 
8
  h11==0.14.0
9
  httpcore==1.0.2
10
  httpx==0.26.0
 
11
  idna==3.6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  openai==1.7.1
13
  packaging==23.2
 
 
 
 
 
 
 
14
  pydantic==2.5.3
15
  pydantic_core==2.14.6
16
- replicate==0.22.0
 
 
 
 
 
 
 
 
17
  sniffio==1.3.0
 
 
18
  starlette==0.32.0.post1
 
 
 
 
19
  tqdm==4.66.1
 
 
20
  typing_extensions==4.9.0
 
 
21
  uvicorn==0.25.0
 
 
 
1
+ accelerate==0.26.1
2
+ aiohttp==3.9.1
3
+ aiosignal==1.3.1
4
  annotated-types==0.6.0
5
  anyio==4.2.0
6
+ async-timeout==4.0.3
7
+ attrs==23.2.0
8
+ audioread==3.0.1
9
  certifi==2023.11.17
10
+ cffi==1.16.0
11
+ charset-normalizer==3.3.2
12
  click==8.1.7
13
+ datasets==2.16.1
14
+ decorator==5.1.1
15
+ dill==0.3.7
16
  distro==1.9.0
17
  exceptiongroup==1.2.0
18
  fastapi==0.108.0
19
+ filelock==3.13.1
20
+ frozenlist==1.4.1
21
+ fsspec==2023.10.0
22
  h11==0.14.0
23
  httpcore==1.0.2
24
  httpx==0.26.0
25
+ huggingface-hub==0.20.2
26
  idna==3.6
27
+ Jinja2==3.1.3
28
+ joblib==1.3.2
29
+ lazy_loader==0.3
30
+ librosa==0.10.1
31
+ llvmlite==0.41.1
32
+ MarkupSafe==2.1.3
33
+ mpmath==1.3.0
34
+ msgpack==1.0.7
35
+ multidict==6.0.4
36
+ multiprocess==0.70.15
37
+ networkx==3.2.1
38
+ numba==0.58.1
39
+ numpy==1.26.3
40
+ nvidia-cublas-cu12==12.1.3.1
41
+ nvidia-cuda-cupti-cu12==12.1.105
42
+ nvidia-cuda-nvrtc-cu12==12.1.105
43
+ nvidia-cuda-runtime-cu12==12.1.105
44
+ nvidia-cudnn-cu12==8.9.2.26
45
+ nvidia-cufft-cu12==11.0.2.54
46
+ nvidia-curand-cu12==10.3.2.106
47
+ nvidia-cusolver-cu12==11.4.5.107
48
+ nvidia-cusparse-cu12==12.1.0.106
49
+ nvidia-nccl-cu12==2.18.1
50
+ nvidia-nvjitlink-cu12==12.3.101
51
+ nvidia-nvtx-cu12==12.1.105
52
  openai==1.7.1
53
  packaging==23.2
54
+ pandas==2.1.4
55
+ platformdirs==4.1.0
56
+ pooch==1.8.0
57
+ psutil==5.9.7
58
+ pyarrow==14.0.2
59
+ pyarrow-hotfix==0.6
60
+ pycparser==2.21
61
  pydantic==2.5.3
62
  pydantic_core==2.14.6
63
+ python-dateutil==2.8.2
64
+ pytz==2023.3.post1
65
+ PyYAML==6.0.1
66
+ regex==2023.12.25
67
+ requests==2.31.0
68
+ safetensors==0.4.1
69
+ scikit-learn==1.3.2
70
+ scipy==1.11.4
71
+ six==1.16.0
72
  sniffio==1.3.0
73
+ soundfile==0.12.1
74
+ soxr==0.3.7
75
  starlette==0.32.0.post1
76
+ sympy==1.12
77
+ threadpoolctl==3.2.0
78
+ tokenizers==0.15.0
79
+ torch==2.1.2
80
  tqdm==4.66.1
81
+ transformers @ git+https://github.com/huggingface/transformers.git@64bdbd888c78dcef5aeaeabc842e12981c8aae7a
82
+ triton==2.1.0
83
  typing_extensions==4.9.0
84
+ tzdata==2023.4
85
+ urllib3==2.1.0
86
  uvicorn==0.25.0
87
+ xxhash==3.4.1
88
+ yarl==1.9.4