Srinivasan Iyer (sviyer) committed
Commit 22c7fe1 · unverified · Parent: fe45f69

fix save and reload model state (#49)

Co-authored-by: Srini Iyer <[email protected]>

Files changed (1):
  1. bytelatent/model/local_models.py  +14 -10
bytelatent/model/local_models.py CHANGED

@@ -74,12 +74,10 @@ class LocalModelBase(nn.Module):
 
         self.boe_id = BOE_ID
 
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.layers = nn.ModuleList(
             [TransformerBlock(args) for _ in range(args.n_layers)]
         )
 
-        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
         if not self.use_rope:
             self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
         else:
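This first hunk drops `norm` and `tok_embeddings` from the shared `LocalModelBase.__init__`; the two hunks at the end re-create each of them only on the subclass that actually uses it. The save/reload connection, as a minimal sketch not taken from the repo: every submodule registered in `__init__` contributes keys to `state_dict()` whether or not `forward()` ever touches it, so checkpoints carry stray parameters and strict reloading fails as soon as the registered set and the class definition drift apart.

```python
# Minimal standalone sketch (not repo code): a registered-but-unused
# submodule still writes keys into state_dict(), so a checkpoint saved
# from one class layout cannot be strictly reloaded into another.
import torch.nn as nn

class WithExtras(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)    # registered, never used in forward

class Trimmed(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(8, 8)  # same layout, minus the dead module

ckpt = WithExtras().state_dict()       # includes 'norm.weight', 'norm.bias'
try:
    Trimmed().load_state_dict(ckpt)    # strict=True is the default
except RuntimeError as err:
    print(err)                         # reports unexpected keys: norm.*
```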
@@ -131,16 +129,18 @@ class LocalModelBase(nn.Module):
 
     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
-        self.norm.reset_parameters()
+        if hasattr(self, "norm"):
+            self.norm.reset_parameters()
 
         init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
+        if hasattr(self, "tok_embeddings"):
+            nn.init.trunc_normal_(
+                self.tok_embeddings.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
         if self.pos_embeddings is not None:
             nn.init.trunc_normal_(
                 self.pos_embeddings.weight,
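With the base class no longer guaranteed to own `norm` or `tok_embeddings`, `init_weights` guards both behind `hasattr` so it only initializes what the concrete subclass created. A quick way to convince yourself the state round-trips after such a refactor, using a hypothetical helper (the function below is ours, not the repo's):

```python
# Hypothetical sanity check, not repo code: a freshly constructed model
# must accept its own saved state with no missing or unexpected keys.
def check_state_roundtrip(make_model) -> None:
    model = make_model()
    state = model.state_dict()
    rebuilt = make_model()
    # strict=False returns the mismatches instead of raising
    missing, unexpected = rebuilt.load_state_dict(state, strict=False)
    assert not missing and not unexpected, (missing, unexpected)
```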
@@ -212,6 +212,8 @@ class LocalEncoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
+        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
+
         if self.cross_attn_encoder:
             self.cross_attn_layers = torch.nn.ModuleList()
             layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
@@ -314,6 +316,8 @@ class LocalDecoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+
         if self.cross_attn_decoder:
             self.cross_attn_layers = torch.nn.ModuleList()
             layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
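Net effect of these last two hunks, reading the diff: `LocalEncoder` now owns the byte-level `tok_embeddings`, `LocalDecoder` now owns the final `RMSNorm`, and `LocalModelBase` keeps only what both sides share. A hedged outline of that ownership split (class bodies below are stand-ins, not the real implementations):

```python
# Hedged outline of the resulting module ownership; LayerNorm stands in
# for the repo's RMSNorm so the sketch runs on stock PyTorch.
import torch.nn as nn

class LocalModelBaseSketch(nn.Module):
    def __init__(self, dim: int, n_layers: int):
        super().__init__()
        # shared trunk only; no norm, no token embeddings here anymore
        self.layers = nn.ModuleList(nn.Identity() for _ in range(n_layers))

class LocalEncoderSketch(LocalModelBaseSketch):
    def __init__(self, vocab_size: int, dim: int, n_layers: int):
        super().__init__(dim, n_layers)
        self.tok_embeddings = nn.Embedding(vocab_size, dim)  # encoder-only

class LocalDecoderSketch(LocalModelBaseSketch):
    def __init__(self, dim: int, n_layers: int):
        super().__init__(dim, n_layers)
        self.norm = nn.LayerNorm(dim)  # decoder-only (RMSNorm in the repo)
```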
 