raygiles3 commited on
Commit
f376665
·
verified ·
1 Parent(s): 0e2a1b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
 
6
 
7
  # Retrieve the token from the environment variable
8
  hf_api_token = os.getenv("HF_API_TOKEN")
@@ -18,13 +19,19 @@ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
18
  whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
19
 
20
  # Initialize the summarization model and tokenizer
21
- # Use a smaller version of the Llama model and load in FP16
22
- summarization_model = AutoModelForCausalLM.from_pretrained(
23
- "meta-llama/LlamaGuard-7b",
24
- torch_dtype=torch.float16,
25
- low_cpu_mem_usage=True
 
 
 
 
 
26
  )
27
- summarization_tokenizer = AutoTokenizer.from_pretrained("meta-llama/LlamaGuard-7b")
 
28
 
29
  # Function to transcribe audio
30
  def transcribe_audio(audio_file):
@@ -51,7 +58,7 @@ def process_audio(audio_file):
51
  # Gradio UI
52
  iface = gr.Interface(
53
  fn=process_audio,
54
- inputs=gr.Audio(source="upload", type="file"),
55
  outputs=[
56
  gr.Textbox(label="Transcription"),
57
  gr.Textbox(label="Summary")
 
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
6
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
7
 
8
  # Retrieve the token from the environment variable
9
  hf_api_token = os.getenv("HF_API_TOKEN")
 
19
  whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
20
 
21
  # Initialize the summarization model and tokenizer
22
+ # Load LLAMA 7B model with accelerate
23
+ model_name = "meta-llama/Llama-2-7b-hf"
24
+ with init_empty_weights():
25
+ summarization_model = AutoModelForCausalLM.from_pretrained(model_name)
26
+
27
+ summarization_model = load_checkpoint_and_dispatch(
28
+ summarization_model,
29
+ checkpoint_path=model_name,
30
+ device_map="auto",
31
+ dtype=torch.float16
32
  )
33
+
34
+ summarization_tokenizer = AutoTokenizer.from_pretrained(model_name)
35
 
36
  # Function to transcribe audio
37
  def transcribe_audio(audio_file):
 
58
  # Gradio UI
59
  iface = gr.Interface(
60
  fn=process_audio,
61
+ inputs=gr.Audio(type="file"),
62
  outputs=[
63
  gr.Textbox(label="Transcription"),
64
  gr.Textbox(label="Summary")