amiguel committed
Commit 4a47453 · verified · 1 Parent(s): 026dc4e

Update app.py

Files changed (1)
  1. app.py +17 -8
app.py CHANGED
@@ -6,7 +6,7 @@ import pickle
 from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
 from huggingface_hub import login, hf_hub_download
 import time
-from ch09util import subsequent_mask, create_model  # Ensure ch09util.py is available
+from utils.ch09util import subsequent_mask, create_model  # Ensure ch09util.py is available

 # Device setup
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -119,13 +119,22 @@ def load_model_and_resources(hf_token):
     else:
         config = TransformerConfig(**config_dict)

-    # Load model
-    model = CustomTransformer.from_pretrained(
-        MODEL_NAME,
-        config=config,
-        token=hf_token
-    ).to(DEVICE)
+    # Initialize model on meta device and load weights explicitly
+    model = CustomTransformer(config)
+    weights_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors", token=hf_token)
+    from safetensors.torch import load_file
+    state_dict = load_file(weights_path)
+    model.load_state_dict(state_dict)
+
+    # Move model to the target device safely
+    if DEVICE == "cuda":
+        model = model.to_empty(device=DEVICE)  # Move structure to GPU
+        model.load_state_dict(state_dict)  # Reload weights on GPU
+    else:
+        model = model.to(DEVICE)  # CPU can handle direct move after loading weights

+    model.eval()
+
     # Load dictionaries from Hugging Face Hub
     dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=hf_token)
     with open(dict_path, "rb") as fb:
@@ -243,7 +252,7 @@ if prompt := st.chat_input("Enter text to translate into French..."):

     # Display metrics
     st.caption(
-        f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
+        f"🤖 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
         f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
         f"💵 Cost (AOA): {total_cost_aoa:.4f}"
     )
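
The heart of this commit is the replacement of CustomTransformer.from_pretrained(...) with explicit weight loading: download model.safetensors from the Hub, read it with safetensors.torch.load_file, copy the tensors into a freshly constructed model, move the model to the target device, and switch it to eval mode. Below is a minimal, self-contained sketch of that pattern; it uses a stand-in nn.Linear and a locally saved checkpoint instead of the repo's CustomTransformer and hf_hub_download, purely for illustration.

import torch
from torch import nn
from safetensors.torch import save_file, load_file

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the checkpoint that app.py fetches with hf_hub_download.
reference = nn.Linear(16, 16)
save_file(reference.state_dict(), "model.safetensors")

# 1. Build the model skeleton from its config (weights are random at this point).
model = nn.Linear(16, 16)

# 2. Read the checkpoint; load_file returns a plain dict of CPU tensors.
state_dict = load_file("model.safetensors")

# 3. Copy the checkpoint weights into the skeleton, then move it to the device.
model.load_state_dict(state_dict)
model = model.to(DEVICE)

# 4. Inference mode: disables dropout and other training-only behaviour.
model.eval()

On CUDA the commit goes one step further: it calls model.to_empty(device=DEVICE) to re-allocate the module's storage on the GPU and then reloads the state dict there, which is the usual route when a model has been instantiated on the meta device.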