you should add statements like:

```python
assert (
    model_pointer.weight.shape == pretrained_weight.shape
), f"Pointer shape of random weight {model_pointer.weight.shape} and array shape of checkpoint weight {pretrained_weight.shape} mismatched"
```
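The shape check and the name logging can be combined into one helper that runs before every weight copy. The following is a minimal sketch that uses only the standard library. `Param` is a hypothetical stand-in for a framework tensor (a name, a shape tuple, and flat data), and `copy_checkpoint_weight` is a hypothetical helper name. Neither is part of any library.

```python
import logging
from dataclasses import dataclass, field

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class Param:
    """Hypothetical stand-in for a framework tensor: a name, a shape, and flat data."""

    name: str
    shape: tuple
    data: list = field(default_factory=list)


def copy_checkpoint_weight(model_param: Param, pretrained: Param) -> None:
    """Hypothetical helper: check shapes, log both names, then copy the data."""
    # Fail loudly on a shape mismatch instead of loading a wrongly shaped weight.
    assert model_param.shape == pretrained.shape, (
        f"Pointer shape of random weight {model_param.shape} and array shape of "
        f"checkpoint weight {pretrained.shape} mismatched"
    )
    # Log both names so that a wrong mapping between layers is easy to spot.
    logger.info(f"Initialize PyTorch weight {model_param.name} from {pretrained.name}")
    # Copy the values so the model does not share the checkpoint's list object.
    model_param.data = list(pretrained.data)


# Usage with made-up layer names and values:
checkpoint = Param("checkpoint/layer_0/dense/weight", (2, 3), [1, 2, 3, 4, 5, 6])
model = Param("layer.0.dense.weight", (2, 3), [0] * 6)
copy_checkpoint_weight(model, checkpoint)
```

Note that `assert` statements are skipped when Python runs with `-O`. If the check must always run, raising an explicit `ValueError` is a more robust choice.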
Additionally, you should print out the names of both weights to make sure they match, e.g.