Steelskull committed on
Commit
29e570d
·
verified ·
1 Parent(s): 69c946a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -88
app.py CHANGED
@@ -12,6 +12,7 @@ import huggingface_hub
12
  import torch
13
  import base64
14
  from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 
15
  from cryptography.hazmat.backends import default_backend
16
  import yaml
17
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
@@ -32,7 +33,7 @@ MARKDOWN_DESCRIPTION = """
32
 
33
  The fastest way to perform a model merge 🔥
34
 
35
- Specify a YAML configuration file (see examples below) and a HF token and this app will perform the merge and upload the merged model to your user profile.
36
  """
37
 
38
  MARKDOWN_ARTICLE = """
@@ -73,52 +74,119 @@ This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](ht
73
 
74
  examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
75
 
 
 
 
 
 
 
 
76
 
77
- def encrypt_file(file_path, key):
78
  """
79
- Encrypt the contents of a file using AES encryption with the provided key.
80
-
 
81
  Args:
82
  file_path: Path to the file to encrypt (pathlib.Path or string)
83
- key: Encryption key
84
-
85
  Returns:
86
  bool: True if encryption was successful, False otherwise
87
  """
88
  try:
89
  file_path = pathlib.Path(file_path)
90
  if not file_path.exists():
 
91
  return False
92
-
93
- # Ensure key is 32 bytes (256 bits)
94
- key_bytes = key.encode('utf-8')
95
- key_bytes = key_bytes + b'\0' * (32 - len(key_bytes)) if len(key_bytes) < 32 else key_bytes[:32]
96
-
97
- # Generate a random IV
98
  iv = os.urandom(16)
99
-
100
- # Create an encryptor
101
  cipher = Cipher(algorithms.AES(key_bytes), modes.CBC(iv), backend=default_backend())
102
  encryptor = cipher.encryptor()
103
-
104
- # Read file content
 
 
105
  with open(file_path, 'rb') as f:
106
- content = f.read()
107
-
108
- # Pad the content to be a multiple of 16 bytes
109
- padding = 16 - (len(content) % 16)
110
- content += bytes([padding]) * padding
111
-
112
- # Encrypt and write back
113
- encrypted = iv + encryptor.update(content) + encryptor.finalize()
 
 
 
 
114
  with open(file_path, 'wb') as f:
115
- f.write(base64.b64encode(encrypted))
116
-
117
  return True
118
  except Exception as e:
119
  print(f"Encryption error: {e}")
120
  return False
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  def merge(yaml_config: str, hf_token: str, repo_name: str, cipher_key: str) -> Iterable[List[Log]]:
124
  runner = LogsViewRunner()
@@ -136,7 +204,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str, cipher_key: str) -> I
136
  if not hf_token:
137
  yield runner.log("No HF token provided. A valid token is required for uploading.", level="ERROR")
138
  return
139
-
140
  # Validate that the token works by trying to get user info
141
  try:
142
  api = huggingface_hub.HfApi(token=hf_token)
@@ -146,10 +214,14 @@ def merge(yaml_config: str, hf_token: str, repo_name: str, cipher_key: str) -> I
146
  yield runner.log(f"Invalid HF token: {e}", level="ERROR")
147
  return
148
 
149
- # Set default cipher key if none provided
150
  if not cipher_key:
151
- cipher_key = "default_key" # Fallback key, though we should encourage users to set their own
152
- yield runner.log("No cipher key provided. Using default key (not recommended).", level="WARNING")
 
 
 
 
153
 
154
  with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
155
  tmpdir = pathlib.Path(tmpdirname)
@@ -182,37 +254,49 @@ def merge(yaml_config: str, hf_token: str, repo_name: str, cipher_key: str) -> I
182
 
183
  if runner.exit_code != 0:
184
  yield runner.log("Merge failed. Deleting repo as no model is uploaded.", level="ERROR")
185
- api.delete_repo(repo_url.repo_id)
 
 
 
 
186
  return
187
 
188
- yield runner.log("Model merged successfully. Uploading to HF.")
189
-
190
- # Delete Readme.md if it exists (case-insensitive check)
191
  merge_dir = merged_path / "merge"
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  readme_deleted = False
193
- for file in merge_dir.glob("*"):
194
- if file.name.lower() == "readme.md":
195
- try:
196
  file.unlink()
197
  readme_deleted = True
198
  yield runner.log(f"Deleted {file.name} file before upload")
199
- except Exception as e:
200
- yield runner.log(f"Error deleting {file.name}: {e}", level="WARNING")
201
-
 
202
  if not readme_deleted:
203
- yield runner.log("No Readme.md file found to delete", level="INFO")
204
-
205
- # Encrypt mergekit_config.yml if it exists
206
- config_yml_path = merged_path / "merge" / "mergekit_config.yml"
207
- if not config_yml_path.exists():
208
- yield runner.log("mergekit_config.yml not found, nothing to encrypt", level="INFO")
209
- elif encrypt_file(config_yml_path, cipher_key):
210
- yield runner.log("Encrypted mergekit_config.yml with provided key")
211
-
212
  yield from runner.run_python(
213
  api.upload_folder,
214
  repo_id=repo_url.repo_id,
215
- folder_path=merged_path / "merge",
216
  )
217
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
218
 
@@ -234,43 +318,65 @@ NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')}
234
  with gr.Blocks() as demo:
235
  gr.Markdown(MARKDOWN_DESCRIPTION)
236
  gr.Markdown(NEXT_RESTART)
237
-
238
- with gr.Row():
239
- filename = gr.Textbox(visible=False, label="filename")
240
- config = gr.Code(language="yaml", lines=10, label="config.yaml")
241
- with gr.Column():
242
- token = gr.Textbox(
243
- lines=1,
244
- label="HF Write Token",
245
- info="https://hf.co/settings/token",
246
- type="password",
247
- placeholder="Required for model upload.",
248
- )
249
- repo_name = gr.Textbox(
250
- lines=1,
251
- label="Repo name",
252
- placeholder="Optional. Will create a random name if empty.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  )
254
- cipher_key = gr.Textbox(
255
- lines=1,
256
- label="Encryption Key",
257
- type="password",
258
- placeholder="Key used to encrypt the config file.",
259
- value="Default"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  )
261
- button = gr.Button("Merge", variant="primary")
262
- logs = LogsView(label="Terminal output")
263
- gr.Examples(
264
- examples,
265
- fn=lambda s: (s,),
266
- run_on_click=True,
267
- label="Examples",
268
- inputs=[filename],
269
- outputs=[config],
270
- )
271
- gr.Markdown(MARKDOWN_ARTICLE)
272
-
273
- button.click(fn=merge, inputs=[config, token, repo_name, cipher_key], outputs=[logs])
274
-
275
-
276
- demo.queue(default_concurrency_limit=1).launch()
 
12
  import torch
13
  import base64
14
  from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
15
+ from cryptography.hazmat.primitives import padding as sym_padding # Use symmetric padding
16
  from cryptography.hazmat.backends import default_backend
17
  import yaml
18
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
 
33
 
34
  The fastest way to perform a model merge 🔥
35
 
36
+ Specify a YAML configuration file (see examples below) and a HF token and this app will perform the merge and upload the merged model to your user profile. Includes encryption for the `mergekit_config.yml` and a tool to decrypt it.
37
  """
