Spestly commited on
Commit
e7e2534
Β·
verified Β·
1 Parent(s): e7e4bd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -119,6 +119,10 @@ class AtlasInferenceApp:
119
  pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
120
  eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
121
  ):
 
 
 
 
122
  # Decode only the new tokens
123
  new_tokens = chunk[:, inputs.input_ids.shape[1]:] # Exclude input tokens
124
  generated_tokens.extend(new_tokens[0].tolist()) # Add new tokens to the list
 
119
  pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
120
  eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
121
  ):
122
+ # Ensure chunk is 2D (batch size Γ— sequence length)
123
+ if chunk.dim() == 1:
124
+ chunk = chunk.unsqueeze(0) # Add batch dimension
125
+
126
  # Decode only the new tokens
127
  new_tokens = chunk[:, inputs.input_ids.shape[1]:] # Exclude input tokens
128
  generated_tokens.extend(new_tokens[0].tolist()) # Add new tokens to the list