FL33TW00D committed on
Commit
dc80200
·
unverified ·
1 Parent(s): 1c9123b

chore: init

Browse files
README.md CHANGED
@@ -1,14 +1,54 @@
1
- ---
2
- title: Throughput Calculator
3
- emoji: 🐠
4
- colorFrom: gray
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.23.3
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Calculate the estimated throughput of on-device LLMs 🚀
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # On-Device LLM Throughput Calculator
2
+
3
+ A Gradio web application that helps visualize LLM throughput on memory-bandwidth-constrained devices.
4
+
5
+ ## Overview
6
+
7
+ This tool calculates and visualizes the theoretical throughput (tokens per second) that can be achieved by a Large Language Model (LLM) running on devices with memory bandwidth constraints. It supports different attention mechanisms:
8
+
9
+ - Grouped Query Attention (GQA)
10
+ - Multi-Query Attention (MQA)
11
+ - Multi-head Latent Attention (MLA)
12
+
13
+ It also visualizes how sliding window attention impacts throughput at different context lengths.
14
+
15
+ ## Features
16
+
17
+ - Customize device specifications (memory bandwidth)
18
+ - Configure model parameters (size, layers, heads)
19
+ - Compare different attention mechanisms
20
+ - Visualize performance across different context lengths
21
+ - Sliding window attention support
22
+
23
+ ## Usage
24
+
25
+ 1. Configure your device details (name, memory bandwidth)
26
+ 2. Set model parameters (number of parameters, layer count, etc.)
27
+ 3. Choose which attention mechanism configurations to compare
28
+ 4. Generate a visualization of expected throughput
29
+
30
+ ## Installation
31
+
32
+ ```bash
33
+ pip install -r requirements.txt
34
+ ```
35
+
36
+ ## Running Locally
37
+
38
+ ```bash
39
+ cd src
40
+ python app.py
41
+ ```
42
+
43
+ ## Theory
44
+
45
+ The calculations are based on memory bandwidth bottlenecks as described in the [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).
46
+
47
+ The basic formula for tokens per second:
48
+ ```
49
+ tokens_per_second = (batch_size * memory_bandwidth) / (batch_size * total_kv_size + parameter_size)
50
+ ```
51
+
52
+ ## License
53
+
54
+ MIT
pyproject.toml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "throughput-calculator"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10.6"
7
+ dependencies = [
8
+ "gradio>=4.0.0",
9
+ "numpy>=1.24.0",
10
+ "matplotlib>=3.7.0",
11
+ "seaborn>=0.12.0",
12
+ ]
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (178 Bytes). View file
 
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (182 Bytes). View file
 
src/__pycache__/app.cpython-310.pyc ADDED
Binary file (7.07 kB). View file
 
src/__pycache__/throughput_utils.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
src/__pycache__/throughput_utils.cpython-313.pyc ADDED
Binary file (6.68 kB). View file
 
src/app.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from enum import Enum
3
+ from throughput_utils import create_throughput_plot
4
+
5
class AttentionType(Enum):
    """Kind of attention used by a transformer layer.

    NOTE(review): nothing else in app.py references this enum, and
    throughput_utils declares its own identical copy; confirm whether this
    duplicate is still needed.
    """

    # Layer attends only within a sliding window.
    LOCAL = 0
    # Layer attends over the full context.
    GLOBAL = 1
8
+
9
class PhoneBandwidth(Enum):
    """Memory-bandwidth presets listed in the "iPhone Model" dropdown.

    Member names are the dropdown choices. Each member's value is handed to
    the plotting code as the device memory bandwidth, where it is scaled by
    1e9 and reported in GB/s.
    """

    Sixteen = 60
    Fifteen = 51.2
    Fourteen = 34.1
