Manoj Acharya committed on
Commit
3c83b5b
·
1 Parent(s): 0abf63d

Add application file

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
def calculate_training_memory(params, precision, batch_size, seq_length, num_heads, head_dim, num_layers):
    """Estimate GPU memory needed to train a transformer model.

    Args:
        params: Total number of model parameters (e.g. 175e9 for 175B).
        precision: "FP16" (2 bytes/value); any other value is treated as FP32 (4 bytes).
        batch_size: Training batch size.
        seq_length: Sequence length in tokens.
        num_heads: Number of attention heads per layer.
        head_dim: Dimension of each attention head.
        num_layers: Number of transformer layers.

    Returns:
        A human-readable multi-line string breaking down memory in GB
        (weights, optimizer, gradients, activations, total).
    """
    element_bytes = 2 if precision == "FP16" else 4

    # Raw model weights.
    weights = params * element_bytes
    # Optimizer states — the original sizes Adam's states at 2x the weights.
    optimizer = 2 * weights
    # One gradient value per weight, at the same precision.
    gradients = weights
    # Approximate activation footprint (rough formula; ignores MLP widths etc.).
    activations = (
        batch_size * seq_length * num_heads * head_dim * num_layers * element_bytes
    )
    total = weights + optimizer + gradients + activations

    gb = 1e9
    report = [
        f"Model Weights: {weights / gb:.2f} GB",
        f"Optimizer: {optimizer / gb:.2f} GB",
        f"Gradients: {gradients / gb:.2f} GB",
        f"Activation Memory: {activations / gb:.2f} GB",
        f"Total Training Memory: {total / gb:.2f} GB",
    ]
    return "\n".join(report)
23
def calculate_inference_memory(params, precision, batch_size, seq_length, num_heads, head_dim, num_layers):
    """Estimate GPU memory needed to run inference with a transformer model.

    Args:
        params: Total number of model parameters (e.g. 175e9 for 175B).
        precision: "FP16" (2 bytes/value); any other value is treated as FP32 (4 bytes).
        batch_size: Inference batch size.
        seq_length: Sequence length (context) in tokens.
        num_heads: Number of attention heads per layer.
        head_dim: Dimension of each attention head.
        num_layers: Number of transformer layers.

    Returns:
        A human-readable multi-line string breaking down memory in GB
        (weights, KV cache, total).
    """
    element_bytes = 2 if precision == "FP16" else 4

    # Raw model weights.
    weights = params * element_bytes
    # KV cache: one key tensor and one value tensor (the factor of 2) per layer.
    kv_cache = (
        batch_size * seq_length * num_heads * head_dim * 2 * num_layers * element_bytes
    )
    total = weights + kv_cache

    gb = 1e9
    report = [
        f"Model Weights: {weights / gb:.2f} GB",
        f"KV Cache: {kv_cache / gb:.2f} GB",
        f"Total Inference Memory: {total / gb:.2f} GB",
    ]
    return "\n".join(report)
37
def calculate_kv_cache(batch_size, seq_length, num_heads, head_dim, num_layers, precision):
    """Estimate the KV-cache memory for a transformer model.

    Args:
        batch_size: Inference batch size.
        seq_length: Sequence length (context) in tokens.
        num_heads: Number of attention heads per layer.
        head_dim: Dimension of each attention head.
        num_layers: Number of transformer layers.
        precision: "FP16" (2 bytes/value); any other value is treated as FP32 (4 bytes).

    Returns:
        A one-line human-readable string with the KV-cache size in GB.
    """
    element_bytes = 2 if precision == "FP16" else 4
    # Factor of 2 covers both the key and the value tensors per layer.
    per_layer = batch_size * seq_length * num_heads * head_dim * 2 * element_bytes
    kv_total = per_layer * num_layers
    return f"KV Cache Memory: {kv_total / 1e9:.2f} GB"
45
# Gradio UI: three tabs, one per calculator. Component creation order inside
# each `with` context defines the on-screen layout, so statement order matters.
with gr.Blocks() as app:
    gr.Markdown("# GPU Memory Calculator for Transformer Models")

    with gr.Tabs():
        # --- Tab 1: full training-memory breakdown (weights + optimizer +
        # gradients + activations). Defaults approximate a GPT-3 175B config.
        with gr.Tab("Training Memory Calculation"):
            with gr.Row():
                params = gr.Number(label="Number of Parameters (e.g., 175B = 175e9)", value=175e9)
                precision = gr.Radio(["FP16", "FP32"], label="Precision", value="FP16")
            with gr.Row():
                batch_size = gr.Number(label="Batch Size", value=1)
                seq_length = gr.Number(label="Sequence Length", value=2048)
            with gr.Row():
                num_heads = gr.Number(label="Number of Attention Heads", value=96)
                head_dim = gr.Number(label="Head Dimension", value=128)
                num_layers = gr.Number(label="Number of Layers", value=96)
            train_button = gr.Button("Calculate Training Memory")
            train_output = gr.Textbox(label="Training Memory Usage")
            # Wire the button: inputs are passed positionally in the order the
            # target function declares them.
            train_button.click(calculate_training_memory, [params, precision, batch_size, seq_length, num_heads, head_dim, num_layers], train_output)

        # --- Tab 2: inference memory (weights + KV cache). Same defaults,
        # separate components ("_inf" suffix) since widgets can't be shared
        # across tabs.
        with gr.Tab("Inference Memory Calculation"):
            with gr.Row():
                params_inf = gr.Number(label="Number of Parameters (e.g., 175B = 175e9)", value=175e9)
                precision_inf = gr.Radio(["FP16", "FP32"], label="Precision", value="FP16")
            with gr.Row():
                batch_size_inf = gr.Number(label="Batch Size", value=1)
                seq_length_inf = gr.Number(label="Sequence Length", value=2048)
            with gr.Row():
                num_heads_inf = gr.Number(label="Number of Attention Heads", value=96)
                head_dim_inf = gr.Number(label="Head Dimension", value=128)
                num_layers_inf = gr.Number(label="Number of Layers", value=96)
            infer_button = gr.Button("Calculate Inference Memory")
            infer_output = gr.Textbox(label="Inference Memory Usage")
            infer_button.click(calculate_inference_memory, [params_inf, precision_inf, batch_size_inf, seq_length_inf, num_heads_inf, head_dim_inf, num_layers_inf], infer_output)

        # --- Tab 3: KV cache only ("_kv" suffix components). Note the wired
        # input order matches calculate_kv_cache's signature, which takes
        # precision last (unlike the other two functions).
        with gr.Tab("KV Cache Calculation"):
            with gr.Row():
                batch_size_kv = gr.Number(label="Batch Size", value=1)
                seq_length_kv = gr.Number(label="Sequence Length", value=2048)
            with gr.Row():
                num_heads_kv = gr.Number(label="Number of Attention Heads", value=96)
                head_dim_kv = gr.Number(label="Head Dimension", value=128)
                num_layers_kv = gr.Number(label="Number of Layers", value=96)
            precision_kv = gr.Radio(["FP16", "FP32"], label="Precision", value="FP16")
            kv_button = gr.Button("Calculate KV Cache Memory")
            kv_output = gr.Textbox(label="KV Cache Memory Usage")
            kv_button.click(calculate_kv_cache, [batch_size_kv, seq_length_kv, num_heads_kv, head_dim_kv, num_layers_kv, precision_kv], kv_output)

# Start the local Gradio server (blocks until the app is stopped).
app.launch()