raygiles3 commited on
Commit
0deaf65
·
verified ·
1 Parent(s): fdf0af4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -16
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
6
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
7
 
8
  # Retrieve the token from the environment variable
9
  hf_api_token = os.getenv("HF_API_TOKEN")
@@ -19,20 +18,8 @@ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
19
  whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
20
 
21
  # Initialize the summarization model and tokenizer
22
- # Load LLAMA 7B model with accelerate
23
- model_name = "meta-llama/Llama-2-7b-hf"
24
- with init_empty_weights():
25
- summarization_model = AutoModelForCausalLM.from_pretrained(model_name)
26
-
27
- # Load checkpoint and dispatch model
28
- summarization_model = load_checkpoint_and_dispatch(
29
- summarization_model,
30
- checkpoint=model_name,
31
- device_map="auto",
32
- dtype=torch.float16
33
- )
34
-
35
- summarization_tokenizer = AutoTokenizer.from_pretrained(model_name)
36
 
37
  # Function to transcribe audio
38
  def transcribe_audio(audio_file):
@@ -59,7 +46,7 @@ def process_audio(audio_file):
59
  # Gradio UI
60
  iface = gr.Interface(
61
  fn=process_audio,
62
- inputs=gr.Audio(type="file"),
63
  outputs=[
64
  gr.Textbox(label="Transcription"),
65
  gr.Textbox(label="Summary")
 
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
 
6
 
7
  # Retrieve the token from the environment variable
8
  hf_api_token = os.getenv("HF_API_TOKEN")
 
18
  whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
19
 
20
  # Initialize the summarization model and tokenizer
21
+ summarization_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
22
+ summarization_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Function to transcribe audio
25
  def transcribe_audio(audio_file):
 
46
  # Gradio UI
47
  iface = gr.Interface(
48
  fn=process_audio,
49
+ inputs=gr.Audio(source="upload", type="file"),
50
  outputs=[
51
  gr.Textbox(label="Transcription"),
52
  gr.Textbox(label="Summary")