Ankerkraut committed on
Commit
d2a7626
·
1 Parent(s): f016689

remove device as loaded with accelerate

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -45,7 +45,7 @@ last_messages = []
45
  def load_model():
46
  ankerbot_model = AutoModelForCausalLM.from_pretrained(
47
  model_name,
48
- device_map="auto",
49
  torch_dtype=torch.float16,
50
  use_cache=True,
51
  offload_folder="../offload"
@@ -55,8 +55,8 @@ def load_model():
55
  torch_dtype=torch.float16,
56
  truncation=True,
57
  padding=True, )
58
- generator = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False, device="cuda:0") # True for flash-attn2 else False
59
- generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False, device="cuda:0") # True for flash-attn2 else False
60
  return (generator, generator_mini)
61
 
62
  _model_cache = None
 
45
  def load_model():
46
  ankerbot_model = AutoModelForCausalLM.from_pretrained(
47
  model_name,
48
+ device_map="cuda:0",
49
  torch_dtype=torch.float16,
50
  use_cache=True,
51
  offload_folder="../offload"
 
55
  torch_dtype=torch.float16,
56
  truncation=True,
57
  padding=True, )
58
+ generator = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False) # True for flash-attn2 else False
59
+ generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False) # True for flash-attn2 else False
60
  return (generator, generator_mini)
61
 
62
  _model_cache = None