13
+
14
# Stylesheet handed to gr.Blocks(css=...). The #id selectors correspond to the
# elem_id values assigned to components in the layout (plot-container,
# generate-button, error-status); .sliders-container is applied through
# elem_classes on the slider columns.
custom_css = """
#plot-container {
    border-radius: 10px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 1px 3px rgba(0, 0, 0, 0.08);
    padding: 1rem;
    background-color: white;
    height: 100%;
    margin-bottom: 1.5rem;
}

#generate-button {
    background-color: #2563eb;
    color: white;
    border-radius: 8px;
    font-weight: bold;
    padding: 10px 20px;
    box-shadow: 0 4px 6px rgba(37, 99, 235, 0.1);
    transition: all 0.2s ease;
    width: 100%;
    max-width: 400px;
    margin: 0 auto;
    font-size: 16px;
}

#generate-button:hover {
    background-color: #1d4ed8;
    box-shadow: 0 6px 8px rgba(37, 99, 235, 0.2);
    transform: translateY(-2px);
}

.gradio-container {
    background-color: #f5f7fa;
}

/* Custom styles for sliders containers */
.sliders-container {
    border: 1px solid rgba(0, 0, 0, 0.1);
    border-radius: 8px;
    padding: 1rem;
    margin-top: 0.5rem;
    background-color: rgba(255, 255, 255, 0.8);
}

#error-status {
    color: #b91c1c;
    background-color: #fee2e2;
    border-radius: 8px;
    padding: 0.75rem;
    margin-top: 0.5rem;
    border: 1px solid #f87171;
    font-weight: 500;
}
"""
67
+
68
# Build the Gradio UI: plot output on top, device/attention settings on the
# left, model/context settings on the right, then wire up the event handlers.
with gr.Blocks(css=custom_css) as demo:
    # Slider components are collected in lists so the event wiring below can
    # treat each family (GQA / MLA) as a group.
    gqa_sliders = []
    mla_sliders = []

    with gr.Column():
        gr.Markdown(
            """# 📊 On-Device LLM Throughput Calculator

            This tool estimates the throughput (tokens per second) of Large Language Models on devices with memory bandwidth constraints.
            It visualizes how different attention mechanisms (GQA, MLA) and context lengths affect throughput.
            """
        )

        with gr.Row():
            plot_output = gr.Image(label="Throughput Plot", type="pil", elem_id="plot-container")

        # Add status element to display validation errors
        status_output = gr.Markdown(visible=False, elem_id="error-status")

        with gr.Row():
            plot_button = gr.Button("Generate Throughput Plot", size="lg", elem_id="generate-button", variant="primary")

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Device Configuration")
                    model_name = gr.Textbox(label="Model Name", value="TinyLLM")
                    iphone_model = gr.Dropdown(
                        label="iPhone Model",
                        choices=[e.name for e in PhoneBandwidth],
                        value=PhoneBandwidth.Sixteen.name,
                        interactive=True
                    )

                with gr.Group():
                    gr.Markdown("### Attention Configurations to Plot")

                    gr.Markdown("#### GQA Head Configurations")
                    gr.Markdown("*Note: GQA head count must be less than or equal to the total number of heads*")

                    with gr.Column(elem_classes="sliders-container"):
                        # step=1: with minimum=1 a step of 2 only lands on odd
                        # values, which made the even head counts (2, 4, 8, 16,
                        # including the defaults below) unreachable by dragging.
                        gqa_slider1 = gr.Slider(minimum=1, maximum=32, step=1, value=4,
                                                label="GQA Head Count #1")
                        gqa_slider2 = gr.Slider(minimum=1, maximum=32, step=1, value=8,
                                                label="GQA Head Count #2")
                        gqa_sliders.extend([gqa_slider1, gqa_slider2])

                    gr.Markdown("#### MLA Compressed Dimensions")
                    gr.Markdown("*Note: MLA dimension must be less than or equal to d_model*")

                    with gr.Column(elem_classes="sliders-container"):
                        mla_slider1 = gr.Slider(minimum=64, maximum=1024, step=64, value=256,
                                                label="MLA Dimension #1")
                        mla_slider2 = gr.Slider(minimum=64, maximum=1024, step=64, value=512,
                                                label="MLA Dimension #2")
                        mla_sliders.extend([mla_slider1, mla_slider2])

            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Model Configuration")
                    num_parameters = gr.Number(label="Parameters (Billions)", value=3)
                    parameter_size = gr.Slider(minimum=1, maximum=16.0, step=1.0, label="Parameter Size (bits per param)", value=5)
                    kv_parameter_size = gr.Slider(minimum=0.25, maximum=4.0, step=0.25,
                                                  label="KV Cache Size (bytes per value)", value=2.0)
                    # precision=0 makes these fields deliver ints; the plotting
                    # code uses them in range() and list repetition, which
                    # reject fractional values.
                    num_layers = gr.Number(label="Number of Layers", value=36, precision=0)
                    num_heads = gr.Number(label="Number of Heads", value=16, precision=0,
                                          info="GQA head counts must be less than or equal to this value")
                    d_model = gr.Number(label="D Model", value=2048, precision=0,
                                        info="MLA dimensions must be less than or equal to this value")

                with gr.Group():
                    gr.Markdown("### Context Configuration")
                    ctx_length = gr.Slider(minimum=1024, maximum=131072, step=1024,
                                           label="Max Context Length", value=65536)
                    local_layers = gr.Number(label="Local Attention Layers", value=0, precision=0)
                    global_layers = gr.Number(label="Global Attention Layers", value=1, precision=0)
                    swa_size = gr.Slider(minimum=1024, maximum=32768, step=1024,
                                         label="Sliding Window Size", value=4096)

        gr.Markdown(
            """
            For more information, see [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).
            """
        )

    def generate_throughput_plot(
        model_name, iphone_model, num_parameters, parameter_size,
        kv_parameter_size, num_layers, num_heads, d_model, ctx_length,
        local_layers, global_layers, swa_size, gqa_1, gqa_2, mla_1, mla_2
    ):
        """Validate the form values and render the throughput plot.

        Returns a two-element list of updates for (plot_output,
        status_output): on success the plot image plus a hidden status, on
        failure a cleared plot plus a visible error message.
        """
        memory_bandwidth = PhoneBandwidth[iphone_model].value

        # Prefix the device so the plotting code can split "device: model".
        if "iPhone" not in model_name:
            model_name = f"iPhone {iphone_model}: {model_name}"

        try:
            # GQA head counts must not exceed the total attention heads
            for gqa_heads, label in [(gqa_1, "GQA Head Count #1"), (gqa_2, "GQA Head Count #2")]:
                if gqa_heads > num_heads:
                    raise ValueError(f"{label} ({gqa_heads}) cannot be greater than the total number of attention heads ({num_heads})")

            # MLA compressed dimensions must not exceed d_model
            for mla_dim, label in [(mla_1, "MLA Dimension #1"), (mla_2, "MLA Dimension #2")]:
                if mla_dim > d_model:
                    raise ValueError(f"{label} ({mla_dim}) cannot be greater than the model dimension (d_model = {d_model})")

            plot_img = create_throughput_plot(
                model_name,
                memory_bandwidth,
                num_parameters,
                parameter_size,
                kv_parameter_size,
                num_layers,
                num_heads,
                d_model,
                ctx_length,
                local_layers,
                global_layers,
                swa_size,
                [gqa_1, gqa_2],
                [mla_1, mla_2],
            )

            # Hide error message, show plot
            return [
                gr.update(value=plot_img),
                gr.update(visible=False, value="")
            ]
        except Exception as e:
            # UI boundary: any failure (validation or plotting) is reported
            # to the user in the status area instead of crashing the handler.
            err_string = f"Error generating plot: {str(e)}"
            print(err_string)
            # Show error message, clear plot
            return [
                gr.update(value=None),
                gr.update(visible=True, value=f"⚠️ {err_string}")
            ]

    def update_gqa_sliders(heads_value, *current_values):
        """Cap the GQA sliders at the head count, keeping the user's picks.

        The live slider values arrive as inputs; a component's ``.value``
        attribute only holds the default it was constructed with, so it
        cannot be used to preserve the current selection.
        """
        if not heads_value or heads_value < 1:
            heads_value = 1
        return [gr.update(maximum=heads_value, value=min(current, heads_value))
                for current in current_values]

    def update_mla_sliders(d_model_value, *current_values):
        """Cap the MLA sliders at d_model, keeping the user's picks.

        Values below the sliders' minimum of 64 are raised to 64.
        """
        if not d_model_value or d_model_value < 64:
            d_model_value = 64
        return [gr.update(maximum=d_model_value, value=min(current, d_model_value))
                for current in current_values]

    # Re-clamp the sliders whenever the model configuration changes.
    num_heads.change(
        update_gqa_sliders,
        inputs=[num_heads, *gqa_sliders],
        outputs=gqa_sliders
    )

    d_model.change(
        update_mla_sliders,
        inputs=[d_model, *mla_sliders],
        outputs=mla_sliders
    )

    plot_button.click(
        generate_throughput_plot,
        inputs=[
            model_name,
            iphone_model,
            num_parameters,
            parameter_size,
            kv_parameter_size,
            num_layers,
            num_heads,
            d_model,
            ctx_length,
            local_layers,
            global_layers,
            swa_size,
            *gqa_sliders,
            *mla_sliders,
        ],
        outputs=[plot_output, status_output]
    )

