Tech-Meld committed
Commit 8bb39cb · verified · 1 Parent(s): 7730f68

Update app.py

Files changed (1)
app.py +29 -32
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
-from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig
-from huggingface_hub import cached_download, hf_hub_url, list_models
+from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig, AutoModelForCausalLM
+from huggingface_hub import cached_download, hf_hub_url, list_models, create_repo, HfApi
 from transformers.modeling_utils import PreTrainedModel
 import requests
 import json
@@ -10,7 +10,6 @@ from io import BytesIO
 import base64
 import torch
 from torch.nn.utils import prune
-from transformers.models.auto import AutoModelForCausalLM # Import for CausalLM
 
 # Function to fetch open-weight LLM models
 def fetch_open_weight_models():
@@ -18,20 +17,15 @@ def fetch_open_weight_models():
     return models
 
 # Function to prune a model using the "merge-kit" approach
-def prune_model(llm_model_name, target_size, output_dir):
+def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
     try:
         # Load the LLM model and tokenizer
         llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
         # Handle cases where the model is split into multiple safetensors
-        if "safetensors" in llm_tokenizer.vocab_files_names:
-            llm_model = AutoModelForCausalLM.from_pretrained(
-                llm_model_name,
-                from_safetensors=True,
-                torch_dtype=torch.float16, # Adjust dtype as needed
-                use_auth_token=None,
-            )
-        else:
-            llm_model = AutoModel.from_pretrained(llm_model_name)
+        llm_model = AutoModelForCausalLM.from_pretrained(
+            llm_model_name,
+            torch_dtype=torch.float16, # Adjust dtype as needed
+        )
 
         # Get the model config
         config = AutoConfig.from_pretrained(llm_model_name)
@@ -41,8 +35,12 @@ def prune_model(llm_model_name, target_size, output_dir):
         # Use merge-kit to prune the model
         pruned_model = merge_kit_prune(llm_model, target_num_parameters)
 
-        # Save the pruned model
-        pruned_model.save_pretrained(output_dir)
+        # Save the pruned model to Hugging Face repository
+        api = HfApi()
+        repo_id = f"{hf_write_token}/{repo_name}"
+        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
+        pruned_model.push_to_hub(repo_id, use_auth_token=hf_write_token)
+        llm_tokenizer.push_to_hub(repo_id, use_auth_token=hf_write_token)
 
         # Create a visualization
         fig, ax = plt.subplots(figsize=(10, 5))
@@ -53,7 +51,7 @@ def prune_model(llm_model_name, target_size, output_dir):
         fig.savefig(buf, format="png")
         buf.seek(0)
         image_base64 = base64.b64encode(buf.read()).decode("utf-8")
-        return f"Pruned model saved to {output_dir}", f"data:image/png;base64,{image_base64}"
+        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}"
 
     except Exception as e:
         return f"Error: {e}", None
@@ -61,23 +59,19 @@ def prune_model(llm_model_name, target_size, output_dir):
 # Merge-kit Pruning Function (adjust as needed)
 def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
     """Prunes a model using a merge-kit approach.
-
     Args:
         model (PreTrainedModel): The model to be pruned.
         target_num_parameters (int): The target number of parameters after pruning.
-
     Returns:
         PreTrainedModel: The pruned model.
     """
-
     # Define the pruning method
     pruning_method = "unstructured"
 
     # Calculate the pruning amount
-    amount = 1 - (target_num_parameters / model.num_parameters)
+    amount = 1 - (target_num_parameters / sum(p.numel() for p in model.parameters()))
 
-    # Prune the model using the selected method (adapt for Llama)
-    # Example: If Llama uses specific layers, adjust the pruning logic here
+    # Prune the model using the selected method
     for name, module in model.named_modules():
         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
             prune.random_unstructured(module, name="weight", amount=amount)
@@ -107,22 +101,25 @@ def create_interface():
             interactive=True,
         )
 
-        # Output for pruning status
-        pruning_status = gr.Textbox(label="Pruning Status")
+        # Input for Hugging Face write token
+        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
+
+        # Input for repository name
+        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
 
-        # Output for saving the model
-        save_model_path = gr.Textbox(label="Save Model Path", placeholder="Path to save the pruned model", interactive=True)
+        # Output for pruning status
+        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
 
         # Button to start pruning
         prune_button = gr.Button("Prune Model")
 
         # Output for visualization
-        visualization = gr.Image(label="Model Size Comparison")
+        visualization = gr.Image(label="Model Size Comparison", interactive=False)
 
         # Connect components
        prune_button.click(
             fn=prune_model,
-            inputs=[llm_model_name, target_size, save_model_path],
+            inputs=[llm_model_name, target_size, hf_write_token, repo_name],
             outputs=[pruning_status, visualization],
         )
 
@@ -133,11 +130,11 @@ def create_interface():
         # Generate text button
         generate_button = gr.Button("Generate Text")
 
-        def generate_text(text, model_path):
+        def generate_text(text, repo_name):
             try:
                 # Load the pruned model and tokenizer
-                tokenizer = AutoTokenizer.from_pretrained(model_path)
-                model = AutoModelForCausalLM.from_pretrained(model_path) # Load as CausalLM
+                tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token)
+                model = AutoModelForCausalLM.from_pretrained(repo_name, use_auth_token=hf_write_token)
 
                 # Use the pipeline for text generation
                 generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
@@ -146,7 +143,7 @@ def create_interface():
             except Exception as e:
                 return f"Error: {e}"
 
-        generate_button.click(fn=generate_text, inputs=[text_input, save_model_path], outputs=text_output)
+        generate_button.click(fn=generate_text, inputs=[text_input, repo_name], outputs=text_output)
 
     return demo
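
In short, this commit changes prune_model so that it no longer saves the pruned model to a local output_dir: it computes the unstructured-pruning amount from the model's actual parameter count and pushes the result to the Hugging Face Hub using a user-supplied write token and repository name. The sketch below reproduces that flow outside the Gradio app. It is a minimal sketch, not the committed code: the model id, token, target parameter count, and repo id are placeholders, and it passes token= where the committed code still uses the older use_auth_token= argument.

import torch
from torch.nn.utils import prune
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import create_repo

# Placeholders -- substitute a real model id, write token, target, and repo id.
model_id = "sshleifer/tiny-gpt2"
hf_write_token = "hf_xxx"
repo_id = "your-username/tiny-gpt2-pruned"
target_num_parameters = 50_000

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pruning amount as computed in merge_kit_prune: the fraction of weights to
# drop so that roughly target_num_parameters remain, clamped to [0, 1].
total = sum(p.numel() for p in model.parameters())
amount = min(max(1 - target_num_parameters / total, 0.0), 1.0)

for _, module in model.named_modules():
    if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
        prune.random_unstructured(module, name="weight", amount=amount)
        prune.remove(module, "weight")  # bake the pruning mask in before uploading

# Upload the pruned model and tokenizer, mirroring the new save path in prune_model.
create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
model.push_to_hub(repo_id, token=hf_write_token)
tokenizer.push_to_hub(repo_id, token=hf_write_token)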