import unittest

import numpy as np
import torch

from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

torch.use_deterministic_algorithms(True)
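
# The expected tensors throughout these tests are golden values captured from
# a seeded run (torch.manual_seed(0) with the modules' default initialization);
# deterministic algorithms keep them stable across reruns on the same build.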


class TestESPNETMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3  # sequence length
        self.B = 1  # batch size
        self.C = 2  # embedding dimension
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)
        self.sample_scores = torch.randn(self.B, 1, self.T, self.T)
        self.MHA = ESPNETMultiHeadedAttention(self.C, 1, dropout=0)

    def test_forward(self):
        expected_scores = torch.tensor(
            [[[0.1713, -0.3776]], [[0.2263, -0.4486]], [[0.2243, -0.4538]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

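    # forward_qkv applies the per-head linear projections and reshapes each of
    # q/k/v to (batch, heads, time, d_k); with one head and C=2 that is
    # (1, 1, 3, 2), matching the expected tensors below.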
    def test_forward_qkv(self):
        expected_query = torch.tensor(
            [[[[-1.0235, 0.0409], [0.4008, 1.3077], [0.5396, 2.0698]]]]
        )
        expected_key = torch.tensor(
            [[[[0.5053, -0.4965], [-0.3730, -0.9473], [-0.7019, -0.1935]]]]
        )
        expected_val = torch.tensor(
            [[[[-0.9940, 0.5403], [0.5924, -0.7619], [0.7504, -1.0892]]]]
        )
        sample_t = self.sample.transpose(0, 1)
        query, key, val = self.MHA.forward_qkv(sample_t, sample_t, sample_t)
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                query.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_key.cpu().detach().numpy(),
                key.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_val.cpu().detach().numpy(),
                val.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

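    # forward_attention softmaxes the raw scores and applies the resulting
    # weights to the value tensor, returning shape (batch, time, embed_dim).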
    def test_forward_attention(self):
        expected_scores = torch.tensor(
            [[[0.1627, -0.6249], [-0.2547, -0.6487], [-0.0711, -0.8545]]]
        )
        scores = self.MHA.forward_attention(
            self.sample.transpose(0, 1).view(self.B, 1, self.T, self.C),
            self.sample_scores,
            mask=None,
        )
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


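# Relative-position attention in the Transformer-XL style: scores combine
# content and position terms, with a (2T - 1)-step positional embedding
# covering relative offsets -(T - 1) .. T - 1.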
class TestRelPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3  # sequence length
        self.B = 1  # batch size
        self.C = 2  # embedding dimension
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)
        # raw scores over the 2T - 1 relative offsets, plus the matching
        # positional embeddings
        self.sample_x = torch.randn(self.B, 1, self.T, self.T * 2 - 1)
        self.sample_pos = torch.randn(self.B, self.T * 2 - 1, self.C)
        self.MHA = RelPositionMultiHeadedAttention(self.C, 1, dropout=0)

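    # rel_shift converts the (T, 2T - 1) relative-offset scores into a (T, T)
    # matrix via the Transformer-XL zero-pad-and-reshape trick.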
    def test_rel_shift(self):
        expected_x = torch.tensor(
            [
                [
                    [
                        [-0.7193, -0.4033, -0.5966],
                        [-0.8567, 1.1006, -1.0712],
                        [-0.5663, 0.3731, -0.8920],
                    ]
                ]
            ]
        )
        x = self.MHA.rel_shift(self.sample_x)
        self.assertTrue(
            np.allclose(
                expected_x.cpu().detach().numpy(),
                x.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward(self):
        # the three-step output pattern tiles five times along the first dim
        expected_scores = torch.tensor(
            [[[-0.9609, -0.5020]], [[-0.9308, -0.4890]], [[-0.9473, -0.4948]]]
        ).repeat(5, 1, 1)
        scores, _ = self.MHA(self.sample, self.sample, self.sample, self.sample_pos)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


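# Rotary position embeddings (RoPE, Su et al., 2021) rotate each query/key
# pair by a position-dependent angle, so no positional tensor is passed in.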
class TestRotaryPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3  # sequence length
        self.B = 1  # batch size
        self.C = 2  # embedding dimension
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)
        self.MHA = RotaryPositionMultiHeadedAttention(
            self.C, 1, dropout=0, precision=None
        )

    def test_forward(self):
        expected_scores = torch.tensor(
            [[[-0.3220, -0.4726]], [[-1.2813, -0.0979]], [[-0.3138, -0.4758]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


if __name__ == "__main__":
    unittest.main()