if __name__ == "__main__":
    demo.launch()
src/throughput_utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+ from matplotlib.ticker import ScalarFormatter
5
+ from enum import Enum
6
+ import io
7
+
8
class AttentionType(Enum):
    """Kind of attention used by a transformer layer.

    compute_tps caps the KV cache of LOCAL layers at the sliding-window
    size; any other layer is charged for the full context length.
    """

    # Sliding-window attention: cached tokens are capped at the window size.
    LOCAL = 0
    # Full attention: every token of the context is cached.
    GLOBAL = 1
11
+
12
def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    """Return the KV-cache size one token occupies in one GQA layer.

    Args:
        n_kv_heads: Number of key/value heads.
        d_head: Dimension of each head.
        kv_parameter_size: Storage size of a single cached value (the UI
            labels this "bytes per value").

    Returns:
        ``2 * kv_parameter_size * n_kv_heads * d_head``; the factor 2
        accounts for storing both a key and a value.
    """
    key_and_value_size = 2 * kv_parameter_size
    return key_and_value_size * n_kv_heads * d_head
14
+
15
def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    """Return the KV-cache size one token occupies in one MLA layer.

    Args:
        d_compressed: Dimension of the compressed latent cached per token.
        kv_parameter_size: Storage size of a single cached value (the UI
            labels this "bytes per value").

    Returns:
        ``kv_parameter_size * d_compressed``.
    """
    latent_size = kv_parameter_size * d_compressed
    return latent_size
