import logging

from dataclasses import dataclass
from typing import Any

from omegaconf import MISSING

import torch.nn as nn

from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.roberta.model import RobertaClassificationHead

from examples.data2vec.data.modality import Modality


logger = logging.getLogger(__name__)


@dataclass
class Data2VecTextClassificationConfig(FairseqDataclass):
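    """Configuration for fine-tuning a pretrained data2vec text model with a
    classification head.

    ``model_path`` must point at the pretrained checkpoint (unless
    ``pretrained_model_args`` is supplied directly), and
    ``no_pretrained_weights`` skips loading its weights. The remaining fields
    configure the RoBERTa-style pooler / classification head.
    """
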
    pooler_dropout: float = 0.0
    pooler_activation_fn: str = "tanh"
    quant_noise_pq: int = 0
    quant_noise_pq_block_size: int = 8
    spectral_norm_classification_head: bool = False

    model_path: str = MISSING
    no_pretrained_weights: bool = False

    pretrained_model_args: Any = None


@register_model(
    "data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
)
class Data2VecTextClassificationModel(BaseFairseqModel):
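    """A pretrained data2vec text encoder with RoBERTa-style classification
    heads on top.

    The encoder is rebuilt from the task/model config stored in the
    checkpoint, pretraining-only modules are stripped, and heads are added
    separately via ``register_classification_head``.
    """
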
    def __init__(self, cfg: Data2VecTextClassificationConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.pretrained_model_args is None:
            # Recover the pretraining config from the checkpoint so the
            # encoder can be rebuilt exactly as it was trained.
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)
            assert (
                pretrained_args is not None
            ), f"no config found in checkpoint {cfg.model_path}"
            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None
            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            state = None
            pretrained_args = cfg.pretrained_model_args

        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        # Strip pretraining-only modules (EMA teacher, decoders); only the
        # encoder is kept for fine-tuning.
        model.remove_pretraining_modules()

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            self.load_model_weights(state, model, cfg)

        # Heads are registered later via register_classification_head().
        self.classification_heads = nn.ModuleDict()

    def load_model_weights(self, state, model, cfg):
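        """Load pretrained encoder weights from ``state``.

        Keys belonging to pretraining-only modules (the shared decoder, the
        ``_ema`` teacher, and any other decoder parameters) are deleted first,
        since ``remove_pretraining_modules()`` already removed those modules
        from the model.
        """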
        for k in list(state["model"].keys()):
            if (
                k.startswith("shared_decoder")
                or k.startswith("_ema")
                or "decoder" in k
            ):
                logger.info(f"Deleting {k} from checkpoint")
                del state["model"][k]
        # strict=True: every remaining key must match the encoder exactly.
        model.load_state_dict(state["model"], strict=True)

    @classmethod
    def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head named ``name``.

        Registering an existing name with different dimensions replaces the
        old head after logging a warning.
        """
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        # The head's input width is tied to the pretrained encoder's embed_dim.
        embed_dim = self.cfg.pretrained_model_args.model.embed_dim
        self.classification_heads[name] = RobertaClassificationHead(
            input_dim=embed_dim,
            inner_dim=inner_dim or embed_dim,
            num_classes=num_classes,
            activation_fn=self.cfg.pooler_activation_fn,
            pooler_dropout=self.cfg.pooler_dropout,
            q_noise=self.cfg.quant_noise_pq,
            qn_block_size=self.cfg.quant_noise_pq_block_size,
            do_spectral_norm=self.cfg.spectral_norm_classification_head,
        )

    def forward(
        self,
        source,
        id,
        padding_mask,
        features_only=True,
        remove_extra_tokens=True,
        classification_head_name=None,
    ):
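        """Encode ``source`` with the data2vec encoder and apply the named head.

        Returns ``(logits, encoder_out)``. The head must already have been
        registered via ``register_classification_head``.
        """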
        encoder_out = self.model(
            source,
            id=id,
            mode=Modality.TEXT,
            padding_mask=padding_mask,
            mask=False,  # no masking during fine-tuning / inference
            features_only=features_only,
            remove_extra_tokens=remove_extra_tokens,
        )
        logits = self.classification_heads[classification_head_name](encoder_out["x"])
        return logits, encoder_out
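

# Minimal usage sketch (illustrative assumption, not part of this module):
# the names `tokens`, `sample_ids`, and `pad_mask` are placeholders, and in
# practice fairseq builds and drives this model through a fine-tuning task
# rather than calling it directly.
#
#   cfg = Data2VecTextClassificationConfig(model_path="/path/to/checkpoint.pt")
#   model = Data2VecTextClassificationModel.build_model(cfg)
#   model.register_classification_head("sentence_classification_head", num_classes=2)
#   logits, _ = model(
#       source=tokens,          # LongTensor [batch, seq_len] of token ids
#       id=sample_ids,          # per-example ids passed through to the encoder
#       padding_mask=pad_mask,  # mask marking padded positions
#       classification_head_name="sentence_classification_head",
#   )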