FL33TW00D
commited on
chore: init
Browse files- README.md +54 -14
- pyproject.toml +12 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/__init__.cpython-313.pyc +0 -0
- src/__pycache__/app.cpython-310.pyc +0 -0
- src/__pycache__/throughput_utils.cpython-310.pyc +0 -0
- src/__pycache__/throughput_utils.cpython-313.pyc +0 -0
- src/app.py +252 -0
- src/throughput_utils.py +148 -0
- uv.lock +0 -0
README.md
CHANGED
@@ -1,14 +1,54 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# On-Device LLM Throughput Calculator
|
2 |
+
|
3 |
+
A Gradio web application that helps visualize LLM throughput on memory-bandwidth-constrained devices.
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
|
7 |
+
This tool calculates and visualizes the theoretical throughput (tokens per second) that can be achieved by a Large Language Model (LLM) running on devices with memory bandwidth constraints. It supports different attention mechanisms:
|
8 |
+
|
9 |
+
- Grouped Query Attention (GQA)
|
10 |
+
- Multi-Query Attention (MQA)
|
11 |
+
- Memory-Latent Attention (MLA)
|
12 |
+
|
13 |
+
It also visualizes how sliding window attention impacts throughput at different context lengths.
|
14 |
+
|
15 |
+
## Features
|
16 |
+
|
17 |
+
- Customize device specifications (memory bandwidth)
|
18 |
+
- Configure model parameters (size, layers, heads)
|
19 |
+
- Compare different attention mechanisms
|
20 |
+
- Visualize performance across different context lengths
|
21 |
+
- Sliding window attention support
|
22 |
+
|
23 |
+
## Usage
|
24 |
+
|
25 |
+
1. Configure your device details (name, memory bandwidth)
|
26 |
+
2. Set model parameters (number of parameters, layer count, etc.)
|
27 |
+
3. Choose which attention mechanism configurations to compare
|
28 |
+
4. Generate a visualization of expected throughput
|
29 |
+
|
30 |
+
## Installation
|
31 |
+
|
32 |
+
```bash
|
33 |
+
pip install -r requirements.txt
|
34 |
+
```
|
35 |
+
|
36 |
+
## Running Locally
|
37 |
+
|
38 |
+
```bash
|
39 |
+
cd src
|
40 |
+
python app.py
|
41 |
+
```
|
42 |
+
|
43 |
+
## Theory
|
44 |
+
|
45 |
+
The calculations are based on memory bandwidth bottlenecks as described in the [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).
|
46 |
+
|
47 |
+
The basic formula for tokens per second:
|
48 |
+
```
|
49 |
+
tokens_per_second = (batch_size * memory_bandwidth) / (batch_size * total_kv_size + parameter_size)
|
50 |
+
```
|
51 |
+
|
52 |
+
## License
|
53 |
+
|
54 |
+
MIT
|
pyproject.toml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "throughput-calculator"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Add your description here"
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.10.6"
|
7 |
+
dependencies = [
|
8 |
+
"gradio>=4.0.0",
|
9 |
+
"numpy>=1.24.0",
|
10 |
+
"matplotlib>=3.7.0",
|
11 |
+
"seaborn>=0.12.0",
|
12 |
+
]
|
src/__init__.py
ADDED
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (178 Bytes). View file
|
|
src/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (182 Bytes). View file
|
|
src/__pycache__/app.cpython-310.pyc
ADDED
Binary file (7.07 kB). View file
|
|
src/__pycache__/throughput_utils.cpython-310.pyc
ADDED
Binary file (4.47 kB). View file
|
|
src/__pycache__/throughput_utils.cpython-313.pyc
ADDED
Binary file (6.68 kB). View file
|
|
src/app.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from enum import Enum
|
3 |
+
from throughput_utils import create_throughput_plot
|
4 |
+
|
5 |
+
class AttentionType(Enum):
|
6 |
+
LOCAL = 0
|
7 |
+
GLOBAL = 1
|
8 |
+
|
9 |
+
class PhoneBandwidth(Enum):
|
10 |
+
Sixteen = 60
|
11 |
+
Fifteen = 51.2
|
12 |
+
Fourteen = 34.1
|
13 |
+
|
14 |
+
custom_css = """
|
15 |
+
#plot-container {
|
16 |
+
border-radius: 10px;
|
17 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 1px 3px rgba(0, 0, 0, 0.08);
|
18 |
+
padding: 1rem;
|
19 |
+
background-color: white;
|
20 |
+
height: 100%;
|
21 |
+
margin-bottom: 1.5rem;
|
22 |
+
}
|
23 |
+
|
24 |
+
#generate-button {
|
25 |
+
background-color: #2563eb;
|
26 |
+
color: white;
|
27 |
+
border-radius: 8px;
|
28 |
+
font-weight: bold;
|
29 |
+
padding: 10px 20px;
|
30 |
+
box-shadow: 0 4px 6px rgba(37, 99, 235, 0.1);
|
31 |
+
transition: all 0.2s ease;
|
32 |
+
width: 100%;
|
33 |
+
max-width: 400px;
|
34 |
+
margin: 0 auto;
|
35 |
+
font-size: 16px;
|
36 |
+
}
|
37 |
+
|
38 |
+
#generate-button:hover {
|
39 |
+
background-color: #1d4ed8;
|
40 |
+
box-shadow: 0 6px 8px rgba(37, 99, 235, 0.2);
|
41 |
+
transform: translateY(-2px);
|
42 |
+
}
|
43 |
+
|
44 |
+
.gradio-container {
|
45 |
+
background-color: #f5f7fa;
|
46 |
+
}
|
47 |
+
|
48 |
+
/* Custom styles for sliders containers */
|
49 |
+
.sliders-container {
|
50 |
+
border: 1px solid rgba(0, 0, 0, 0.1);
|
51 |
+
border-radius: 8px;
|
52 |
+
padding: 1rem;
|
53 |
+
margin-top: 0.5rem;
|
54 |
+
background-color: rgba(255, 255, 255, 0.8);
|
55 |
+
}
|
56 |
+
|
57 |
+
#error-status {
|
58 |
+
color: #b91c1c;
|
59 |
+
background-color: #fee2e2;
|
60 |
+
border-radius: 8px;
|
61 |
+
padding: 0.75rem;
|
62 |
+
margin-top: 0.5rem;
|
63 |
+
border: 1px solid #f87171;
|
64 |
+
font-weight: 500;
|
65 |
+
}
|
66 |
+
"""
|
67 |
+
|
68 |
+
with gr.Blocks(css=custom_css) as demo:
|
69 |
+
gqa_sliders = []
|
70 |
+
mla_sliders = []
|
71 |
+
|
72 |
+
with gr.Column():
|
73 |
+
gr.Markdown(
|
74 |
+
"""# 📊 On-Device LLM Throughput Calculator
|
75 |
+
|
76 |
+
This tool estimates the throughput (tokens per second) of Large Language Models on devices with memory bandwidth constraints.
|
77 |
+
It visualizes how different attention mechanisms (GQA, MLA) and context lengths affect throughput.
|
78 |
+
"""
|
79 |
+
)
|
80 |
+
|
81 |
+
with gr.Row():
|
82 |
+
plot_output = gr.Image(label="Throughput Plot", type="pil", elem_id="plot-container")
|
83 |
+
|
84 |
+
# Add status element to display validation errors
|
85 |
+
status_output = gr.Markdown(visible=False, elem_id="error-status")
|
86 |
+
|
87 |
+
with gr.Row():
|
88 |
+
plot_button = gr.Button("Generate Throughput Plot", size="lg", elem_id="generate-button", variant="primary")
|
89 |
+
|
90 |
+
with gr.Row():
|
91 |
+
with gr.Column(scale=1):
|
92 |
+
with gr.Group():
|
93 |
+
gr.Markdown("### Device Configuration")
|
94 |
+
model_name = gr.Textbox(label="Model Name", value="TinyLLM")
|
95 |
+
iphone_model = gr.Dropdown(
|
96 |
+
label="iPhone Model",
|
97 |
+
choices=[e.name for e in PhoneBandwidth],
|
98 |
+
value=PhoneBandwidth.Sixteen.name,
|
99 |
+
interactive=True
|
100 |
+
)
|
101 |
+
|
102 |
+
with gr.Group():
|
103 |
+
gr.Markdown("### Attention Configurations to Plot")
|
104 |
+
|
105 |
+
gr.Markdown("#### GQA Head Configurations")
|
106 |
+
gr.Markdown("*Note: GQA head count must be less than or equal to the total number of heads*")
|
107 |
+
|
108 |
+
with gr.Column(elem_classes="sliders-container"):
|
109 |
+
gqa_slider1 = gr.Slider(minimum=1, maximum=32, step=2, value=4,
|
110 |
+
label="GQA Head Count #1")
|
111 |
+
gqa_slider2 = gr.Slider(minimum=1, maximum=32, step=2, value=8,
|
112 |
+
label="GQA Head Count #2")
|
113 |
+
gqa_sliders.extend([gqa_slider1, gqa_slider2])
|
114 |
+
|
115 |
+
gr.Markdown("#### MLA Compressed Dimensions")
|
116 |
+
gr.Markdown("*Note: MLA dimension must be less than or equal to d_model*")
|
117 |
+
|
118 |
+
with gr.Column(elem_classes="sliders-container"):
|
119 |
+
mla_slider1 = gr.Slider(minimum=64, maximum=1024, step=64, value=256,
|
120 |
+
label="MLA Dimension #1")
|
121 |
+
mla_slider2 = gr.Slider(minimum=64, maximum=1024, step=64, value=512,
|
122 |
+
label="MLA Dimension #2")
|
123 |
+
mla_sliders.extend([mla_slider1, mla_slider2])
|
124 |
+
|
125 |
+
with gr.Column(scale=1):
|
126 |
+
with gr.Group():
|
127 |
+
gr.Markdown("### Model Configuration")
|
128 |
+
num_parameters = gr.Number(label="Parameters (Billions)", value=3)
|
129 |
+
parameter_size = gr.Slider(minimum=1, maximum=16.0, step=1.0, label="Parameter Size (bits per param)", value=5)
|
130 |
+
kv_parameter_size = gr.Slider(minimum=0.25, maximum=4.0, step=0.25,
|
131 |
+
label="KV Cache Size (bytes per value)", value=2.0)
|
132 |
+
num_layers = gr.Number(label="Number of Layers", value=36)
|
133 |
+
num_heads = gr.Number(label="Number of Heads", value=16,
|
134 |
+
info="GQA head counts must be less than or equal to this value")
|
135 |
+
d_model = gr.Number(label="D Model", value=2048,
|
136 |
+
info="MLA dimensions must be less than or equal to this value")
|
137 |
+
|
138 |
+
with gr.Group():
|
139 |
+
gr.Markdown("### Context Configuration")
|
140 |
+
ctx_length = gr.Slider(minimum=1024, maximum=131072, step=1024,
|
141 |
+
label="Max Context Length", value=65536)
|
142 |
+
local_layers = gr.Number(label="Local Attention Layers", value=0)
|
143 |
+
global_layers = gr.Number(label="Global Attention Layers", value=1)
|
144 |
+
swa_size = gr.Slider(minimum=1024, maximum=32768, step=1024,
|
145 |
+
label="Sliding Window Size", value=4096)
|
146 |
+
|
147 |
+
gr.Markdown(
|
148 |
+
"""
|
149 |
+
For more information, see [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).
|
150 |
+
"""
|
151 |
+
)
|
152 |
+
|
153 |
+
def generate_throughput_plot(
|
154 |
+
model_name, iphone_model, num_parameters, parameter_size,
|
155 |
+
kv_parameter_size, num_layers, num_heads, d_model, ctx_length,
|
156 |
+
local_layers, global_layers, swa_size, gqa_1, gqa_2, mla_1, mla_2
|
157 |
+
):
|
158 |
+
memory_bandwidth = PhoneBandwidth[iphone_model].value
|
159 |
+
|
160 |
+
if "iPhone" not in model_name:
|
161 |
+
model_name = f"iPhone {iphone_model}: {model_name}"
|
162 |
+
|
163 |
+
try:
|
164 |
+
# Validate GQA head counts must be less than total attention heads
|
165 |
+
for gqa_heads, label in [(gqa_1, "GQA Head Count #1"), (gqa_2, "GQA Head Count #2")]:
|
166 |
+
if gqa_heads > num_heads:
|
167 |
+
raise ValueError(f"{label} ({gqa_heads}) cannot be greater than the total number of attention heads ({num_heads})")
|
168 |
+
|
169 |
+
# Validate MLA compressed dimensions must be less than d_model
|
170 |
+
for mla_dim, label in [(mla_1, "MLA Dimension #1"), (mla_2, "MLA Dimension #2")]:
|
171 |
+
if mla_dim > d_model:
|
172 |
+
raise ValueError(f"{label} ({mla_dim}) cannot be greater than the model dimension (d_model = {d_model})")
|
173 |
+
|
174 |
+
plot_img = create_throughput_plot(
|
175 |
+
model_name,
|
176 |
+
memory_bandwidth,
|
177 |
+
num_parameters,
|
178 |
+
parameter_size,
|
179 |
+
kv_parameter_size,
|
180 |
+
num_layers,
|
181 |
+
num_heads,
|
182 |
+
d_model,
|
183 |
+
ctx_length,
|
184 |
+
local_layers,
|
185 |
+
global_layers,
|
186 |
+
swa_size,
|
187 |
+
[gqa_1, gqa_2],
|
188 |
+
[mla_1, mla_2],
|
189 |
+
)
|
190 |
+
|
191 |
+
# Hide error message, show plot
|
192 |
+
return [
|
193 |
+
gr.update(value=plot_img),
|
194 |
+
gr.update(visible=False, value="")
|
195 |
+
]
|
196 |
+
except Exception as e:
|
197 |
+
err_string = f"Error generating plot: {str(e)}"
|
198 |
+
print(err_string)
|
199 |
+
# Show error message, clear plot
|
200 |
+
return [
|
201 |
+
gr.update(value=None),
|
202 |
+
gr.update(visible=True, value=f"⚠️ {err_string}")
|
203 |
+
]
|
204 |
+
|
205 |
+
# Function to update GQA sliders based on number of heads
|
206 |
+
def update_gqa_sliders(heads_value):
|
207 |
+
if not heads_value or heads_value < 1:
|
208 |
+
heads_value = 1
|
209 |
+
return [gr.update(maximum=heads_value, value=min(slider.value, heads_value)) for slider in gqa_sliders]
|
210 |
+
|
211 |
+
# Function to update MLA sliders based on d_model
|
212 |
+
def update_mla_sliders(d_model_value):
|
213 |
+
if not d_model_value or d_model_value < 64:
|
214 |
+
d_model_value = 64
|
215 |
+
return [gr.update(maximum=d_model_value, value=min(slider.value, d_model_value)) for slider in mla_sliders]
|
216 |
+
|
217 |
+
# Add event handlers to update sliders when model configuration changes
|
218 |
+
num_heads.change(
|
219 |
+
update_gqa_sliders,
|
220 |
+
inputs=[num_heads],
|
221 |
+
outputs=gqa_sliders
|
222 |
+
)
|
223 |
+
|
224 |
+
d_model.change(
|
225 |
+
update_mla_sliders,
|
226 |
+
inputs=[d_model],
|
227 |
+
outputs=mla_sliders
|
228 |
+
)
|
229 |
+
|
230 |
+
plot_button.click(
|
231 |
+
generate_throughput_plot,
|
232 |
+
inputs=[
|
233 |
+
model_name,
|
234 |
+
iphone_model,
|
235 |
+
num_parameters,
|
236 |
+
parameter_size,
|
237 |
+
kv_parameter_size,
|
238 |
+
num_layers,
|
239 |
+
num_heads,
|
240 |
+
d_model,
|
241 |
+
ctx_length,
|
242 |
+
local_layers,
|
243 |
+
global_layers,
|
244 |
+
swa_size,
|
245 |
+
*gqa_sliders,
|
246 |
+
*mla_sliders,
|
247 |
+
],
|
248 |
+
outputs=[plot_output, status_output]
|
249 |
+
)
|
250 |
+
|
251 |
+
if __name__ == "__main__":
|
252 |
+
demo.launch()
|
src/throughput_utils.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import seaborn as sns
|
4 |
+
from matplotlib.ticker import ScalarFormatter
|
5 |
+
from enum import Enum
|
6 |
+
import io
|
7 |
+
|
8 |
+
class AttentionType(Enum):
|
9 |
+
LOCAL = 0
|
10 |
+
GLOBAL = 1
|
11 |
+
|
12 |
+
def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
|
13 |
+
return 2 * kv_parameter_size * n_kv_heads * d_head
|
14 |
+
|
15 |
+
def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
|
16 |
+
return kv_parameter_size * d_compressed
|
17 |
+
|
18 |
+
def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
|
19 |
+
return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)
|
20 |
+
|
21 |
+
def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
|
22 |
+
num_layers, swa_pattern, swa_size, bandwidth):
|
23 |
+
tps_values = []
|
24 |
+
for ctx_len in seq_len:
|
25 |
+
total_kv_size = 0
|
26 |
+
for l in range(num_layers):
|
27 |
+
if swa_pattern[l % len(swa_pattern)] == AttentionType.LOCAL:
|
28 |
+
total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
|
29 |
+
else:
|
30 |
+
total_kv_size += kv_per_layer_per_token * ctx_len
|
31 |
+
tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
|
32 |
+
tps_values.append(tps)
|
33 |
+
return tps_values
|
34 |
+
|
35 |
+
def create_throughput_plot(
|
36 |
+
model_name,
|
37 |
+
memory_bandwidth,
|
38 |
+
num_parameters,
|
39 |
+
parameter_size,
|
40 |
+
kv_parameter_size,
|
41 |
+
num_layers,
|
42 |
+
num_heads,
|
43 |
+
d_model,
|
44 |
+
ctx_length,
|
45 |
+
local_layers,
|
46 |
+
global_layers,
|
47 |
+
swa_size,
|
48 |
+
gqa_heads,
|
49 |
+
mla_d_compressed,
|
50 |
+
):
|
51 |
+
memory_bandwidth = float(memory_bandwidth) * 1_000_000_000
|
52 |
+
num_parameters = float(num_parameters) * 1_000_000_000
|
53 |
+
|
54 |
+
d_head = d_model // num_heads
|
55 |
+
total_param_size = num_parameters * (parameter_size / 8.0)
|
56 |
+
|
57 |
+
swa_pattern = ([AttentionType.LOCAL] * local_layers +
|
58 |
+
[AttentionType.GLOBAL] * global_layers)
|
59 |
+
|
60 |
+
if len(swa_pattern) == 0:
|
61 |
+
swa_pattern = [AttentionType.GLOBAL]
|
62 |
+
|
63 |
+
sns.set_theme(style="whitegrid", context="paper")
|
64 |
+
palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
|
65 |
+
plt.figure(figsize=(14, 8), dpi=300)
|
66 |
+
|
67 |
+
seq_len = np.logspace(2, 5, 100).astype(int)
|
68 |
+
batch_size = 1
|
69 |
+
|
70 |
+
tps_values = []
|
71 |
+
gqa_count = len(gqa_heads)
|
72 |
+
for i, n_kv_head in enumerate(gqa_heads):
|
73 |
+
n_kv_head = int(n_kv_head)
|
74 |
+
kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
|
75 |
+
gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
|
76 |
+
num_layers, swa_pattern, swa_size, memory_bandwidth)
|
77 |
+
tps_values.extend(gqa_tps_values)
|
78 |
+
plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads", color=palette[i],
|
79 |
+
linewidth=3.5, alpha=0.85)
|
80 |
+
|
81 |
+
plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
|
82 |
+
label=f"Max Context Length ({ctx_length:,})")
|
83 |
+
|
84 |
+
local_count = swa_pattern.count(AttentionType.LOCAL)
|
85 |
+
global_count = swa_pattern.count(AttentionType.GLOBAL)
|
86 |
+
if local_count > 0:
|
87 |
+
plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
|
88 |
+
label=f"Sliding Window Limit ({swa_size:,})")
|
89 |
+
|
90 |
+
for i, d_comp in enumerate(mla_d_compressed):
|
91 |
+
d_comp = int(d_comp)
|
92 |
+
kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
|
93 |
+
mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
|
94 |
+
num_layers, swa_pattern, swa_size, memory_bandwidth)
|
95 |
+
tps_values.extend(mla_tps_values)
|
96 |
+
plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
|
97 |
+
color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)
|
98 |
+
|
99 |
+
plt.xscale('log')
|
100 |
+
if all(np.isfinite(tps_values)):
|
101 |
+
min_tps = min(tps_values)
|
102 |
+
max_tps = max(tps_values)
|
103 |
+
y_min = max(0, min_tps * 0.9)
|
104 |
+
y_max = max_tps * 1.1
|
105 |
+
|
106 |
+
plt.ylim(y_min, y_max)
|
107 |
+
else:
|
108 |
+
plt.ylim(15, 40)
|
109 |
+
|
110 |
+
plt.gca().xaxis.set_major_formatter(ScalarFormatter())
|
111 |
+
plt.gca().yaxis.set_major_formatter(ScalarFormatter())
|
112 |
+
|
113 |
+
ax = plt.gca()
|
114 |
+
ax.spines['top'].set_visible(False)
|
115 |
+
ax.spines['right'].set_visible(False)
|
116 |
+
ax.spines['left'].set_linewidth(1.5)
|
117 |
+
ax.spines['bottom'].set_linewidth(1.5)
|
118 |
+
|
119 |
+
attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
|
120 |
+
device_name = model_name.split(':')[0] if ':' in model_name else model_name
|
121 |
+
|
122 |
+
plt.annotate(f"{device_name}\nBandwidth: {memory_bandwidth/1e9:.1f} GB/s\nParameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
|
123 |
+
xy=(0.8, 0.97),
|
124 |
+
xycoords='axes fraction',
|
125 |
+
bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor='darkgray'),
|
126 |
+
va='top',
|
127 |
+
fontsize=11)
|
128 |
+
|
129 |
+
plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
|
130 |
+
plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
|
131 |
+
|
132 |
+
plt.tick_params(axis='both', which='major', labelsize=12)
|
133 |
+
|
134 |
+
model_title = model_name.split(':')[1] if ':' in model_name else model_name
|
135 |
+
plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
|
136 |
+
fontweight='bold', pad=20)
|
137 |
+
|
138 |
+
plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12, title_fontsize=14)
|
139 |
+
|
140 |
+
plt.grid(True, alpha=0.5)
|
141 |
+
|
142 |
+
buf = io.BytesIO()
|
143 |
+
plt.savefig(buf, format='png')
|
144 |
+
plt.close()
|
145 |
+
buf.seek(0)
|
146 |
+
from PIL import Image
|
147 |
+
img = Image.open(buf)
|
148 |
+
return img
|
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|