Spaces:
Commit · aff2efe
Parent(s): 1323203
initialize as cpu
app.py CHANGED
@@ -118,7 +118,7 @@ def load_model():
     ankerbot_model = AutoModelForCausalLM.from_pretrained(
         model_name,
         quantization_config=bnb_config,
-        device_map="
+        device_map="cpu",
         torch_dtype=torch.float16,
         use_cache=True,
         offload_folder="../offload"
@@ -128,10 +128,10 @@ def load_model():
         torch_dtype=torch.float16,
         truncation=True,
         padding=True, )
+    ankerbot_model.to("cuda")
     prompt_format = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
     generator = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=True) # True for flash-attn2 else False
     generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=True) # True for flash-attn2 else False
-
 load_model()
 @spaces.GPU
 def generate_response(query, context, prompts, max_tokens, temperature, top_p):
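
For context on what this change amounts to: on a ZeroGPU Space, model weights are loaded while the Space starts (CPU only) and a GPU is attached only inside functions decorated with @spaces.GPU. Below is a minimal, hypothetical sketch of that pattern, not this Space's exact code: the model name, prompt handling, and generation parameters are placeholders, the bitsandbytes quantization config from the real app.py is omitted to keep the sketch self-contained, and it calls model.generate directly instead of the transformers pipeline the app builds.

# Hypothetical sketch of the load-on-CPU / run-on-GPU pattern (placeholders, not this Space's code).
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "org/placeholder-model"  # placeholder checkpoint name

# Startup: keep the weights on CPU, mirroring device_map="cpu" in this commit.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float16,
)

@spaces.GPU
def generate_response(prompt, max_tokens=256, temperature=0.7, top_p=0.9):
    # A GPU is only guaranteed while a @spaces.GPU-decorated function runs,
    # so the move to CUDA and the generation both happen here.
    model.to("cuda")
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)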