17
+
18
def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    """Estimate decode throughput for a memory-bandwidth-bound model.

    Args:
        batch_size: Number of sequences decoded together.
        bandwidth: Device memory bandwidth.
        total_kv_size: KV-cache size of one sequence, summed over layers.
        param_size: Total size of the model weights.

    Returns:
        ``(batch_size * bandwidth) / (batch_size * total_kv_size + param_size)``.
        ``bandwidth`` and the two sizes must share a unit base (e.g. bytes/s
        and bytes) for the result to be tokens per second.
    """
    numerator = batch_size * bandwidth
    denominator = batch_size * total_kv_size + param_size
    return numerator / denominator
20
+
21
def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    """Compute tokens-per-second for each context length in ``seq_len``.

    Args:
        kv_per_layer_per_token: KV-cache size of one token in one layer.
        seq_len: Iterable of context lengths to evaluate.
        batch_size: Number of sequences decoded together.
        total_param_size: Total size of the model weights.
        num_layers: Number of transformer layers; converted with ``int()``
            so a float coming from a UI number field is accepted.
        swa_pattern: Sequence of ``AttentionType`` values that is repeated
            cyclically across the layers. An empty pattern is treated as
            all-global.
        swa_size: Sliding-window size; LOCAL layers cache at most this many
            tokens.
        bandwidth: Device memory bandwidth.

    Returns:
        A list with one throughput value per entry of ``seq_len``.
    """
    num_layers = int(num_layers)
    if not swa_pattern:
        # Avoid a modulo-by-zero below; no pattern means no local layers.
        swa_pattern = [AttentionType.GLOBAL]
    pattern_len = len(swa_pattern)

    # Which layers are local depends only on the pattern, not on the context
    # length, so classify the layers once instead of once per context length.
    local_layer_count = sum(
        1
        for layer in range(num_layers)
        if swa_pattern[layer % pattern_len] == AttentionType.LOCAL
    )
    global_layer_count = num_layers - local_layer_count

    tps_values = []
    for ctx_len in seq_len:
        # Local layers stop growing once the context exceeds the window.
        cached_tokens = (local_layer_count * min(ctx_len, swa_size)
                         + global_layer_count * ctx_len)
        total_kv_size = kv_per_layer_per_token * cached_tokens
        tps_values.append(
            tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        )
    return tps_values
