# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
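
"""Unit tests for fairseq.checkpoint_utils."""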

import contextlib
import logging
import os
import tempfile
import unittest
from io import StringIO
from unittest.mock import patch

from fairseq import checkpoint_utils
from tests.utils import (
    create_dummy_data,
    preprocess_translation_data,
    train_translation_model,
)
import torch


class TestCheckpointUtils(unittest.TestCase):
    def setUp(self):
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        logging.disable(logging.NOTSET)

    @contextlib.contextmanager
    def _train_transformer(self, seed, extra_args=None):
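        """Train a tiny transformer on dummy data and yield the path to its last checkpoint.

        The temporary data directory (and the checkpoint inside it) is removed
        when the context manager exits.
        """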
        if extra_args is None:
            extra_args = []
        with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir:
            create_dummy_data(data_dir)
            preprocess_translation_data(data_dir)
            train_translation_model(
                data_dir,
                "transformer_iwslt_de_en",
                [
                    "--encoder-layers",
                    "3",
                    "--decoder-layers",
                    "3",
                    "--encoder-embed-dim",
                    "8",
                    "--decoder-embed-dim",
                    "8",
                    "--seed",
                    str(seed),
                ]
                + extra_args,
            )
            yield os.path.join(data_dir, "checkpoint_last.pt")

    def test_load_model_ensemble_and_task(self):
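        """Loading two checkpoints returns both models, the first model's task, and the last cfg."""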
        # with contextlib.redirect_stdout(StringIO()):
        with self._train_transformer(seed=123) as model1:
            with self._train_transformer(seed=456) as model2:
                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    filenames=[model1, model2]
                )
                self.assertEqual(len(ensemble), 2)

                # after Transformer has been migrated to Hydra, this will probably
                # become cfg.common.seed
                self.assertEqual(ensemble[0].args.seed, 123)
                self.assertEqual(ensemble[1].args.seed, 456)

                # the task from the first model should be returned
                self.assertTrue("seed123" in task.cfg.data)

                # last cfg is saved
                self.assertEqual(cfg.common.seed, 456)

    def test_prune_state_dict(self):
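        """The *_layers_to_keep overrides should prune encoder/decoder layers at load time."""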
        with contextlib.redirect_stdout(StringIO()):
            extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"]
            with self._train_transformer(seed=1, extra_args=extra_args) as model:
                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    filenames=[model],
                    arg_overrides={
                        "encoder_layers_to_keep": "0,2",
                        "decoder_layers_to_keep": "1",
                    },
                )
                self.assertEqual(len(ensemble), 1)
                self.assertEqual(len(ensemble[0].encoder.layers), 2)
                self.assertEqual(len(ensemble[0].decoder.layers), 1)

    def test_torch_persistent_save_async(self):
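        """async_write=True should open the file via PathManager.opena and delegate to _torch_persistent_save."""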
        state_dict = {}
        filename = "async_checkpoint.pt"

        with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena:
            with patch(
                f"{checkpoint_utils.__name__}._torch_persistent_save"
            ) as mock_save:
                checkpoint_utils.torch_persistent_save(
                    state_dict, filename, async_write=True
                )
                mock_opena.assert_called_with(filename, "wb")
                mock_save.assert_called()

    def test_load_ema_from_checkpoint(self):
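        """EMA weights stored under extra_state['ema'] should be returned as the 'model' state."""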
        dummy_state = {"a": torch.tensor([1]), "b": torch.tensor([0.1])}
        with patch(f"{checkpoint_utils.__name__}.PathManager.open") as mock_open, patch(
            f"{checkpoint_utils.__name__}.torch.load"
        ) as mock_load:

            mock_load.return_value = {"extra_state": {"ema": dummy_state}}
            filename = "ema_checkpoint.pt"
            state = checkpoint_utils.load_ema_from_checkpoint(filename)

            mock_open.assert_called_with(filename, "rb")
            mock_load.assert_called()

            self.assertIn("a", state["model"])
            self.assertIn("b", state["model"])
            self.assertTrue(torch.allclose(dummy_state["a"], state["model"]["a"]))
            self.assertTrue(torch.allclose(dummy_state["b"], state["model"]["b"]))


if __name__ == "__main__":
    unittest.main()