# LVM-Med / object_detection / finetune_with_path_modify_test_eval.py
import argparse
import os
import shutil
from collections import OrderedDict

import torch
from natsort import natsorted
from mmdet.datasets import build_dataset, CocoDataset
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.models import build_detector
from mmdet.apis import train_detector

from base_config_track import get_config


@DATASETS.register_module()
class CocoDatasetSubset(CocoDataset):
"""
A subclass of MMDetection's default COCO dataset which has the ability
to take the first or last n% of the original dataset. Set either
take_first_percent or take_last_percent to a value greater than 0.
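
    Example (hypothetical config fragment)::

        # take the first 80% of images
        train=dict(type='CocoDatasetSubset', take_first_percent=0.8, ...)
        # take the last 20% of images
        train=dict(type='CocoDatasetSubset', take_last_percent=0.2, ...)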
"""
def __init__(self, *args, take_first_percent=-1, take_last_percent=-1, **kwargs):
self.take_first_percent = take_first_percent
self.take_last_percent = take_last_percent
super().__init__(*args, **kwargs)

    def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.
Args:
ann_file (str): Path of annotation file.
Returns:
list[dict]: Annotation info from COCO api.
"""
        assert (self.take_first_percent > 0) != (self.take_last_percent > 0), (
            f'Exactly one of take_first_percent ({self.take_first_percent}) and '
            f'take_last_percent ({self.take_last_percent}) must be > 0')
self.coco = COCO(ann_file)
# The order of returned `cat_ids` will not
# change with the order of the CLASSES
self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.img_ids = self.coco.get_img_ids()
original_count = len(self.img_ids)
# make a subset
if self.take_first_percent > 0:
first_n = True
count = int(len(self.img_ids) * self.take_first_percent)
self.img_ids = self.img_ids[:count]
elif self.take_last_percent > 0:
first_n = False
count = int(len(self.img_ids) * self.take_last_percent)
self.img_ids = self.img_ids[-count:]
new_count = len(self.img_ids)
        print(f'Taking {"first" if first_n else "last"} {new_count} of original '
              f'dataset ({original_count}), {new_count / original_count * 100:.1f}%')
data_infos = []
total_ann_ids = []
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
ann_ids = self.coco.get_ann_ids(img_ids=[i])
total_ann_ids.extend(ann_ids)
assert len(set(total_ann_ids)) == len(
total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
return data_infos


def get_training_datasets(labeled_dataset_percent, base_directory='.'):
cfg = get_config(base_directory)
cfg.data.train['dataset']['take_last_percent'] = labeled_dataset_percent
dataset_finetune = build_dataset(cfg.data.train)
if labeled_dataset_percent < 1:
cfg.data.train['dataset']['take_last_percent'] = -1
cfg.data.train['dataset']['take_first_percent'] = 1 - labeled_dataset_percent
dataset_pretrain = build_dataset(cfg.data.train)
else:
dataset_pretrain = None
return dataset_pretrain, dataset_finetune
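
# Illustration (hypothetical numbers): get_training_datasets(0.2) builds the
# finetuning split from the last 20% of images and the pretraining split from
# the first 80%, so the two subsets are disjoint.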


def train(experiment_name, weight_path, labeled_dataset_percent, epochs,
          batch_size, optim, clip, lr, resume):
cfg = get_config()
cfg.total_epochs = epochs
cfg.runner.max_epochs = epochs
cfg.data.samples_per_gpu = batch_size
    if optim == 'adam':
cfg.optimizer = dict(type='Adam', lr=lr, weight_decay=0.0001)
else:
cfg.optimizer = dict(type='SGD', lr=lr, momentum=0.9, weight_decay=0.0001)
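    # Optional gradient clipping; --clip 0 (the default) disables it.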
if clip:
cfg.optimizer_config = dict(grad_clip=dict(max_norm=clip, norm_type=2))
else:
cfg.optimizer_config = dict(grad_clip=None)
    cfg.work_dir = os.path.join(cfg.work_dir, experiment_name)
logs_folder = os.path.join(cfg.work_dir, 'tf_logs')
    if resume:
        # Resume from the newest epoch checkpoint in the work dir.
        checkpoints = [p for p in natsorted(os.listdir(cfg.work_dir)) if 'epoch_' in p]
        assert checkpoints, f'No epoch checkpoints found in {cfg.work_dir} to resume from'
        cfg.resume_from = os.path.join(cfg.work_dir, checkpoints[-1])
        print('Re-initializing learning rate')
        cfg.optimizer.lr = lr
        print(cfg.optimizer)
else:
        if os.path.exists(logs_folder):
            shutil.rmtree(logs_folder)
print(cfg.model.backbone.init_cfg)
        if os.path.exists(weight_path):
            # Load on CPU so GPU-pinned checkpoints do not require the original device.
            state_dict = torch.load(weight_path, map_location='cpu')
            # Prefix keys with 'backbone.' so the weights match the detector's backbone.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_state_dict['backbone.' + k] = v
            torch.save(new_state_dict, 'tmp.pth')
            cfg.load_from = 'tmp.pth'
            print('Loading pretrained backbone from ' + weight_path)
_, train_dataset = get_training_datasets(labeled_dataset_percent)
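    # Build the detector from the prepared config; pretrained weights are
    # picked up later by the runner via cfg.load_from / cfg.resume_from.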
model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'))
datasets = [train_dataset]
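    # A workflow of [('train', 1)] runs training only; validation is provided
    # by validate=True in train_detector below.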
cfg.workflow = [('train', 1)]
cfg.device = 'cuda'
# train model
train_detector(model, datasets, cfg, distributed=False, validate=True)


def parse_args():
parser = argparse.ArgumentParser(description='Train using MMDet and Lightly SSL')
parser.add_argument('--experiment-name', default='no-exp')
parser.add_argument('--weight-path', type=str, required=True)
parser.add_argument('--labeled-dataset-percent', type=float, default=1)
parser.add_argument(
'--epochs',
type=int,
default=100,
help='number of epochs to train',
)
parser.add_argument(
'--batch-size',
type=int,
default=6,
)
parser.add_argument(
'--optim',
type=str,
default='sgd',
)
parser.add_argument(
'--clip',
type=float,
default=0,
)
parser.add_argument(
'--lr',
type=float,
        default=0.02 / 8,  # the common 0.02-for-8-GPUs default, scaled for one GPU
)
parser.add_argument(
'--resume',
default=False,
action='store_true',
help='resume training from last checkpoint in work dir'
)
args = parser.parse_args()
return args
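
# Example invocation (all paths and values are placeholders):
#   python finetune_with_path_modify_test_eval.py \
#       --experiment-name exp1 \
#       --weight-path /path/to/pretrained_backbone.pth \
#       --labeled-dataset-percent 0.2 \
#       --epochs 50 --batch-size 6 --optim adam --clip 1.0 --lr 1e-4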


def main():
args = parse_args()
train(**vars(args))


if __name__ == '__main__':
    main()
    # Clean up the temporary re-keyed checkpoint only if one was written.
    if os.path.exists('tmp.pth'):
        os.remove('tmp.pth')