awacke1 commited on
Commit
acca2e2
Β·
verified Β·
1 Parent(s): d334e6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import os
3
  import shutil
4
  import glob
5
- import base64 # Added this import
6
  import streamlit as st
7
  import pandas as pd
8
  import torch
@@ -77,13 +77,15 @@ class ModelBuilder:
77
  self.tokenizer = None
78
  self.sft_data = None
79
 
80
- def load_model(self, model_path: str):
81
- """Load a model from a path"""
82
  with st.spinner("Loading model... ⏳"):
83
  self.model = AutoModelForCausalLM.from_pretrained(model_path)
84
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
85
  if self.tokenizer.pad_token is None:
86
  self.tokenizer.pad_token = self.tokenizer.eos_token
 
 
87
  st.success("Model loaded! βœ…")
88
  return self
89
 
@@ -156,7 +158,9 @@ selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dir
156
  if selected_model != "None" and st.sidebar.button("Load Model πŸ“‚"):
157
  if 'builder' not in st.session_state:
158
  st.session_state['builder'] = ModelBuilder()
159
- st.session_state['builder'].load_model(selected_model)
 
 
160
  st.session_state['model_loaded'] = True
161
  st.rerun()
162
 
@@ -176,7 +180,7 @@ with tab1:
176
  if st.button("Download Model ⬇️"):
177
  config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
178
  builder = ModelBuilder()
179
- builder.load_model(base_model)
180
  builder.save_model(config.model_path)
181
  st.session_state['builder'] = builder
182
  st.session_state['model_loaded'] = True
@@ -210,7 +214,12 @@ with tab2:
210
  with open(csv_path, "wb") as f:
211
  f.write(uploaded_csv.read())
212
  new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
213
- new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
 
 
 
 
 
214
  st.session_state['builder'].config = new_config
215
  with st.status("Fine-tuning model... ⏳", expanded=True) as status:
216
  st.session_state['builder'].fine_tune_sft(csv_path)
 
2
  import os
3
  import shutil
4
  import glob
5
+ import base64
6
  import streamlit as st
7
  import pandas as pd
8
  import torch
 
77
  self.tokenizer = None
78
  self.sft_data = None
79
 
80
+ def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
81
+ """Load a model from a path with an optional config"""
82
  with st.spinner("Loading model... ⏳"):
83
  self.model = AutoModelForCausalLM.from_pretrained(model_path)
84
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
85
  if self.tokenizer.pad_token is None:
86
  self.tokenizer.pad_token = self.tokenizer.eos_token
87
+ if config:
88
+ self.config = config
89
  st.success("Model loaded! βœ…")
90
  return self
91
 
 
158
  if selected_model != "None" and st.sidebar.button("Load Model πŸ“‚"):
159
  if 'builder' not in st.session_state:
160
  st.session_state['builder'] = ModelBuilder()
161
+ # Create a config for the loaded model if none exists
162
+ config = ModelConfig(name=os.path.basename(selected_model), base_model="unknown", size="small", domain="general")
163
+ st.session_state['builder'].load_model(selected_model, config)
164
  st.session_state['model_loaded'] = True
165
  st.rerun()
166
 
 
180
  if st.button("Download Model ⬇️"):
181
  config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
182
  builder = ModelBuilder()
183
+ builder.load_model(base_model, config) # Pass config here
184
  builder.save_model(config.model_path)
185
  st.session_state['builder'] = builder
186
  st.session_state['model_loaded'] = True
 
214
  with open(csv_path, "wb") as f:
215
  f.write(uploaded_csv.read())
216
  new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
217
+ new_config = ModelConfig(
218
+ name=new_model_name,
219
+ base_model=st.session_state['builder'].config.base_model,
220
+ size="small",
221
+ domain=st.session_state['builder'].config.domain
222
+ )
223
  st.session_state['builder'].config = new_config
224
  with st.status("Fine-tuning model... ⏳", expanded=True) as status:
225
  st.session_state['builder'].fine_tune_sft(csv_path)