# ssl-aasist/fairseq/examples/data2vec/tasks/audio_classification.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import math
import os
from dataclasses import dataclass

import numpy as np
import torch
from sklearn import metrics as sklearn_metrics

from fairseq.logging import metrics
from fairseq.tasks import register_task
from fairseq.tasks.audio_pretraining import (
    AudioPretrainingConfig,
    AudioPretrainingTask,
)

from ..data.add_class_target_dataset import AddClassTargetDataset

logger = logging.getLogger(__name__)

@dataclass
class AudioClassificationConfig(AudioPretrainingConfig):
    # CSV file (relative to `data`) with one `<index>,<label>` pair per line
    label_descriptors: str = "label_descriptors.csv"
    # file-name extension of the per-split label files, e.g. `train.lbl`
    labels: str = "lbl"


@register_task("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationTask(AudioPretrainingTask):
    """Audio classification task with multi-label targets (e.g. AudioSet)."""

    cfg: AudioClassificationConfig

def __init__(
self,
cfg: AudioClassificationConfig,
):
super().__init__(cfg)
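        # register a lazy factory: labels are loaded on first access to
        # `self.state.labels` and cached on the task state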
self.state.add_factory("labels", self.load_labels)
def load_labels(self):
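        """Load the label vocabulary from the `label_descriptors` CSV.

        Each non-empty line holds an `<index>,<label>` pair; returns a
        dict mapping label name to index.
        """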
labels = {}
path = os.path.join(self.cfg.data, self.cfg.label_descriptors)
with open(path, "r") as ldf:
for line in ldf:
if line.strip() == "":
continue
                # strip the trailing newline so label keys are clean
                items = line.rstrip("\n").split(",")
idx = items[0]
lbl = items[1]
assert lbl not in labels, lbl
labels[lbl] = idx
return labels
@property
def labels(self):
return self.state.labels
def load_dataset(
self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
):
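        """Load the audio dataset for `split` and attach class targets.

        The label file `{split}.{labels}` is tab-separated; its third
        column holds a comma-separated list of integer class indices.
        """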
super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg
data_path = self.cfg.data
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
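        # lines the audio dataset skipped (e.g. clips outside the min/max
        # sample-size bounds) must be skipped in the label file as well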
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
labels = []
with open(label_path, "r") as f:
for i, line in enumerate(f):
if i not in skipped_indices:
lbl_items = line.rstrip().split("\t")
labels.append([int(x) for x in lbl_items[2].split(",")])
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
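        # wrap the dataset so each example also carries its class target(s)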
self.datasets[split] = AddClassTargetDataset(
self.datasets[split],
labels,
multi_class=True,
add_to_input=True,
num_classes=len(self.labels),
)
def calculate_stats(self, output, target):
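        """Compute per-class average precision (AP).

        `output` and `target` are arrays of shape (num_samples,
        num_classes); returns one statistics dict per class.
        """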
classes_num = target.shape[-1]
stats = []
        # Accuracy is only meaningful for single-label classification
        # (e.g. ESC-50); it is not used for multi-label datasets such as
        # AudioSet:
        # acc = sklearn_metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
# Class-wise statistics
for k in range(classes_num):
# Average precision
avg_precision = sklearn_metrics.average_precision_score(
target[:, k], output[:, k], average=None
)
            # additional per-class statistics from the reference
            # implementation (AUC, precision/recall curves, FPR/FNR,
            # sampled every 1000 steps) are omitted here; only AP is kept
            stats.append({"AP": avg_precision})
return stats
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
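        # concatenate predictions/targets accumulated over validation
        # batches so mAP can be derived over the full evaluation set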
if "_predictions" in logging_outputs[0]:
metrics.log_concat_tensor(
"_predictions",
torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0),
)
metrics.log_concat_tensor(
"_targets",
torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0),
)
def compute_stats(meters):
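            # report mAP only once at least 100 predictions have accumulated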
if meters["_predictions"].tensor.shape[0] < 100:
return 0
stats = self.calculate_stats(
meters["_predictions"].tensor, meters["_targets"].tensor
)
return np.nanmean([stat["AP"] for stat in stats])
metrics.log_derived("mAP", compute_stats)
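
# A minimal sketch of how this task could be selected via fairseq-hydra-train;
# the config directory/name and data path below are placeholders, not part of
# this repository's verified setup:
#
#   fairseq-hydra-train \
#       --config-dir examples/data2vec/config/audio/classification \
#       --config-name base_classification \
#       task._name=audio_classification \
#       task.data=/path/to/manifests \
#       task.labels=lbl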