34
+
35
def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
    """Plot tokens-per-second against context length and return a PIL image.

    One curve is drawn per GQA head count and per MLA compressed dimension,
    evaluated at 100 log-spaced context lengths between 1e2 and 1e5.

    Args:
        model_name: Plot label. If it contains ``:``, the part before the
            first colon is shown as the device name in the annotation box
            and the remainder is used in the title.
        memory_bandwidth: Device bandwidth in GB/s (scaled by 1e9 here).
        num_parameters: Model size in billions of parameters.
        parameter_size: Bits per parameter.
        kv_parameter_size: Storage size of one KV-cache value.
        num_layers: Number of transformer layers.
        num_heads: Number of attention heads; must be at least 1.
        d_model: Model dimension; ``d_model // num_heads`` is the head size.
        ctx_length: Context length marked with a vertical red line.
        local_layers: Number of LOCAL entries in the repeating layer pattern.
        global_layers: Number of GLOBAL entries in the repeating pattern.
        swa_size: Sliding-window size for LOCAL layers.
        gqa_heads: Iterable of GQA KV-head counts to plot.
        mla_d_compressed: Iterable of MLA compressed dimensions to plot.

    Returns:
        The rendered PNG as a ``PIL.Image.Image``.

    Raises:
        ValueError: If ``num_heads`` is less than 1.
    """
    from PIL import Image

    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000
    num_parameters = float(num_parameters) * 1_000_000_000

    # UI number fields and sliders may deliver floats. These values are used
    # for list repetition, range() and thousands-separated labels, all of
    # which need true ints.
    num_layers = int(num_layers)
    num_heads = int(num_heads)
    d_model = int(d_model)
    ctx_length = int(ctx_length)
    swa_size = int(swa_size)
    local_layers = int(local_layers)
    global_layers = int(global_layers)

    if num_heads < 1:
        raise ValueError("Number of heads must be at least 1")

    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)  # bits -> bytes

    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)
    if not swa_pattern:
        swa_pattern = [AttentionType.GLOBAL]
    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)

    gqa_heads = list(gqa_heads)
    mla_d_compressed = list(mla_d_compressed)
    gqa_count = len(gqa_heads)

    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", gqa_count + len(mla_d_compressed))

    seq_len = np.logspace(2, 5, 100).astype(int)
    batch_size = 1

    # Use an explicit figure and always close it: the implicit pyplot
    # "current figure" is shared global state, and an exception part-way
    # through would otherwise leave the figure open.
    fig, ax = plt.subplots(figsize=(14, 8), dpi=300)
    try:
        tps_values = []
        for i, n_kv_head in enumerate(gqa_heads):
            n_kv_head = int(n_kv_head)
            kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
            gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                         num_layers, swa_pattern, swa_size, memory_bandwidth)
            tps_values.extend(gqa_tps_values)
            ax.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads", color=palette[i],
                    linewidth=3.5, alpha=0.85)

        ax.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                   label=f"Max Context Length ({ctx_length:,})")

        if local_count > 0:
            ax.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                       label=f"Sliding Window Limit ({swa_size:,})")

        for i, d_comp in enumerate(mla_d_compressed):
            d_comp = int(d_comp)
            kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
            mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                         num_layers, swa_pattern, swa_size, memory_bandwidth)
            tps_values.extend(mla_tps_values)
            ax.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                    color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)

        ax.set_xscale('log')
        # Fit the y-range to the data when there is finite data to fit; with
        # no curves or non-finite values, fall back to a fixed range.
        if tps_values and np.all(np.isfinite(tps_values)):
            y_min = max(0, min(tps_values) * 0.9)
            y_max = max(tps_values) * 1.1
            ax.set_ylim(y_min, y_max)
        else:
            ax.set_ylim(15, 40)

        ax.xaxis.set_major_formatter(ScalarFormatter())
        ax.yaxis.set_major_formatter(ScalarFormatter())

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_linewidth(1.5)
        ax.spines['bottom'].set_linewidth(1.5)

        attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"

        # Split on the first colon only, so a model name that itself contains
        # a colon is kept whole, and strip the space following "device:".
        device_name, separator, remainder = model_name.partition(':')
        model_title = remainder.strip() if separator else model_name

        ax.annotate(f"{device_name}\nBandwidth: {memory_bandwidth/1e9:.1f} GB/s\nParameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
                    xy=(0.8, 0.97),
                    xycoords='axes fraction',
                    bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor='darkgray'),
                    va='top',
                    fontsize=11)

        ax.set_xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
        ax.set_ylabel('Tokens per Second', fontsize=14, fontweight='bold')

        ax.tick_params(axis='both', which='major', labelsize=12)

        ax.set_title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
                     fontweight='bold', pad=20)

        ax.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12, title_fontsize=14)

        ax.grid(True, alpha=0.5)

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    return img
uv.lock ADDED
The diff for this file is too large to render. See raw diff