Sofia Casadei commited on
Commit
97f18ea
·
1 Parent(s): 382a8a5

fix: turn server config

Browse files
Files changed (2) hide show
  1. main.py +21 -14
  2. utils/turn_server.py +86 -81
main.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import logging
3
  import json
4
  import torch
 
5
 
6
  import gradio as gr
7
  import numpy as np
@@ -16,8 +17,6 @@ from fastrtc import (
16
  AlgoOptions,
17
  SileroVadOptions,
18
  audio_to_bytes,
19
- get_cloudflare_turn_credentials_async,
20
- get_cloudflare_turn_credentials,
21
  )
22
  from transformers import (
23
  AutoModelForSpeechSeq2Seq,
@@ -28,7 +27,7 @@ from transformers.utils import is_flash_attn_2_available
28
 
29
  from utils.logger_config import setup_logging
30
  from utils.device import get_device, get_torch_and_np_dtypes
31
- from utils.turn_server import get_rtc_credentials
32
 
33
 
34
  load_dotenv()
@@ -39,10 +38,10 @@ logger = logging.getLogger(__name__)
39
  UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
40
  UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
41
  APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
 
42
  MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
43
  LANGUAGE = os.getenv("LANGUAGE", "english").lower()
44
 
45
-
46
  device = get_device(force_cpu=False)
47
  torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
48
  logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
@@ -93,8 +92,8 @@ async def transcribe(audio: tuple[int, np.ndarray]):
93
 
94
  outputs = transcribe_pipeline(
95
  audio_to_bytes(audio),
96
- chunk_length_s=6,
97
- batch_size=1,
98
  generate_kwargs={
99
  'task': 'transcribe',
100
  'language': LANGUAGE,
@@ -103,8 +102,8 @@ async def transcribe(audio: tuple[int, np.ndarray]):
103
  )
104
  yield AdditionalOutputs(outputs["text"].strip())
105
 
106
- async def get_credentials():
107
- return await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
108
 
109
  logger.info("Initializing FastRTC stream")
110
  stream = Stream(
@@ -123,12 +122,13 @@ stream = Stream(
123
  threshold=0.5,
124
  # Final speech chunks shorter min_speech_duration_ms are thrown out (default 250)
125
  min_speech_duration_ms=250,
126
- # Max duration of speech chunks, longer will be split (default float('inf'))
127
- max_speech_duration_s=6,
 
128
  # Wait for ms at the end of each speech chunk before separating it (default 2000)
129
  min_silence_duration_ms=100,
130
  # Chunk size for VAD model. Can be 512, 1024, 1536 for 16k s.r. (default 1024)
131
- window_size_samples=1024,
132
  # Final speech chunks are padded by speech_pad_ms each side (default 400)
133
  speech_pad_ms=200,
134
  ),
@@ -142,8 +142,8 @@ stream = Stream(
142
  gr.Textbox(label="Transcript"),
143
  ],
144
  additional_outputs_handler=lambda current, new: current + " " + new,
145
- rtc_configuration=get_credentials if APP_MODE == "deployed" else None,
146
- server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000) if APP_MODE == "deployed" else None,
147
  concurrency_limit=6
148
  )
149
 
@@ -158,7 +158,14 @@ async def index():
158
  elif UI_TYPE == "screen":
159
  html_content = open("static/index-screen.html").read()
160
 
161
- rtc_config = get_credentials if APP_MODE == "deployed" else None
 
 
 
 
 
 
 
162
  return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
163
 
164
  @app.get("/transcript")
 
2
  import logging
3
  import json
4
  import torch
5
+ import asyncio
6
 
7
  import gradio as gr
8
  import numpy as np
 
17
  AlgoOptions,
18
  SileroVadOptions,
19
  audio_to_bytes,
 
 
20
  )
