---
license: mit
datasets:
- llm-blender/mix-instruct
metrics:
- BERTScore
- BLEURT
- BARTScore
- Pairwise Rank
tags:
- pair_ranker
- reward_model
- RLHF
---
PairRanker is the ranker used in LLM-Blender, built on the deberta-v3-large backbone. This checkpoint is the ranker used in the experiments of the LLM-Blender paper;
it was trained on the [MixInstruct](https://huggingface.co/datasets/llm-blender/mix-instruct) dataset for 5 epochs.

| PairRanker type | Source max length | Candidate max length | Total max length |
|:-----------------:|:-----------------:|----------------------|------------------|
| [pair-ranker](https://huggingface.co/llm-blender/pair-ranker) (This model) | 128 | 128 | 384 |
| [pair-reward-model](https://huggingface.co/llm-blender/pair-reward-model/) | 1224 | 412 | 2048 |
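
In both rows the total max length equals the source max length plus twice the candidate max length, which fits a ranker that encodes one source together with a pair of candidates. The short check below illustrates that arithmetic. The formula and the helper name `pair_budget` are assumptions inferred from the numbers in the table, not something the model card states.

```python
# Length settings copied from the table above.
LENGTHS = {
    "pair-ranker": {"source": 128, "candidate": 128, "total": 384},
    "pair-reward-model": {"source": 1224, "candidate": 412, "total": 2048},
}


def pair_budget(source_len: int, candidate_len: int) -> int:
    """Assumed token budget: one source plus a pair of candidates."""
    return source_len + 2 * candidate_len


# Confirm that the assumed formula reproduces the "Total max length" column.
for name, cfg in LENGTHS.items():
    assert pair_budget(cfg["source"], cfg["candidate"]) == cfg["total"], name
```
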
- Github: [https://github.com/yuchenlin/LLM-Blender](https://github.com/yuchenlin/LLM-Blender)
- Paper: [https://arxiv.org/abs/2306.02561](https://arxiv.org/abs/2306.02561)
## Usage Example
Because PairRanker contains custom layers and tokens, we recommend using it through our `llm-blender` Python package.
Loading it directly with the Hugging Face `from_pretrained()` API will raise errors.
First, install `llm-blender` with `pip install git+https://github.com/yuchenlin/LLM-Blender.git`.
Then load PairRanker with the following code:
```python
import llm_blender
# ranker config
ranker_config = llm_blender.RankerConfig()
ranker_config.ranker_type = "pairranker" # only supports pairranker now.
ranker_config.model_type = "deberta"
ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
ranker_config.load_checkpoint = "llm-blender/pair-ranker" # hugging face hub model path or your local ranker checkpoint <your checkpoint path>
ranker_config.cache_dir = "./hf_models" # hugging face model cache dir
ranker_config.source_maxlength = 128
ranker_config.candidate_maxlength = 128
ranker_config.n_tasks = 1 # number of signals used to train the ranker. This checkpoint was trained with BARTScore only, so the value is 1.
# fuser config is not used here; configure it if you also want fusion
fuser_config = llm_blender.GenFuserConfig()
# blender config
blender_config = llm_blender.BlenderConfig()
blender_config.device = "cuda" # blender ranker and fuser device
blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)
```
You can then use PairRanker through:
- `blender.rank()` to rank candidates
- `blender.compare()` to compare two candidates.

See the LLM-Blender GitHub [README.md](https://github.com/yuchenlin/LLM-Blender#rank-and-fusion)
and the Jupyter notebook [blender_usage.ipynb](https://github.com/yuchenlin/LLM-Blender/blob/main/blender_usage.ipynb)
for detailed usage examples.
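
As a rough illustration of how the output of `blender.rank()` might be consumed, the sketch below picks the top-ranked candidate for each input. It assumes that `blender.rank(inputs, candidates)` accepts a list of input strings and a list of candidate lists, and returns one list of ranks per input where rank 1 is best; check the LLM-Blender README for the exact signature and return format. The helper name `pick_top_candidates` is hypothetical and not part of `llm-blender`.

```python
from typing import Any, List, Sequence


def pick_top_candidates(
    blender: Any,
    inputs: Sequence[str],
    candidates: Sequence[Sequence[str]],
) -> List[str]:
    """Return the best-ranked candidate text for every input.

    Assumes blender.rank() returns ranks[i][j], the rank of candidate j
    for input i, with lower values meaning better candidates.
    """
    ranks = blender.rank(list(inputs), [list(row) for row in candidates])
    best = []
    for cand_row, rank_row in zip(candidates, ranks):
        # Index of the candidate with the smallest (best) rank.
        best_idx = min(range(len(cand_row)), key=lambda j: rank_row[j])
        best.append(cand_row[best_idx])
    return best
```

With the `blender` object built above, a call could look like `pick_top_candidates(blender, ["What is 2 + 2?"], [["5", "4", "22"]])`; the helper only depends on the `rank()` method, so any object exposing it in the assumed form works.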