Steelskull committed
Commit e8ac14f · verified · Parent: f20a860

Update app.py

Files changed (1):
  app.py (+100, -70)
app.py CHANGED
@@ -10,6 +10,9 @@ from typing import Iterable, List
 import gradio as gr
 import huggingface_hub
 import torch
+import base64
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.backends import default_backend
 import yaml
 from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
 from mergekit.config import MergeConfiguration
@@ -20,30 +23,6 @@ from datetime import datetime, timezone
 
 has_gpu = torch.cuda.is_available()
 
-# Running directly from Python doesn't work well with Gradio+run_process because of:
-# Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
-# Let's use the CLI instead.
-#
-# import mergekit.merge
-# from mergekit.common import parse_kmb
-# from mergekit.options import MergeOptions
-#
-# merge_options = (
-#     MergeOptions(
-#         copy_tokenizer=True,
-#         cuda=True,
-#         low_cpu_memory=True,
-#         write_model_card=True,
-#     )
-#     if has_gpu
-#     else MergeOptions(
-#         allow_crimes=True,
-#         out_shard_size=parse_kmb("1B"),
-#         lazy_unpickle=True,
-#         write_model_card=True,
-#     )
-# )
-
 cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
     " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --out-shard-size 1B --lazy-unpickle"
 )
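The comment removed in the hunk above cites the "Cannot re-initialize CUDA in forked subprocess" error as the reason the Space shells out to `mergekit-yaml` instead of calling the Python API. For reference, a minimal sketch of the "spawn" start method that comment alludes to (illustrative only, not part of this commit):

```python
# Illustrative sketch only: the "spawn" start method the removed comment refers to,
# which avoids "Cannot re-initialize CUDA in forked subprocess".
import multiprocessing as mp


def _cuda_worker() -> None:
    # CUDA is initialized inside the child process, so spawn (not fork) is safe.
    import torch

    print("CUDA available in child:", torch.cuda.is_available())


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=_cuda_worker)
    p.start()
    p.join()
```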
@@ -87,19 +66,6 @@ A quick overview of the currently supported merge methods:
 | Passthrough | `passthrough` | ❌ | ❌ |
 | [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
 
-
-## Citation
-
-This GUI is powered by [Arcee's MergeKit](https://arxiv.org/abs/2403.13257).
-If you use it in your research, please cite the following paper:
-
-```
-@article{goddard2024arcee,
-    title={Arcee's MergeKit: A Toolkit for Merging Large Language Models},
-    author={Goddard, Charles and Siriwardhana, Shamane and Ehghaghi, Malikeh and Meyers, Luke and Karpukhin, Vlad and Benedict, Brian and McQuade, Mark and Solawetz, Jacob},
-    journal={arXiv preprint arXiv:2403.13257},
-    year={2024}
-}
 ```
 
 This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb)).
@@ -107,13 +73,54 @@ This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](ht
 
 examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
 
-# Do not set community token as `HF_TOKEN` to avoid accidentally using it in merge scripts.
-# `COMMUNITY_HF_TOKEN` is used to upload models to the community organization (https://huggingface.co/mergekit-community)
-# when user do not provide a token.
-COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
+
+def encrypt_file(file_path, key):
+    """
+    Encrypt the contents of a file using AES encryption with the provided key.
+
+    Args:
+        file_path: Path to the file to encrypt (pathlib.Path or string)
+        key: Encryption key
+
+    Returns:
+        bool: True if encryption was successful, False otherwise
+    """
+    try:
+        file_path = pathlib.Path(file_path)
+        if not file_path.exists():
+            return False
+
+        # Ensure key is 32 bytes (256 bits)
+        key_bytes = key.encode('utf-8')
+        key_bytes = key_bytes + b'\0' * (32 - len(key_bytes)) if len(key_bytes) < 32 else key_bytes[:32]
+
+        # Generate a random IV
+        iv = os.urandom(16)
+
+        # Create an encryptor
+        cipher = Cipher(algorithms.AES(key_bytes), modes.CBC(iv), backend=default_backend())
+        encryptor = cipher.encryptor()
+
+        # Read file content
+        with open(file_path, 'rb') as f:
+            content = f.read()
+
+        # Pad the content to be a multiple of 16 bytes
+        padding = 16 - (len(content) % 16)
+        content += bytes([padding]) * padding
+
+        # Encrypt and write back
+        encrypted = iv + encryptor.update(content) + encryptor.finalize()
+        with open(file_path, 'wb') as f:
+            f.write(base64.b64encode(encrypted))
+
+        return True
+    except Exception as e:
+        print(f"Encryption error: {e}")
+        return False
 
 
-def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
+def merge(yaml_config: str, hf_token: str, repo_name: str, cipher_key: str) -> Iterable[List[Log]]:
     runner = LogsViewRunner()
 
     if not yaml_config:
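The `encrypt_file` helper added above overwrites the target file with base64(IV + AES-256-CBC ciphertext), using a key that is zero-padded or truncated to 32 bytes and PKCS#7-style padding on the plaintext. A matching decryption sketch, shown only to document that format (the `decrypt_file` name is hypothetical and not part of the commit):

```python
import base64
import pathlib

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes


def decrypt_file(file_path, key):
    """Reverse of encrypt_file above: return the decrypted bytes (illustrative sketch)."""
    raw = base64.b64decode(pathlib.Path(file_path).read_bytes())
    iv, ciphertext = raw[:16], raw[16:]

    # Same key normalization as encrypt_file: zero-pad or truncate to 32 bytes.
    key_bytes = key.encode("utf-8")
    key_bytes = key_bytes + b"\0" * (32 - len(key_bytes)) if len(key_bytes) < 32 else key_bytes[:32]

    cipher = Cipher(algorithms.AES(key_bytes), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    padded = decryptor.update(ciphertext) + decryptor.finalize()

    # Strip the padding bytes appended by encrypt_file (last byte holds the pad length).
    return padded[: -padded[-1]]
```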
@@ -125,23 +132,24 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
         yield runner.log(f"Invalid yaml {e}", level="ERROR")
         return
 
-    is_community_model = False
+    # Check if HF token is provided
     if not hf_token:
-        if "/" in repo_name and not repo_name.startswith("mergekit-community/"):
-            yield runner.log(
-                f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
-                level="ERROR",
-            )
-            return
-        yield runner.log(
-            "No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization."
-        )
-        is_community_model = True
-        if not COMMUNITY_HF_TOKEN:
-            raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
-        hf_token = COMMUNITY_HF_TOKEN
+        yield runner.log("No HF token provided. A valid token is required for uploading.", level="ERROR")
+        return
+
+    # Validate that the token works by trying to get user info
+    try:
+        api = huggingface_hub.HfApi(token=hf_token)
+        me = api.whoami()
+        yield runner.log(f"Authenticated as: {me['name']} ({me.get('fullname', '')})")
+    except Exception as e:
+        yield runner.log(f"Invalid HF token: {e}", level="ERROR")
+        return
 
-    api = huggingface_hub.HfApi(token=hf_token)
+    # Set default cipher key if none provided
+    if not cipher_key:
+        cipher_key = "default_key"  # Fallback key, though we should encourage users to set their own
+        yield runner.log("No cipher key provided. Using default key (not recommended).", level="WARNING")
 
     with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
         tmpdir = pathlib.Path(tmpdirname)
@@ -158,9 +166,6 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
             repo_name += "-" + "".join(random.choices(string.ascii_lowercase, k=7))
             repo_name = repo_name.replace("/", "-").strip("-")
 
-        if is_community_model and not repo_name.startswith("mergekit-community/"):
-            repo_name = f"mergekit-community/{repo_name}"
-
         try:
             yield runner.log(f"Creating repo {repo_name}")
             repo_url = api.create_repo(repo_name, exist_ok=True)
@@ -181,6 +186,29 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
             return
 
         yield runner.log("Model merged successfully. Uploading to HF.")
+
+        # Delete Readme.md if it exists (case-insensitive check)
+        merge_dir = merged_path / "merge"
+        readme_deleted = False
+        for file in merge_dir.glob("*"):
+            if file.name.lower() == "readme.md":
+                try:
+                    file.unlink()
+                    readme_deleted = True
+                    yield runner.log(f"Deleted {file.name} file before upload")
+                except Exception as e:
+                    yield runner.log(f"Error deleting {file.name}: {e}", level="WARNING")
+
+        if not readme_deleted:
+            yield runner.log("No Readme.md file found to delete", level="INFO")
+
+        # Encrypt mergekit_config.yml if it exists
+        config_yml_path = merged_path / "merge" / "mergekit_config.yml"
+        if not config_yml_path.exists():
+            yield runner.log("mergekit_config.yml not found, nothing to encrypt", level="INFO")
+        elif encrypt_file(config_yml_path, cipher_key):
+            yield runner.log("Encrypted mergekit_config.yml with provided key")
+
         yield from runner.run_python(
             api.upload_folder,
             repo_id=repo_url.repo_id,
@@ -188,22 +216,18 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
         )
         yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
 
-# This is workaround. As the space always getting stuck.
-def _restart_space():
-    huggingface_hub.HfApi().restart_space(repo_id="arcee-ai/mergekit-gui", token=COMMUNITY_HF_TOKEN, factory_reboot=False)
 # Run garbage collection every hour to keep the community org clean.
 # Empty models might exists if the merge fails abruptly (e.g. if user leaves the Space).
 def _garbage_remover():
     try:
-        garbage_collect_empty_models(token=COMMUNITY_HF_TOKEN)
+        garbage_collect_empty_models(token=os.getenv("COMMUNITY_HF_TOKEN"))
     except Exception as e:
         print("Error running garbage collection", e)
 
 scheduler = BackgroundScheduler()
-restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=21600)
 garbage_remover_job = scheduler.add_job(_garbage_remover, "interval", seconds=3600)
 scheduler.start()
-next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
+next_run_time_utc = garbage_remover_job.next_run_time.astimezone(timezone.utc)
 
 NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC)"
@@ -220,13 +244,20 @@ with gr.Blocks() as demo:
                 label="HF Write Token",
                 info="https://hf.co/settings/token",
                 type="password",
-                placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
+                placeholder="Required for model upload.",
            )
            repo_name = gr.Textbox(
                lines=1,
                label="Repo name",
                placeholder="Optional. Will create a random name if empty.",
            )
+            cipher_key = gr.Textbox(
+                lines=1,
+                label="Encryption Key",
+                type="password",
+                placeholder="Key used to encrypt the config file.",
+                value="GPuccini"
+            )
    button = gr.Button("Merge", variant="primary")
    logs = LogsView(label="Terminal output")
    gr.Examples(
@@ -239,8 +270,7 @@
     )
     gr.Markdown(MARKDOWN_ARTICLE)
 
-    button.click(fn=merge, inputs=[config, token, repo_name], outputs=[logs])
-
+    button.click(fn=merge, inputs=[config, token, repo_name, cipher_key], outputs=[logs])
 
 
-demo.queue(default_concurrency_limit=1).launch()
+demo.queue(default_concurrency_limit=1).launch()
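After these changes, a repo produced by the Space should contain an encrypted `mergekit_config.yml` and no `README.md`. A quick hub-side spot check (the repo id below is a hypothetical placeholder, not part of the commit):

```python
import huggingface_hub

# Hypothetical repo id; substitute one actually produced by the Space.
repo_id = "your-namespace/mergekit-abcdefg"

files = huggingface_hub.list_repo_files(repo_id)
print("README.md present:", any(f.lower() == "readme.md" for f in files))
print("mergekit_config.yml present:", "mergekit_config.yml" in files)

# The config should now be base64 text rather than readable YAML.
path = huggingface_hub.hf_hub_download(repo_id, "mergekit_config.yml")
print(open(path, "rb").read()[:60])
```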
 