pmelnechuk commited on
Commit
7ff3311
verified
1 Parent(s): 28bc663

Update src/model_load.py

Browse files
Files changed (1) hide show
  1. src/model_load.py +2 -1
src/model_load.py CHANGED
@@ -12,8 +12,9 @@ def load_model():
12
  max_memory = {0: "24GB", "cpu": "30GB"}
13
  # Cargar tokenizer y modelo de Hugging Face
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
15
  model = AutoModelForCausalLM.from_pretrained(model_name,
16
- torch_dtype=torch.bfloat16).to("cuda")
17
 
18
  # Crear pipeline de generaci贸n de texto
19
  text_generation_pipeline = pipeline(
 
12
  max_memory = {0: "24GB", "cpu": "30GB"}
13
  # Cargar tokenizer y modelo de Hugging Face
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model = AutoModelForCausalLM.from_pretrained(model_name,
17
+ torch_dtype=torch.bfloat16).to(device)
18
 
19
  # Crear pipeline de generaci贸n de texto
20
  text_generation_pipeline = pipeline(