---
title: On-Device LLM Throughput Calculator 
emoji: πŸš€
colorFrom: pink
colorTo: blue
sdk: gradio
sdk_version: 4.36.0
app_file: src/app.py
pinned: false
license: mit 
---


# On-Device LLM Throughput Calculator

A Gradio web application that helps visualize LLM throughput on memory-bandwidth-constrained devices.

## Overview

This tool calculates and visualizes the theoretical throughput (tokens per second) that a Large Language Model (LLM) can achieve when decoding is limited by the device's memory bandwidth. It supports several attention mechanisms:

- Grouped Query Attention (GQA)
- Multi-Query Attention (MQA)
- Multi-head Latent Attention (MLA)

It also visualizes how sliding window attention impacts throughput at different context lengths.
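
For intuition, here is a minimal sketch (independent of the app's own code) of how the KV cache, and therefore the memory traffic per decoded token, shrinks as the number of KV heads drops or a sliding window caps the cached context. All model dimensions below are illustrative assumptions.

```python
# Minimal sketch, not the app's implementation: KV-cache size for one
# sequence, for attention variants that differ only in the number of KV
# heads, with an optional sliding window capping the cached positions.
# All dimensions below are illustrative assumptions.

def kv_cache_bytes(context_len, num_layers, head_dim, num_kv_heads,
                   bytes_per_value=2, window=None):
    """KV-cache bytes for one sequence at a given context length."""
    cached_positions = min(context_len, window) if window else context_len
    # Factor of 2 covers keys and values; one pair per layer and KV head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * cached_positions

# Hypothetical 8B-class model: 32 layers, head_dim 128, fp16 cache.
gqa = kv_cache_bytes(8192, num_layers=32, head_dim=128, num_kv_heads=8)  # GQA, 8 KV heads
mqa = kv_cache_bytes(8192, num_layers=32, head_dim=128, num_kv_heads=1)  # MQA, 1 KV head
swa = kv_cache_bytes(8192, num_layers=32, head_dim=128, num_kv_heads=8,
                     window=4096)                                        # GQA + sliding window
print(f"GQA: {gqa/1e9:.2f} GB  MQA: {mqa/1e9:.2f} GB  GQA+SWA: {swa/1e9:.2f} GB")
```

MLA is not covered by this sketch because it caches a compressed latent per token rather than full per-head keys and values, which reduces the cache further.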

## Features

- Customize device specifications (memory bandwidth)
- Configure model parameters (size, layers, heads)
- Compare different attention mechanisms
- Visualize performance across different context lengths
- Sliding window attention support

## Usage

1. Configure your device details (name, memory bandwidth)
2. Set model parameters (number of parameters, layer count, etc.)
3. Choose which attention mechanism configurations to compare
4. Generate a visualization of expected throughput

## Installation

```bash
pip install -r requirements.txt
```

## Running Locally

```bash
cd src
python app.py
```

## Theory

The calculations are based on memory bandwidth bottlenecks as described in the [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).

The basic formula for decode tokens per second (bandwidth in bytes per second, sizes in bytes):
```
tokens_per_second = (batch_size * memory_bandwidth) / (batch_size * total_kv_size + parameter_size)
```
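
As a hedged sketch of how that formula can be evaluated (the device and model numbers below are illustrative assumptions, not values from the app):

```python
# Minimal sketch of the formula above; all inputs are in bytes or bytes/s.

def tokens_per_second(batch_size, memory_bandwidth, total_kv_size, parameter_size):
    """Theoretical decode throughput when limited purely by memory bandwidth.

    memory_bandwidth: bytes/s the device can stream from memory
    total_kv_size:    KV-cache bytes read per decode step for one sequence
    parameter_size:   bytes of model weights read per decode step
    """
    bytes_per_step = batch_size * total_kv_size + parameter_size
    return batch_size * memory_bandwidth / bytes_per_step

# Illustrative numbers: 100 GB/s device, 8B parameters in fp16 (~16 GB),
# ~1 GB of KV cache per sequence.
print(tokens_per_second(batch_size=1, memory_bandwidth=100e9,
                        total_kv_size=1e9, parameter_size=16e9))  # ~5.9 tokens/s
```

Larger batches amortize the weight reads across more tokens, which is why throughput grows with batch size until the KV-cache term dominates.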

## License

MIT