|
import argparse
import os
import shutil
from collections import OrderedDict

import torch
from natsort import natsorted

from mmdet.datasets import build_dataset, CocoDataset
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.models import build_detector
from mmdet.apis import train_detector

from base_config_track import get_config
|
|
@DATASETS.register_module()
class CocoDatasetSubset(CocoDataset):
    """A subclass of MMDetection's default COCO dataset that can take the
    first or last fraction of the original dataset. Set exactly one of
    ``take_first_percent`` or ``take_last_percent`` to a fraction in (0, 1].
    """

    def __init__(self, *args, take_first_percent=-1, take_last_percent=-1, **kwargs):
        self.take_first_percent = take_first_percent
        self.take_last_percent = take_last_percent
        super().__init__(*args, **kwargs)
|
    def load_annotations(self, ann_file):
        """Load annotations from a COCO-style annotation file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: Annotation info from the COCO api.
        """
        assert self.take_first_percent > 0 or self.take_last_percent > 0, \
            f'take_first_percent: {self.take_first_percent}, ' \
            f'take_last_percent: {self.take_last_percent}'
        # Exactly one of the two fractions may be set.
        assert not (self.take_first_percent > 0 and self.take_last_percent > 0)

        self.coco = COCO(ann_file)
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()

        original_count = len(self.img_ids)

        if self.take_first_percent > 0:
            first_n = True
            # max(..., 1) guards against an empty slice for tiny fractions;
            # img_ids[-0:] below would otherwise return the whole list.
            count = max(int(len(self.img_ids) * self.take_first_percent), 1)
            self.img_ids = self.img_ids[:count]
        else:
            first_n = False
            count = max(int(len(self.img_ids) * self.take_last_percent), 1)
            self.img_ids = self.img_ids[-count:]

        new_count = len(self.img_ids)
        print(f'Taking {"first" if first_n else "last"} {new_count} of '
              f'{original_count} images '
              f'({new_count / original_count * 100:.1f}%)')

        data_infos = []
        total_ann_ids = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            ann_ids = self.coco.get_ann_ids(img_ids=[i])
            total_ann_ids.extend(ann_ids)
        assert len(set(total_ann_ids)) == len(
            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        return data_infos
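# A minimal sketch of how this dataset might appear in the config produced by
# base_config_track.get_config(). The wrapper layout (a 'dataset' key inside
# data.train) is what get_training_datasets() below assumes; the paths and the
# RepeatDataset wrapper are illustrative, not taken from the real config:
#
#   data = dict(
#       train=dict(
#           type='RepeatDataset',
#           times=1,
#           dataset=dict(
#               type='CocoDatasetSubset',
#               ann_file='data/annotations/train.json',
#               img_prefix='data/images/',
#               take_first_percent=0.8,
#               pipeline=train_pipeline)))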
|
|
def get_training_datasets(labeled_dataset_percent, base_directory='.'):
    # The last `labeled_dataset_percent` fraction of images becomes the
    # labeled fine-tuning split.
    cfg = get_config(base_directory)
    cfg.data.train['dataset']['take_last_percent'] = labeled_dataset_percent
    dataset_finetune = build_dataset(cfg.data.train)

    if labeled_dataset_percent < 1:
        # The remaining first (1 - fraction) of images becomes the
        # pretraining split.
        cfg.data.train['dataset']['take_last_percent'] = -1
        cfg.data.train['dataset']['take_first_percent'] = 1 - labeled_dataset_percent
        dataset_pretrain = build_dataset(cfg.data.train)
    else:
        dataset_pretrain = None

    return dataset_pretrain, dataset_finetune
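# Example (values illustrative): with labeled_dataset_percent=0.2, the last
# 20% of images becomes the labeled fine-tuning split and the first 80% the
# pretraining split; with the default of 1, dataset_pretrain is None:
#
#   pretrain_ds, finetune_ds = get_training_datasets(0.2)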
|
|
def train(experiment_name, weight_path, labeled_dataset_percent, epochs, batch_size, optim, clip, lr, resume):
    cfg = get_config()
    cfg.total_epochs = epochs
    cfg.runner.max_epochs = epochs
    cfg.data.samples_per_gpu = batch_size

    if optim == 'adam':
        cfg.optimizer = dict(type='Adam', lr=lr, weight_decay=0.0001)
    else:
        cfg.optimizer = dict(type='SGD', lr=lr, momentum=0.9, weight_decay=0.0001)

    if clip:
        # Clip gradients to an L2 norm of `clip`.
        cfg.optimizer_config = dict(grad_clip=dict(max_norm=clip, norm_type=2))
    else:
        cfg.optimizer_config = dict(grad_clip=None)

    cfg.work_dir += '/' + experiment_name
    logs_folder = os.path.join(cfg.work_dir, 'tf_logs')

    if resume:
        # Resume from the newest epoch checkpoint in the work dir.
        checkpoints = natsorted(os.listdir(cfg.work_dir))
        checkpoints = [p for p in checkpoints if 'epoch_' in p]
        cfg.resume_from = os.path.join(cfg.work_dir, checkpoints[-1])
        print('Reinitializing learning rate')
        cfg.optimizer.lr = lr
        print(cfg.optimizer)
    else:
        # Fresh run: drop stale TensorBoard logs from any previous attempt.
        if os.path.exists(logs_folder):
            shutil.rmtree(logs_folder)

    print(cfg.model.backbone.init_cfg)

    if os.path.exists(weight_path):
        # The SSL checkpoint contains bare backbone weights; prefix every key
        # with 'backbone.' so MMDetection can map them onto the detector.
        state_dict = torch.load(weight_path)
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict['backbone.' + k] = v
        torch.save(new_state_dict, 'tmp.pth')
        cfg.load_from = 'tmp.pth'
        print('Loading pretrained backbone from ' + weight_path)

    _, train_dataset = get_training_datasets(labeled_dataset_percent)

    model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
    datasets = [train_dataset]
    cfg.workflow = [('train', 1)]
    cfg.device = 'cuda'

    train_detector(model, datasets, cfg, distributed=False, validate=True)
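    # With validate=True, train_detector also builds the evaluation dataset
    # from cfg.data.val (defined in base_config_track, not shown here) and
    # runs evaluation during training.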
|
|
def parse_args():
    parser = argparse.ArgumentParser(description='Train using MMDet and Lightly SSL')
    parser.add_argument('--experiment-name', default='no-exp')
    parser.add_argument('--weight-path', type=str, required=True)
    parser.add_argument('--labeled-dataset-percent', type=float, default=1)
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train')
    parser.add_argument('--batch-size', type=int, default=6)
    parser.add_argument('--optim', type=str, default='sgd', help="optimizer: 'sgd' or 'adam'")
    parser.add_argument('--clip', type=float, default=0, help='gradient clipping max norm (0 disables clipping)')
    parser.add_argument('--lr', type=float, default=0.02 / 8)
    parser.add_argument('--resume', action='store_true',
                        help='resume training from last checkpoint in work dir')
    args = parser.parse_args()
    return args
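# Example invocation (script name, paths, and values are illustrative):
#   python train_mmdet.py --experiment-name exp1 \
#       --weight-path checkpoints/ssl_backbone.pth \
#       --labeled-dataset-percent 0.2 --epochs 50 --optim adam --lr 1e-4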
|
|
def main():
    args = parse_args()
    train(**vars(args))
|
|
if __name__ == '__main__':
    main()
    # Remove the temporary re-prefixed checkpoint if one was written.
    if os.path.exists('tmp.pth'):
        os.remove('tmp.pth')