316usman commited on
Commit
a8fabcc
·
1 Parent(s): 4ae8bb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -48,11 +48,26 @@ def load_model():
48
  tokenizer = LlamaTokenizer.from_pretrained(repo_id)
49
  return model, tokenizer
50
 
51
- if st.button("Load Model"):
52
- model1, tokenizer1 = load_model()
53
- model_loaded = True
54
-
55
  print ("Model Loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  if model_loaded:
58
  # Set up initial values for pipeline parameters
 
48
  tokenizer = LlamaTokenizer.from_pretrained(repo_id)
49
  return model, tokenizer
50
 
 
 
 
 
51
  print ("Model Loaded")
52
+ # Initialize session state variables
53
+ if "model_loaded" not in st.session_state:
54
+ st.session_state["model_loaded"] = False
55
+ if "model" not in st.session_state:
56
+ st.session_state["model"] = None
57
+ if "tokenizer" not in st.session_state:
58
+ st.session_state["tokenizer"] = None
59
+
60
+ # Display the "Load Model" button
61
+ if not st.session_state["model_loaded"]:
62
+ if st.button("Load Model"):
63
+ model1, tokenizer1 = load_model(repo_id)
64
+ st.session_state["model"] = model1
65
+ st.session_state["tokenizer"] = tokenizer1
66
+ st.session_state["model_loaded"] = True
67
+ else:
68
+ model1 = st.session_state["model"]
69
+ tokenizer1 = st.session_state["tokenizer"]
70
+
71
 
72
  if model_loaded:
73
  # Set up initial values for pipeline parameters