Spaces:
Sleeping
Sleeping
Sofia Casadei
committed on
Commit
·
97f18ea
1
Parent(s):
382a8a5
fix: turn server config
Browse files- main.py +21 -14
- utils/turn_server.py +86 -81
main.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import logging
|
3 |
import json
|
4 |
import torch
|
|
|
5 |
|
6 |
import gradio as gr
|
7 |
import numpy as np
|
@@ -16,8 +17,6 @@ from fastrtc import (
|
|
16 |
AlgoOptions,
|
17 |
SileroVadOptions,
|
18 |
audio_to_bytes,
|
19 |
-
get_cloudflare_turn_credentials_async,
|
20 |
-
get_cloudflare_turn_credentials,
|
21 |
)
|
22 |
from transformers import (
|
23 |
AutoModelForSpeechSeq2Seq,
|
@@ -28,7 +27,7 @@ from transformers.utils import is_flash_attn_2_available
|
|
28 |
|
29 |
from utils.logger_config import setup_logging
|
30 |
from utils.device import get_device, get_torch_and_np_dtypes
|
31 |
-
from utils.turn_server import get_rtc_credentials
|
32 |
|
33 |
|
34 |
load_dotenv()
|
@@ -39,10 +38,10 @@ logger = logging.getLogger(__name__)
|
|
39 |
UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
|
40 |
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
|
41 |
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
|
|
|
42 |
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
|
43 |
LANGUAGE = os.getenv("LANGUAGE", "english").lower()
|
44 |
|
45 |
-
|
46 |
device = get_device(force_cpu=False)
|
47 |
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
|
48 |
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
|
@@ -93,8 +92,8 @@ async def transcribe(audio: tuple[int, np.ndarray]):
|
|
93 |
|
94 |
outputs = transcribe_pipeline(
|
95 |
audio_to_bytes(audio),
|
96 |
-
chunk_length_s=
|
97 |
-
batch_size=
|
98 |
generate_kwargs={
|
99 |
'task': 'transcribe',
|
100 |
'language': LANGUAGE,
|
@@ -103,8 +102,8 @@ async def transcribe(audio: tuple[int, np.ndarray]):
|
|
103 |
)
|
104 |
yield AdditionalOutputs(outputs["text"].strip())
|
105 |
|
106 |
-
|
107 |
-
|
108 |
|
109 |
logger.info("Initializing FastRTC stream")
|
110 |
stream = Stream(
|
@@ -123,12 +122,13 @@ stream = Stream(
|
|
123 |
threshold=0.5,
|
124 |
# Final speech chunks shorter min_speech_duration_ms are thrown out (default 250)
|
125 |
min_speech_duration_ms=250,
|
126 |
-
# Max duration of speech chunks, longer will be split
|
127 |
-
max_speech_duration_s
|
|
|
128 |
# Wait for ms at the end of each speech chunk before separating it (default 2000)
|
129 |
min_silence_duration_ms=100,
|
130 |
# Chunk size for VAD model. Can be 512, 1024, 1536 for 16k s.r. (default 1024)
|
131 |
-
window_size_samples=
|
132 |
# Final speech chunks are padded by speech_pad_ms each side (default 400)
|
133 |
speech_pad_ms=200,
|
134 |
),
|
@@ -142,8 +142,8 @@ stream = Stream(
|
|
142 |
gr.Textbox(label="Transcript"),
|
143 |
],
|
144 |
additional_outputs_handler=lambda current, new: current + " " + new,
|
145 |
-
rtc_configuration=get_credentials
|
146 |
-
server_rtc_configuration=
|
147 |
concurrency_limit=6
|
148 |
)
|
149 |
|
@@ -158,7 +158,14 @@ async def index():
|
|
158 |
elif UI_TYPE == "screen":
|
159 |
html_content = open("static/index-screen.html").read()
|
160 |
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
|
163 |
|
164 |
@app.get("/transcript")
|
|
|
2 |
import logging
|
3 |
import json
|
4 |
import torch
|
5 |
+
import asyncio
|
6 |
|
7 |
import gradio as gr
|
8 |
import numpy as np
|
|
|
17 |
AlgoOptions,
|
18 |
SileroVadOptions,
|
19 |
audio_to_bytes,
|
|
|
|
|
20 |
)
|
21 |
from transformers import (
|
22 |
AutoModelForSpeechSeq2Seq,
|
|
|
27 |
|
28 |
from utils.logger_config import setup_logging
|
29 |
from utils.device import get_device, get_torch_and_np_dtypes
|
30 |
+
from utils.turn_server import get_credential_function, get_rtc_credentials
|
31 |
|
32 |
|
33 |
load_dotenv()
|
|
|
38 |
UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
|
39 |
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
|
40 |
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
|
41 |
+
TURN_SERVER_PROVIDER = os.getenv("TURN_SERVER_PROVIDER", "hf-cloudflare").lower() # hf-cloudflare | cloudflare | hf | twilio
|
42 |
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
|
43 |
LANGUAGE = os.getenv("LANGUAGE", "english").lower()
|
44 |
|
|
|
45 |
device = get_device(force_cpu=False)
|
46 |
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
|
47 |
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
|
|
|
92 |
|
93 |
outputs = transcribe_pipeline(
|
94 |
audio_to_bytes(audio),
|
95 |
+
chunk_length_s=3,
|
96 |
+
batch_size=2,
|
97 |
generate_kwargs={
|
98 |
'task': 'transcribe',
|
99 |
'language': LANGUAGE,
|
|
|
102 |
)
|
103 |
yield AdditionalOutputs(outputs["text"].strip())
|
104 |
|
105 |
+
get_credentials = get_credential_function(TURN_SERVER_PROVIDER, is_async=True) if APP_MODE == "deployed" else None
|
106 |
+
server_rtc_configuration = get_rtc_credentials(provider=TURN_SERVER_PROVIDER, ttl=360_000) if APP_MODE == "deployed" else None
|
107 |
|
108 |
logger.info("Initializing FastRTC stream")
|
109 |
stream = Stream(
|
|
|
122 |
threshold=0.5,
|
123 |
# Final speech chunks shorter min_speech_duration_ms are thrown out (default 250)
|
124 |
min_speech_duration_ms=250,
|
125 |
+
# Max duration of speech chunks, longer will be split at the timestamp of the last silence
|
126 |
+
# that lasts more than 100ms (if any) or just before max_speech_duration_s (default float('inf'))
|
127 |
+
max_speech_duration_s=3,
|
128 |
# Wait for ms at the end of each speech chunk before separating it (default 2000)
|
129 |
min_silence_duration_ms=100,
|
130 |
# Chunk size for VAD model. Can be 512, 1024, 1536 for 16k s.r. (default 1024)
|
131 |
+
window_size_samples=512,
|
132 |
# Final speech chunks are padded by speech_pad_ms each side (default 400)
|
133 |
speech_pad_ms=200,
|
134 |
),
|
|
|
142 |
gr.Textbox(label="Transcript"),
|
143 |
],
|
144 |
additional_outputs_handler=lambda current, new: current + " " + new,
|
145 |
+
rtc_configuration=get_credentials,
|
146 |
+
server_rtc_configuration=server_rtc_configuration,
|
147 |
concurrency_limit=6
|
148 |
)
|
149 |
|
|
|
158 |
elif UI_TYPE == "screen":
|
159 |
html_content = open("static/index-screen.html").read()
|
160 |
|
161 |
+
# Return the actual credentials for client-side, not the function
|
162 |
+
rtc_config = None
|
163 |
+
if APP_MODE == "deployed":
|
164 |
+
if asyncio.iscoroutinefunction(get_credentials):
|
165 |
+
rtc_config = await get_credentials()
|
166 |
+
else:
|
167 |
+
rtc_config = get_rtc_credentials(provider=TURN_SERVER_PROVIDER)
|
168 |
+
|
169 |
return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
|
170 |
|
171 |
@app.get("/transcript")
|
utils/turn_server.py
CHANGED
@@ -1,19 +1,24 @@
|
|
1 |
import os
|
2 |
-
from typing import Literal, Optional, Dict, Any
|
3 |
import requests
|
4 |
|
5 |
-
from fastrtc import
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
def get_rtc_credentials(
|
9 |
-
provider: Literal["hf", "twilio", "cloudflare"] = "hf",
|
10 |
**kwargs
|
11 |
) -> Dict[str, Any]:
|
12 |
"""
|
13 |
Get RTC configuration for different TURN server providers.
|
14 |
|
15 |
Args:
|
16 |
-
provider: The TURN server provider to use ('hf', 'twilio', or 'cloudflare')
|
17 |
**kwargs: Additional arguments passed to the specific provider's function
|
18 |
|
19 |
Returns:
|
@@ -21,99 +26,99 @@ def get_rtc_credentials(
|
|
21 |
"""
|
22 |
try:
|
23 |
if provider == "hf":
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
elif provider == "twilio":
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
elif provider == "cloudflare":
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
except Exception as e:
|
30 |
raise Exception(f"Failed to get RTC credentials ({provider}): {str(e)}")
|
31 |
|
32 |
|
33 |
-
def
|
34 |
-
"""
|
35 |
-
|
36 |
-
|
37 |
-
Required setup:
|
38 |
-
1. Create a Hugging Face account at huggingface.co
|
39 |
-
2. Visit: https://huggingface.co/spaces/fastrtc/turn-server-login
|
40 |
-
3. Set HF_TOKEN environment variable or pass token directly
|
41 |
-
"""
|
42 |
-
token = token or os.environ.get("HF_TOKEN")
|
43 |
-
if not token:
|
44 |
-
raise ValueError("HF_TOKEN environment variable not set")
|
45 |
-
|
46 |
-
try:
|
47 |
-
return get_hf_turn_credentials(token=token)
|
48 |
-
except Exception as e:
|
49 |
-
raise Exception(f"Failed to get HF TURN credentials: {str(e)}")
|
50 |
-
|
51 |
-
|
52 |
-
def get_twilio_credentials(
|
53 |
-
account_sid: Optional[str] = None,
|
54 |
-
auth_token: Optional[str] = None
|
55 |
) -> Dict[str, Any]:
|
56 |
"""
|
57 |
-
Get
|
58 |
-
|
59 |
-
Required setup:
|
60 |
-
1. Create a free Twilio account at: https://login.twilio.com/u/signup
|
61 |
-
2. Get your Account SID and Auth Token from the Twilio Console
|
62 |
-
3. Set environment variables:
|
63 |
-
- TWILIO_ACCOUNT_SID (or pass directly)
|
64 |
-
- TWILIO_AUTH_TOKEN (or pass directly)
|
65 |
-
"""
|
66 |
-
account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID")
|
67 |
-
auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN")
|
68 |
|
69 |
-
|
70 |
-
|
|
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
try:
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
74 |
except Exception as e:
|
75 |
-
raise Exception(f"Failed to get
|
76 |
|
77 |
|
78 |
-
def
|
79 |
-
key_id: Optional[str] = None,
|
80 |
-
api_token: Optional[str] = None,
|
81 |
-
ttl: int = 86400
|
82 |
-
) -> Dict[str, Any]:
|
83 |
"""
|
84 |
-
Get
|
85 |
-
|
86 |
-
Required setup:
|
87 |
-
1. Create a free Cloudflare account
|
88 |
-
2. Go to Cloudflare dashboard -> Calls section
|
89 |
-
3. Create a TURN App and get the Turn Token ID and API Token
|
90 |
-
4. Set environment variables:
|
91 |
-
- TURN_KEY_ID
|
92 |
-
- TURN_KEY_API_TOKEN
|
93 |
|
94 |
Args:
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
98 |
"""
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
raise ValueError("Cloudflare credentials not found. Set TURN_KEY_ID and TURN_KEY_API_TOKEN env vars")
|
104 |
-
|
105 |
-
response = requests.post(
|
106 |
-
f"https://rtc.live.cloudflare.com/v1/turn/keys/{key_id}/credentials/generate",
|
107 |
-
headers={
|
108 |
-
"Authorization": f"Bearer {api_token}",
|
109 |
-
"Content-Type": "application/json",
|
110 |
-
},
|
111 |
-
json={"ttl": ttl},
|
112 |
-
)
|
113 |
-
|
114 |
-
if response.ok:
|
115 |
-
return {"iceServers": [response.json()["iceServers"]]}
|
116 |
else:
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
1 |
import os
|
2 |
+
from typing import Literal, Optional, Dict, Any, Callable, Awaitable
|
3 |
import requests
|
4 |
|
5 |
+
from fastrtc import (
|
6 |
+
get_hf_turn_credentials,
|
7 |
+
get_twilio_turn_credentials,
|
8 |
+
get_cloudflare_turn_credentials,
|
9 |
+
get_cloudflare_turn_credentials_async
|
10 |
+
)
|
11 |
|
12 |
|
13 |
def get_rtc_credentials(
    provider: Literal["hf", "twilio", "cloudflare", "hf-cloudflare"] = "hf-cloudflare",
    **kwargs
) -> Dict[str, Any]:
    """
    Get RTC configuration for different TURN server providers.

    Credentials may be supplied as keyword arguments. Any credential that is
    missing, ``None`` or empty falls back to the matching environment variable.

    Args:
        provider: The TURN server provider to use ('hf', 'twilio', 'cloudflare', or 'hf-cloudflare')
        **kwargs: Additional arguments passed to the specific provider's function:
            - 'hf': ``token`` (env: HF_TOKEN)
            - 'twilio': ``account_sid`` (env: TWILIO_ACCOUNT_SID),
              ``auth_token`` (env: TWILIO_AUTH_TOKEN)
            - 'cloudflare': ``key_id`` (env: TURN_KEY_ID),
              ``api_token`` (env: TURN_KEY_API_TOKEN), ``ttl`` (default 86400)
            - 'hf-cloudflare': ``hf_token`` (env: HF_TOKEN), ``ttl`` (default 86400)
            Keys that the selected provider does not use are ignored.

    Returns:
        Dictionary containing the RTC configuration

    Raises:
        Exception: If credentials are missing, the provider is unknown, or the
            provider call fails. The underlying error is chained as ``__cause__``.
    """
    try:
        if provider == "hf":
            # HF Community Server (Deprecated)
            # 1. Create a Hugging Face account at huggingface.co
            # 2. Visit: https://huggingface.co/settings/tokens to create a token
            # 3. Set HF_TOKEN environment variable or pass token directly
            # `or` (rather than a pop default) so an explicit None still
            # falls back to the environment.
            token = kwargs.pop("token", None) or os.environ.get("HF_TOKEN")
            if not token:
                raise ValueError("HF_TOKEN environment variable not set")
            return get_hf_turn_credentials(token=token)

        elif provider == "twilio":
            # Twilio TURN Server
            # 1. Create a free Twilio account at: https://login.twilio.com/u/signup
            # 2. Get your Account SID and Auth Token from the Twilio Console
            # 3. Set environment variables: TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN
            account_sid = kwargs.pop("account_sid", None) or os.environ.get("TWILIO_ACCOUNT_SID")
            auth_token = kwargs.pop("auth_token", None) or os.environ.get("TWILIO_AUTH_TOKEN")
            if not account_sid or not auth_token:
                raise ValueError("Twilio credentials not found. Set TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN env vars")
            return get_twilio_turn_credentials(account_sid=account_sid, auth_token=auth_token)

        elif provider == "cloudflare":
            # Cloudflare TURN Server
            # 1. Create a free Cloudflare account
            # 2. Go to Cloudflare dashboard -> Calls section
            # 3. Create a TURN App and get the Turn Token ID and API Token
            # 4. Set environment variables: TURN_KEY_ID and TURN_KEY_API_TOKEN
            key_id = kwargs.pop("key_id", None) or os.environ.get("TURN_KEY_ID")
            api_token = kwargs.pop("api_token", None) or os.environ.get("TURN_KEY_API_TOKEN")
            ttl = kwargs.pop("ttl", 86400)
            if not key_id or not api_token:
                raise ValueError("Cloudflare credentials not found. Set TURN_KEY_ID and TURN_KEY_API_TOKEN env vars")
            return get_cloudflare_turn_credentials(key_id=key_id, api_token=api_token, ttl=ttl)

        elif provider == "hf-cloudflare":
            # Cloudflare with Hugging Face Token (10GB free traffic per month)
            # 1. Create a Hugging Face account at huggingface.co
            # 2. Visit: https://huggingface.co/settings/tokens to create a token
            # 3. Set HF_TOKEN environment variable or pass token directly
            hf_token = kwargs.pop("hf_token", None) or os.environ.get("HF_TOKEN")
            ttl = kwargs.pop("ttl", 86400)
            if not hf_token:
                raise ValueError("HF_TOKEN environment variable not set")
            return get_cloudflare_turn_credentials(hf_token=hf_token, ttl=ttl)
        else:
            raise ValueError(f"Unknown provider: {provider}")
    except Exception as e:
        # Keep the generic Exception type callers already catch, but chain the
        # original error so the root cause stays visible in tracebacks.
        raise Exception(f"Failed to get RTC credentials ({provider}): {str(e)}") from e
|
76 |
|
77 |
|
78 |
+
async def get_rtc_credentials_async(
    provider: Literal["hf-cloudflare"] = "hf-cloudflare",
    **kwargs
) -> Dict[str, Any]:
    """
    Get RTC configuration asynchronously for different TURN server providers.

    Args:
        provider: Currently only supports 'hf-cloudflare'
        **kwargs: Additional arguments passed to the specific provider's function:
            ``hf_token`` (falls back to the HF_TOKEN environment variable when
            missing, ``None`` or empty) and ``ttl`` (default 600).
            Other keys are ignored.

    Returns:
        Dictionary containing the RTC configuration

    Raises:
        NotImplementedError: If ``provider`` is anything other than 'hf-cloudflare'.
        Exception: If the token is missing or the provider call fails. The
            underlying error is chained as ``__cause__``.
    """
    if provider != "hf-cloudflare":
        raise NotImplementedError(f"Async credentials for {provider} not implemented")

    try:
        # Cloudflare with Hugging Face Token (10GB free traffic per month)
        # `or` (rather than a pop default) so an explicit None still falls
        # back to the environment.
        hf_token = kwargs.pop("hf_token", None) or os.environ.get("HF_TOKEN")
        ttl = kwargs.pop("ttl", 600)  # Default 10 minutes for client-side
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable not set")
        return await get_cloudflare_turn_credentials_async(hf_token=hf_token, ttl=ttl)
    except Exception as e:
        # Same generic Exception type as before, with the root cause chained.
        raise Exception(f"Failed to get async RTC credentials: {str(e)}") from e
|
104 |
|
105 |
|
106 |
+
def get_credential_function(provider: str, is_async: bool = False) -> Callable:
    """
    Build a zero-argument callable that fetches credentials for ``provider``.

    Args:
        provider: The TURN server provider
        is_async: Whether an async function is requested. An async function is
            only produced for 'hf-cloudflare'; every other combination yields
            a synchronous function.

    Returns:
        Function that returns credentials (async or sync)
    """
    # Guard clause: anything that is not the async hf-cloudflare case gets the
    # synchronous fetcher.
    if not (is_async and provider == "hf-cloudflare"):
        def get_creds():
            return get_rtc_credentials(provider=provider)

        return get_creds

    async def get_creds():
        return await get_rtc_credentials_async(provider=provider)

    return get_creds
|