38
 
39
  MARKDOWN_ARTICLE = """
 
74
 
75
  examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
76
 
77
def _prepare_key(key: str) -> bytes:
    """Return *key* as exactly 32 bytes (256 bits) for use as an AES key.

    The string is UTF-8 encoded; anything past the first 32 bytes is dropped
    and a shorter value is right-filled with NUL bytes.
    """
    raw = key.encode('utf-8')
    # Slice first, then fill: a long key is cut, a short key is NUL-padded.
    return raw[:32].ljust(32, b'\0')
84
 
85
def encrypt_file(file_path, key: str) -> bool:
    """
    Replace a file's contents with its AES-256-CBC encrypted form.

    The file is overwritten with Base64 text encoding a freshly generated
    16-byte IV followed by the ciphertext of the PKCS7-padded original bytes.

    Args:
        file_path: Path to the file to encrypt (pathlib.Path or string)
        key: Encryption key string; it is normalized to 32 bytes by
            ``_prepare_key``.

    Returns:
        bool: True if encryption was successful, False otherwise
    """
    try:
        target = pathlib.Path(file_path)
        if not target.exists():
            print(f"Encryption error: File not found at {target}")
            return False

        aes_key = _prepare_key(key)

        # A new random IV per call; it is stored in front of the ciphertext
        # so the decrypting side can read it back.
        iv = os.urandom(16)
        encryptor = Cipher(
            algorithms.AES(aes_key),
            modes.CBC(iv),
            backend=default_backend(),
        ).encryptor()
        padder = sym_padding.PKCS7(algorithms.AES.block_size).padder()

        # CBC works on whole blocks, so pad before encrypting.
        padded = padder.update(target.read_bytes()) + padder.finalize()
        ciphertext = encryptor.update(padded) + encryptor.finalize()

        # Overwrite the original file with Base64(IV || ciphertext).
        target.write_bytes(base64.b64encode(iv + ciphertext))
        return True
    except Exception as e:
        print(f"Encryption error: {e}")
        return False
135
 
136
def decrypt_file_content(file_input, key: str) -> str:
    """
    Decrypt an uploaded file with AES-256-CBC and return the plaintext.

    The file is expected to hold Base64 text that decodes to a 16-byte IV
    followed by PKCS7-padded ciphertext (the layout ``encrypt_file`` writes).

    Args:
        file_input: Either a filesystem path string or an object exposing the
            path as ``.name`` (presumably what the Gradio File component
            passes in -- confirm against the installed Gradio version).
        key: Decryption key string; normalized to 32 bytes by ``_prepare_key``.

    Returns:
        str: Decrypted content as a UTF-8 string, or an error message.
            Malformed input (bad Base64, wrong payload length) yields a
            "Decryption Error: ..." message; failures during decryption,
            unpadding or UTF-8 decoding yield a "Decryption Failed: ..."
            message. No exception is propagated to the caller.
    """
    if file_input is None:
        return "Error: No file provided for decryption."
    if not key:
        return "Error: Decryption key cannot be empty."

    iv_size = 16  # AES block size in bytes; the IV occupies the first block

    try:
        # Accept a plain path as well as an object carrying the path in .name.
        file_path = file_input if isinstance(file_input, str) else file_input.name

        with open(file_path, 'rb') as f:
            base64_encoded_data = f.read()

        # Only Base64 problems are reported as "invalid input"; the handler is
        # scoped to this step so later ValueErrors are not misreported.
        try:
            encrypted_data_with_iv = base64.b64decode(base64_encoded_data)
        except (ValueError, TypeError) as e:
            return f"Decryption Error: Invalid input data or key format. ({e})"

        # A valid payload is the IV plus at least one full ciphertext block,
        # and CBC ciphertext is always a whole number of blocks.
        total = len(encrypted_data_with_iv)
        if total < 2 * iv_size or total % iv_size != 0:
            return (
                "Decryption Error: Invalid input data or key format. "
                f"(unexpected payload length: {total} bytes)"
            )

        key_bytes = _prepare_key(key)
        iv = encrypted_data_with_iv[:iv_size]
        ciphertext = encrypted_data_with_iv[iv_size:]

        cipher = Cipher(algorithms.AES(key_bytes), modes.CBC(iv), backend=default_backend())
        decryptor = cipher.decryptor()
        padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize()

        # With a wrong key the padding check (or the UTF-8 decode below)
        # normally fails; both raise ValueError subclasses and are handled by
        # the generic branch so the user is told to check the key.
        unpadder = sym_padding.PKCS7(algorithms.AES.block_size).unpadder()
        plaintext = unpadder.update(padded_plaintext) + unpadder.finalize()

        return plaintext.decode('utf-8')

    except Exception as e:
        # Unreadable file, bad padding, undecodable plaintext or any other
        # crypto failure: log details server-side, return a short message.
        print(f"Decryption error details: {e}")
        return f"Decryption Failed: Likely incorrect key or corrupted file. Error: {type(e).__name__}"
189
+
190
 
191
  def merge(yaml_config: str, hf_token: str, repo_name: str, cipher_key: str) -> Iterable[List[Log]]:
192
  runner = LogsViewRunner()
 
204
  if not hf_token:
205
  yield runner.log("No HF token provided. A valid token is required for uploading.", level="ERROR")
206
  return
207
+
208
  # Validate that the token works by trying to get user info
209
  try:
210
  api = huggingface_hub.HfApi(token=hf_token)
 
214
  yield runner.log(f"Invalid HF token: {e}", level="ERROR")
215
  return
216
 
217
+ # Use default key if none provided, but log a warning
218
  if not cipher_key:
219
+ cipher_key = "default_insecure_key" # Make the default explicitly insecure-sounding
220
+ yield runner.log("No cipher key provided. Using a default, insecure key. Please provide your own key for security.", level="WARNING")
221
+ elif cipher_key == "Default": # Check against the placeholder value
222
+ cipher_key = "default_insecure_key" # Treat placeholder as no key provided
223
+ yield runner.log("Default placeholder key detected. Using an insecure key. Please provide your own key.", level="WARNING")
224
+
225
 
226
  with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
227
  tmpdir = pathlib.Path(tmpdirname)
 
254
 
255
  if runner.exit_code != 0:
256
  yield runner.log("Merge failed. Deleting repo as no model is uploaded.", level="ERROR")
257
+ try:
258
+ api.delete_repo(repo_url.repo_id)
259
+ yield runner.log(f"Repo {repo_url.repo_id} deleted.")
260
+ except Exception as delete_e:
261
+ yield runner.log(f"Failed to delete repo {repo_url.repo_id}: {delete_e}", level="WARNING")
262
  return
263
 
264
+ yield runner.log("Model merged successfully. Preparing for upload.")
265
+
266
+ # ---- Encryption Step ----
267
  merge_dir = merged_path / "merge"
268
+ config_yml_path = merge_dir / "mergekit_config.yml"
269
+
270
+ if config_yml_path.exists():
271
+ yield runner.log(f"Found {config_yml_path.name}. Encrypting...")
272
+ if encrypt_file(config_yml_path, cipher_key):
273
+ yield runner.log(f"Successfully encrypted {config_yml_path.name} with provided key.")
274
+ else:
275
+ yield runner.log(f"Failed to encrypt {config_yml_path.name}. Uploading unencrypted.", level="ERROR")
276
+ else:
277
+ yield runner.log(f"{config_yml_path.name} not found in merge output, nothing to encrypt.", level="INFO")
278
+ # ---- End Encryption Step ----
279
+
280
+ # Delete Readme.md if it exists (case-insensitive check) before upload
281
  readme_deleted = False
282
+ try:
283
+ for file in merge_dir.glob("*"):
284
+ if file.name.lower() == "readme.md":
285
  file.unlink()
286
  readme_deleted = True
287
  yield runner.log(f"Deleted {file.name} file before upload")
288
+ break # Assume only one readme
289
+ except Exception as e:
290
+ yield runner.log(f"Error deleting Readme.md: {e}", level="WARNING")
291
+
292
  if not readme_deleted:
293
+ yield runner.log("No Readme.md file found to delete.", level="INFO")
294
+
295
+ yield runner.log("Uploading merged model files to HF.")
 
 
 
 
 
 
296
  yield from runner.run_python(
297
  api.upload_folder,
298
  repo_id=repo_url.repo_id,
299
+ folder_path=merge_dir, # Upload from the 'merge' subdirectory
300
  )
301
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
302
 
 
318
  with gr.Blocks() as demo:
319
  gr.Markdown(MARKDOWN_DESCRIPTION)
320
  gr.Markdown(NEXT_RESTART)
321
+
322
+ with gr.Tabs():
323
+ with gr.TabItem("Merge Model"):
324
+ with gr.Row():
325
+ filename = gr.Textbox(visible=False, label="filename")
326
+ config = gr.Code(language="yaml", lines=10, label="config.yaml")
327
+ with gr.Column():
328
+ token = gr.Textbox(
329
+ lines=1,
330
+ label="HF Write Token",
331
+ info="https://hf.co/settings/token",
332
+ type="password",
333
+ placeholder="Required for model upload.",
334
+ )
335
+ repo_name = gr.Textbox(
336
+ lines=1,
337
+ label="Repo name",
338
+ placeholder="Optional. Will create a random name if empty.",
339
+ )
340
+ cipher_key = gr.Textbox(
341
+ lines=1,
342
+ label="Encryption Key",
343
+ type="password",
344
+ info="Key used to encrypt the generated mergekit_config.yml file before upload. Leave blank or 'Default' for no encryption (or insecure default).",
345
+ placeholder="Enter your secret key here",
346
+ value="Default" # Set a default placeholder
347
+ )
348
+ button = gr.Button("Merge and Upload", variant="primary")
349
+ logs = LogsView(label="Merge Progress / Terminal output")
350
+ gr.Examples(
351
+ examples,
352
+ fn=lambda s: (s,),
353
+ run_on_click=True,
354
+ label="Merge Examples",
355
+ inputs=[filename],
356
+ outputs=[config],
357
  )
358
+ gr.Markdown(MARKDOWN_ARTICLE)
359
+
360
+ button.click(fn=merge, inputs=[config, token, repo_name, cipher_key], outputs=[logs])
361
+
362
+ with gr.TabItem("Decrypt Configuration"):
363
+ gr.Markdown("Upload an encrypted `mergekit_config.yml` file and provide the key to decrypt it.")
364
+ with gr.Row():
365
+ decrypt_file_input = gr.File(label="Upload Encrypted mergekit_config.yml")
366
+ decrypt_key_input = gr.Textbox(
367
+ lines=1,
368
+ label="Decryption Key",
369
+ type="password",
370
+ placeholder="Enter the key used for encryption",
371
+ )
372
+ decrypt_button = gr.Button("Decrypt File", variant="secondary")
373
+ decrypted_output = gr.Code(language="yaml", label="Decrypted Configuration", lines=15, interactive=False)
374
+
375
+ decrypt_button.click(
376
+ fn=decrypt_file_content,
377
+ inputs=[decrypt_file_input, decrypt_key_input],
378
+ outputs=[decrypted_output]
379
  )
380
+
381
+
382
+ demo.queue(default_concurrency_limit=1).launch()