piecurus committed on
Commit
b79461c
·
1 Parent(s): 17ded20

removed GPU inference

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -12,7 +12,6 @@ import torch
12
  import gradio as gr
13
  from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
14
  nltk.download("punkt")
15
- torch_device = 'cuda'
16
 
17
 
18
  # In[ ]:
@@ -23,7 +22,7 @@ model_name = "facebook/wav2vec2-base-960h"
23
 
24
  #model_name = "facebook/wav2vec2-large-xlsr-53"
25
  tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)#omdel_name
26
- model = Wav2Vec2ForCTC.from_pretrained(model_name).to(torch_device)
27
 
28
 
29
  # In[ ]:
@@ -96,7 +95,7 @@ def asr_transcript_long(input_file,tokenizer=tokenizer, model=model ):
96
  if sample_rate !=16000:
97
  speech = librosa.resample(speech, sample_rate,16000)
98
  input_values = tokenizer(speech, return_tensors="pt").input_values
99
- logits = model(input_values.to(torch_device)).logits
100
 
101
  predicted_ids = torch.argmax(logits, dim=-1)
102
  transcription = tokenizer.decode(predicted_ids[0])
 
12
  import gradio as gr
13
  from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
14
  nltk.download("punkt")
 
15
 
16
 
17
  # In[ ]:
 
22
 
23
  #model_name = "facebook/wav2vec2-large-xlsr-53"
24
  tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)#omdel_name
25
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
26
 
27
 
28
  # In[ ]:
 
95
  if sample_rate !=16000:
96
  speech = librosa.resample(speech, sample_rate,16000)
97
  input_values = tokenizer(speech, return_tensors="pt").input_values
98
+ logits = model(input_values).logits
99
 
100
  predicted_ids = torch.argmax(logits, dim=-1)
101
  transcription = tokenizer.decode(predicted_ids[0])