Ankerkraut committed · Commit aff2efe · 1 Parent(s): 1323203

initialize as cpu

Files changed (1)
  1. app.py +2 -2
app.py CHANGED
@@ -118,7 +118,7 @@ def load_model():
     ankerbot_model = AutoModelForCausalLM.from_pretrained(
         model_name,
         quantization_config=bnb_config,
-        device_map="cuda:0",
+        device_map="cpu",
         torch_dtype=torch.float16,
         use_cache=True,
         offload_folder="../offload"
@@ -128,10 +128,10 @@ def load_model():
         torch_dtype=torch.float16,
         truncation=True,
         padding=True, )
-
+    ankerbot_model.to("cuda")
     prompt_format = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
     generator = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=True) # True for flash-attn2 else False
     generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=True) # True for flash-attn2 else False
 load_model()
 @spaces.GPU
 def generate_response(query, context, prompts, max_tokens, temperature, top_p):
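Context for the change: on ZeroGPU Spaces the GPU is only attached inside functions decorated with @spaces.GPU, so the model is now initialized with device_map="cpu" and moved to CUDA afterwards rather than being placed on cuda:0 at load time. Below is a minimal, self-contained sketch of that pattern, not the Space's actual code: the placeholder checkpoint is an assumption, and quantization_config is omitted because transformers raises a ValueError when .to() is called on a bitsandbytes-quantized model, so the plain move only works for unquantized weights.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "gpt2"  # hypothetical placeholder; the Space loads its own checkpoint

# Initialize on CPU so no CUDA context is created at startup.
ankerbot_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float16,
    use_cache=True,
)
ankerbot_tokenizer = AutoTokenizer.from_pretrained(model_name)

# Once a GPU is available, move the weights over in one step.
ankerbot_model.to("cuda")

generator = pipeline(
    task="text-generation",
    model=ankerbot_model,
    tokenizer=ankerbot_tokenizer,
    torch_dtype=torch.float16,
)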