|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import torch |
|
|
|
from omegaconf import OmegaConf |
|
|
|
from fairseq.criterions.model_criterion import ModelCriterionConfig |
|
from fairseq.dataclass.configs import FairseqConfig |
|
|
|
from tasks import ImageClassificationConfig, ImagePretrainingConfig |
|
from models.data2vec_image_classification import ( |
|
Data2VecImageClassificationConfig, |
|
Data2VecImageClassificationModel, |
|
) |
|
from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel |
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the checkpoint conversion script.

    Returns:
        argparse.ArgumentParser: parser accepting a positional ``checkpoint``
        path, a required ``--output`` path, a ``--type`` selector
        (``vision`` or ``image_classification``, defaulting to the latter)
        and an ``--inception_norms`` boolean flag.
    """
    cli = argparse.ArgumentParser(
        description="convert beit checkpoint into data2vec - vision checkpoint"
    )

    # Each entry is (positional names/flags, keyword options) for add_argument.
    argument_specs = (
        (("checkpoint",), {"help": "checkpoint to convert"}),
        (
            ("--output",),
            {
                "required": True,
                "metavar": "PATH",
                "help": "where to output converted checkpoint",
            },
        ),
        (
            ("--type",),
            {
                "type": str,
                "choices": ["vision", "image_classification"],
                "default": "image_classification",
                "help": "type of model to upgrade",
            },
        ),
        (("--inception_norms",), {"action": "store_true", "default": False}),
    )
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)

    return cli
|
|
|
|
|
def update_checkpoint(model_dict, prefix, is_nested):
    """Rename the keys of ``model_dict`` from BEiT names to data2vec names.

    The dict is modified in place and the same object is returned, so any
    other reference to it observes the renamed keys as well.

    Renaming rules (a later rule overrides an earlier one for the same key):

    1. Exact keys ``cls_token``, ``patch_embed`` and ``mask_token`` map to
       fixed names; the first two get a ``model.`` root when ``is_nested``.
    2. Keys that *start with* ``patch_embed.proj``, ``lm_head``, ``fc_norm``
       or ``head`` have that leading prefix swapped. ``prefix`` is NOT added.
    3. Keys containing ``mlp.fc1`` / ``mlp.fc2`` have that fragment replaced
       by ``mlp.0`` / ``mlp.2`` and ``prefix`` IS prepended.
    4. If ``prefix`` is non-empty, every key not covered above gets
       ``prefix`` prepended.

    Args:
        model_dict: mutable mapping of parameter name (str) to value.
        prefix: string prepended to keys under rules 3 and 4.
        is_nested: whether the ``cls_token`` / ``patch_embed`` targets live
            under a ``model.`` root.

    Returns:
        The same ``model_dict`` object, with renamed keys.
    """
    nested_root = "model." if is_nested else ""

    # Rule 1: exact-key renames. Per-key entries from rules 2-4 are added
    # to this same table below.
    replace_paths = {
        "cls_token": nested_root + "cls_emb",
        "patch_embed": nested_root + "patch_embed",
        "mask_token": "mask_emb",
    }

    starts_with = {
        "patch_embed.proj": nested_root + "patch_embed.conv",
        "lm_head": "final_proj",
        "fc_norm": "fc_norm",
        "head": "head",
    }

    partial = {
        "mlp.fc1": "mlp.0",
        "mlp.fc2": "mlp.2",
    }

    # Snapshot the keys: the dict is only mutated in the final loop, so one
    # snapshot serves every pass.
    original_keys = list(model_dict)

    for key in original_keys:
        for old_prefix, new_prefix in starts_with.items():
            if key.startswith(old_prefix):
                # Swap only the leading prefix. A plain str.replace would
                # also rewrite later occurrences of the same text in the key.
                replace_paths[key] = new_prefix + key[len(old_prefix):]
        for fragment, replacement in partial.items():
            if fragment in key:
                replace_paths[key] = prefix + key.replace(fragment, replacement)

    if prefix != "":
        for key in original_keys:
            replace_paths.setdefault(key, prefix + key)

    for key in original_keys:
        if key in replace_paths:
            new_key = replace_paths[key]
            model_dict[new_key] = model_dict[key]
            # Identity mappings (e.g. "head" -> "head") must keep their entry.
            if new_key != key:
                del model_dict[key]

    return model_dict
|
|
|
|
|
def main():
    """Convert a checkpoint file into a fairseq-style data2vec checkpoint.

    Parses the command line, loads the input checkpoint on CPU, builds a
    config for the requested ``--type`` (``image_classification`` or
    ``vision``), renames the weights via ``update_checkpoint``, verifies they
    load into a freshly constructed model with ``strict=True``, prints the
    resulting config and state keys, and saves the state to ``--output``.

    Raises:
        Exception: if ``--type`` is neither ``image_classification`` nor
            ``vision``.
    """
    args = get_parser().parse_args()

    source_ckpt = torch.load(args.checkpoint, map_location="cpu")

    cfg = FairseqConfig(
        criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]),
    )

    is_classification = args.type == "image_classification"
    if not is_classification and args.type != "vision":
        raise Exception("unsupported type " + args.type)

    # Task config: both variants point `data` at the current directory.
    if is_classification:
        cfg.task = ImageClassificationConfig(
            _name="image_classification",
            data=".",
        )
    else:
        cfg.task = ImagePretrainingConfig(
            _name="image_pretraining",
            data=".",
        )

    if args.inception_norms:
        cfg.task.normalization_mean = [0.5, 0.5, 0.5]
        cfg.task.normalization_std = [0.5, 0.5, 0.5]

    # Model config plus the per-type settings used further down: the model
    # class to instantiate, the checkpoint entry holding the weights, and
    # the arguments forwarded to update_checkpoint.
    if is_classification:
        cfg.model = Data2VecImageClassificationConfig(
            _name="data2vec_image_classification",
        )
        cfg.model.pretrained_model_args = FairseqConfig(
            model=Data2VecVisionConfig(
                _name="data2vec_vision", shared_rel_pos_bias=False
            ),
            task=ImagePretrainingConfig(
                _name="image_pretraining",
            ),
        )
        model_cls = Data2VecImageClassificationModel
        weights_key = "module"
        key_prefix = "model.encoder."
        nested = True
    else:
        cfg.model = Data2VecVisionConfig(
            _name="data2vec_vision",
        )
        model_cls = Data2VecVisionModel
        weights_key = "model"
        key_prefix = "encoder."
        nested = False

    cfg = OmegaConf.create(cfg)

    state = {
        "cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
        "model": source_ckpt[weights_key],
        "best_loss": None,
        "optimizer": None,
        "extra_state": {},
    }

    # update_checkpoint renames keys in place, so state["model"] (saved
    # below) ends up holding the converted names. The strict load checks
    # that the converted keys match the target model.
    model = model_cls(cfg.model)
    model.load_state_dict(
        update_checkpoint(state["model"], prefix=key_prefix, is_nested=nested),
        strict=True,
    )

    print(state["cfg"], state.keys())
    torch.save(state, args.output)
|
|
|
|
|
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|