weismart1807's picture
Upload folder using huggingface_hub
e90b704 verified
import torch
import torch.nn as nn
from . import functional as F
__all__ = ["BinaryHeatmap2Coordinate"]


class BinaryHeatmap2Coordinate(nn.Module):
    """Decode coordinates from a two-channel ("binary") heatmap.

    ``forward`` takes index 1 along dimension 1 of its input. It passes that
    slice, together with ``topk``, to ``F.heatmap2coord`` and multiplies the
    result by ``stride``.

    Args:
        stride: Multiplier applied to the decoded coordinates. This is
            presumably the downsampling factor between the network input and
            the heatmap, so the output lands in input-image units.
            NOTE(review): confirm.
        topk: Forwarded unchanged to ``F.heatmap2coord``. Its exact meaning
            is defined there.
        **kwargs: Accepted and ignored. This is presumably so the module can
            be built from a generic config dict; verify against callers.
    """

    def __init__(self, stride=4.0, topk=5, **kwargs):
        super().__init__()
        self.topk = topk
        self.stride = stride

    def forward(self, input):
        """Return ``stride`` times the coordinates decoded from channel 1.

        Args:
            input: Tensor indexed as ``input[:, 1, ...]``. It must therefore
                have at least two dimensions, with size >= 2 along dim 1.
                Assumed to be (batch, channels, H, W) with channel 1 holding
                the foreground map. TODO: confirm.

        Returns:
            ``self.stride * F.heatmap2coord(input[:, 1, ...], self.topk)``.
        """
        # Only the second channel is decoded. Channel 0 (presumably the
        # background / complement map) is never read here.
        foreground = input[:, 1, ...]
        coords = F.heatmap2coord(foreground, self.topk)
        return self.stride * coords

    def __repr__(self):
        return f"{type(self).__name__}(topk={self.topk}, stride={self.stride})"