sagawa commited on
Commit
542de7d
·
1 Parent(s): 91adf86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -3
app.py CHANGED
@@ -25,8 +25,8 @@ class CFG():
25
  input_data = st.text_area(display_text)
26
  model_name_or_path = 'sagawa/ZINC-t5-productpredicition'
27
  model = 't5'
28
- num_beams = 5
29
- num_return_sequences = 5
30
  seed = 42
31
 
32
 
@@ -79,7 +79,8 @@ if CFG.uploaded_file is not None:
79
  def convert_df(df):
80
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
81
  return df.to_csv(index=False)
82
-
 
83
  csv = convert_df(output_df)
84
 
85
  st.download_button(
@@ -110,5 +111,20 @@ else:
110
  try:
111
  output_df = pd.DataFrame(np.array(output).reshape(1, -1), columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
112
  st.table(output_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  except:
114
  pass
 
25
  input_data = st.text_area(display_text)
26
  model_name_or_path = 'sagawa/ZINC-t5-productpredicition'
27
  model = 't5'
28
+ num_beams = st.number_input(label='num beams', min_value=1, max_value=10, value=5, step=1)
29
+ num_return_sequences = st.number_input(label='num return sequences', min_value=1, max_value=num_beams, value=num_beams, step=1)
30
  seed = 42
31
 
32
 
 
79
  def convert_df(df):
80
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
81
  return df.to_csv(index=False)
82
+
83
+ st.table(output_df)
84
  csv = convert_df(output_df)
85
 
86
  st.download_button(
 
111
  try:
112
  output_df = pd.DataFrame(np.array(output).reshape(1, -1), columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
113
  st.table(output_df)
114
+
115
+ @st.cache
116
+ def convert_df(df):
117
+ # IMPORTANT: Cache the conversion to prevent computation on every rerun
118
+ return df.to_csv(index=False)
119
+
120
+ csv = convert_df(output_df)
121
+
122
+ st.download_button(
123
+ label="Download data as CSV",
124
+ data=csv,
125
+ file_name='output.csv',
126
+ mime='text/csv',
127
+ )
128
+
129
  except:
130
  pass