Text-to-Image
Diffusers
Safetensors
tolgacangoz's picture
Upload 7 files
2e73038 verified
raw
history blame
1.73 kB
from torch import nn
from .RecCTCHead import CTCHead
from .RecMv1_enhance import MobileNetV1Enhance
from .RNN import Im2Im, Im2Seq, SequenceEncoder
backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
head_dict = {"CTCHead": CTCHead}
class RecModel(nn.Module):
def __init__(self, config):
super().__init__()
assert "in_channels" in config, "in_channels must in model config"
backbone_type = config.backbone.pop("type")
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
neck_type = config.neck.pop("type")
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
head_type = config.head.pop("type")
assert head_type in head_dict, f"head.type must in {head_dict}"
self.head = head_dict[head_type](self.neck.out_channels, **config.head)
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
def load_3rd_state_dict(self, _3rd_name, _state):
self.backbone.load_3rd_state_dict(_3rd_name, _state)
self.neck.load_3rd_state_dict(_3rd_name, _state)
self.head.load_3rd_state_dict(_3rd_name, _state)
def forward(self, x):
import torch
x = x.to(torch.float32)
x = self.backbone(x)
x = self.neck(x)
x = self.head(x)
return x
def encode(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head.ctc_encoder(x)
return x