# Generate with max_length=100, take element 0 of the result, and convert
# it to a plain Python list of token ids.
output_ids = (
    model.generate(input_ids, max_length=100)[0]
    .tolist()
)
output_ids
[0, 258, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 257, 35, 108, 113, 35, 119, 107, 104, 35, 103, 108, 118, 102, 114, 256, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49, 35, 87, 107, 104, 35, 103, 114, 106, 35, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 35, 100, 35, 101, 100, 111, 111, 35, 108, 113, 255, 35, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49] | |
^- Note how the sentinel token ids descend from 258 to 257, 256, 255 | |
Now we need to split on the sentinel tokens; let's write a short loop for this. | |
# Cut the generated ids into chunks at each sentinel token. Sentinel ids
# are looked up in descending order starting at 258; the loop stops at the
# first sentinel id that does not occur in output_ids.
output_ids_list = []
start_token, sentinel_token = 0, 258
while True:
    try:
        # First occurrence of the current sentinel in the full sequence.
        split_idx = output_ids.index(sentinel_token)
    except ValueError:
        # This sentinel is absent, so there is nothing left to split on.
        break
    output_ids_list.append(output_ids[start_token:split_idx])
    # The next chunk starts at the sentinel itself, so the sentinel id
    # stays at the head of the following chunk.
    start_token, sentinel_token = split_idx, sentinel_token - 1

# Everything from the last sentinel found to the end forms the final chunk.
output_ids_list.append(output_ids[start_token:])

# Decode every chunk of ids back into text in a single call.
output_string = tokenizer.batch_decode(output_ids_list)
output_string
['', 'is the one who does', ' in the disco', 'in the park. |