# File size: 4,718 Bytes | be2715b
import timm
import functools
import torch.utils.model_zoo as model_zoo
from .resnet import resnet_encoders
from .dpn import dpn_encoders
from .vgg import vgg_encoders
from .senet import senet_encoders
from .densenet import densenet_encoders
from .inceptionresnetv2 import inceptionresnetv2_encoders
from .inceptionv4 import inceptionv4_encoders
from .efficientnet import efficient_net_encoders
from .mobilenet import mobilenet_encoders
from .xception import xception_encoders
from .timm_efficientnet import timm_efficientnet_encoders
from .timm_resnest import timm_resnest_encoders
from .timm_res2net import timm_res2net_encoders
from .timm_regnet import timm_regnet_encoders
from .timm_sknet import timm_sknet_encoders
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
from .timm_gernet import timm_gernet_encoders
from .mix_transformer import mix_transformer_encoders
from .mobileone import mobileone_encoders
from .timm_universal import TimmUniversalEncoder
from ._preprocessing import preprocess_input
# Registry of all built-in encoders. The per-family tables are merged in the
# order listed, so if two tables share a key the later one takes precedence.
encoders = {
    **resnet_encoders,
    **dpn_encoders,
    **vgg_encoders,
    **senet_encoders,
    **densenet_encoders,
    **inceptionresnetv2_encoders,
    **inceptionv4_encoders,
    **efficient_net_encoders,
    **mobilenet_encoders,
    **xception_encoders,
    **timm_efficientnet_encoders,
    **timm_resnest_encoders,
    **timm_res2net_encoders,
    **timm_regnet_encoders,
    **timm_sknet_encoders,
    **timm_mobilenetv3_encoders,
    **timm_gernet_encoders,
    **mix_transformer_encoders,
    **mobileone_encoders,
}
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build an encoder by name and optionally load pretrained weights into it.

    Names starting with ``"tu-"`` are handed to ``TimmUniversalEncoder`` with
    the prefix stripped; ``**kwargs`` are forwarded only on that path. Any
    other name is looked up in the module-level ``encoders`` registry.

    Args:
        name: Registry key, or ``"tu-<model name>"`` for the timm wrapper.
        in_channels: Number of input channels, applied through
            ``encoder.set_in_channels`` (or passed to the timm wrapper).
        depth: Passed to the encoder constructor as ``depth``.
        weights: Key into the encoder's ``pretrained_settings``; ``None``
            skips weight loading.
        output_stride: When not 32, ``encoder.make_dilated`` is called with
            this value.
        **kwargs: Extra arguments for ``TimmUniversalEncoder`` (``"tu-"``
            names only).

    Returns:
        The constructed encoder.

    Raises:
        KeyError: If ``name`` is not registered, ``weights`` is not a known
            pretrained option for it, or the selected settings have no
            ``"url"`` entry.
    """
    if name.startswith("tu-"):
        return TimmUniversalEncoder(
            name=name[3:],
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs,
        )

    try:
        Encoder = encoders[name]["encoder"]
    except KeyError as err:
        raise KeyError(
            "Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))
        ) from err

    # Build a per-call copy: updating the registry's own "params" dict in
    # place would leak this call's depth into the shared module-level state.
    params = dict(encoders[name]["params"], depth=depth)
    encoder = Encoder(**params)

    if weights is not None:
        try:
            settings = encoders[name]["pretrained_settings"][weights]
        except KeyError as err:
            raise KeyError(
                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                    weights,
                    name,
                    list(encoders[name]["pretrained_settings"].keys()),
                )
            ) from err

        try:
            url = settings["url"]
        except KeyError as err:
            raise KeyError("Pretrained weights not exist") from err

        if "lvmmed" in url:
            # For these entries "url" is passed straight to torch.load as a
            # checkpoint path rather than being downloaded.
            print(url)
            import torch

            state_dict = torch.load(url, map_location="cpu")
        else:
            # Every other entry is fetched through the model zoo. Without this
            # branch the `weights` string itself reached load_state_dict.
            state_dict = model_zoo.load_url(url)
        encoder.load_state_dict(state_dict)

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)

    return encoder
def get_encoder_names():
    """Return the keys of the ``encoders`` registry as a new list."""
    return [*encoders]
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    """Return the input preprocessing parameters for an encoder.

    For ``"tu-"`` names the settings come from timm's pretrained config of the
    un-prefixed model and ``pretrained`` is not consulted. Otherwise they come
    from ``encoders[encoder_name]["pretrained_settings"][pretrained]``.

    Returns:
        A dict with the keys ``"input_space"`` (default ``"RGB"``),
        ``"input_range"`` (default ``[0, 1]``), ``"mean"`` and ``"std"``;
        the last three are converted to lists.

    Raises:
        ValueError: If the timm model has no pretrained weights, or
            ``pretrained`` is not an available option for the encoder.
    """
    if not encoder_name.startswith("tu-"):
        available = encoders[encoder_name]["pretrained_settings"]
        if pretrained not in available:
            raise ValueError("Available pretrained options {}".format(available.keys()))
        cfg = available[pretrained]
    else:
        encoder_name = encoder_name[3:]
        has_pretrained = timm.models.is_model_pretrained(encoder_name)
        if not has_pretrained:
            raise ValueError(f"{encoder_name} does not have pretrained weights and preprocessing parameters")
        cfg = timm.models.get_pretrained_cfg(encoder_name).__dict__

    # "mean" and "std" are mandatory; the other two fall back to defaults.
    return {
        "input_space": cfg.get("input_space", "RGB"),
        "input_range": list(cfg.get("input_range", [0, 1])),
        "mean": list(cfg["mean"]),
        "std": list(cfg["std"]),
    }
def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    """Return ``preprocess_input`` with the encoder's preprocessing parameters bound.

    The parameters are those produced by ``get_preprocessing_params`` for the
    same ``encoder_name`` and ``pretrained`` arguments.
    """
    return functools.partial(
        preprocess_input,
        **get_preprocessing_params(encoder_name, pretrained=pretrained),
    )
|