Upload 11 files
Browse files- CHANGELOG.md +41 -0
- LICENSE +201 -0
- README.md +88 -0
- noxfile.py +29 -0
- pyproject.toml +7 -0
- setup.py +55 -0
- src/pytorch_fid/__init__.py +1 -0
- src/pytorch_fid/__main__.py +3 -0
- src/pytorch_fid/fid_score.py +355 -0
- src/pytorch_fid/inception.py +344 -0
- tests/test_fid_score.py +102 -0
CHANGELOG.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changelog
|
2 |
+
|
3 |
+
## [0.3.0] - 2023-01-05
|
4 |
+
|
5 |
+
### Added
|
6 |
+
|
7 |
+
* Add argument `--save-stats` allowing to compute dataset statistics and save them as an `.npz` file ([#80](https://github.com/mseitzer/pytorch-fid/pull/80)). The `.npz` file can be used in subsequent FID computations instead of recomputing the dataset statistics. This option can be used in the following way: `python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile`.
|
8 |
+
|
9 |
+
### Fixed
|
10 |
+
|
11 |
+
* Do not use `os.sched_getaffinity` to get number of available CPUs on Windows, as it is not available there ([232b3b14](https://github.com/mseitzer/pytorch-fid/commit/232b3b1468800102fcceaf6f2bb8977811fc991a), [#84](https://github.com/mseitzer/pytorch-fid/issues/84)).
|
12 |
+
* Do not use Inception model argument `pretrained`, as it was deprecated in torchvision 0.13 ([#88](https://github.com/mseitzer/pytorch-fid/pull/88)).
|
13 |
+
|
14 |
+
## [0.2.1] - 2021-10-10
|
15 |
+
|
16 |
+
### Added
|
17 |
+
|
18 |
+
* Add argument `--num-workers` to select number of dataloader processes ([#66](https://github.com/mseitzer/pytorch-fid/pull/66)). Defaults to 8 or the number of available CPUs if less than 8 CPUs are available.
|
19 |
+
|
20 |
+
### Fixed
|
21 |
+
|
22 |
+
* Fixed package setup to work under Windows ([#55](https://github.com/mseitzer/pytorch-fid/pull/55), [#72](https://github.com/mseitzer/pytorch-fid/issues/72))
|
23 |
+
|
24 |
+
## [0.2.0] - 2020-11-30
|
25 |
+
|
26 |
+
### Added
|
27 |
+
|
28 |
+
* Load images using a Pytorch dataloader, which should result in a speed-up. ([#47](https://github.com/mseitzer/pytorch-fid/pull/47))
|
29 |
+
* Support more image extensions ([#53](https://github.com/mseitzer/pytorch-fid/pull/53))
|
30 |
+
* Improve tooling by setting up Nox, add linting and test support ([#52](https://github.com/mseitzer/pytorch-fid/pull/52))
|
31 |
+
* Add some unit tests
|
32 |
+
|
33 |
+
## [0.1.1] - 2020-08-16
|
34 |
+
|
35 |
+
### Fixed
|
36 |
+
|
37 |
+
* Fixed software license string in `setup.py`
|
38 |
+
|
39 |
+
## [0.1.0] - 2020-08-16
|
40 |
+
|
41 |
+
Initial release as a pypi package. Use `pip install pytorch-fid` to install.
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[](https://pypi.org/project/pytorch-fid/)
|
2 |
+
|
3 |
+
# FID score for PyTorch
|
4 |
+
|
5 |
+
This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch.
|
6 |
+
See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow.
|
7 |
+
|
8 |
+
FID is a measure of similarity between two datasets of images.
|
9 |
+
It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks.
|
10 |
+
FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network.
|
11 |
+
|
12 |
+
Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337).
|
13 |
+
|
14 |
+
The weights and the model are exactly the same as in [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR), and were tested to give very similar results (e.g. `.08` absolute error and `0.0009` relative error on LSUN, using ProGAN generated images). However, due to differences in the image interpolation implementation and library backends, FID results still differ slightly from the original implementation. So if you report FID scores in your paper, and you want them to be *exactly comparable* to FID scores reported in other papers, you should consider using [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR).
|
15 |
+
|
16 |
+
## Installation
|
17 |
+
|
18 |
+
Install from [pip](https://pypi.org/project/pytorch-fid/):
|
19 |
+
|
20 |
+
```
|
21 |
+
pip install pytorch-fid
|
22 |
+
```
|
23 |
+
|
24 |
+
Requirements:
|
25 |
+
- python3
|
26 |
+
- pytorch
|
27 |
+
- torchvision
|
28 |
+
- pillow
|
29 |
+
- numpy
|
30 |
+
- scipy
|
31 |
+
|
32 |
+
## Usage
|
33 |
+
|
34 |
+
To compute the FID score between two datasets, where images of each dataset are contained in an individual folder:
|
35 |
+
```
|
36 |
+
python -m pytorch_fid path/to/dataset1 path/to/dataset2
|
37 |
+
```
|
38 |
+
|
39 |
+
To run the evaluation on GPU, use the flag `--device cuda:N`, where `N` is the index of the GPU to use.
|
40 |
+
|
41 |
+
### Using different layers for feature maps
|
42 |
+
|
43 |
+
In difference to the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer.
|
44 |
+
As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance.
|
45 |
+
|
46 |
+
This might be useful if the datasets you want to compare have less than the otherwise required 2048 images.
|
47 |
+
Note that this changes the magnitude of the FID score and you can not compare them against scores calculated on another dimensionality.
|
48 |
+
The resulting scores might also no longer correlate with visual quality.
|
49 |
+
|
50 |
+
You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features.
|
51 |
+
The choices are:
|
52 |
+
- 64: first max pooling features
|
53 |
+
- 192: second max pooling features
|
54 |
+
- 768: pre-aux classifier features
|
55 |
+
- 2048: final average pooling features (this is the default)
|
56 |
+
|
57 |
+
## Generating a compatible `.npz` archive from a dataset
|
58 |
+
A frequent use case will be to compare multiple models against an original dataset.
|
59 |
+
To save training multiple times on the original dataset, there is also the ability to generate a compatible `.npz` archive from a dataset. This is done using any combination of the previously mentioned arguments with the addition of the `--save-stats` flag. For example:
|
60 |
+
```
|
61 |
+
python -m pytorch_fid --save-stats path/to/dataset path/to/outputfile
|
62 |
+
```
|
63 |
+
|
64 |
+
The output file may then be used in place of the path to the original dataset for further comparisons.
|
65 |
+
|
66 |
+
## Citing
|
67 |
+
|
68 |
+
If you use this repository in your research, consider citing it using the following Bibtex entry:
|
69 |
+
|
70 |
+
```
|
71 |
+
@misc{Seitzer2020FID,
|
72 |
+
author={Maximilian Seitzer},
|
73 |
+
title={{pytorch-fid: FID Score for PyTorch}},
|
74 |
+
month={August},
|
75 |
+
year={2020},
|
76 |
+
note={Version 0.3.0},
|
77 |
+
howpublished={\url{https://github.com/mseitzer/pytorch-fid}},
|
78 |
+
}
|
79 |
+
```
|
80 |
+
|
81 |
+
## License
|
82 |
+
|
83 |
+
This implementation is licensed under the Apache License 2.0.
|
84 |
+
|
85 |
+
FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500)
|
86 |
+
|
87 |
+
The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0.
|
88 |
+
See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR).
|
noxfile.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nox
|
2 |
+
|
3 |
+
LOCATIONS = ("src/", "tests/", "noxfile.py", "setup.py")
|
4 |
+
|
5 |
+
|
6 |
+
@nox.session
|
7 |
+
def lint(session):
|
8 |
+
session.install("flake8")
|
9 |
+
session.install("flake8-bugbear")
|
10 |
+
session.install("flake8-isort")
|
11 |
+
session.install("black==24.3.0")
|
12 |
+
|
13 |
+
args = session.posargs or LOCATIONS
|
14 |
+
session.run("flake8", *args)
|
15 |
+
session.run("black", "--check", "--diff", *args)
|
16 |
+
|
17 |
+
|
18 |
+
@nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"])
|
19 |
+
def tests(session):
|
20 |
+
session.install(
|
21 |
+
"torch==2.2.1",
|
22 |
+
"torchvision",
|
23 |
+
"--index-url",
|
24 |
+
"https://download.pytorch.org/whl/cpu",
|
25 |
+
)
|
26 |
+
session.install(".")
|
27 |
+
session.install("pytest")
|
28 |
+
session.install("pytest-mock")
|
29 |
+
session.run("pytest", *session.posargs)
|
pyproject.toml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.black]
|
2 |
+
target-version = ["py311"]
|
3 |
+
|
4 |
+
[tool.isort]
|
5 |
+
profile = "black"
|
6 |
+
line_length = 88
|
7 |
+
multi_line_output = 3
|
setup.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import setuptools
|
4 |
+
|
5 |
+
|
6 |
+
def read(rel_path):
|
7 |
+
base_path = os.path.abspath(os.path.dirname(__file__))
|
8 |
+
with open(os.path.join(base_path, rel_path), "r") as f:
|
9 |
+
return f.read()
|
10 |
+
|
11 |
+
|
12 |
+
def get_version(rel_path):
|
13 |
+
for line in read(rel_path).splitlines():
|
14 |
+
if line.startswith("__version__"):
|
15 |
+
# __version__ = "0.9"
|
16 |
+
delim = '"' if '"' in line else "'"
|
17 |
+
return line.split(delim)[1]
|
18 |
+
|
19 |
+
raise RuntimeError("Unable to find version string.")
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
setuptools.setup(
|
24 |
+
name="pytorch-fid",
|
25 |
+
version=get_version(os.path.join("src", "pytorch_fid", "__init__.py")),
|
26 |
+
author="Max Seitzer",
|
27 |
+
description=(
|
28 |
+
"Package for calculating Frechet Inception Distance (FID)" " using PyTorch"
|
29 |
+
),
|
30 |
+
long_description=read("README.md"),
|
31 |
+
long_description_content_type="text/markdown",
|
32 |
+
url="https://github.com/mseitzer/pytorch-fid",
|
33 |
+
package_dir={"": "src"},
|
34 |
+
packages=setuptools.find_packages(where="src"),
|
35 |
+
classifiers=[
|
36 |
+
"Programming Language :: Python :: 3",
|
37 |
+
"License :: OSI Approved :: Apache Software License",
|
38 |
+
],
|
39 |
+
python_requires=">=3.5",
|
40 |
+
entry_points={
|
41 |
+
"console_scripts": [
|
42 |
+
"pytorch-fid = pytorch_fid.fid_score:main",
|
43 |
+
],
|
44 |
+
},
|
45 |
+
install_requires=[
|
46 |
+
"numpy",
|
47 |
+
"pillow",
|
48 |
+
"scipy",
|
49 |
+
"torch>=1.0.1",
|
50 |
+
"torchvision>=0.2.2",
|
51 |
+
],
|
52 |
+
extras_require={
|
53 |
+
"dev": ["flake8", "flake8-bugbear", "flake8-isort", "black==24.3.0", "nox"]
|
54 |
+
},
|
55 |
+
)
|
src/pytorch_fid/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.3.0"
|
src/pytorch_fid/__main__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_fid.fid_score
|
2 |
+
|
3 |
+
pytorch_fid.fid_score.main()
|
src/pytorch_fid/fid_score.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
|
2 |
+
|
3 |
+
The FID metric calculates the distance between two distributions of images.
|
4 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
5 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
6 |
+
|
7 |
+
When run as a stand-alone program, it compares the distribution of
|
8 |
+
images that are stored as PNG/JPEG at a specified location with a
|
9 |
+
distribution given by summary statistics (in pickle format).
|
10 |
+
|
11 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
12 |
+
the pool_3 layer of the inception net for generated samples and real world
|
13 |
+
samples respectively.
|
14 |
+
|
15 |
+
See --help to see further details.
|
16 |
+
|
17 |
+
Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
18 |
+
of Tensorflow
|
19 |
+
|
20 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
21 |
+
|
22 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
23 |
+
you may not use this file except in compliance with the License.
|
24 |
+
You may obtain a copy of the License at
|
25 |
+
|
26 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
27 |
+
|
28 |
+
Unless required by applicable law or agreed to in writing, software
|
29 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
30 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
31 |
+
See the License for the specific language governing permissions and
|
32 |
+
limitations under the License.
|
33 |
+
"""
|
34 |
+
|
35 |
+
import os
|
36 |
+
import pathlib
|
37 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
38 |
+
|
39 |
+
import numpy as np
|
40 |
+
import torch
|
41 |
+
import torchvision.transforms as TF
|
42 |
+
from PIL import Image
|
43 |
+
from scipy import linalg
|
44 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
45 |
+
|
46 |
+
try:
|
47 |
+
from tqdm import tqdm
|
48 |
+
except ImportError:
|
49 |
+
# If tqdm is not available, provide a mock version of it
|
50 |
+
def tqdm(x):
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
from pytorch_fid.inception import InceptionV3
|
55 |
+
|
56 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
57 |
+
parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
|
58 |
+
parser.add_argument(
|
59 |
+
"--num-workers",
|
60 |
+
type=int,
|
61 |
+
help=(
|
62 |
+
"Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`"
|
63 |
+
),
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu"
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
"--dims",
|
70 |
+
type=int,
|
71 |
+
default=2048,
|
72 |
+
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
|
73 |
+
help=(
|
74 |
+
"Dimensionality of Inception features to use. "
|
75 |
+
"By default, uses pool3 features"
|
76 |
+
),
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--save-stats",
|
80 |
+
action="store_true",
|
81 |
+
help=(
|
82 |
+
"Generate an npz archive from a directory of "
|
83 |
+
"samples. The first path is used as input and the "
|
84 |
+
"second as output."
|
85 |
+
),
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"path",
|
89 |
+
type=str,
|
90 |
+
nargs=2,
|
91 |
+
help=("Paths to the generated images or " "to .npz statistic files"),
|
92 |
+
)
|
93 |
+
|
94 |
+
IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
|
95 |
+
|
96 |
+
|
97 |
+
class ImagePathDataset(torch.utils.data.Dataset):
|
98 |
+
def __init__(self, files, transforms=None):
|
99 |
+
self.files = files
|
100 |
+
self.transforms = transforms
|
101 |
+
|
102 |
+
def __len__(self):
|
103 |
+
return len(self.files)
|
104 |
+
|
105 |
+
def __getitem__(self, i):
|
106 |
+
path = self.files[i]
|
107 |
+
img = Image.open(path).convert("RGB")
|
108 |
+
if self.transforms is not None:
|
109 |
+
img = self.transforms(img)
|
110 |
+
return img
|
111 |
+
|
112 |
+
|
113 |
+
def get_activations(
|
114 |
+
files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
|
115 |
+
):
|
116 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
117 |
+
|
118 |
+
Params:
|
119 |
+
-- files : List of image files paths
|
120 |
+
-- model : Instance of inception model
|
121 |
+
-- batch_size : Batch size of images for the model to process at once.
|
122 |
+
Make sure that the number of samples is a multiple of
|
123 |
+
the batch size, otherwise some samples are ignored. This
|
124 |
+
behavior is retained to match the original FID score
|
125 |
+
implementation.
|
126 |
+
-- dims : Dimensionality of features returned by Inception
|
127 |
+
-- device : Device to run calculations
|
128 |
+
-- num_workers : Number of parallel dataloader workers
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
132 |
+
activations of the given tensor when feeding inception with the
|
133 |
+
query tensor.
|
134 |
+
"""
|
135 |
+
model.eval()
|
136 |
+
|
137 |
+
if batch_size > len(files):
|
138 |
+
print(
|
139 |
+
(
|
140 |
+
"Warning: batch size is bigger than the data size. "
|
141 |
+
"Setting batch size to data size"
|
142 |
+
)
|
143 |
+
)
|
144 |
+
batch_size = len(files)
|
145 |
+
|
146 |
+
dataset = ImagePathDataset(files, transforms=TF.ToTensor())
|
147 |
+
dataloader = torch.utils.data.DataLoader(
|
148 |
+
dataset,
|
149 |
+
batch_size=batch_size,
|
150 |
+
shuffle=False,
|
151 |
+
drop_last=False,
|
152 |
+
num_workers=num_workers,
|
153 |
+
)
|
154 |
+
|
155 |
+
pred_arr = np.empty((len(files), dims))
|
156 |
+
|
157 |
+
start_idx = 0
|
158 |
+
|
159 |
+
for batch in tqdm(dataloader):
|
160 |
+
batch = batch.to(device)
|
161 |
+
|
162 |
+
with torch.no_grad():
|
163 |
+
pred = model(batch)[0]
|
164 |
+
|
165 |
+
# If model output is not scalar, apply global spatial average pooling.
|
166 |
+
# This happens if you choose a dimensionality not equal 2048.
|
167 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
168 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
169 |
+
|
170 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
171 |
+
|
172 |
+
pred_arr[start_idx : start_idx + pred.shape[0]] = pred
|
173 |
+
|
174 |
+
start_idx = start_idx + pred.shape[0]
|
175 |
+
|
176 |
+
return pred_arr
|
177 |
+
|
178 |
+
|
179 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
180 |
+
"""Numpy implementation of the Frechet Distance.
|
181 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
182 |
+
and X_2 ~ N(mu_2, C_2) is
|
183 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
184 |
+
|
185 |
+
Stable version by Dougal J. Sutherland.
|
186 |
+
|
187 |
+
Params:
|
188 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
189 |
+
inception net (like returned by the function 'get_predictions')
|
190 |
+
for generated samples.
|
191 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
192 |
+
representative data set.
|
193 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
194 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
195 |
+
representative data set.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
-- : The Frechet Distance.
|
199 |
+
"""
|
200 |
+
|
201 |
+
mu1 = np.atleast_1d(mu1)
|
202 |
+
mu2 = np.atleast_1d(mu2)
|
203 |
+
|
204 |
+
sigma1 = np.atleast_2d(sigma1)
|
205 |
+
sigma2 = np.atleast_2d(sigma2)
|
206 |
+
|
207 |
+
assert (
|
208 |
+
mu1.shape == mu2.shape
|
209 |
+
), "Training and test mean vectors have different lengths"
|
210 |
+
assert (
|
211 |
+
sigma1.shape == sigma2.shape
|
212 |
+
), "Training and test covariances have different dimensions"
|
213 |
+
|
214 |
+
diff = mu1 - mu2
|
215 |
+
|
216 |
+
# Product might be almost singular
|
217 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
218 |
+
if not np.isfinite(covmean).all():
|
219 |
+
msg = (
|
220 |
+
"fid calculation produces singular product; "
|
221 |
+
"adding %s to diagonal of cov estimates"
|
222 |
+
) % eps
|
223 |
+
print(msg)
|
224 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
225 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
226 |
+
|
227 |
+
# Numerical error might give slight imaginary component
|
228 |
+
if np.iscomplexobj(covmean):
|
229 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
230 |
+
m = np.max(np.abs(covmean.imag))
|
231 |
+
raise ValueError("Imaginary component {}".format(m))
|
232 |
+
covmean = covmean.real
|
233 |
+
|
234 |
+
tr_covmean = np.trace(covmean)
|
235 |
+
|
236 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
237 |
+
|
238 |
+
|
239 |
+
def calculate_activation_statistics(
|
240 |
+
files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
|
241 |
+
):
|
242 |
+
"""Calculation of the statistics used by the FID.
|
243 |
+
Params:
|
244 |
+
-- files : List of image files paths
|
245 |
+
-- model : Instance of inception model
|
246 |
+
-- batch_size : The images numpy array is split into batches with
|
247 |
+
batch size batch_size. A reasonable batch size
|
248 |
+
depends on the hardware.
|
249 |
+
-- dims : Dimensionality of features returned by Inception
|
250 |
+
-- device : Device to run calculations
|
251 |
+
-- num_workers : Number of parallel dataloader workers
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
255 |
+
the inception model.
|
256 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
257 |
+
the inception model.
|
258 |
+
"""
|
259 |
+
act = get_activations(files, model, batch_size, dims, device, num_workers)
|
260 |
+
mu = np.mean(act, axis=0)
|
261 |
+
sigma = np.cov(act, rowvar=False)
|
262 |
+
return mu, sigma
|
263 |
+
|
264 |
+
|
265 |
+
def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1):
|
266 |
+
if path.endswith(".npz"):
|
267 |
+
with np.load(path) as f:
|
268 |
+
m, s = f["mu"][:], f["sigma"][:]
|
269 |
+
else:
|
270 |
+
path = pathlib.Path(path)
|
271 |
+
files = sorted(
|
272 |
+
[file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))]
|
273 |
+
)
|
274 |
+
m, s = calculate_activation_statistics(
|
275 |
+
files, model, batch_size, dims, device, num_workers
|
276 |
+
)
|
277 |
+
|
278 |
+
return m, s
|
279 |
+
|
280 |
+
|
281 |
+
def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
|
282 |
+
"""Calculates the FID of two paths"""
|
283 |
+
for p in paths:
|
284 |
+
if not os.path.exists(p):
|
285 |
+
raise RuntimeError("Invalid path: %s" % p)
|
286 |
+
|
287 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
288 |
+
|
289 |
+
model = InceptionV3([block_idx]).to(device)
|
290 |
+
|
291 |
+
m1, s1 = compute_statistics_of_path(
|
292 |
+
paths[0], model, batch_size, dims, device, num_workers
|
293 |
+
)
|
294 |
+
m2, s2 = compute_statistics_of_path(
|
295 |
+
paths[1], model, batch_size, dims, device, num_workers
|
296 |
+
)
|
297 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
298 |
+
|
299 |
+
return fid_value
|
300 |
+
|
301 |
+
|
302 |
+
def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
|
303 |
+
"""Saves FID statistics of one path"""
|
304 |
+
if not os.path.exists(paths[0]):
|
305 |
+
raise RuntimeError("Invalid path: %s" % paths[0])
|
306 |
+
|
307 |
+
if os.path.exists(paths[1]):
|
308 |
+
raise RuntimeError("Existing output file: %s" % paths[1])
|
309 |
+
|
310 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
311 |
+
|
312 |
+
model = InceptionV3([block_idx]).to(device)
|
313 |
+
|
314 |
+
print(f"Saving statistics for {paths[0]}")
|
315 |
+
|
316 |
+
m1, s1 = compute_statistics_of_path(
|
317 |
+
paths[0], model, batch_size, dims, device, num_workers
|
318 |
+
)
|
319 |
+
|
320 |
+
np.savez_compressed(paths[1], mu=m1, sigma=s1)
|
321 |
+
|
322 |
+
|
323 |
+
def main():
|
324 |
+
args = parser.parse_args()
|
325 |
+
|
326 |
+
if args.device is None:
|
327 |
+
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
|
328 |
+
else:
|
329 |
+
device = torch.device(args.device)
|
330 |
+
|
331 |
+
if args.num_workers is None:
|
332 |
+
try:
|
333 |
+
num_cpus = len(os.sched_getaffinity(0))
|
334 |
+
except AttributeError:
|
335 |
+
# os.sched_getaffinity is not available under Windows, use
|
336 |
+
# os.cpu_count instead (which may not return the *available* number
|
337 |
+
# of CPUs).
|
338 |
+
num_cpus = os.cpu_count()
|
339 |
+
|
340 |
+
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
341 |
+
else:
|
342 |
+
num_workers = args.num_workers
|
343 |
+
|
344 |
+
if args.save_stats:
|
345 |
+
save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
|
346 |
+
return
|
347 |
+
|
348 |
+
fid_value = calculate_fid_given_paths(
|
349 |
+
args.path, args.batch_size, device, args.dims, num_workers
|
350 |
+
)
|
351 |
+
print("FID: ", fid_value)
|
352 |
+
|
353 |
+
|
354 |
+
if __name__ == "__main__":
|
355 |
+
main()
|
src/pytorch_fid/inception.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
try:
|
7 |
+
from torchvision.models.utils import load_state_dict_from_url
|
8 |
+
except ImportError:
|
9 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
10 |
+
|
11 |
+
# Inception weights ported to Pytorch from
|
12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
13 |
+
FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501
|
14 |
+
|
15 |
+
|
16 |
+
class InceptionV3(nn.Module):
|
17 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
18 |
+
|
19 |
+
# Index of default block of inception to return,
|
20 |
+
# corresponds to output of final average pooling
|
21 |
+
DEFAULT_BLOCK_INDEX = 3
|
22 |
+
|
23 |
+
# Maps feature dimensionality to their output blocks indices
|
24 |
+
BLOCK_INDEX_BY_DIM = {
|
25 |
+
64: 0, # First max pooling features
|
26 |
+
192: 1, # Second max pooling featurs
|
27 |
+
768: 2, # Pre-aux classifier features
|
28 |
+
2048: 3, # Final average pooling features
|
29 |
+
}
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
output_blocks=(DEFAULT_BLOCK_INDEX,),
|
34 |
+
resize_input=True,
|
35 |
+
normalize_input=True,
|
36 |
+
requires_grad=False,
|
37 |
+
use_fid_inception=True,
|
38 |
+
):
|
39 |
+
"""Build pretrained InceptionV3
|
40 |
+
|
41 |
+
Parameters
|
42 |
+
----------
|
43 |
+
output_blocks : list of int
|
44 |
+
Indices of blocks to return features of. Possible values are:
|
45 |
+
- 0: corresponds to output of first max pooling
|
46 |
+
- 1: corresponds to output of second max pooling
|
47 |
+
- 2: corresponds to output which is fed to aux classifier
|
48 |
+
- 3: corresponds to output of final average pooling
|
49 |
+
resize_input : bool
|
50 |
+
If true, bilinearly resizes input to width and height 299 before
|
51 |
+
feeding input to model. As the network without fully connected
|
52 |
+
layers is fully convolutional, it should be able to handle inputs
|
53 |
+
of arbitrary size, so resizing might not be strictly needed
|
54 |
+
normalize_input : bool
|
55 |
+
If true, scales the input from range (0, 1) to the range the
|
56 |
+
pretrained Inception network expects, namely (-1, 1)
|
57 |
+
requires_grad : bool
|
58 |
+
If true, parameters of the model require gradients. Possibly useful
|
59 |
+
for finetuning the network
|
60 |
+
use_fid_inception : bool
|
61 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
62 |
+
FID implementation. If false, uses the pretrained Inception model
|
63 |
+
available in torchvision. The FID Inception model has different
|
64 |
+
weights and a slightly different structure from torchvision's
|
65 |
+
Inception model. If you want to compute FID scores, you are
|
66 |
+
strongly advised to set this parameter to true to get comparable
|
67 |
+
results.
|
68 |
+
"""
|
69 |
+
super(InceptionV3, self).__init__()
|
70 |
+
|
71 |
+
self.resize_input = resize_input
|
72 |
+
self.normalize_input = normalize_input
|
73 |
+
self.output_blocks = sorted(output_blocks)
|
74 |
+
self.last_needed_block = max(output_blocks)
|
75 |
+
|
76 |
+
assert self.last_needed_block <= 3, "Last possible output block index is 3"
|
77 |
+
|
78 |
+
self.blocks = nn.ModuleList()
|
79 |
+
|
80 |
+
if use_fid_inception:
|
81 |
+
inception = fid_inception_v3()
|
82 |
+
else:
|
83 |
+
inception = _inception_v3(weights="DEFAULT")
|
84 |
+
|
85 |
+
# Block 0: input to maxpool1
|
86 |
+
block0 = [
|
87 |
+
inception.Conv2d_1a_3x3,
|
88 |
+
inception.Conv2d_2a_3x3,
|
89 |
+
inception.Conv2d_2b_3x3,
|
90 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
91 |
+
]
|
92 |
+
self.blocks.append(nn.Sequential(*block0))
|
93 |
+
|
94 |
+
# Block 1: maxpool1 to maxpool2
|
95 |
+
if self.last_needed_block >= 1:
|
96 |
+
block1 = [
|
97 |
+
inception.Conv2d_3b_1x1,
|
98 |
+
inception.Conv2d_4a_3x3,
|
99 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
100 |
+
]
|
101 |
+
self.blocks.append(nn.Sequential(*block1))
|
102 |
+
|
103 |
+
# Block 2: maxpool2 to aux classifier
|
104 |
+
if self.last_needed_block >= 2:
|
105 |
+
block2 = [
|
106 |
+
inception.Mixed_5b,
|
107 |
+
inception.Mixed_5c,
|
108 |
+
inception.Mixed_5d,
|
109 |
+
inception.Mixed_6a,
|
110 |
+
inception.Mixed_6b,
|
111 |
+
inception.Mixed_6c,
|
112 |
+
inception.Mixed_6d,
|
113 |
+
inception.Mixed_6e,
|
114 |
+
]
|
115 |
+
self.blocks.append(nn.Sequential(*block2))
|
116 |
+
|
117 |
+
# Block 3: aux classifier to final avgpool
|
118 |
+
if self.last_needed_block >= 3:
|
119 |
+
block3 = [
|
120 |
+
inception.Mixed_7a,
|
121 |
+
inception.Mixed_7b,
|
122 |
+
inception.Mixed_7c,
|
123 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
|
124 |
+
]
|
125 |
+
self.blocks.append(nn.Sequential(*block3))
|
126 |
+
|
127 |
+
for param in self.parameters():
|
128 |
+
param.requires_grad = requires_grad
|
129 |
+
|
130 |
+
def forward(self, inp):
|
131 |
+
"""Get Inception feature maps
|
132 |
+
|
133 |
+
Parameters
|
134 |
+
----------
|
135 |
+
inp : torch.autograd.Variable
|
136 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
137 |
+
range (0, 1)
|
138 |
+
|
139 |
+
Returns
|
140 |
+
-------
|
141 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
142 |
+
block, sorted ascending by index
|
143 |
+
"""
|
144 |
+
outp = []
|
145 |
+
x = inp
|
146 |
+
|
147 |
+
if self.resize_input:
|
148 |
+
x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
|
149 |
+
|
150 |
+
if self.normalize_input:
|
151 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
152 |
+
|
153 |
+
for idx, block in enumerate(self.blocks):
|
154 |
+
x = block(x)
|
155 |
+
if idx in self.output_blocks:
|
156 |
+
outp.append(x)
|
157 |
+
|
158 |
+
if idx == self.last_needed_block:
|
159 |
+
break
|
160 |
+
|
161 |
+
return outp
|
162 |
+
|
163 |
+
|
164 |
+
def _inception_v3(*args, **kwargs):
|
165 |
+
"""Wraps `torchvision.models.inception_v3`"""
|
166 |
+
try:
|
167 |
+
version = tuple(map(int, torchvision.__version__.split(".")[:2]))
|
168 |
+
except ValueError:
|
169 |
+
# Just a caution against weird version strings
|
170 |
+
version = (0,)
|
171 |
+
|
172 |
+
# Skips default weight inititialization if supported by torchvision
|
173 |
+
# version. See https://github.com/mseitzer/pytorch-fid/issues/28.
|
174 |
+
if version >= (0, 6):
|
175 |
+
kwargs["init_weights"] = False
|
176 |
+
|
177 |
+
# Backwards compatibility: `weights` argument was handled by `pretrained`
|
178 |
+
# argument prior to version 0.13.
|
179 |
+
if version < (0, 13) and "weights" in kwargs:
|
180 |
+
if kwargs["weights"] == "DEFAULT":
|
181 |
+
kwargs["pretrained"] = True
|
182 |
+
elif kwargs["weights"] is None:
|
183 |
+
kwargs["pretrained"] = False
|
184 |
+
else:
|
185 |
+
raise ValueError(
|
186 |
+
"weights=={} not supported in torchvision {}".format(
|
187 |
+
kwargs["weights"], torchvision.__version__
|
188 |
+
)
|
189 |
+
)
|
190 |
+
del kwargs["weights"]
|
191 |
+
|
192 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
193 |
+
|
194 |
+
|
195 |
+
def fid_inception_v3():
|
196 |
+
"""Build pretrained Inception model for FID computation
|
197 |
+
|
198 |
+
The Inception model for FID computation uses a different set of weights
|
199 |
+
and has a slightly different structure than torchvision's Inception.
|
200 |
+
|
201 |
+
This method first constructs torchvision's Inception and then patches the
|
202 |
+
necessary parts that are different in the FID Inception model.
|
203 |
+
"""
|
204 |
+
inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
|
205 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
206 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
207 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
208 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
209 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
210 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
211 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
212 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
213 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
214 |
+
|
215 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
216 |
+
inception.load_state_dict(state_dict)
|
217 |
+
return inception
|
218 |
+
|
219 |
+
|
220 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
221 |
+
"""InceptionA block patched for FID computation"""
|
222 |
+
|
223 |
+
def __init__(self, in_channels, pool_features):
|
224 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
225 |
+
|
226 |
+
def forward(self, x):
|
227 |
+
branch1x1 = self.branch1x1(x)
|
228 |
+
|
229 |
+
branch5x5 = self.branch5x5_1(x)
|
230 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
231 |
+
|
232 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
233 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
234 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
235 |
+
|
236 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
237 |
+
# its average calculation
|
238 |
+
branch_pool = F.avg_pool2d(
|
239 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
240 |
+
)
|
241 |
+
branch_pool = self.branch_pool(branch_pool)
|
242 |
+
|
243 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
244 |
+
return torch.cat(outputs, 1)
|
245 |
+
|
246 |
+
|
247 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
248 |
+
"""InceptionC block patched for FID computation"""
|
249 |
+
|
250 |
+
def __init__(self, in_channels, channels_7x7):
|
251 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
252 |
+
|
253 |
+
def forward(self, x):
|
254 |
+
branch1x1 = self.branch1x1(x)
|
255 |
+
|
256 |
+
branch7x7 = self.branch7x7_1(x)
|
257 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
258 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
259 |
+
|
260 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
261 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
262 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
263 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
264 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
265 |
+
|
266 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
267 |
+
# its average calculation
|
268 |
+
branch_pool = F.avg_pool2d(
|
269 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
270 |
+
)
|
271 |
+
branch_pool = self.branch_pool(branch_pool)
|
272 |
+
|
273 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
274 |
+
return torch.cat(outputs, 1)
|
275 |
+
|
276 |
+
|
277 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
278 |
+
"""First InceptionE block patched for FID computation"""
|
279 |
+
|
280 |
+
def __init__(self, in_channels):
|
281 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
282 |
+
|
283 |
+
def forward(self, x):
|
284 |
+
branch1x1 = self.branch1x1(x)
|
285 |
+
|
286 |
+
branch3x3 = self.branch3x3_1(x)
|
287 |
+
branch3x3 = [
|
288 |
+
self.branch3x3_2a(branch3x3),
|
289 |
+
self.branch3x3_2b(branch3x3),
|
290 |
+
]
|
291 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
292 |
+
|
293 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
294 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
295 |
+
branch3x3dbl = [
|
296 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
297 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
298 |
+
]
|
299 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
300 |
+
|
301 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
302 |
+
# its average calculation
|
303 |
+
branch_pool = F.avg_pool2d(
|
304 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
305 |
+
)
|
306 |
+
branch_pool = self.branch_pool(branch_pool)
|
307 |
+
|
308 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
309 |
+
return torch.cat(outputs, 1)
|
310 |
+
|
311 |
+
|
312 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
313 |
+
"""Second InceptionE block patched for FID computation"""
|
314 |
+
|
315 |
+
def __init__(self, in_channels):
|
316 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
317 |
+
|
318 |
+
def forward(self, x):
|
319 |
+
branch1x1 = self.branch1x1(x)
|
320 |
+
|
321 |
+
branch3x3 = self.branch3x3_1(x)
|
322 |
+
branch3x3 = [
|
323 |
+
self.branch3x3_2a(branch3x3),
|
324 |
+
self.branch3x3_2b(branch3x3),
|
325 |
+
]
|
326 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
327 |
+
|
328 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
329 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
330 |
+
branch3x3dbl = [
|
331 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
332 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
333 |
+
]
|
334 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
335 |
+
|
336 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
337 |
+
# pooling. This is likely an error in this specific Inception
|
338 |
+
# implementation, as other Inception models use average pooling here
|
339 |
+
# (which matches the description in the paper).
|
340 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
341 |
+
branch_pool = self.branch_pool(branch_pool)
|
342 |
+
|
343 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
344 |
+
return torch.cat(outputs, 1)
|
tests/test_fid_score.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pytest
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
from pytorch_fid import fid_score, inception
|
7 |
+
|
8 |
+
|
9 |
+
@pytest.fixture
|
10 |
+
def device():
|
11 |
+
return torch.device("cpu")
|
12 |
+
|
13 |
+
|
14 |
+
def test_calculate_fid_given_statistics(mocker, tmp_path, device):
|
15 |
+
dim = 2048
|
16 |
+
m1, m2 = np.zeros((dim,)), np.ones((dim,))
|
17 |
+
sigma = np.eye(dim)
|
18 |
+
|
19 |
+
def dummy_statistics(path, model, batch_size, dims, device, num_workers):
|
20 |
+
if path.endswith("1"):
|
21 |
+
return m1, sigma
|
22 |
+
elif path.endswith("2"):
|
23 |
+
return m2, sigma
|
24 |
+
else:
|
25 |
+
raise ValueError
|
26 |
+
|
27 |
+
mocker.patch(
|
28 |
+
"pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics
|
29 |
+
)
|
30 |
+
|
31 |
+
dir_names = ["1", "2"]
|
32 |
+
paths = []
|
33 |
+
for name in dir_names:
|
34 |
+
path = tmp_path / name
|
35 |
+
path.mkdir()
|
36 |
+
paths.append(str(path))
|
37 |
+
|
38 |
+
fid_value = fid_score.calculate_fid_given_paths(
|
39 |
+
paths, batch_size=dim, device=device, dims=dim, num_workers=0
|
40 |
+
)
|
41 |
+
|
42 |
+
# Given equal covariance, FID is just the squared norm of difference
|
43 |
+
assert fid_value == np.sum((m1 - m2) ** 2)
|
44 |
+
|
45 |
+
|
46 |
+
def test_compute_statistics_of_path(mocker, tmp_path, device):
|
47 |
+
model = mocker.MagicMock(inception.InceptionV3)()
|
48 |
+
model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)]
|
49 |
+
|
50 |
+
size = (4, 4, 3)
|
51 |
+
arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)]
|
52 |
+
images = [(arr * 255).astype(np.uint8) for arr in arrays]
|
53 |
+
|
54 |
+
paths = []
|
55 |
+
for idx, image in enumerate(images):
|
56 |
+
paths.append(str(tmp_path / "{}.png".format(idx)))
|
57 |
+
Image.fromarray(image, mode="RGB").save(paths[-1])
|
58 |
+
|
59 |
+
stats = fid_score.compute_statistics_of_path(
|
60 |
+
str(tmp_path),
|
61 |
+
model,
|
62 |
+
batch_size=len(images),
|
63 |
+
dims=3,
|
64 |
+
device=device,
|
65 |
+
num_workers=0,
|
66 |
+
)
|
67 |
+
|
68 |
+
assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3)
|
69 |
+
assert np.allclose(stats[1], np.ones((3, 3)) * 0.25)
|
70 |
+
|
71 |
+
|
72 |
+
def test_compute_statistics_of_path_from_file(mocker, tmp_path, device):
|
73 |
+
model = mocker.MagicMock(inception.InceptionV3)()
|
74 |
+
|
75 |
+
mu = np.random.randn(5)
|
76 |
+
sigma = np.random.randn(5, 5)
|
77 |
+
|
78 |
+
path = tmp_path / "stats.npz"
|
79 |
+
with path.open("wb") as f:
|
80 |
+
np.savez(f, mu=mu, sigma=sigma)
|
81 |
+
|
82 |
+
stats = fid_score.compute_statistics_of_path(
|
83 |
+
str(path), model, batch_size=1, dims=5, device=device, num_workers=0
|
84 |
+
)
|
85 |
+
|
86 |
+
assert np.allclose(stats[0], mu)
|
87 |
+
assert np.allclose(stats[1], sigma)
|
88 |
+
|
89 |
+
|
90 |
+
def test_image_types(tmp_path):
|
91 |
+
in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255
|
92 |
+
in_image = Image.fromarray(in_arr, mode="RGB")
|
93 |
+
|
94 |
+
paths = []
|
95 |
+
for ext in fid_score.IMAGE_EXTENSIONS:
|
96 |
+
paths.append(str(tmp_path / "img.{}".format(ext)))
|
97 |
+
in_image.save(paths[-1])
|
98 |
+
|
99 |
+
dataset = fid_score.ImagePathDataset(paths)
|
100 |
+
|
101 |
+
for img in dataset:
|
102 |
+
assert np.allclose(np.array(img), in_arr)
|