Sofia Casadei committed
Commit 437ed2e · 1 Parent(s): 7338a56

fix: use hf-cloudflare turn server

Files changed (1): main.py (+4, -14)
main.py CHANGED
@@ -29,8 +29,6 @@ from transformers.utils import is_flash_attn_2_available
 
 from utils.logger_config import setup_logging
 from utils.device import get_device, get_torch_and_np_dtypes
-from utils.turn_server import get_credential_function, get_rtc_credentials
-
 
 load_dotenv()
 setup_logging()
@@ -40,7 +38,6 @@ logger = logging.getLogger(__name__)
 UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
 UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
 APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
-TURN_SERVER_PROVIDER = os.getenv("TURN_SERVER_PROVIDER", "hf-cloudflare").lower() # hf-cloudflare | cloudflare | hf | twilio
 MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
 LANGUAGE = os.getenv("LANGUAGE", "english").lower()
 
@@ -48,7 +45,6 @@ device = get_device(force_cpu=False)
 torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
 logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
 
-
 attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
 logger.info(f"Using attention: {attention}")
 
@@ -87,7 +83,6 @@ warmup_audio = np.zeros((16000,), dtype=np_dtype) # 1s of silence
 transcribe_pipeline(warmup_audio)
 logger.info("Model warmup complete")
 
-
 async def transcribe(audio: tuple[int, np.ndarray]):
     sample_rate, audio_array = audio
     logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
@@ -104,11 +99,6 @@ async def transcribe(audio: tuple[int, np.ndarray]):
     )
     yield AdditionalOutputs(outputs["text"].strip())
 
-async def get_credentials():
-    return await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
-
-server_credentials = get_cloudflare_turn_credentials(ttl=360_000) if APP_MODE == "deployed" else None
-
 logger.info("Initializing FastRTC stream")
 stream = Stream(
     handler=ReplyOnPause(
@@ -146,8 +136,7 @@ stream = Stream(
         gr.Textbox(label="Transcript"),
     ],
     additional_outputs_handler=lambda current, new: current + " " + new,
-    rtc_configuration=get_credentials,
-    server_rtc_configuration=server_credentials,
+    rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
     concurrency_limit=6
 )
 
@@ -162,8 +151,9 @@ async def index():
     elif UI_TYPE == "screen":
         html_content = open("static/index-screen.html").read()
 
-    # Use the same server credentials for the client
-    return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(server_credentials)))
+    rtc_configuration = get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None
+    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_configuration))
+    return HTMLResponse(content=html_content)
 
 @app.get("/transcript")
 def _(webrtc_id: str):
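
The net effect of the change is that TURN configuration no longer goes through the project's `utils.turn_server` abstraction or a `TURN_SERVER_PROVIDER` switch; deployed instances use FastRTC's HF-managed Cloudflare TURN helper directly, and local runs pass `None`. A minimal sketch of the resulting `Stream` wiring, assuming `Stream` accepts the awaitable returned by `get_cloudflare_turn_credentials_async` (as the diff implies); the `modality`/`mode` arguments and the stub handler are illustrative stand-ins for the real Whisper pipeline in main.py:

```python
import os

import numpy as np
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    get_cloudflare_turn_credentials_async,
)

APP_MODE = os.getenv("APP_MODE", "local").lower()  # local | deployed


async def transcribe(audio: tuple[int, np.ndarray]):
    # Stub in place of the Whisper pipeline from main.py: report the clip length.
    sample_rate, audio_array = audio
    yield AdditionalOutputs(f"{audio_array.shape[-1] / sample_rate:.1f}s of audio received")


stream = Stream(
    handler=ReplyOnPause(transcribe),
    modality="audio",
    mode="send-receive",
    # Only deployed instances need a TURN relay; credentials come from the
    # HF-managed Cloudflare service, authenticated with the account's HF_TOKEN.
    rtc_configuration=(
        get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
        if APP_MODE == "deployed"
        else None
    ),
    concurrency_limit=6,
)
```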
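The index route follows the same pattern on the page-serving side: the TURN credentials (or `null` locally) are serialized into the static HTML through the `__RTC_CONFIGURATION__` placeholder. A minimal sketch of that route under stated assumptions: a FastAPI app object, `static/index.html` as the base page path (not shown in the diff), and the credential helper awaited so that `json.dumps` receives a plain dict rather than a coroutine:

```python
import json
import os

from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import get_cloudflare_turn_credentials_async

app = FastAPI()

APP_MODE = os.getenv("APP_MODE", "local").lower()  # local | deployed
UI_TYPE = os.getenv("UI_TYPE", "base").lower()  # base | screen


@app.get("/")
async def index():
    # Pick the page variant; "static/index.html" for the base UI is assumed here.
    page = "static/index-screen.html" if UI_TYPE == "screen" else "static/index.html"
    html_content = open(page).read()

    # Deployed instances embed the HF/Cloudflare TURN credentials in the page;
    # locally the placeholder becomes null and the browser connects without a relay.
    rtc_configuration = (
        await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
        if APP_MODE == "deployed"
        else None
    )
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_configuration))
    return HTMLResponse(content=html_content)
```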