Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from . import functional as F | |
__all__ = ['BinaryHeatmap2Coordinate']


class BinaryHeatmap2Coordinate(nn.Module):
    """Turn a binary (two-channel) heatmap into stride-scaled coordinates.

    ``forward`` takes channel 1 of its input and passes it, together with
    ``topk``, to ``F.heatmap2coord``. It then multiplies the result by
    ``stride``.

    Args:
        stride (float): Multiplier applied to the output of
            ``F.heatmap2coord``. Defaults to 4.0. Presumably this is the
            down-sampling factor between the heatmap and the input image;
            TODO confirm.
        topk (int): Forwarded unchanged to ``F.heatmap2coord``. Defaults to 5.
        **kwargs: Accepted and ignored. Presumably this lets the module be
            built from a config dict that carries extra keys; TODO confirm.
    """

    def __init__(self, stride=4.0, topk=5, **kwargs):
        super(BinaryHeatmap2Coordinate, self).__init__()
        self.topk = topk
        self.stride = stride

    def forward(self, input):
        """Decode coordinates from channel 1 of ``input``.

        Args:
            input: Tensor indexed as ``input[:, 1, ...]``. It must therefore
                have rank >= 2 and at least two entries along dim 1.

        Returns:
            ``self.stride * F.heatmap2coord(input[:, 1, ...], self.topk)``.
        """
        # NOTE(review): channel 1 is assumed to be the foreground channel of a
        # binary heatmap (channel 0 = background); confirm against the model.
        return self.stride * F.heatmap2coord(input[:, 1, ...], self.topk)

    def extra_repr(self):
        # nn.Module.__repr__ renders this as "ClassName(<extra_repr>)". With no
        # child modules, that is exactly "BinaryHeatmap2Coordinate(topk=..,
        # stride=..)", the same string the previous hand-written __repr__ built.
        return 'topk={}, stride={}'.format(self.topk, self.stride)