21
  from transformers import (
22
  AutoModelForSpeechSeq2Seq,
 
27
 
28
  from utils.logger_config import setup_logging
29
  from utils.device import get_device, get_torch_and_np_dtypes
30
+ from utils.turn_server import get_credential_function, get_rtc_credentials
31
 
32
 
33
  load_dotenv()
 
38
  UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
39
  UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
40
  APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
41
+ TURN_SERVER_PROVIDER = os.getenv("TURN_SERVER_PROVIDER", "hf-cloudflare").lower() # hf-cloudflare | cloudflare | hf | twilio
42
  MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
43
  LANGUAGE = os.getenv("LANGUAGE", "english").lower()
44
 
 
45
  device = get_device(force_cpu=False)
46
  torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
47
  logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
 
92
 
93
  outputs = transcribe_pipeline(
94
  audio_to_bytes(audio),
95
+ chunk_length_s=3,
96
+ batch_size=2,
97
  generate_kwargs={
98
  'task': 'transcribe',
99
  'language': LANGUAGE,
 
102
  )
103
  yield AdditionalOutputs(outputs["text"].strip())
104
 
105
+ get_credentials = get_credential_function(TURN_SERVER_PROVIDER, is_async=True) if APP_MODE == "deployed" else None
106
+ server_rtc_configuration = get_rtc_credentials(provider=TURN_SERVER_PROVIDER, ttl=360_000) if APP_MODE == "deployed" else None
107
 
108
  logger.info("Initializing FastRTC stream")
109
  stream = Stream(
 
122
  threshold=0.5,
123
  # Final speech chunks shorter min_speech_duration_ms are thrown out (default 250)
124
  min_speech_duration_ms=250,
125
+ # Max duration of speech chunks, longer will be split at the timestamp of the last silence
126
+ # that lasts more than 100ms (if any) or just before max_speech_duration_s (default float('inf'))
127
+ max_speech_duration_s=3,
128
  # Wait for ms at the end of each speech chunk before separating it (default 2000)
129
  min_silence_duration_ms=100,
130
  # Chunk size for VAD model. Can be 512, 1024, 1536 for 16k s.r. (default 1024)
131
+ window_size_samples=512,
132
  # Final speech chunks are padded by speech_pad_ms each side (default 400)
133
  speech_pad_ms=200,
134
  ),
 
142
  gr.Textbox(label="Transcript"),
143
  ],
144
  additional_outputs_handler=lambda current, new: current + " " + new,
145
+ rtc_configuration=get_credentials,
146
+ server_rtc_configuration=server_rtc_configuration,
147
  concurrency_limit=6
148
  )
149
 
 
158
  elif UI_TYPE == "screen":
159
  html_content = open("static/index-screen.html").read()
160
 
161
+ # Return the actual credentials for client-side, not the function
162
+ rtc_config = None
163
+ if APP_MODE == "deployed":
164
+ if asyncio.iscoroutinefunction(get_credentials):
165
+ rtc_config = await get_credentials()
166
+ else:
167
+ rtc_config = get_rtc_credentials(provider=TURN_SERVER_PROVIDER)
168
+
169
  return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
170
 
171
  @app.get("/transcript")
utils/turn_server.py CHANGED
@@ -1,19 +1,24 @@
1
  import os
2
- from typing import Literal, Optional, Dict, Any
3
  import requests
4
 
5
- from fastrtc import get_hf_turn_credentials, get_twilio_turn_credentials
 
 
 
 
 
6
 
7
 
8
  def get_rtc_credentials(
9
- provider: Literal["hf", "twilio", "cloudflare"] = "hf",
10
  **kwargs
11
  ) -> Dict[str, Any]:
12
  """
13
  Get RTC configuration for different TURN server providers.
14
 
15
  Args:
16
- provider: The TURN server provider to use ('hf', 'twilio', or 'cloudflare')
17
  **kwargs: Additional arguments passed to the specific provider's function
18
 
19
  Returns:
@@ -21,99 +26,99 @@ def get_rtc_credentials(
21
  """
22
  try:
23
  if provider == "hf":
24
- return get_hf_credentials(**kwargs)
 
 
 
 
 
 
 
 
25
  elif provider == "twilio":
26
- return get_twilio_credentials(**kwargs)
 
 
 
 
 
 
 
 
 
27
  elif provider == "cloudflare":
