import logging
import os
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from sklearn import metrics as sklearn_metrics

from fairseq.logging import metrics
from fairseq.tasks import register_task
from fairseq.tasks.audio_pretraining import (
    AudioPretrainingConfig,
    AudioPretrainingTask,
)

from ..data.add_class_target_dataset import AddClassTargetDataset


logger = logging.getLogger(__name__)


@dataclass
class AudioClassificationConfig(AudioPretrainingConfig):
    # CSV file (under cfg.data) listing "index,label" pairs, one class per line.
    label_descriptors: str = "label_descriptors.csv"
    # File extension of the per-split label files, e.g. train.lbl / valid.lbl.
    labels: str = "lbl"


@register_task("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationTask(AudioPretrainingTask):
    """Multi-label audio classification (e.g. AudioSet-style audio tagging)
    built on top of the audio pretraining task."""

    cfg: AudioClassificationConfig

    def __init__(
        self,
        cfg: AudioClassificationConfig,
    ):
        super().__init__(cfg)
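        # Register a lazy factory: the label dictionary is parsed from disk
        # the first time self.state.labels is accessed.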
        self.state.add_factory("labels", self.load_labels)

    def load_labels(self):
        """Parse the label-descriptor CSV into a {label: index-string} dict."""
        labels = {}
        path = os.path.join(self.cfg.data, self.cfg.label_descriptors)
        with open(path, "r") as ldf:
            for line in ldf:
                if line.strip() == "":
                    continue
                items = line.split(",")
                idx = items[0]
                lbl = items[1]
                assert lbl not in labels, lbl
                labels[lbl] = idx
        return labels
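    # For reference: a hypothetical descriptor line such as
    #   0,/m/09x0r,Speech
    # is parsed into the entry {"/m/09x0r": "0"}. The index is kept as a
    # string; only the number of classes is used downstream (num_classes).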
    @property
    def labels(self):
        return self.state.labels

    def load_dataset(
        self, split: str, task_cfg: Optional[AudioClassificationConfig] = None, **kwargs
    ):
        super().load_dataset(split, task_cfg, **kwargs)

        task_cfg = task_cfg or self.cfg

        data_path = self.cfg.data
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        # Keep labels aligned with the audio examples: skip the same line
        # indices that the base dataset loader skipped.
        skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
        labels = []
        with open(label_path, "r") as f:
            for i, line in enumerate(f):
                if i not in skipped_indices:
                    lbl_items = line.rstrip().split("\t")
                    # The third tab-separated field holds the comma-separated
                    # class indices for this example.
                    labels.append([int(x) for x in lbl_items[2].split(",")])

        assert len(labels) == len(self.datasets[split]), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(self.datasets[split])}) do not match"
        )

        self.datasets[split] = AddClassTargetDataset(
            self.datasets[split],
            labels,
            multi_class=True,
            add_to_input=True,
            num_classes=len(self.labels),
        )
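    # For reference: each line of the "<split>.<labels>" file is expected to
    # be tab-separated with at least three fields, e.g. (hypothetical values)
    #   utt_0001<TAB>42000<TAB>0,137,3
    # where only the third field, the comma-separated class indices, is read.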
    def calculate_stats(self, output, target):
        """Compute per-class average precision (AP) over arrays of shape
        (num_examples, num_classes) holding scores and multi-hot targets."""
        classes_num = target.shape[-1]
        stats = []

        for k in range(classes_num):
            # One-vs-rest average precision for class k.
            avg_precision = sklearn_metrics.average_precision_score(
                target[:, k], output[:, k], average=None
            )
            stats.append({"AP": avg_precision})

        return stats
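    # Minimal usage sketch (the variable names here are illustrative only):
    #
    #   scores = torch.sigmoid(logits).cpu()    # (N, C) per-class scores
    #   targets = multi_hot.cpu()               # (N, C) multi-hot labels
    #   stats = task.calculate_stats(scores, targets)
    #   mAP = np.nanmean([s["AP"] for s in stats])  # as in reduce_metrics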
    def valid_step(self, sample, model, criterion):
        # Passthrough: per-batch predictions/targets arrive via the
        # criterion's logging_output and are aggregated in reduce_metrics.
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        if "_predictions" in logging_outputs[0]:
            # Accumulate predictions/targets across batches (and workers) so
            # mAP can be computed over the whole validation set.
            metrics.log_concat_tensor(
                "_predictions",
                torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0),
            )
            metrics.log_concat_tensor(
                "_targets",
                torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0),
            )

            def compute_stats(meters):
                # AP over only a few examples is too noisy to be meaningful.
                if meters["_predictions"].tensor.shape[0] < 100:
                    return 0
                stats = self.calculate_stats(
                    meters["_predictions"].tensor, meters["_targets"].tensor
                )
                return np.nanmean([stat["AP"] for stat in stats])

            metrics.log_derived("mAP", compute_stats)
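    # Note: log_derived registers a callable that fairseq evaluates when the
    # logged meters are aggregated, so mAP is computed once over the full
    # concatenated prediction/target tensors rather than per batch.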