yakine committed on
Commit
b6c092a
·
verified ·
1 Parent(s): a73fd11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM
4
  from io import StringIO
5
  import os
 
6
  from huggingface_hub import HfFolder
7
 
8
  # Access the Hugging Face API token from environment variables
@@ -18,7 +19,8 @@ model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
18
  # Load the Llama3 model in sharded mode
19
  model_name = "meta-llama/Meta-Llama-3.1-8B"
20
  try:
21
- model_llama = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token = hf_token) # use device_map for automatic sharding
 
22
  except OSError as e:
23
  print(f"Error loading model: {e}")
24
 
 
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM
4
  from io import StringIO
5
  import os
6
+ import torch
7
  from huggingface_hub import HfFolder
8
 
9
  # Access the Hugging Face API token from environment variables
 
19
  # Load the Llama3 model in sharded mode
20
  model_name = "meta-llama/Meta-Llama-3.1-8B"
21
  try:
22
+ model_llama = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16,
23
+ load_in_8bit=True, token = hf_token) # use device_map for automatic sharding
24
  except OSError as e:
25
  print(f"Error loading model: {e}")
26