28
- return get_cloudflare_credentials(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  except Exception as e:
30
  raise Exception(f"Failed to get RTC credentials ({provider}): {str(e)}")
31
 
32
 
33
- def get_hf_credentials(token: Optional[str] = None) -> Dict[str, Any]:
34
- """
35
- Get credentials for Hugging Face's community TURN server.
36
-
37
- Required setup:
38
- 1. Create a Hugging Face account at huggingface.co
39
- 2. Visit: https://huggingface.co/spaces/fastrtc/turn-server-login
40
- 3. Set HF_TOKEN environment variable or pass token directly
41
- """
42
- token = token or os.environ.get("HF_TOKEN")
43
- if not token:
44
- raise ValueError("HF_TOKEN environment variable not set")
45
-
46
- try:
47
- return get_hf_turn_credentials(token=token)
48
- except Exception as e:
49
- raise Exception(f"Failed to get HF TURN credentials: {str(e)}")
50
-
51
-
52
- def get_twilio_credentials(
53
- account_sid: Optional[str] = None,
54
- auth_token: Optional[str] = None
55
  ) -> Dict[str, Any]:
56
  """
57
- Get credentials for Twilio's TURN server.
58
-
59
- Required setup:
60
- 1. Create a free Twilio account at: https://login.twilio.com/u/signup
61
- 2. Get your Account SID and Auth Token from the Twilio Console
62
- 3. Set environment variables:
63
- - TWILIO_ACCOUNT_SID (or pass directly)
64
- - TWILIO_AUTH_TOKEN (or pass directly)
65
- """
66
- account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID")
67
- auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN")
68
 
69
- if not account_sid or not auth_token:
70
- raise ValueError("Twilio credentials not found. Set TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN env vars")
 
71
 
 
 
 
 
 
 
72
  try:
73
- return get_twilio_turn_credentials(account_sid=account_sid, auth_token=auth_token)
 
 
 
 
 
74
  except Exception as e:
75
- raise Exception(f"Failed to get Twilio TURN credentials: {str(e)}")
76
 
77
 
78
- def get_cloudflare_credentials(
79
- key_id: Optional[str] = None,
80
- api_token: Optional[str] = None,
81
- ttl: int = 86400
82
- ) -> Dict[str, Any]:
83
  """
84
- Get credentials for Cloudflare's TURN server.
85
-
86
- Required setup:
87
- 1. Create a free Cloudflare account
88
- 2. Go to Cloudflare dashboard -> Calls section
89
- 3. Create a TURN App and get the Turn Token ID and API Token
90
- 4. Set environment variables:
91
- - TURN_KEY_ID
92
- - TURN_KEY_API_TOKEN
93
 
94
  Args:
95
- key_id: Cloudflare Turn Token ID (optional, will use env var if not provided)
96
- api_token: Cloudflare API Token (optional, will use env var if not provided)
97
- ttl: Time-to-live for credentials in seconds (default: 24 hours)
 
 
98
  """
99
- key_id = key_id or os.environ.get("TURN_KEY_ID")
100
- api_token = api_token or os.environ.get("TURN_KEY_API_TOKEN")
101
-
102
- if not key_id or not api_token:
103
- raise ValueError("Cloudflare credentials not found. Set TURN_KEY_ID and TURN_KEY_API_TOKEN env vars")
104
-
105
- response = requests.post(
106
- f"https://rtc.live.cloudflare.com/v1/turn/keys/{key_id}/credentials/generate",
107
- headers={
108
- "Authorization": f"Bearer {api_token}",
109
- "Content-Type": "application/json",
110
- },
111
- json={"ttl": ttl},
112
- )
113
-
114
- if response.ok:
115
- return {"iceServers": [response.json()["iceServers"]]}
116
  else:
117
- raise Exception(
118
- f"Failed to get Cloudflare TURN credentials: {response.status_code} {response.text}"
119
- )
 
1
  import os
2
+ from typing import Literal, Optional, Dict, Any, Callable, Awaitable
3
  import requests
4
 
5
+ from fastrtc import (
6
+ get_hf_turn_credentials,
7
+ get_twilio_turn_credentials,
8
+ get_cloudflare_turn_credentials,
9
+ get_cloudflare_turn_credentials_async
10
+ )
11
 
12
 
13
  def get_rtc_credentials(
14
+ provider: Literal["hf", "twilio", "cloudflare", "hf-cloudflare"] = "hf-cloudflare",
15
  **kwargs
16
  ) -> Dict[str, Any]:
17
  """
18
  Get RTC configuration for different TURN server providers.
19
 
20
  Args:
21
+ provider: The TURN server provider to use ('hf', 'twilio', 'cloudflare', or 'hf-cloudflare')
22
  **kwargs: Additional arguments passed to the specific provider's function
23
 
24
  Returns:
 
26
  """
27
  try:
28
  if provider == "hf":
29
+ # HF Community Server (Deprecated)
30
+ # 1. Create a Hugging Face account at huggingface.co
31
+ # 2. Visit: https://huggingface.co/settings/tokens to create a token
32
+ # 3. Set HF_TOKEN environment variable or pass token directly
33
+ token = kwargs.pop("token", os.environ.get("HF_TOKEN"))
34
+ if not token:
35
+ raise ValueError("HF_TOKEN environment variable not set")
36
+ return get_hf_turn_credentials(token=token)
37
+
38
  elif provider == "twilio":
39
+ # Twilio TURN Server
40
+ # 1. Create a free Twilio account at: https://login.twilio.com/u/signup
41
+ # 2. Get your Account SID and Auth Token from the Twilio Console
42
+ # 3. Set environment variables: TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN
43
+ account_sid = kwargs.pop("account_sid", os.environ.get("TWILIO_ACCOUNT_SID"))
44
+ auth_token = kwargs.pop("auth_token", os.environ.get("TWILIO_AUTH_TOKEN"))
45
+ if not account_sid or not auth_token:
46
+ raise ValueError("Twilio credentials not found. Set TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN env vars")
47
+ return get_twilio_turn_credentials(account_sid=account_sid, auth_token=auth_token)
48
+
49
  elif provider == "cloudflare":
50
+ # Cloudflare TURN Server
51
+ # 1. Create a free Cloudflare account
52
+ # 2. Go to Cloudflare dashboard -> Calls section
53
+ # 3. Create a TURN App and get the Turn Token ID and API Token
54
+ # 4. Set environment variables: TURN_KEY_ID and TURN_KEY_API_TOKEN
55
+ key_id = kwargs.pop("key_id", os.environ.get("TURN_KEY_ID"))
56
+ api_token = kwargs.pop("api_token", os.environ.get("TURN_KEY_API_TOKEN"))
57
+ ttl = kwargs.pop("ttl", 86400)
58
+ if not key_id or not api_token:
59
+ raise ValueError("Cloudflare credentials not found. Set TURN_KEY_ID and TURN_KEY_API_TOKEN env vars")
60
+ return get_cloudflare_turn_credentials(key_id=key_id, api_token=api_token, ttl=ttl)
61
+
62
+ elif provider == "hf-cloudflare":
63
+ # Cloudflare with Hugging Face Token (10GB free traffic per month)
64
+ # 1. Create a Hugging Face account at huggingface.co
65
+ # 2. Visit: https://huggingface.co/settings/tokens to create a token
66
+ # 3. Set HF_TOKEN environment variable or pass token directly
67
+ hf_token = kwargs.pop("hf_token", os.environ.get("HF_TOKEN"))
68
+ ttl = kwargs.pop("ttl", 86400)
69
+ if not hf_token:
70
+ raise ValueError("HF_TOKEN environment variable not set")
71
+ return get_cloudflare_turn_credentials(hf_token=hf_token, ttl=ttl)
72
+ else:
73
+ raise ValueError(f"Unknown provider: {provider}")
74
  except Exception as e:
75
  raise Exception(f"Failed to get RTC credentials ({provider}): {str(e)}")
76
 
77
 
78
+ async def get_rtc_credentials_async(
79
+ provider: Literal["hf-cloudflare"] = "hf-cloudflare",
80
+ **kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  ) -> Dict[str, Any]:
82
  """
83
+ Get RTC configuration asynchronously for different TURN server providers.
 
 
 
 
 
 
 
 
 
 
84
 
85
+ Args:
86
+ provider: Currently only supports 'hf-cloudflare'
87
+ **kwargs: Additional arguments passed to the specific provider's function
88
 
89
+ Returns:
90
+ Dictionary containing the RTC configuration
91
+ """
92
+ if provider != "hf-cloudflare":
93
+ raise NotImplementedError(f"Async credentials for {provider} not implemented")
94
+
95
  try:
96
+ # Cloudflare with Hugging Face Token (10GB free traffic per month)
97
+ hf_token = kwargs.pop("hf_token", os.environ.get("HF_TOKEN"))
98
+ ttl = kwargs.pop("ttl", 600) # Default 10 minutes for client-side
99
+ if not hf_token:
100
+ raise ValueError("HF_TOKEN environment variable not set")
101
+ return await get_cloudflare_turn_credentials_async(hf_token=hf_token, ttl=ttl)
102
  except Exception as e:
103
+ raise Exception(f"Failed to get async RTC credentials: {str(e)}")
104
 
105
 
106
+ def get_credential_function(provider: str, is_async: bool = False) -> Callable:
 
 
 
 
107
  """
108
+ Get the appropriate credential function based on provider and whether async is needed.
 
 
 
 
 
 
 
 
109
 
110
  Args:
111
+ provider: The TURN server provider
112
+ is_async: Whether to return an async function
113
+
114
+ Returns:
115
+ Function that returns credentials (async or sync)
116
  """
117
+ if is_async and provider == "hf-cloudflare":
118
+ async def get_creds():
119
+ return await get_rtc_credentials_async(provider=provider)
120
+ return get_creds
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  else:
122
+ def get_creds():
123
+ return get_rtc_credentials(provider=provider)
124
+ return get_creds