# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging
from dataclasses import dataclass
from typing import Any
from omegaconf import II, MISSING
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.roberta.model import RobertaClassificationHead
from examples.data2vec.data.modality import Modality

logger = logging.getLogger(__name__)


@dataclass
class Data2VecTextClassificationConfig(FairseqDataclass):
    # dropout and activation used inside the classification head (pooler)
    pooler_dropout: float = 0.0
    pooler_activation_fn: str = "tanh"
    # quantization-noise settings forwarded to the classification head
    quant_noise_pq: int = 0
    quant_noise_pq_block_size: int = 8
    # apply spectral normalization inside the classification head
    spectral_norm_classification_head: bool = False
    # path to the pretrained data2vec checkpoint (required)
    model_path: str = MISSING
    # if set, reuse only the architecture and skip loading the checkpoint weights
    no_pretrained_weights: bool = False
    # config of the pretrained model; read from the checkpoint when left as None
    pretrained_model_args: Any = None
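
# Minimal instantiation sketch for the config above (comments only). The checkpoint
# path and the dropout override are hypothetical placeholders; in practice the fields
# are usually populated through fairseq's configuration system rather than by hand.
#
#   cfg = Data2VecTextClassificationConfig(
#       model_path="/path/to/data2vec_text_pretrained.pt",  # hypothetical path
#       pooler_dropout=0.1,
#   )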

@register_model(
"data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
)
class Data2VecTextClassificationModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecTextClassificationConfig):
super().__init__()
self.cfg = cfg
        if cfg.pretrained_model_args is None:
            # Load the pretrained data2vec checkpoint and adopt the config stored in it.
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
pretrained_args = state.get("cfg", None)
pretrained_args.criterion = None
pretrained_args.lr_scheduler = None
cfg.pretrained_model_args = pretrained_args
logger.info(pretrained_args)
else:
state = None
pretrained_args = cfg.pretrained_model_args
        # Rebuild the pretraining task and model from the stored config, then strip
        # modules that are only needed during pretraining.
        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)
        model.remove_pretraining_modules()
self.model = model
if state is not None and not cfg.no_pretrained_weights:
self.load_model_weights(state, model, cfg)
self.classification_heads = nn.ModuleDict()

    def load_model_weights(self, state, model, cfg):
        # Drop pretraining-only parameters (shared decoder, EMA teacher, anything with
        # "decoder" in its name) before loading the remaining weights strictly.
        for k in list(state["model"].keys()):
if (
k.startswith("shared_decoder") or
k.startswith("_ema") or
"decoder" in k
):
logger.info(f"Deleting {k} from checkpoint")
del state["model"][k]
model.load_state_dict(state["model"], strict=True)
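
    # Illustrative sketch of the filtering above (hypothetical key names, not executed):
    #   "_ema.transformer.weight"                    -> dropped (EMA teacher state)
    #   "shared_decoder.proj.weight"                 -> dropped (pretraining decoder)
    #   "encoder.layers.0.self_attn.k_proj.weight"   -> kept and loaded with strict=True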

    @classmethod
def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
"""Build a new model instance."""
return cls(cfg)

    def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
embed_dim = self.cfg.pretrained_model_args.model.embed_dim
self.classification_heads[name] = RobertaClassificationHead(
input_dim=embed_dim,
inner_dim=inner_dim or embed_dim,
num_classes=num_classes,
activation_fn=self.cfg.pooler_activation_fn,
pooler_dropout=self.cfg.pooler_dropout,
q_noise=self.cfg.quant_noise_pq,
qn_block_size=self.cfg.quant_noise_pq_block_size,
do_spectral_norm=self.cfg.spectral_norm_classification_head,
)
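
    # Usage sketch (comments only): a fine-tuning task would typically register a head
    # before training; the head name and class count below are assumptions:
    #
    #   model.register_classification_head("sentence_classification_head", num_classes=2)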

    def forward(
self,
source,
id,
padding_mask,
features_only=True,
remove_extra_tokens=True,
classification_head_name=None,
):
        # Encode the input with the pretrained data2vec model; masking is disabled
        # because this is a fine-tuning / inference forward pass.
        encoder_out = self.model(
source,
id=id,
mode=Modality.TEXT,
padding_mask=padding_mask,
mask=False,
features_only=features_only,
remove_extra_tokens=remove_extra_tokens
)
        # Pool the encoder features with the requested classification head to get logits.
        logits = self.classification_heads[classification_head_name](encoder_out["x"])
return logits, encoder_out
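
# ------------------------------------------------------------------------------
# End-to-end usage sketch (comments only). A minimal, hypothetical illustration of
# how this model might be built and queried; the head name, tensor names, and
# shapes are assumptions, not values taken from this repository.
#
#   cfg = Data2VecTextClassificationConfig(model_path="/path/to/ckpt.pt")
#   model = Data2VecTextClassificationModel.build_model(cfg)
#   model.register_classification_head("sentence_classification_head", num_classes=2)
#
#   # source: LongTensor of token ids (B, T); padding_mask: BoolTensor, True at pads
#   logits, encoder_out = model(
#       source=source,
#       id=sample_ids,
#       padding_mask=padding_mask,
#       classification_head_name="sentence_classification_head",
#   )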