ssl-aasist/fairseq/examples/data2vec/fb_convert_beit_cp.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
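"""Convert a BEiT checkpoint into a fairseq data2vec vision or image
classification checkpoint.

Example invocation (paths are illustrative):

    python fb_convert_beit_cp.py /path/to/beit_checkpoint.pt \
        --type image_classification \
        --output /path/to/converted.pt
"""
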
import argparse
import torch
from omegaconf import OmegaConf
from fairseq.criterions.model_criterion import ModelCriterionConfig
from fairseq.dataclass.configs import FairseqConfig
from tasks import ImageClassificationConfig, ImagePretrainingConfig
from models.data2vec_image_classification import (
Data2VecImageClassificationConfig,
Data2VecImageClassificationModel,
)
from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel


def get_parser():
parser = argparse.ArgumentParser(
description="convert beit checkpoint into data2vec - vision checkpoint"
)
# fmt: off
parser.add_argument('checkpoint', help='checkpoint to convert')
parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted checkpoint')
parser.add_argument('--type', type=str, choices=['vision', 'image_classification'], default='image_classification', help='type of model to upgrade')
    parser.add_argument('--inception_norms', action='store_true', default=False,
                        help='use 0.5 mean/std (Inception-style) normalization instead of the task defaults')
# fmt: on
return parser
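

# Remap BEiT parameter names onto the data2vec module layout. Keys matched by
# the explicit tables below are renamed; every remaining key simply gets
# `prefix` prepended (e.g. "model.encoder." when the backbone is nested inside
# a classification model).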
def update_checkpoint(model_dict, prefix, is_nested):
replace_paths = {
"cls_token": "model.cls_emb" if is_nested else "cls_emb",
"patch_embed": "model.patch_embed" if is_nested else "patch_embed",
"mask_token": "mask_emb",
}
starts_with = {
"patch_embed.proj": "model.patch_embed.conv"
if is_nested
else "patch_embed.conv",
"lm_head": "final_proj",
"fc_norm": "fc_norm",
"head": "head",
}
partial = {
"mlp.fc1": "mlp.0",
"mlp.fc2": "mlp.2",
}
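    # Build the full rename map: prefix-based and substring-based rewrites
    # first, then a bare `prefix` for every key not covered above.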
for k in list(model_dict.keys()):
for sw, r in starts_with.items():
if k.startswith(sw):
replace_paths[k] = k.replace(sw, r)
for p, r in partial.items():
if p in k:
replace_paths[k] = prefix + k.replace(p, r)
if prefix != "":
for k in list(model_dict.keys()):
if k not in replace_paths:
replace_paths[k] = prefix + k
for k in list(model_dict.keys()):
if k in replace_paths:
model_dict[replace_paths[k]] = model_dict[k]
if k != replace_paths[k]:
del model_dict[k]
return model_dict
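

# Load the source checkpoint, build the matching FairseqConfig, remap the
# weights, sanity-check them by loading into a freshly constructed model
# (strict=True), and save the result in fairseq checkpoint format.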
def main():
parser = get_parser()
args = parser.parse_args()
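    # Load the source BEiT checkpoint on CPU.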
cp = torch.load(args.checkpoint, map_location="cpu")
cfg = FairseqConfig(
criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]),
)
if args.type == "image_classification":
cfg.task = ImageClassificationConfig(
_name="image_classification",
data=".",
)
if args.inception_norms:
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
cfg.task.normalization_std = [0.5, 0.5, 0.5]
cfg.model = Data2VecImageClassificationConfig(
_name="data2vec_image_classification",
)
cfg.model.pretrained_model_args = FairseqConfig(
model=Data2VecVisionConfig(
_name="data2vec_vision", shared_rel_pos_bias=False
),
task=ImagePretrainingConfig(
_name="image_pretraining",
),
)
cfg = OmegaConf.create(cfg)
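        # Assemble a fairseq-style checkpoint dict; this checkpoint variant
        # stores its weights under the "module" key.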
state = {
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
"model": cp["module"],
"best_loss": None,
"optimizer": None,
"extra_state": {},
}
model = Data2VecImageClassificationModel(cfg.model)
model.load_state_dict(
update_checkpoint(state["model"], prefix="model.encoder.", is_nested=True),
strict=True,
)
elif args.type == "vision":
cfg.task = ImagePretrainingConfig(
_name="image_pretraining",
data=".",
)
if args.inception_norms:
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
cfg.task.normalization_std = [0.5, 0.5, 0.5]
cfg.model = Data2VecVisionConfig(
_name="data2vec_vision",
)
cfg = OmegaConf.create(cfg)
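        # Same as above, but the pretraining checkpoint keeps its weights
        # under the "model" key.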
state = {
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
"model": cp["model"],
"best_loss": None,
"optimizer": None,
"extra_state": {},
}
model = Data2VecVisionModel(cfg.model)
model.load_state_dict(
update_checkpoint(state["model"], prefix="encoder.", is_nested=False),
strict=True,
)
else:
raise Exception("unsupported type " + args.type)
print(state["cfg"], state.keys())
torch.save(state, args.output)


if __name__ == "__main__":
main()