File size: 925 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from argparse import ArgumentParser

import torch

parser = ArgumentParser()
parser.add_argument('src', default='old.pth')
parser.add_argument('dst', default='new.pth')  # ('training','validation')
parser.add_argument('--code_size', type=int, default='10')
args = parser.parse_args()
model = torch.load(args.src)
code_size = args.code_size
if model['meta'].get('detr3d_convert_tag') is not None:
    print('this model has already converted!')
else:
    print('converting...')
    # (cx, cy, w, l, cz, h, sin(φ), cos(φ), vx, vy)
    for key in model['state_dict']:
        tsr = model['state_dict'][key]
        if 'reg_branches' in key and tsr.shape[0] == code_size:
            print(key, ' with ', tsr.shape, 'has changed')
            tsr[[2, 3], ...] = tsr[[3, 2], ...]
            tsr[[6, 7], ...] = -tsr[[7, 6], ...]
    model['meta']['detr3d_convert_tag'] = True
    torch.save(model, args.dst)
    print('done...')