import torch


def extend_instance(obj, mixin):
    """Apply mixins to a class instance after creation"""
    base_cls = obj.__class__
    base_cls_name = obj.__class__.__name__
    # Place the mixin first so its methods take precedence in the MRO.
    obj.__class__ = type(base_cls_name, (mixin, base_cls), {})
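
# Illustrative usage only (not part of the original module): a minimal sketch of
# extending a live instance with a hypothetical mixin. After the call, the
# mixin's methods take precedence over the original class's via the MRO.
def _example_extend_instance():
    class VerboseMixin:
        def forward(self, *args, **kwargs):
            print("forward intercepted by VerboseMixin")
            return super().forward(*args, **kwargs)

    layer = torch.nn.Linear(4, 4)
    extend_instance(layer, VerboseMixin)
    layer(torch.randn(1, 4))  # prints, then runs the original Linear.forward
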
def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    """
    if att == "":
        return obj
    i = att.find(".")
    if i < 0:
        return getattr(obj, att)
    else:
        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
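
# Illustrative usage only: reading a dotted attribute path out of a small
# nn.Module tree (the submodule names here are hypothetical).
def _example_getattr_recursive():
    model = torch.nn.Module()
    model.block = torch.nn.Module()
    model.block.proj = torch.nn.Linear(2, 2)
    proj = getattr_recursive(model, "block.proj")
    assert proj is model.block.proj
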
def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    if "." in att:
        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
    setattr(obj, att.split(".")[-1], val)
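
# Illustrative usage only: swapping a nested submodule by its dotted path,
# e.g. replacing a linear layer with a fresh one (names are hypothetical).
def _example_setattr_recursive():
    model = torch.nn.Module()
    model.block = torch.nn.Module()
    model.block.proj = torch.nn.Linear(2, 2)
    setattr_recursive(model, "block.proj", torch.nn.Linear(2, 4))
    assert model.block.proj.out_features == 4
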
def apply_with_stopping_condition(
    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
    """
    Recursively apply apply_fn (with **other_args) to module and its children.
    Subtrees rooted at a module satisfying stopping_condition are skipped
    entirely; apply_fn is only called on modules satisfying apply_condition.
    A condition left as None means "never stop" / "always apply".
    """
    if stopping_condition is not None and stopping_condition(module):
        return
    if apply_condition is None or apply_condition(module):
        apply_fn(module, **other_args)
    for child in module.children():
        apply_with_stopping_condition(
            child,
            apply_fn,
            apply_condition=apply_condition,
            stopping_condition=stopping_condition,
            **other_args
        )
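
# Illustrative usage only: convert every Linear in a model to float16, but
# leave a hypothetical frozen encoder untouched by stopping the recursion there.
def _example_apply_with_stopping_condition():
    model = torch.nn.Module()
    model.encoder = torch.nn.Sequential(torch.nn.Linear(4, 4))
    model.head = torch.nn.Linear(4, 2)
    apply_with_stopping_condition(
        model,
        lambda m: m.half(),
        apply_condition=lambda m: isinstance(m, torch.nn.Linear),
        stopping_condition=lambda m: m is model.encoder,
    )
    assert model.head.weight.dtype == torch.float16  # converted
    assert model.encoder[0].weight.dtype == torch.float32  # subtree skipped
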
def num_params(module, filter_to_trainable=False):
    """Returns the number of parameters in the module, or optionally only the trainable parameters"""
    if filter_to_trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())
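
# Illustrative usage only: comparing total vs. trainable parameter counts
# after freezing part of a module.
def _example_num_params():
    model = torch.nn.Linear(10, 10)  # 10*10 weights + 10 biases = 110 params
    model.bias.requires_grad = False
    assert num_params(model) == 110
    assert num_params(model, filter_to_trainable=True) == 100
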
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
    """
    Stack a list of tensors with padding on one side
    Args:
        list_of_tensors (list[torch.Tensor]): List of tensors to stack
        padding_value (int, optional): Value to pad with. Defaults to 0.
        padding_side (str, optional): Side to pad on, "right" or "left". Defaults to "right".
    Returns:
        torch.Tensor: Stacked tensor of shape (len(list_of_tensors), max_len, ...)
    """
    max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
    padded_tensors = []
    for tensor in list_of_tensors:
        num_tokens = tensor.size(0)
        # Pad along the first (token) dimension; any trailing dimensions are
        # preserved as-is, so this handles 1D and higher-rank tensors alike.
        padding = torch.full(
            (max_tokens - num_tokens, *tensor.shape[1:]),
            padding_value,
            dtype=tensor.dtype,
            device=tensor.device,
        )
        padded_tensor = (
            torch.cat((tensor, padding), dim=0)
            if padding_side == "right"
            else torch.cat((padding, tensor), dim=0)
        )
        padded_tensors.append(padded_tensor)
    return torch.stack(padded_tensors)
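
# Illustrative usage only: batching variable-length 1D token sequences into a
# single (batch, max_len) tensor, padded on either side.
def _example_stack_with_padding():
    seqs = [torch.tensor([1, 2, 3]), torch.tensor([4])]
    right = stack_with_padding(seqs, padding_value=0)
    assert right.tolist() == [[1, 2, 3], [4, 0, 0]]
    left = stack_with_padding(seqs, padding_value=0, padding_side="left")
    assert left.tolist() == [[1, 2, 3], [0, 0, 4]]
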
def stack_with_padding_2D_attention(list_of_tensors):
    """
    Stack a list of 2D attention masks, zero-padding each one along both of
    its dimensions up to the largest mask size in the list.
    """
    max_size = max(tensor.size(1) for tensor in list_of_tensors)

    padded_tensors = []
    for tensor in list_of_tensors:
        a = tensor.shape[-1]
        # F.pad interprets the tuple as (left, right, top, bottom) padding
        # for the last two dimensions, so this pads right and bottom only.
        padding = (0, max_size - a, 0, max_size - a)
        padded_tensor = torch.nn.functional.pad(tensor, padding)
        padded_tensors.append(padded_tensor)
    return torch.stack(padded_tensors)
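
# Illustrative usage only: batching square attention masks of different sizes;
# each mask is zero-padded on the right and bottom before stacking.
def _example_stack_with_padding_2D_attention():
    masks = [torch.ones(2, 2), torch.ones(3, 3)]
    stacked = stack_with_padding_2D_attention(masks)
    assert stacked.shape == (2, 3, 3)
    assert stacked[0, :2, :2].all()
    assert not stacked[0, 2, :].any()  # padded row is all zeros
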