yakine committed on
Commit 8628226 · verified · 1 Parent(s): c90ec2a

Update app.py

Files changed (1)
  1. app.py +3 -8
app.py CHANGED
@@ -34,19 +34,14 @@ model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 # Create a pipeline for text generation using GPT-2
 text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
 
-# Initialize accelerator with disk offload
-accelerator = Accelerator(cpu=False, disk_offload=True)
-
-# Load the Llama-3 model and tokenizer with disk offload
+# Load the Llama-3 model and tokenizer once during startup
 tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
 model_llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Meta-Llama-3-8B",
-    torch_dtype='auto',
+    torch_dtype='float16',
     device_map='auto',
-    offload_folder="offload",  # Folder to offload weights to disk
-    offload_state_dict=True,   # Offload state_dict to disk
     token=hf_token
-).to(accelerator.device)
+)
 
 # Define your prompt template
 prompt_template = """..."""  # Your existing prompt template here
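
For reference, a minimal, self-contained sketch of the loading path as it stands after this commit, assuming `hf_token` is defined earlier in app.py (the unchanged context lines suggest it is) and that Llama-3 generation reuses the same `pipeline` pattern the file already applies to GPT-2; the `llama_generator` name and the sample prompt are illustrative, not part of the original file.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

hf_token = "hf_..."  # placeholder; the real app defines this earlier

# Load the Llama-3 model and tokenizer once during startup
tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.float16,  # same effect as the string 'float16' used in the diff
    device_map="auto",          # let accelerate place layers across available devices
    token=hf_token,
)

# Generation can then go through the same pipeline pattern used for GPT-2 above
llama_generator = pipeline("text-generation", model=model_llama, tokenizer=tokenizer_llama)
print(llama_generator("Hello from Llama-3:", max_new_tokens=32)[0]["generated_text"])

Because the model is dispatched with device_map="auto", no explicit .to(...) call is needed afterwards, which is consistent with this commit dropping the manual Accelerator setup and the .to(accelerator.device) call.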