olegshulyakov commited on
Commit
46a6f32
·
verified ·
1 Parent(s): d1518f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -28
app.py CHANGED
@@ -33,10 +33,10 @@ QUANT_PARAMS = {
33
  def list_files_in_folder(folder_path):
34
  # List all files and directories in the specified folder
35
  all_items = os.listdir(folder_path)
36
-
37
  # Filter out only files
38
  files = [item for item in all_items if os.path.isfile(os.path.join(folder_path, item))]
39
-
40
  return files
41
 
42
  def clear_hf_cache_space():
@@ -48,8 +48,8 @@ def clear_hf_cache_space():
48
  scan.delete_revisions(*to_delete).execute()
49
  print("Cache has been cleared")
50
 
51
- def upload_to_hub(path, upload_repo, hf_path, oauth_token):
52
- card = ModelCard.load(hf_path, token=oauth_token.token)
53
  card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx", "mlx-my-repo"]
54
  card.data.base_model = hf_path
55
  card.text = dedent(
@@ -85,7 +85,7 @@ def upload_to_hub(path, upload_repo, hf_path, oauth_token):
85
 
86
  logging.set_verbosity_info()
87
 
88
- api = HfApi(token=oauth_token.token)
89
  api.create_repo(repo_id=upload_repo, exist_ok=True)
90
 
91
  files = list_files_in_folder(path)
@@ -98,35 +98,47 @@ def upload_to_hub(path, upload_repo, hf_path, oauth_token):
98
  path_in_repo=file,
99
  repo_id=upload_repo,
100
  )
101
-
102
- print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
103
 
104
  def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
105
  if oauth_token.token is None:
106
- raise ValueError("You must be logged in to use MLX-my-repo")
107
-
108
- model_name = model_id.split('/')[-1]
109
- username = whoami(oauth_token.token)["name"]
 
 
 
 
 
 
 
110
  try:
 
 
 
 
111
  if q_method == "FP16":
112
- upload_repo = f"{username}/{model_name}-mlx-fp16"
113
- with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
114
- # The target directory must not exist
115
- mlx_path = os.path.join(tmpdir, "mlx")
116
- convert(model_id, mlx_path=mlx_path, quantize=False, dtype="float16")
117
- print("Conversion done")
118
- upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, oauth_token=oauth_token)
119
- print("Upload done")
120
  else:
121
  q_bits = QUANT_PARAMS[q_method]
122
- upload_repo = f"{username}/{model_name}-mlx-{q_bits}Bit"
123
- with tempfile.TemporaryDirectory(dir="converted") as tmpdir:
124
- # The target directory must not exist
125
- mlx_path = os.path.join(tmpdir, "mlx")
 
 
 
 
 
 
126
  convert(model_id, mlx_path=mlx_path, quantize=True, q_bits=q_bits)
127
- print("Conversion done")
128
- upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, oauth_token=oauth_token)
129
- print("Upload done")
 
130
  return (
131
  f'Find your repo <a href="https://hf.co/{upload_repo}" target="_blank" style="text-decoration:underline">here</a>',
132
  "llama.png",
@@ -141,7 +153,7 @@ css="""/* Custom CSS to allow scrolling */
141
  .gradio-container {overflow-y: auto;}
142
  """
143
  # Create Gradio interface
144
- with gr.Blocks(css=css) as demo:
145
  gr.Markdown("You must be logged in to use MLX-my-repo.")
146
  gr.LoginButton(min_width=250)
147
 
@@ -159,7 +171,7 @@ with gr.Blocks(css=css) as demo:
159
  filterable=False,
160
  visible=True
161
  )
162
-
163
  iface = gr.Interface(
164
  fn=process_model,
165
  inputs=[
 
33
  def list_files_in_folder(folder_path):
34
  # List all files and directories in the specified folder
35
  all_items = os.listdir(folder_path)
36
+
37
  # Filter out only files
38
  files = [item for item in all_items if os.path.isfile(os.path.join(folder_path, item))]
39
+
40
  return files
41
 
42
  def clear_hf_cache_space():
 
48
  scan.delete_revisions(*to_delete).execute()
49
  print("Cache has been cleared")
50
 
51
+ def upload_to_hub(path, upload_repo, hf_path, token):
52
+ card = ModelCard.load(hf_path, token=token)
53
  card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx", "mlx-my-repo"]
54
  card.data.base_model = hf_path
55
  card.text = dedent(
 
85
 
86
  logging.set_verbosity_info()
87
 
88
+ api = HfApi(token=token)
89
  api.create_repo(repo_id=upload_repo, exist_ok=True)
90
 
91
  files = list_files_in_folder(path)
 
98
  path_in_repo=file,
99
  repo_id=upload_repo,
100
  )
101
+
102
+ print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
103
 
104
  def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
105
  if oauth_token.token is None:
106
+ return ("You must be logged in to use MLX-my-repo", "error.png")
107
+
108
+ # Verify the token
109
+ username = None
110
+ try:
111
+ user_info = whoami(token=token)
112
+ username = user_info["name"]
113
+ print(f"✅ Logged in as {username}")
114
+ except Exception as e:
115
+ return (f"❌ Authentication failed: {e}", "error.png")
116
+
117
  try:
118
+ model_name = model_id.split('/')[-1]
119
+ repo_name = None
120
+ q_bits = None
121
+
122
  if q_method == "FP16":
123
+ q_bits = "float16"
124
+ repo_name = f"{model_name}-fp16"
 
 
 
 
 
 
125
  else:
126
  q_bits = QUANT_PARAMS[q_method]
127
+ repo_name = f"{model_name}-{q_bits}bit"
128
+
129
+ upload_repo = f"${username}/${repo_name}"
130
+
131
+ with tempfile.TemporaryDirectory(dir=f"converted/${repo_name}") as tmpdir:
132
+ # The target directory must not exist
133
+ mlx_path = os.path.join(tmpdir, "mlx")
134
+ if q_method == "FP16":
135
+ convert(model_id, mlx_path=mlx_path, quantize=False, dtype="float16")
136
+ else:
137
  convert(model_id, mlx_path=mlx_path, quantize=True, q_bits=q_bits)
138
+ print("Conversion done")
139
+ upload_to_hub(path=mlx_path, upload_repo=upload_repo, hf_path=model_id, token=token)
140
+ print("Upload done")
141
+
142
  return (
143
  f'Find your repo <a href="https://hf.co/{upload_repo}" target="_blank" style="text-decoration:underline">here</a>',
144
  "llama.png",
 
153
  .gradio-container {overflow-y: auto;}
154
  """
155
  # Create Gradio interface
156
+ with gr.Blocks(css=css) as demo:
157
  gr.Markdown("You must be logged in to use MLX-my-repo.")
158
  gr.LoginButton(min_width=250)
159
 
 
171
  filterable=False,
172
  visible=True
173
  )
174
+
175
  iface = gr.Interface(
176
  fn=process_model,
177
  inputs=[