raygiles3 committed on
Commit
8d31a5a
·
verified ·
1 Parent(s): f376665

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
6
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
7
 
8
  # Retrieve the token from the environment variable
9
  hf_api_token = os.getenv("HF_API_TOKEN")
@@ -24,12 +24,9 @@ model_name = "meta-llama/Llama-2-7b-hf"
24
  with init_empty_weights():
25
  summarization_model = AutoModelForCausalLM.from_pretrained(model_name)
26
 
27
- summarization_model = load_checkpoint_and_dispatch(
28
- summarization_model,
29
- checkpoint_path=model_name,
30
- device_map="auto",
31
- dtype=torch.float16
32
- )
33
 
34
  summarization_tokenizer = AutoTokenizer.from_pretrained(model_name)
35
 
 
3
  from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
6
+ from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
7
 
8
  # Retrieve the token from the environment variable
9
  hf_api_token = os.getenv("HF_API_TOKEN")
 
24
  with init_empty_weights():
25
  summarization_model = AutoModelForCausalLM.from_pretrained(model_name)
26
 
27
+ # Infer device map and dispatch model
28
+ device_map = infer_auto_device_map(summarization_model, max_memory={0: "14GiB", 1: "14GiB"})
29
+ summarization_model = dispatch_model(summarization_model, device_map=device_map)
 
 
 
30
 
31
  summarization_tokenizer = AutoTokenizer.from_pretrained(model_name)
32