raygiles3 committed on
Commit
0a9e822
·
verified ·
1 Parent(s): 8d31a5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
6
- from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
7
 
8
  # Retrieve the token from the environment variable
9
  hf_api_token = os.getenv("HF_API_TOKEN")
@@ -24,9 +24,13 @@ model_name = "meta-llama/Llama-2-7b-hf"
24
  with init_empty_weights():
25
  summarization_model = AutoModelForCausalLM.from_pretrained(model_name)
26
 
27
- # Infer device map and dispatch model
28
- device_map = infer_auto_device_map(summarization_model, max_memory={0: "14GiB", 1: "14GiB"})
29
- summarization_model = dispatch_model(summarization_model, device_map=device_map)
 
 
 
 
30
 
31
  summarization_tokenizer = AutoTokenizer.from_pretrained(model_name)
32
 
 
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
6
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
7
 
8
  # Retrieve the token from the environment variable
9
  hf_api_token = os.getenv("HF_API_TOKEN")
 
24
  with init_empty_weights():
25
  summarization_model = AutoModelForCausalLM.from_pretrained(model_name)
26
 
27
+ # Load checkpoint and dispatch model
28
+ summarization_model = load_checkpoint_and_dispatch(
29
+ summarization_model,
30
+ checkpoint=model_name,
31
+ device_map="auto",
32
+ dtype=torch.float16
33
+ )
34
 
35
  summarization_tokenizer = AutoTokenizer.from_pretrained(model_name)
36