Tech-Meld committed on
Commit 0445e3f · verified · 1 Parent(s): 5c5e320

Update app.py

Files changed (1):
  1. app.py +43 -14
app.py CHANGED
```diff
@@ -1,6 +1,6 @@
 import gradio as gr
 from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig, AutoModelForCausalLM
-from huggingface_hub import cached_download, hf_hub_url, list_models, create_repo, HfApi
+from huggingface_hub import create_repo, HfApi, list_models
 from transformers.modeling_utils import PreTrainedModel
 import requests
 import json
```
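The trimmed import is the likely motivation for the commit: `cached_download` was deprecated and has since been removed from `huggingface_hub`, so the old import line fails at startup on recent releases. For anyone porting similar code, a minimal sketch of the current single-file download API (the `gpt2` repo is purely illustrative):

```python
# Minimal sketch: hf_hub_download is the maintained replacement for the
# removed cached_download helper; "gpt2" is just an example repo.
from huggingface_hub import hf_hub_download

config_path = hf_hub_download(repo_id="gpt2", filename="config.json")
print(config_path)  # local path of the cached file
```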
```diff
@@ -11,11 +11,11 @@ import base64
 import torch
 from torch.nn.utils import prune
 import subprocess
+from tqdm import tqdm
+import logging
 
-# Function to fetch open-weight LLM models
-def fetch_open_weight_models():
-    models = list_models()
-    return models
+# Setup logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # Ensure sentencepiece is installed
 try:
```
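One detail on the new logging setup: `basicConfig` without a `filename` sends records to stderr, while the "Show Logs" button added later in this commit reads `pruning.log`, a file this configuration never creates. A minimal sketch of a file-backed setup that would match the UI:

```python
# Sketch: write log records to the pruning.log file the "Show Logs" button
# expects; without filename=, basicConfig logs to stderr only.
import logging

logging.basicConfig(
    filename="pruning.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logging.info("pruning run started")  # now lands in pruning.log
```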
```diff
@@ -23,8 +23,14 @@ try:
 except ImportError:
     subprocess.check_call(["pip", "install", "sentencepiece"])
 
+# Function to fetch open-weight LLM models
+def fetch_open_weight_models():
+    models = list_models()
+    return models
+
 # Function to prune a model using the "merge-kit" approach
-def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
+def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress=gr.Progress(track_tqdm=True)):
+    log_messages = []
     try:
         # Load the LLM model and tokenizer
         llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
```
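The restored `fetch_open_weight_models` calls `list_models()` with no arguments, which pages through every model on the Hub. If the result is meant to populate a picker, a narrower query is far cheaper; a sketch using standard `list_models` parameters:

```python
# Sketch: constrain the Hub query instead of listing every model; filter,
# sort, direction and limit are standard list_models() parameters.
from huggingface_hub import list_models

def fetch_text_generation_models(limit: int = 50) -> list[str]:
    models = list_models(filter="text-generation", sort="downloads",
                         direction=-1, limit=limit)
    return [m.id for m in models]
```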
```diff
@@ -33,12 +39,18 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
             torch_dtype=torch.float16,
         )
 
+        log_messages.append("Model and tokenizer loaded successfully.")
+        logging.info("Model and tokenizer loaded successfully.")
+
         # Get the model config
         config = AutoConfig.from_pretrained(llm_model_name)
         target_num_parameters = int(config.num_parameters * (target_size / 100))
 
         # Prune the model
-        pruned_model = merge_kit_prune(llm_model, target_num_parameters)
+        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
+
+        log_messages.append("Model pruned successfully.")
+        logging.info("Model pruned successfully.")
 
         # Save the pruned model
         api = HfApi()
```
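A bug worth flagging in this hunk, unchanged by the commit: transformers config objects do not expose a `num_parameters` attribute, so `config.num_parameters` raises `AttributeError` before pruning even starts. Counting on the instantiated model is reliable; a sketch:

```python
# Sketch: count parameters on the loaded model rather than on the config;
# PreTrainedModel.num_parameters() exists for exactly this purpose.
from transformers import PreTrainedModel

def target_param_count(model: PreTrainedModel, target_size_pct: float) -> int:
    return int(model.num_parameters() * (target_size_pct / 100))
```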
```diff
@@ -47,6 +59,9 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
         pruned_model.push_to_hub(repo_id, use_auth_token=hf_write_token)
         llm_tokenizer.push_to_hub(repo_id, use_auth_token=hf_write_token)
 
+        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
+        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
+
         # Create a visualization
         fig, ax = plt.subplots(figsize=(10, 5))
         ax.bar(["Original", "Pruned"], [config.num_parameters, pruned_model.num_parameters])
```
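An API note on the push: recent transformers releases deprecate `use_auth_token` in favour of `token`, and creating the repository up front avoids a failed first push. A hedged sketch of that flow, reusing this file's variable names:

```python
# Sketch, reusing repo_name/hf_write_token/pruned_model/llm_tokenizer from
# this file: create the repo first (exist_ok tolerates reruns), then push
# with the current `token` keyword instead of the deprecated use_auth_token.
from huggingface_hub import create_repo

repo_id = create_repo(repo_name, token=hf_write_token, exist_ok=True).repo_id
pruned_model.push_to_hub(repo_id, token=hf_write_token)
llm_tokenizer.push_to_hub(repo_id, token=hf_write_token)
```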
```diff
@@ -57,13 +72,16 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
         buf.seek(0)
         image_base64 = base64.b64encode(buf.read()).decode("utf-8")
 
-        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", None
+        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", "\n".join(log_messages)
 
     except Exception as e:
-        return f"Error: {e}", None, None
+        error_message = f"Error: {e}"
+        log_messages.append(error_message)
+        logging.error(error_message)
+        return error_message, None, "\n".join(log_messages)
 
 # Merge-kit Pruning Function (adjust as needed)
-def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
+def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress) -> PreTrainedModel:
     """Prunes a model using a merge-kit approach.
     Args:
         model (PreTrainedModel): The model to be pruned.
```
```diff
@@ -75,10 +93,11 @@ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
     pruning_method = "unstructured"
 
     # Calculate the pruning amount
-    amount = 1 - (target_num_parameters / sum(p.numel() for p in model.parameters()))
+    total_params = sum(p.numel() for p in model.parameters())
+    amount = 1 - (target_num_parameters / total_params)
 
     # Prune the model using the selected method
-    for name, module in model.named_modules():
+    for name, module in tqdm(model.named_modules(), desc="Pruning", file=sys.stdout):
         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
             prune.random_unstructured(module, name="weight", amount=amount)
 
```
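Two caveats here. The new `tqdm(..., file=sys.stdout)` call needs `import sys`, which the diff never adds. And `torch.nn.utils.prune` works by masking: `random_unstructured` attaches `weight_orig`/`weight_mask` pairs, so weights are zeroed rather than removed, `num_parameters()` reports the same count afterwards, and the pushed checkpoint actually grows. Folding the masks in makes the pruning permanent; a sketch:

```python
# Sketch: fold each weight_mask into .weight and drop the reparametrization,
# so the saved checkpoint carries plain (zeroed) weights only.
import torch
from torch.nn.utils import prune

def finalize_pruning(model: torch.nn.Module) -> torch.nn.Module:
    for module in model.modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)) and prune.is_pruned(module):
            prune.remove(module, "weight")
    return model
```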
```diff
@@ -101,8 +120,18 @@ def create_interface():
         pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
         prune_button = gr.Button("Prune Model")
         visualization = gr.Image(label="Model Size Comparison", interactive=False)
+        progress_bar = gr.Progress()
+        logs_button = gr.Button("Show Logs")
+        logs_output = gr.Textbox(label="Logs", interactive=False)
+
+        def show_logs():
+            with open("pruning.log", "r") as log_file:
+                logs = log_file.read()
+            return logs
+
+        logs_button.click(fn=show_logs, outputs=logs_output)
 
-        prune_button.click(fn=prune_model, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization])
+        prune_button.click(fn=prune_model, inputs=[llm_model_name, target_size, hf_write_token, repo_name, progress_bar], outputs=[pruning_status, visualization, logs_output])
 
         text_input = gr.Textbox(label="Input Text")
         text_output = gr.Textbox(label="Generated Text")
```
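On the wiring: `gr.Progress` is not an input component in Gradio; it reaches the handler through the default argument, which this commit already sets with `progress=gr.Progress(track_tqdm=True)`. Listing `progress_bar` in `inputs` therefore mismatches the handler signature. A sketch of the documented pattern:

```python
# Sketch of Gradio's progress pattern: the Progress object is injected via
# the handler's default argument, so only the four real inputs are wired.
prune_button.click(
    fn=prune_model,
    inputs=[llm_model_name, target_size, hf_write_token, repo_name],
    outputs=[pruning_status, visualization, logs_output],
)
```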
```diff
@@ -124,4 +153,4 @@ def create_interface():
 
 # Create and launch the Gradio interface
 demo = create_interface()
-demo.launch(share=True)
+demo.launch(share=True)
```