qninhdt commited on
Commit
78000ed
·
verified ·
1 Parent(s): bc79059

Upload 11 files

Browse files
CHANGELOG.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog
2
+
3
+ ## [0.3.0] - 2023-01-05
4
+
5
+ ### Added
6
+
7
+ * Add argument `--save-stats` allowing to compute dataset statistics and save them as an `.npz` file ([#80](https://github.com/mseitzer/pytorch-fid/pull/80)). The `.npz` file can be used in subsequent FID computations instead of recomputing the dataset statistics. This option can be used in the following way: `python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile`.
8
+
9
+ ### Fixed
10
+
11
+ * Do not use `os.sched_getaffinity` to get number of available CPUs on Windows, as it is not available there ([232b3b14](https://github.com/mseitzer/pytorch-fid/commit/232b3b1468800102fcceaf6f2bb8977811fc991a), [#84](https://github.com/mseitzer/pytorch-fid/issues/84)).
12
+ * Do not use Inception model argument `pretrained`, as it was deprecated in torchvision 0.13 ([#88](https://github.com/mseitzer/pytorch-fid/pull/88)).
13
+
14
+ ## [0.2.1] - 2021-10-10
15
+
16
+ ### Added
17
+
18
+ * Add argument `--num-workers` to select number of dataloader processes ([#66](https://github.com/mseitzer/pytorch-fid/pull/66)). Defaults to 8 or the number of available CPUs if less than 8 CPUs are available.
19
+
20
+ ### Fixed
21
+
22
+ * Fixed package setup to work under Windows ([#55](https://github.com/mseitzer/pytorch-fid/pull/55), [#72](https://github.com/mseitzer/pytorch-fid/issues/72))
23
+
24
+ ## [0.2.0] - 2020-11-30
25
+
26
+ ### Added
27
+
28
+ * Load images using a Pytorch dataloader, which should result in a speed-up. ([#47](https://github.com/mseitzer/pytorch-fid/pull/47))
29
+ * Support more image extensions ([#53](https://github.com/mseitzer/pytorch-fid/pull/53))
30
+ * Improve tooling by setting up Nox, add linting and test support ([#52](https://github.com/mseitzer/pytorch-fid/pull/52))
31
+ * Add some unit tests
32
+
33
+ ## [0.1.1] - 2020-08-16
34
+
35
+ ### Fixed
36
+
37
+ * Fixed software license string in `setup.py`
38
+
39
+ ## [0.1.0] - 2020-08-16
40
+
41
+ Initial release as a pypi package. Use `pip install pytorch-fid` to install.
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [![PyPI](https://img.shields.io/pypi/v/pytorch-fid.svg)](https://pypi.org/project/pytorch-fid/)
2
+
3
+ # FID score for PyTorch
4
+
5
+ This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch.
6
+ See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow.
7
+
8
+ FID is a measure of similarity between two datasets of images.
9
+ It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks.
10
+ FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network.
11
+
12
+ Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337).
13
+
14
+ The weights and the model are exactly the same as in [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR), and were tested to give very similar results (e.g. `.08` absolute error and `0.0009` relative error on LSUN, using ProGAN generated images). However, due to differences in the image interpolation implementation and library backends, FID results still differ slightly from the original implementation. So if you report FID scores in your paper, and you want them to be *exactly comparable* to FID scores reported in other papers, you should consider using [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR).
15
+
16
+ ## Installation
17
+
18
+ Install from [pip](https://pypi.org/project/pytorch-fid/):
19
+
20
+ ```
21
+ pip install pytorch-fid
22
+ ```
23
+
24
+ Requirements:
25
+ - python3
26
+ - pytorch
27
+ - torchvision
28
+ - pillow
29
+ - numpy
30
+ - scipy
31
+
32
+ ## Usage
33
+
34
+ To compute the FID score between two datasets, where images of each dataset are contained in an individual folder:
35
+ ```
36
+ python -m pytorch_fid path/to/dataset1 path/to/dataset2
37
+ ```
38
+
39
+ To run the evaluation on GPU, use the flag `--device cuda:N`, where `N` is the index of the GPU to use.
40
+
41
+ ### Using different layers for feature maps
42
+
43
+ In difference to the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer.
44
+ As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance.
45
+
46
+ This might be useful if the datasets you want to compare have less than the otherwise required 2048 images.
47
+ Note that this changes the magnitude of the FID score and you can not compare them against scores calculated on another dimensionality.
48
+ The resulting scores might also no longer correlate with visual quality.
49
+
50
+ You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features.
51
+ The choices are:
52
+ - 64: first max pooling features
53
+ - 192: second max pooling features
54
+ - 768: pre-aux classifier features
55
+ - 2048: final average pooling features (this is the default)
56
+
57
+ ## Generating a compatible `.npz` archive from a dataset
58
+ A frequent use case will be to compare multiple models against an original dataset.
59
+ To save training multiple times on the original dataset, there is also the ability to generate a compatible `.npz` archive from a dataset. This is done using any combination of the previously mentioned arguments with the addition of the `--save-stats` flag. For example:
60
+ ```
61
+ python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile
62
+ ```
63
+
64
+ The output file may then be used in place of the path to the original dataset for further comparisons.
65
+
66
+ ## Citing
67
+
68
+ If you use this repository in your research, consider citing it using the following Bibtex entry:
69
+
70
+ ```
71
+ @misc{Seitzer2020FID,
72
+ author={Maximilian Seitzer},
73
+ title={{pytorch-fid: FID Score for PyTorch}},
74
+ month={August},
75
+ year={2020},
76
+ note={Version 0.3.0},
77
+ howpublished={\url{https://github.com/mseitzer/pytorch-fid}},
78
+ }
79
+ ```
80
+
81
+ ## License
82
+
83
+ This implementation is licensed under the Apache License 2.0.
84
+
85
+ FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500)
86
+
87
+ The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0.
88
+ See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR).
noxfile.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nox
2
+
3
+ LOCATIONS = ("src/", "tests/", "noxfile.py", "setup.py")
4
+
5
+
6
+ @nox.session
7
+ def lint(session):
8
+ session.install("flake8")
9
+ session.install("flake8-bugbear")
10
+ session.install("flake8-isort")
11
+ session.install("black==24.3.0")
12
+
13
+ args = session.posargs or LOCATIONS
14
+ session.run("flake8", *args)
15
+ session.run("black", "--check", "--diff", *args)
16
+
17
+
18
+ @nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"])
19
+ def tests(session):
20
+ session.install(
21
+ "torch==2.2.1",
22
+ "torchvision",
23
+ "--index-url",
24
+ "https://download.pytorch.org/whl/cpu",
25
+ )
26
+ session.install(".")
27
+ session.install("pytest")
28
+ session.install("pytest-mock")
29
+ session.run("pytest", *session.posargs)
pyproject.toml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ target-version = ["py311"]
3
+
4
+ [tool.isort]
5
+ profile = "black"
6
+ line_length = 88
7
+ multi_line_output = 3
setup.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import setuptools
4
+
5
+
6
+ def read(rel_path):
7
+ base_path = os.path.abspath(os.path.dirname(__file__))
8
+ with open(os.path.join(base_path, rel_path), "r") as f:
9
+ return f.read()
10
+
11
+
12
+ def get_version(rel_path):
13
+ for line in read(rel_path).splitlines():
14
+ if line.startswith("__version__"):
15
+ # __version__ = "0.9"
16
+ delim = '"' if '"' in line else "'"
17
+ return line.split(delim)[1]
18
+
19
+ raise RuntimeError("Unable to find version string.")
20
+
21
+
22
+ if __name__ == "__main__":
23
+ setuptools.setup(
24
+ name="pytorch-fid",
25
+ version=get_version(os.path.join("src", "pytorch_fid", "__init__.py")),
26
+ author="Max Seitzer",
27
+ description=(
28
+ "Package for calculating Frechet Inception Distance (FID)" " using PyTorch"
29
+ ),
30
+ long_description=read("README.md"),
31
+ long_description_content_type="text/markdown",
32
+ url="https://github.com/mseitzer/pytorch-fid",
33
+ package_dir={"": "src"},
34
+ packages=setuptools.find_packages(where="src"),
35
+ classifiers=[
36
+ "Programming Language :: Python :: 3",
37
+ "License :: OSI Approved :: Apache Software License",
38
+ ],
39
+ python_requires=">=3.5",
40
+ entry_points={
41
+ "console_scripts": [
42
+ "pytorch-fid = pytorch_fid.fid_score:main",
43
+ ],
44
+ },
45
+ install_requires=[
46
+ "numpy",
47
+ "pillow",
48
+ "scipy",
49
+ "torch>=1.0.1",
50
+ "torchvision>=0.2.2",
51
+ ],
52
+ extras_require={
53
+ "dev": ["flake8", "flake8-bugbear", "flake8-isort", "black==24.3.0", "nox"]
54
+ },
55
+ )
src/pytorch_fid/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.3.0"
src/pytorch_fid/__main__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import pytorch_fid.fid_score
2
+
3
+ pytorch_fid.fid_score.main()
src/pytorch_fid/fid_score.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+
35
+ import os
36
+ import pathlib
37
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
38
+
39
+ import numpy as np
40
+ import torch
41
+ import torchvision.transforms as TF
42
+ from PIL import Image
43
+ from scipy import linalg
44
+ from torch.nn.functional import adaptive_avg_pool2d
45
+
46
+ try:
47
+ from tqdm import tqdm
48
+ except ImportError:
49
+ # If tqdm is not available, provide a mock version of it
50
+ def tqdm(x):
51
+ return x
52
+
53
+
54
+ from pytorch_fid.inception import InceptionV3
55
+
56
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
57
+ parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
58
+ parser.add_argument(
59
+ "--num-workers",
60
+ type=int,
61
+ help=(
62
+ "Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`"
63
+ ),
64
+ )
65
+ parser.add_argument(
66
+ "--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu"
67
+ )
68
+ parser.add_argument(
69
+ "--dims",
70
+ type=int,
71
+ default=2048,
72
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
73
+ help=(
74
+ "Dimensionality of Inception features to use. "
75
+ "By default, uses pool3 features"
76
+ ),
77
+ )
78
+ parser.add_argument(
79
+ "--save-stats",
80
+ action="store_true",
81
+ help=(
82
+ "Generate an npz archive from a directory of "
83
+ "samples. The first path is used as input and the "
84
+ "second as output."
85
+ ),
86
+ )
87
+ parser.add_argument(
88
+ "path",
89
+ type=str,
90
+ nargs=2,
91
+ help=("Paths to the generated images or " "to .npz statistic files"),
92
+ )
93
+
94
+ IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
95
+
96
+
97
+ class ImagePathDataset(torch.utils.data.Dataset):
98
+ def __init__(self, files, transforms=None):
99
+ self.files = files
100
+ self.transforms = transforms
101
+
102
+ def __len__(self):
103
+ return len(self.files)
104
+
105
+ def __getitem__(self, i):
106
+ path = self.files[i]
107
+ img = Image.open(path).convert("RGB")
108
+ if self.transforms is not None:
109
+ img = self.transforms(img)
110
+ return img
111
+
112
+
113
+ def get_activations(
114
+ files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
115
+ ):
116
+ """Calculates the activations of the pool_3 layer for all images.
117
+
118
+ Params:
119
+ -- files : List of image files paths
120
+ -- model : Instance of inception model
121
+ -- batch_size : Batch size of images for the model to process at once.
122
+ Make sure that the number of samples is a multiple of
123
+ the batch size, otherwise some samples are ignored. This
124
+ behavior is retained to match the original FID score
125
+ implementation.
126
+ -- dims : Dimensionality of features returned by Inception
127
+ -- device : Device to run calculations
128
+ -- num_workers : Number of parallel dataloader workers
129
+
130
+ Returns:
131
+ -- A numpy array of dimension (num images, dims) that contains the
132
+ activations of the given tensor when feeding inception with the
133
+ query tensor.
134
+ """
135
+ model.eval()
136
+
137
+ if batch_size > len(files):
138
+ print(
139
+ (
140
+ "Warning: batch size is bigger than the data size. "
141
+ "Setting batch size to data size"
142
+ )
143
+ )
144
+ batch_size = len(files)
145
+
146
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
147
+ dataloader = torch.utils.data.DataLoader(
148
+ dataset,
149
+ batch_size=batch_size,
150
+ shuffle=False,
151
+ drop_last=False,
152
+ num_workers=num_workers,
153
+ )
154
+
155
+ pred_arr = np.empty((len(files), dims))
156
+
157
+ start_idx = 0
158
+
159
+ for batch in tqdm(dataloader):
160
+ batch = batch.to(device)
161
+
162
+ with torch.no_grad():
163
+ pred = model(batch)[0]
164
+
165
+ # If model output is not scalar, apply global spatial average pooling.
166
+ # This happens if you choose a dimensionality not equal 2048.
167
+ if pred.size(2) != 1 or pred.size(3) != 1:
168
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
169
+
170
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
171
+
172
+ pred_arr[start_idx : start_idx + pred.shape[0]] = pred
173
+
174
+ start_idx = start_idx + pred.shape[0]
175
+
176
+ return pred_arr
177
+
178
+
179
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
180
+ """Numpy implementation of the Frechet Distance.
181
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
182
+ and X_2 ~ N(mu_2, C_2) is
183
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
184
+
185
+ Stable version by Dougal J. Sutherland.
186
+
187
+ Params:
188
+ -- mu1 : Numpy array containing the activations of a layer of the
189
+ inception net (like returned by the function 'get_predictions')
190
+ for generated samples.
191
+ -- mu2 : The sample mean over activations, precalculated on an
192
+ representative data set.
193
+ -- sigma1: The covariance matrix over activations for generated samples.
194
+ -- sigma2: The covariance matrix over activations, precalculated on an
195
+ representative data set.
196
+
197
+ Returns:
198
+ -- : The Frechet Distance.
199
+ """
200
+
201
+ mu1 = np.atleast_1d(mu1)
202
+ mu2 = np.atleast_1d(mu2)
203
+
204
+ sigma1 = np.atleast_2d(sigma1)
205
+ sigma2 = np.atleast_2d(sigma2)
206
+
207
+ assert (
208
+ mu1.shape == mu2.shape
209
+ ), "Training and test mean vectors have different lengths"
210
+ assert (
211
+ sigma1.shape == sigma2.shape
212
+ ), "Training and test covariances have different dimensions"
213
+
214
+ diff = mu1 - mu2
215
+
216
+ # Product might be almost singular
217
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
218
+ if not np.isfinite(covmean).all():
219
+ msg = (
220
+ "fid calculation produces singular product; "
221
+ "adding %s to diagonal of cov estimates"
222
+ ) % eps
223
+ print(msg)
224
+ offset = np.eye(sigma1.shape[0]) * eps
225
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
226
+
227
+ # Numerical error might give slight imaginary component
228
+ if np.iscomplexobj(covmean):
229
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
230
+ m = np.max(np.abs(covmean.imag))
231
+ raise ValueError("Imaginary component {}".format(m))
232
+ covmean = covmean.real
233
+
234
+ tr_covmean = np.trace(covmean)
235
+
236
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
237
+
238
+
239
+ def calculate_activation_statistics(
240
+ files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
241
+ ):
242
+ """Calculation of the statistics used by the FID.
243
+ Params:
244
+ -- files : List of image files paths
245
+ -- model : Instance of inception model
246
+ -- batch_size : The images numpy array is split into batches with
247
+ batch size batch_size. A reasonable batch size
248
+ depends on the hardware.
249
+ -- dims : Dimensionality of features returned by Inception
250
+ -- device : Device to run calculations
251
+ -- num_workers : Number of parallel dataloader workers
252
+
253
+ Returns:
254
+ -- mu : The mean over samples of the activations of the pool_3 layer of
255
+ the inception model.
256
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
257
+ the inception model.
258
+ """
259
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
260
+ mu = np.mean(act, axis=0)
261
+ sigma = np.cov(act, rowvar=False)
262
+ return mu, sigma
263
+
264
+
265
+ def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1):
266
+ if path.endswith(".npz"):
267
+ with np.load(path) as f:
268
+ m, s = f["mu"][:], f["sigma"][:]
269
+ else:
270
+ path = pathlib.Path(path)
271
+ files = sorted(
272
+ [file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))]
273
+ )
274
+ m, s = calculate_activation_statistics(
275
+ files, model, batch_size, dims, device, num_workers
276
+ )
277
+
278
+ return m, s
279
+
280
+
281
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
282
+ """Calculates the FID of two paths"""
283
+ for p in paths:
284
+ if not os.path.exists(p):
285
+ raise RuntimeError("Invalid path: %s" % p)
286
+
287
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
288
+
289
+ model = InceptionV3([block_idx]).to(device)
290
+
291
+ m1, s1 = compute_statistics_of_path(
292
+ paths[0], model, batch_size, dims, device, num_workers
293
+ )
294
+ m2, s2 = compute_statistics_of_path(
295
+ paths[1], model, batch_size, dims, device, num_workers
296
+ )
297
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
298
+
299
+ return fid_value
300
+
301
+
302
+ def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
303
+ """Saves FID statistics of one path"""
304
+ if not os.path.exists(paths[0]):
305
+ raise RuntimeError("Invalid path: %s" % paths[0])
306
+
307
+ if os.path.exists(paths[1]):
308
+ raise RuntimeError("Existing output file: %s" % paths[1])
309
+
310
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
311
+
312
+ model = InceptionV3([block_idx]).to(device)
313
+
314
+ print(f"Saving statistics for {paths[0]}")
315
+
316
+ m1, s1 = compute_statistics_of_path(
317
+ paths[0], model, batch_size, dims, device, num_workers
318
+ )
319
+
320
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
321
+
322
+
323
+ def main():
324
+ args = parser.parse_args()
325
+
326
+ if args.device is None:
327
+ device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
328
+ else:
329
+ device = torch.device(args.device)
330
+
331
+ if args.num_workers is None:
332
+ try:
333
+ num_cpus = len(os.sched_getaffinity(0))
334
+ except AttributeError:
335
+ # os.sched_getaffinity is not available under Windows, use
336
+ # os.cpu_count instead (which may not return the *available* number
337
+ # of CPUs).
338
+ num_cpus = os.cpu_count()
339
+
340
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
341
+ else:
342
+ num_workers = args.num_workers
343
+
344
+ if args.save_stats:
345
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
346
+ return
347
+
348
+ fid_value = calculate_fid_given_paths(
349
+ args.path, args.batch_size, device, args.dims, num_workers
350
+ )
351
+ print("FID: ", fid_value)
352
+
353
+
354
+ if __name__ == "__main__":
355
+ main()
src/pytorch_fid/inception.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling featurs
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3, # Final average pooling features
29
+ }
30
+
31
+ def __init__(
32
+ self,
33
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
34
+ resize_input=True,
35
+ normalize_input=True,
36
+ requires_grad=False,
37
+ use_fid_inception=True,
38
+ ):
39
+ """Build pretrained InceptionV3
40
+
41
+ Parameters
42
+ ----------
43
+ output_blocks : list of int
44
+ Indices of blocks to return features of. Possible values are:
45
+ - 0: corresponds to output of first max pooling
46
+ - 1: corresponds to output of second max pooling
47
+ - 2: corresponds to output which is fed to aux classifier
48
+ - 3: corresponds to output of final average pooling
49
+ resize_input : bool
50
+ If true, bilinearly resizes input to width and height 299 before
51
+ feeding input to model. As the network without fully connected
52
+ layers is fully convolutional, it should be able to handle inputs
53
+ of arbitrary size, so resizing might not be strictly needed
54
+ normalize_input : bool
55
+ If true, scales the input from range (0, 1) to the range the
56
+ pretrained Inception network expects, namely (-1, 1)
57
+ requires_grad : bool
58
+ If true, parameters of the model require gradients. Possibly useful
59
+ for finetuning the network
60
+ use_fid_inception : bool
61
+ If true, uses the pretrained Inception model used in Tensorflow's
62
+ FID implementation. If false, uses the pretrained Inception model
63
+ available in torchvision. The FID Inception model has different
64
+ weights and a slightly different structure from torchvision's
65
+ Inception model. If you want to compute FID scores, you are
66
+ strongly advised to set this parameter to true to get comparable
67
+ results.
68
+ """
69
+ super(InceptionV3, self).__init__()
70
+
71
+ self.resize_input = resize_input
72
+ self.normalize_input = normalize_input
73
+ self.output_blocks = sorted(output_blocks)
74
+ self.last_needed_block = max(output_blocks)
75
+
76
+ assert self.last_needed_block <= 3, "Last possible output block index is 3"
77
+
78
+ self.blocks = nn.ModuleList()
79
+
80
+ if use_fid_inception:
81
+ inception = fid_inception_v3()
82
+ else:
83
+ inception = _inception_v3(weights="DEFAULT")
84
+
85
+ # Block 0: input to maxpool1
86
+ block0 = [
87
+ inception.Conv2d_1a_3x3,
88
+ inception.Conv2d_2a_3x3,
89
+ inception.Conv2d_2b_3x3,
90
+ nn.MaxPool2d(kernel_size=3, stride=2),
91
+ ]
92
+ self.blocks.append(nn.Sequential(*block0))
93
+
94
+ # Block 1: maxpool1 to maxpool2
95
+ if self.last_needed_block >= 1:
96
+ block1 = [
97
+ inception.Conv2d_3b_1x1,
98
+ inception.Conv2d_4a_3x3,
99
+ nn.MaxPool2d(kernel_size=3, stride=2),
100
+ ]
101
+ self.blocks.append(nn.Sequential(*block1))
102
+
103
+ # Block 2: maxpool2 to aux classifier
104
+ if self.last_needed_block >= 2:
105
+ block2 = [
106
+ inception.Mixed_5b,
107
+ inception.Mixed_5c,
108
+ inception.Mixed_5d,
109
+ inception.Mixed_6a,
110
+ inception.Mixed_6b,
111
+ inception.Mixed_6c,
112
+ inception.Mixed_6d,
113
+ inception.Mixed_6e,
114
+ ]
115
+ self.blocks.append(nn.Sequential(*block2))
116
+
117
+ # Block 3: aux classifier to final avgpool
118
+ if self.last_needed_block >= 3:
119
+ block3 = [
120
+ inception.Mixed_7a,
121
+ inception.Mixed_7b,
122
+ inception.Mixed_7c,
123
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
124
+ ]
125
+ self.blocks.append(nn.Sequential(*block3))
126
+
127
+ for param in self.parameters():
128
+ param.requires_grad = requires_grad
129
+
130
+ def forward(self, inp):
131
+ """Get Inception feature maps
132
+
133
+ Parameters
134
+ ----------
135
+ inp : torch.autograd.Variable
136
+ Input tensor of shape Bx3xHxW. Values are expected to be in
137
+ range (0, 1)
138
+
139
+ Returns
140
+ -------
141
+ List of torch.autograd.Variable, corresponding to the selected output
142
+ block, sorted ascending by index
143
+ """
144
+ outp = []
145
+ x = inp
146
+
147
+ if self.resize_input:
148
+ x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
149
+
150
+ if self.normalize_input:
151
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
152
+
153
+ for idx, block in enumerate(self.blocks):
154
+ x = block(x)
155
+ if idx in self.output_blocks:
156
+ outp.append(x)
157
+
158
+ if idx == self.last_needed_block:
159
+ break
160
+
161
+ return outp
162
+
163
+
164
+ def _inception_v3(*args, **kwargs):
165
+ """Wraps `torchvision.models.inception_v3`"""
166
+ try:
167
+ version = tuple(map(int, torchvision.__version__.split(".")[:2]))
168
+ except ValueError:
169
+ # Just a caution against weird version strings
170
+ version = (0,)
171
+
172
+ # Skips default weight inititialization if supported by torchvision
173
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
174
+ if version >= (0, 6):
175
+ kwargs["init_weights"] = False
176
+
177
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
178
+ # argument prior to version 0.13.
179
+ if version < (0, 13) and "weights" in kwargs:
180
+ if kwargs["weights"] == "DEFAULT":
181
+ kwargs["pretrained"] = True
182
+ elif kwargs["weights"] is None:
183
+ kwargs["pretrained"] = False
184
+ else:
185
+ raise ValueError(
186
+ "weights=={} not supported in torchvision {}".format(
187
+ kwargs["weights"], torchvision.__version__
188
+ )
189
+ )
190
+ del kwargs["weights"]
191
+
192
+ return torchvision.models.inception_v3(*args, **kwargs)
193
+
194
+
195
+ def fid_inception_v3():
196
+ """Build pretrained Inception model for FID computation
197
+
198
+ The Inception model for FID computation uses a different set of weights
199
+ and has a slightly different structure than torchvision's Inception.
200
+
201
+ This method first constructs torchvision's Inception and then patches the
202
+ necessary parts that are different in the FID Inception model.
203
+ """
204
+ inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
205
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
206
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
207
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
208
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
209
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
210
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
211
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
212
+ inception.Mixed_7b = FIDInceptionE_1(1280)
213
+ inception.Mixed_7c = FIDInceptionE_2(2048)
214
+
215
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
216
+ inception.load_state_dict(state_dict)
217
+ return inception
218
+
219
+
220
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
221
+ """InceptionA block patched for FID computation"""
222
+
223
+ def __init__(self, in_channels, pool_features):
224
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
225
+
226
+ def forward(self, x):
227
+ branch1x1 = self.branch1x1(x)
228
+
229
+ branch5x5 = self.branch5x5_1(x)
230
+ branch5x5 = self.branch5x5_2(branch5x5)
231
+
232
+ branch3x3dbl = self.branch3x3dbl_1(x)
233
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
234
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
235
+
236
+ # Patch: Tensorflow's average pool does not use the padded zero's in
237
+ # its average calculation
238
+ branch_pool = F.avg_pool2d(
239
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
240
+ )
241
+ branch_pool = self.branch_pool(branch_pool)
242
+
243
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
244
+ return torch.cat(outputs, 1)
245
+
246
+
247
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
248
+ """InceptionC block patched for FID computation"""
249
+
250
+ def __init__(self, in_channels, channels_7x7):
251
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
252
+
253
+ def forward(self, x):
254
+ branch1x1 = self.branch1x1(x)
255
+
256
+ branch7x7 = self.branch7x7_1(x)
257
+ branch7x7 = self.branch7x7_2(branch7x7)
258
+ branch7x7 = self.branch7x7_3(branch7x7)
259
+
260
+ branch7x7dbl = self.branch7x7dbl_1(x)
261
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
262
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
263
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
264
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
265
+
266
+ # Patch: Tensorflow's average pool does not use the padded zero's in
267
+ # its average calculation
268
+ branch_pool = F.avg_pool2d(
269
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
270
+ )
271
+ branch_pool = self.branch_pool(branch_pool)
272
+
273
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
274
+ return torch.cat(outputs, 1)
275
+
276
+
277
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
278
+ """First InceptionE block patched for FID computation"""
279
+
280
+ def __init__(self, in_channels):
281
+ super(FIDInceptionE_1, self).__init__(in_channels)
282
+
283
+ def forward(self, x):
284
+ branch1x1 = self.branch1x1(x)
285
+
286
+ branch3x3 = self.branch3x3_1(x)
287
+ branch3x3 = [
288
+ self.branch3x3_2a(branch3x3),
289
+ self.branch3x3_2b(branch3x3),
290
+ ]
291
+ branch3x3 = torch.cat(branch3x3, 1)
292
+
293
+ branch3x3dbl = self.branch3x3dbl_1(x)
294
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
295
+ branch3x3dbl = [
296
+ self.branch3x3dbl_3a(branch3x3dbl),
297
+ self.branch3x3dbl_3b(branch3x3dbl),
298
+ ]
299
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
300
+
301
+ # Patch: Tensorflow's average pool does not use the padded zero's in
302
+ # its average calculation
303
+ branch_pool = F.avg_pool2d(
304
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
305
+ )
306
+ branch_pool = self.branch_pool(branch_pool)
307
+
308
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
309
+ return torch.cat(outputs, 1)
310
+
311
+
312
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
313
+ """Second InceptionE block patched for FID computation"""
314
+
315
+ def __init__(self, in_channels):
316
+ super(FIDInceptionE_2, self).__init__(in_channels)
317
+
318
+ def forward(self, x):
319
+ branch1x1 = self.branch1x1(x)
320
+
321
+ branch3x3 = self.branch3x3_1(x)
322
+ branch3x3 = [
323
+ self.branch3x3_2a(branch3x3),
324
+ self.branch3x3_2b(branch3x3),
325
+ ]
326
+ branch3x3 = torch.cat(branch3x3, 1)
327
+
328
+ branch3x3dbl = self.branch3x3dbl_1(x)
329
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
330
+ branch3x3dbl = [
331
+ self.branch3x3dbl_3a(branch3x3dbl),
332
+ self.branch3x3dbl_3b(branch3x3dbl),
333
+ ]
334
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
335
+
336
+ # Patch: The FID Inception model uses max pooling instead of average
337
+ # pooling. This is likely an error in this specific Inception
338
+ # implementation, as other Inception models use average pooling here
339
+ # (which matches the description in the paper).
340
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
341
+ branch_pool = self.branch_pool(branch_pool)
342
+
343
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
344
+ return torch.cat(outputs, 1)
tests/test_fid_score.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pytest
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from pytorch_fid import fid_score, inception
7
+
8
+
9
+ @pytest.fixture
10
+ def device():
11
+ return torch.device("cpu")
12
+
13
+
14
+ def test_calculate_fid_given_statistics(mocker, tmp_path, device):
15
+ dim = 2048
16
+ m1, m2 = np.zeros((dim,)), np.ones((dim,))
17
+ sigma = np.eye(dim)
18
+
19
+ def dummy_statistics(path, model, batch_size, dims, device, num_workers):
20
+ if path.endswith("1"):
21
+ return m1, sigma
22
+ elif path.endswith("2"):
23
+ return m2, sigma
24
+ else:
25
+ raise ValueError
26
+
27
+ mocker.patch(
28
+ "pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics
29
+ )
30
+
31
+ dir_names = ["1", "2"]
32
+ paths = []
33
+ for name in dir_names:
34
+ path = tmp_path / name
35
+ path.mkdir()
36
+ paths.append(str(path))
37
+
38
+ fid_value = fid_score.calculate_fid_given_paths(
39
+ paths, batch_size=dim, device=device, dims=dim, num_workers=0
40
+ )
41
+
42
+ # Given equal covariance, FID is just the squared norm of difference
43
+ assert fid_value == np.sum((m1 - m2) ** 2)
44
+
45
+
46
+ def test_compute_statistics_of_path(mocker, tmp_path, device):
47
+ model = mocker.MagicMock(inception.InceptionV3)()
48
+ model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]
49
+
50
+ size = (4, 4, 3)
51
+ arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
52
+ images = [(arr * 255).astype(np.uint8) for arr in arrays]
53
+
54
+ paths = []
55
+ for idx, image in enumerate(images):
56
+ paths.append(str(tmp_path / "{}.png".format(idx)))
57
+ Image.fromarray(image, mode="RGB").save(paths[-1])
58
+
59
+ stats = fid_score.compute_statistics_of_path(
60
+ str(tmp_path),
61
+ model,
62
+ batch_size=len(images),
63
+ dims=3,
64
+ device=device,
65
+ num_workers=0,
66
+ )
67
+
68
+ assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
69
+ assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)
70
+
71
+
72
+ def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
73
+ model = mocker.MagicMock(inception.InceptionV3)()
74
+
75
+ mu = np.random.randn(5)
76
+ sigma = np.random.randn(5, 5)
77
+
78
+ path = tmp_path / "stats.npz"
79
+ with path.open("wb") as f:
80
+ np.savez(f, mu=mu, sigma=sigma)
81
+
82
+ stats = fid_score.compute_statistics_of_path(
83
+ str(path), model, batch_size=1, dims=5, device=device, num_workers=0
84
+ )
85
+
86
+ assert np.allclose(stats[0], mu)
87
+ assert np.allclose(stats[1], sigma)
88
+
89
+
90
+ def test_image_types(tmp_path):
91
+ in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255
92
+ in_image = Image.fromarray(in_arr, mode="RGB")
93
+
94
+ paths = []
95
+ for ext in fid_score.IMAGE_EXTENSIONS:
96
+ paths.append(str(tmp_path / "img.{}".format(ext)))
97
+ in_image.save(paths[-1])
98
+
99
+ dataset = fid_score.ImagePathDataset(paths)
100
+
101
+ for img in dataset:
102
+ assert np.allclose(np.array(img), in_arr)