AustingDong committed
Commit: 1ca9e3b
Parent(s): d95dc04

init

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitignore +421 -0
- .gradio/certificate.pem +31 -0
- Dockerfile +41 -0
- LICENSE-CODE +21 -0
- LICENSE-MODEL +91 -0
- Makefile +99 -0
- app.py +398 -0
- demo/Janus_colab_demo.ipynb +0 -0
- demo/app.py +224 -0
- demo/app_janusflow.py +247 -0
- demo/app_januspro.py +294 -0
- demo/app_vqa.py +333 -0
- demo/cam.py +486 -0
- demo/demo.ipynb +0 -0
- demo/demo_attn.ipynb +0 -0
- demo/fastapi_app.py +178 -0
- demo/fastapi_client.py +78 -0
- demo/model_utils.py +208 -0
- demo/modify_llama.py +11 -0
- demo/visualize_architecture.ipynb +1715 -0
- images/AreaChart.png +0 -0
- images/BarChart.png +0 -0
- images/BubbleChart.png +0 -0
- images/Choropleth_New.png +0 -0
- images/Histogram.png +0 -0
- images/LineChart.png +0 -0
- images/PieChart.png +0 -0
- images/Scatterplot.png +0 -0
- images/Stacked100.png +0 -0
- images/StackedArea.png +0 -0
- images/StackedBar.png +0 -0
- images/TreeMap.png +0 -0
- images/badge.svg +1 -0
- images/cat_dog.png +0 -0
- images/doge.png +0 -0
- images/equation.png +0 -0
- images/logo.png +0 -0
- images/logo.svg +22 -0
- images/pie_chart.png +0 -0
- images/ve.png +0 -0
- janus/__init__.py +31 -0
- janus/janusflow/__init__.py +31 -0
- janus/janusflow/models/__init__.py +28 -0
- janus/janusflow/models/clip_encoder.py +122 -0
- janus/janusflow/models/image_processing_vlm.py +208 -0
- janus/janusflow/models/modeling_vlm.py +226 -0
- janus/janusflow/models/processing_vlm.py +455 -0
- janus/janusflow/models/siglip_vit.py +691 -0
- janus/janusflow/models/uvit.py +714 -0
- janus/models/__init__.py +28 -0
.gitignore
ADDED
@@ -0,0 +1,421 @@
##### Python.gitignore #####
# Byte-compiled / optimized / DLL files
**/__pycache__/
*.pyc
*.pyo
*.pyd
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
wheelhouse/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
*.whl

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/source/_build/
_autosummary/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# ruff
.ruff_cache/

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/


##### macOS.gitignore #####
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk


##### Linux.gitignore #####
*~

# Temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*


##### Windows.gitignore #####
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk


##### Archives.gitignore #####
# It's better to unpack these files and commit the raw source because
# git has its own built in compression methods.
*.7z
*.jar
*.rar
*.zip
*.gz
*.gzip
*.tgz
*.bzip
*.bzip2
*.bz2
*.xz
*.lzma
*.cab
*.xar

# Packing-only formats
*.iso
*.tar

# Package management formats
*.dmg
*.xpi
*.gem
*.egg
*.deb
*.rpm
*.msi
*.msm
*.msp
*.txz


##### Xcode.gitignore #####
# Xcode
#
# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore

## User settings
xcuserdata/

## Compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
*.xcscmblueprint
*.xccheckout

## Compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
build/
DerivedData/
*.moved-aside
*.pbxuser
!default.pbxuser
*.mode1v3
!default.mode1v3
*.mode2v3
!default.mode2v3
*.perspectivev3
!default.perspectivev3

## Gcc Patch
/*.gcno


##### JetBrains.gitignore #####
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User settings
.idea/*

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser


##### VisualStudioCode.gitignore #####
.vscode/*
# !.vscode/settings.json
# !.vscode/tasks.json
# !.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/


##### Vim.gitignore #####
# Swap
.*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
.*.sw[a-p]
.s[a-rt-v][a-z]
.ss[a-gi-z]
.sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~
.vscode
.github
generated_samples/
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
Dockerfile
ADDED
@@ -0,0 +1,41 @@
FROM python:3.10

COPY ./requirements-gradio.txt /code/requirements-gradio.txt

# Install system dependencies and create user
RUN apt-get update && apt-get install -y --no-install-recommends \
    && useradd -m -u 1000 user \
    && rm -rf /var/lib/apt/lists/*

# Install OpenGL and other dependencies required for OpenCV
RUN apt-get update && apt-get install -y \
    libgl1-mesa-glx \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Switch to "user" before installing dependencies
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    PYTHONPATH=$HOME/app \
    PYTHONUNBUFFERED=1 \
    GRADIO_ALLOW_FLAGGING=never \
    GRADIO_NUM_PORTS=1 \
    GRADIO_SERVER_NAME=0.0.0.0 \
    GRADIO_THEME=huggingface \
    SYSTEM=spaces

WORKDIR $HOME/app

# Copy project files as "user" before installing dependencies
COPY --chown=user . $HOME/app
COPY --chown=user ./images /home/user/app/images

# Install dependencies as "user"
RUN pip install --no-cache-dir --user -e .
RUN pip install --no-cache-dir --user opencv-python
RUN pip install --no-cache-dir --user -r /code/requirements-gradio.txt
RUN ls -l /home/user/app/images/

CMD ["python", "app.py"]
LICENSE-CODE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LICENSE-MODEL
ADDED
@@ -0,0 +1,91 @@
DEEPSEEK LICENSE AGREEMENT

Version 1.0, 23 October 2023

Copyright (c) 2023 DeepSeek

Section I: PREAMBLE

Large generative models are being widely adopted and used, and have the potential to transform the way individuals conceive and benefit from AI or ML technologies.

Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.

In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for content generation.

Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI.

This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.

NOW THEREFORE, You and DeepSeek agree as follows:

1. Definitions
"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates.
"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc.
"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You.

Section II: INTELLECTUAL PROPERTY RIGHTS

Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.

2. Grant of Copyright License. Subject to the terms and conditions of this License, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.

3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by DeepSeek that are necessarily infringed by its contribution(s). If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or works shall terminate as of the date such litigation is asserted or filed.

Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION

4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
c. You must cause any modified files to carry prominent notices stating that You changed the files;
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. – for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.

5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).

6. The Output You Generate. Except as set forth herein, DeepSeek claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.

Section IV: OTHER PROVISIONS

7. Updates and Runtime Restrictions. To the maximum extent permitted by law, DeepSeek reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.

8. Trademarks and related. Nothing in this License permits You to make use of DeepSeek’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by DeepSeek.

9. Personal information, IP rights and related. This Model may contain personal information and works with IP rights. You commit to complying with applicable laws and regulations in the handling of personal information and the use of such works. Please note that DeepSeek's license granted to you to use the Model does not imply that you have obtained a legitimate basis for processing the related information or works. As an independent personal information processor and IP rights user, you need to ensure full compliance with relevant legal and regulatory requirements when handling personal information and works with IP rights that may be contained in the Model, and are willing to assume solely any risks and consequences that may arise from that.

10. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, DeepSeek provides the Model and the Complementary Material on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.

11. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall DeepSeek be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if DeepSeek has been advised of the possibility of such damages.

12. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of DeepSeek, and only if You agree to indemnify, defend, and hold DeepSeek harmless for any liability incurred by, or claims asserted against, DeepSeek by reason of your accepting any such warranty or additional liability.

13. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

14. Governing Law and Jurisdiction. This agreement will be governed and construed under PRC laws without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this agreement. The courts located in the domicile of Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. shall have exclusive jurisdiction of any dispute arising out of this agreement.

END OF TERMS AND CONDITIONS

Attachment A

Use Restrictions

You agree not to use the Model or Derivatives of the Model:

- In any way that violates any applicable national or international law or regulation or infringes upon the lawful rights and interests of any third party;
- For military use in any way;
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
- To generate or disseminate inappropriate content subject to applicable regulatory requirements;
- To generate or disseminate personal identifiable information without due authorization or for unreasonable use;
- To defame, disparage or otherwise harass others;
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories.
Makefile
ADDED
@@ -0,0 +1,99 @@
print-% : ; @echo $* = $($*)
PROJECT_NAME = Janus
COPYRIGHT = "DeepSeek."
PROJECT_PATH = janus
SHELL = /bin/bash
SOURCE_FOLDERS = janus
PYTHON_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.py" -o -name "*.pyi") inference.py
COMMIT_HASH = $(shell git log -1 --format=%h)
PATH := $(HOME)/go/bin:$(PATH)
PYTHON ?= $(shell command -v python3 || command -v python)
PYTESTOPTS ?=

.PHONY: default
default: install

# Tools Installation

check_pip_install = $(PYTHON) -m pip show $(1) &>/dev/null || (cd && $(PYTHON) -m pip install $(1) --upgrade)
check_pip_install_extra = $(PYTHON) -m pip show $(1) &>/dev/null || (cd && $(PYTHON) -m pip install $(2) --upgrade)

pylint-install:
	$(call check_pip_install_extra,pylint,pylint[spelling])
	$(call check_pip_install,pyenchant)

flake8-install:
	$(call check_pip_install,flake8)
	$(call check_pip_install,flake8-bugbear)
	$(call check_pip_install,flake8-comprehensions)
	$(call check_pip_install,flake8-docstrings)
	$(call check_pip_install,flake8-pyi)
	$(call check_pip_install,flake8-simplify)

py-format-install:
	$(call check_pip_install,isort)
	$(call check_pip_install_extra,black,black[jupyter])

ruff-install:
	$(call check_pip_install,ruff)

mypy-install:
	$(call check_pip_install,mypy)

pre-commit-install:
	$(call check_pip_install,pre-commit)
	$(PYTHON) -m pre_commit install --install-hooks

go-install:
	# requires go >= 1.16
	command -v go || (sudo apt-get install -y golang && sudo ln -sf /usr/lib/go/bin/go /usr/bin/go)

addlicense-install: go-install
	command -v addlicense || go install github.com/google/addlicense@latest

addlicense: addlicense-install
	addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l mit -y 2023-$(shell date +"%Y") -check $(SOURCE_FOLDERS)

# Python linters

pylint: pylint-install
	$(PYTHON) -m pylint $(PROJECT_PATH)

flake8: flake8-install
	$(PYTHON) -m flake8 --count --show-source --statistics

py-format: py-format-install
	$(PYTHON) -m isort --project $(PROJECT_PATH) --check $(PYTHON_FILES) && \
	$(PYTHON) -m black --check $(PYTHON_FILES)

black-format: py-format-install
	$(PYTHON) -m black --check $(PYTHON_FILES)

ruff: ruff-install
	$(PYTHON) -m ruff check .

ruff-fix: ruff-install
	$(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix

mypy: mypy-install
	$(PYTHON) -m mypy $(PROJECT_PATH) --install-types --non-interactive

pre-commit: pre-commit-install
	$(PYTHON) -m pre_commit run --all-files

# Utility functions

lint: ruff flake8 py-format mypy pylint addlicense

format: py-format-install ruff-install addlicense-install
	$(PYTHON) -m isort --project $(PROJECT_PATH) $(PYTHON_FILES)
	$(PYTHON) -m black $(PYTHON_FILES)
	addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l mit -y 2023-$(shell date +"%Y") $(SOURCE_FOLDERS) inference.py

clean-py:
	find . -type f -name '*.py[co]' -delete
	find . -depth -type d -name "__pycache__" -exec rm -r "{}" +
	find . -depth -type d -name ".ruff_cache" -exec rm -r "{}" +
	find . -depth -type d -name ".mypy_cache" -exec rm -r "{}" +

clean: clean-py
app.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
4 |
+
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
5 |
+
from janus.utils.io import load_pil_images
|
6 |
+
from demo.cam import generate_gradcam, AttentionGuidedCAMJanus, AttentionGuidedCAMClip, AttentionGuidedCAMLLaVA
|
7 |
+
from demo.model_utils import Clip_Utils, Janus_Utils, LLaVA_Utils, add_title_to_image
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import gc
|
12 |
+
import spaces
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
def set_seed(model_seed = 42):
|
16 |
+
torch.manual_seed(model_seed)
|
17 |
+
np.random.seed(model_seed)
|
18 |
+
torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
|
19 |
+
|
20 |
+
set_seed()
|
21 |
+
clip_utils = Clip_Utils()
|
22 |
+
clip_utils.init_Clip()
|
23 |
+
model_utils, vl_gpt, tokenizer = None, None, None
|
24 |
+
model_name = "Clip"
|
25 |
+
|
26 |
+
|
27 |
+
def clean():
|
28 |
+
global model_utils, vl_gpt, tokenizer, clip_utils
|
29 |
+
# Move models to CPU first (prevents CUDA references)
|
30 |
+
if 'vl_gpt' in globals() and vl_gpt is not None:
|
31 |
+
vl_gpt.to("cpu")
|
32 |
+
if 'clip_utils' in globals() and clip_utils is not None:
|
33 |
+
del clip_utils
|
34 |
+
|
35 |
+
# Delete all references
|
36 |
+
del model_utils, vl_gpt, tokenizer
|
37 |
+
model_utils, vl_gpt, tokenizer, clip_utils = None, None, None, None
|
38 |
+
gc.collect()
|
39 |
+
|
40 |
+
# Empty CUDA cache
|
41 |
+
if torch.cuda.is_available():
|
42 |
+
torch.cuda.empty_cache()
|
43 |
+
torch.cuda.ipc_collect() # Frees inter-process CUDA memory
|
44 |
+
|
45 |
+
# Empty MacOS Metal backend (if using Apple Silicon)
|
46 |
+
if torch.backends.mps.is_available():
|
47 |
+
torch.mps.empty_cache()
|
48 |
+
|
49 |
+
# Multimodal Understanding function
|
50 |
+
@spaces.GPU(duration=120)
|
51 |
+
def multimodal_understanding(model_type,
|
52 |
+
saliency_map_method,
|
53 |
+
visual_pooling_method,
|
54 |
+
image, question, seed, top_p, temperature, target_token_idx,
|
55 |
+
visualization_layer_min, visualization_layer_max, focus, response_type):
|
56 |
+
# Clear CUDA cache before generating
|
57 |
+
gc.collect()
|
58 |
+
if torch.cuda.is_available():
|
59 |
+
torch.cuda.empty_cache()
|
60 |
+
torch.cuda.ipc_collect()
|
61 |
+
|
62 |
+
# set seed
|
63 |
+
torch.manual_seed(seed)
|
64 |
+
np.random.seed(seed)
|
65 |
+
torch.cuda.manual_seed(seed) if torch.cuda.is_available() else None
|
66 |
+
|
67 |
+
input_text_decoded = ""
|
68 |
+
answer = ""
|
69 |
+
if model_name == "Clip":
|
70 |
+
|
71 |
+
inputs = clip_utils.prepare_inputs([question], image)
|
72 |
+
|
73 |
+
|
74 |
+
if saliency_map_method == "GradCAM":
|
75 |
+
# Generate Grad-CAM
|
76 |
+
all_layers = [layer.layer_norm1 for layer in clip_utils.model.vision_model.encoder.layers]
|
77 |
+
if visualization_layers_min.value != visualization_layers_max.value:
|
78 |
+
target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
|
79 |
+
else:
|
80 |
+
target_layers = [all_layers[visualization_layer_min-1]]
|
81 |
+
grad_cam = AttentionGuidedCAMClip(clip_utils.model, target_layers)
|
82 |
+
cam, outputs, grid_size = grad_cam.generate_cam(inputs, class_idx=0, visual_pooling_method=visual_pooling_method)
|
83 |
+
cam = cam.to("cpu")
|
84 |
+
cam = [generate_gradcam(cam, image, size=(224, 224))]
|
85 |
+
grad_cam.remove_hooks()
|
86 |
+
target_token_decoded = ""
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
else:
|
91 |
+
|
92 |
+
for param in vl_gpt.parameters():
|
93 |
+
param.requires_grad = True
|
94 |
+
|
95 |
+
|
96 |
+
prepare_inputs = model_utils.prepare_inputs(question, image)
|
97 |
+
|
98 |
+
if response_type == "answer + visualization":
|
99 |
+
if model_name.split('-')[0] == "Janus":
|
100 |
+
inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
|
101 |
+
outputs = model_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)
|
102 |
+
else:
|
103 |
+
outputs = model_utils.generate_outputs(prepare_inputs, temperature, top_p)
|
104 |
+
|
105 |
+
sequences = outputs.sequences.cpu().tolist()
|
106 |
+
answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
|
107 |
+
attention_raw = outputs.attentions
|
108 |
+
print("answer generated")
|
109 |
+
|
110 |
+
input_ids = prepare_inputs.input_ids[0].cpu().tolist()
|
111 |
+
input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
|
112 |
+
start=620 if model_name.split('-')[0] == "Janus" else 512
|
113 |
+
|
114 |
+
if saliency_map_method == "GradCAM":
|
115 |
+
# target_layers = vl_gpt.vision_model.vision_tower.blocks
|
116 |
+
if focus == "Visual Encoder":
|
117 |
+
all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
|
118 |
+
else:
|
119 |
+
all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]
|
120 |
+
|
121 |
+
if visualization_layers_min.value != visualization_layers_max.value:
|
122 |
+
target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
|
123 |
+
else:
|
124 |
+
target_layers = [all_layers[visualization_layer_min-1]]
|
125 |
+
|
126 |
+
if model_name.split('-')[0] == "Janus":
|
127 |
+
gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
|
128 |
+
elif model_name.split('-')[0] == "LLaVA":
|
129 |
+
gradcam = AttentionGuidedCAMLLaVA(vl_gpt, target_layers)
|
130 |
+
cam_tensors, grid_size = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
|
131 |
+
gradcam.remove_hooks()
|
132 |
+
if focus == "Visual Encoder":
|
133 |
+
cam_grid = cam_tensors.reshape(grid_size, grid_size)
|
134 |
+
cam = [generate_gradcam(cam_grid, image)]
|
135 |
+
else:
|
136 |
+
if target_token_idx != -1:
|
137 |
+
input_text_decoded = input_ids_decoded[start + target_token_idx]
|
138 |
+
for i, cam_tensor in enumerate(cam_tensors):
|
139 |
+
if i == target_token_idx:
|
140 |
+
cam_grid = cam_tensor.reshape(grid_size, grid_size)
|
141 |
+
cam_i = generate_gradcam(cam_grid, image)
|
142 |
+
cam = [add_title_to_image(cam_i, input_text_decoded)]
|
143 |
+
break
|
144 |
+
else:
|
145 |
+
cam = []
|
146 |
+
for i, cam_tensor in enumerate(cam_tensors):
|
147 |
+
cam_grid = cam_tensor.reshape(24, 24)
|
148 |
+
cam_i = generate_gradcam(cam_grid, image)
|
149 |
+
cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
|
150 |
+
|
151 |
+
cam.append(cam_i)
|
152 |
+
|
153 |
+
return answer, cam, input_text_decoded
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
# Gradio interface
|
159 |
+
|
160 |
+
def model_slider_change(model_type):
|
161 |
+
global model_utils, vl_gpt, tokenizer, clip_utils, model_name
|
162 |
+
model_name = model_type
|
163 |
+
if model_type == "Clip":
|
164 |
+
clean()
|
165 |
+
set_seed()
|
166 |
+
clip_utils = Clip_Utils()
|
167 |
+
clip_utils.init_Clip()
|
168 |
+
res = (
|
169 |
+
gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type"),
|
170 |
+
gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
|
171 |
+
gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max"),
|
172 |
+
gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus"),
|
173 |
+
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
|
174 |
+
)
|
175 |
+
return res
|
176 |
+
elif model_type.split('-')[0] == "Janus":
|
177 |
+
|
178 |
+
clean()
|
179 |
+
set_seed()
|
180 |
+
model_utils = Janus_Utils()
|
181 |
+
vl_gpt, tokenizer = model_utils.init_Janus(model_type.split('-')[-1])
|
182 |
+
|
183 |
+
res = (
|
184 |
+
gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="Visualization only", label="response_type"),
|
185 |
+
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
|
186 |
+
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
|
187 |
+
gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Visual Encoder", label="focus"),
|
188 |
+
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
|
189 |
+
)
|
190 |
+
return res
|
191 |
+
|
192 |
+
elif model_type.split('-')[0] == "LLaVA":
|
193 |
+
|
194 |
+
clean()
|
195 |
+
set_seed()
|
196 |
+
model_utils = LLaVA_Utils()
|
197 |
+
vl_gpt, tokenizer = model_utils.init_LLaVA()
|
198 |
+
|
199 |
+
res = (
|
200 |
+
gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="Visualization only", label="response_type"),
|
201 |
+
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
|
202 |
+
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
|
203 |
+
gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
|
204 |
+
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
|
205 |
+
)
|
206 |
+
return res
|
207 |
+
|
208 |
+
def focus_change(focus):
|
209 |
+
global model_name
|
210 |
+
if model_name == "Clip":
|
211 |
+
res = (
|
212 |
+
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
|
213 |
+
gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
|
214 |
+
gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
|
215 |
+
)
|
216 |
+
return res
|
217 |
+
|
218 |
+
if focus == "Language Model":
|
219 |
+
if response_type.value == "answer + visualization":
|
220 |
+
res = (
|
221 |
+
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
|
222 |
+
gr.Slider(minimum=1, maximum=24, value=8, step=1, label="visualization layers min"),
|
223 |
+
gr.Slider(minimum=1, maximum=24, value=8, step=1, label="visualization layers max")
|
224 |
+
)
|
225 |
+
return res
|
226 |
+
else:
|
227 |
+
res = (
|
228 |
+
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
|
229 |
+
gr.Slider(minimum=1, maximum=24, value=8, step=1, label="visualization layers min"),
|
230 |
+
gr.Slider(minimum=1, maximum=24, value=8, step=1, label="visualization layers max")
|
231 |
+
)
|
232 |
+
return res
|
233 |
+
|
234 |
+
else:
|
235 |
+
res = (
|
236 |
+
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
|
237 |
+
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
|
238 |
+
gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
|
239 |
+
)
|
240 |
+
return res
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
|
246 |
+
with gr.Blocks() as demo:
|
247 |
+
gr.Markdown(value="# Multimodal Understanding")
|
248 |
+
with gr.Row():
|
249 |
+
with gr.Column():
|
250 |
+
image_input = gr.Image()
|
251 |
+
saliency_map_output = gr.Gallery(label="Saliency Map", height=300, columns=1)
|
252 |
+
|
253 |
+
with gr.Column():
|
254 |
+
model_selector = gr.Dropdown(choices=["Clip", "Janus-1B", "Janus-7B", "LLaVA-1.5-7B"], value="Clip", label="model")
|
255 |
+
response_type = gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type")
|
256 |
+
focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
|
257 |
+
saliency_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
|
258 |
+
visual_pooling_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
|
259 |
+
|
260 |
+
|
261 |
+
visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
|
262 |
+
visualization_layers_max = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
|
263 |
+
|
264 |
+
question_input = gr.Textbox(label="Question")
|
265 |
+
und_seed_input = gr.Number(label="Seed", precision=0, value=42)
|
266 |
+
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
|
267 |
+
temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
|
268 |
+
target_token_idx = gr.Number(label="target_token_idx (-1 means all)", precision=0, value=-1)
|
269 |
+
|
270 |
+
|
271 |
+
|
272 |
+
model_selector.change(
|
273 |
+
fn=model_slider_change,
|
274 |
+
inputs=model_selector,
|
275 |
+
outputs=[
|
276 |
+
response_type,
|
                    visualization_layers_min,
                    visualization_layers_max,
                    focus,
                    saliency_map_method
                ]
            )

            focus.change(
                fn=focus_change,
                inputs=focus,
                outputs=[
                    saliency_map_method,
                    visualization_layers_min,
                    visualization_layers_max,
                ]
            )

            # response_type.change(
            #     fn = response_type_change,
            #     inputs = response_type,
            #     outputs = [saliency_map_method]
            # )

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Answer")
    understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "What is the approximate global smartphone market share of Samsung?",
                "images/PieChart.png"
            ],
            [
                "What is the average internet speed in Japan?",
                "images/BarChart.png"
            ],
            [
                "What was the average price of coffee beans in October 2019?",
                "images/AreaChart.png"
            ],
            [
                "Which city's metro system has the largest number of stations?",
                "images/BubbleChart.png"
            ],
            [
                "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
                "images/Choropleth_New.png"
            ],
            [
                "What distance have customers traveled in the taxi the most?",
                "images/Histogram.png"
            ],
            [
                "What was the price of a barrel of oil in February 2020?",
                "images/LineChart.png"
            ],
            [
                "True/False: eBay is nested in the Software category.",
                "images/TreeMap.png"
            ],
            [
                "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
                "images/Scatterplot.png"
            ],
            [
                "Which country has the lowest proportion of Gold medals?",
                "images/Stacked100.png"
            ],
            [
                "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
                "images/StackedArea.png"
            ],
            [
                "What is the cost of peanuts in Seoul?",
                "images/StackedBar.png"
            ],
            [
                "Where is the dog? Left or Right?",
                "images/cat_dog.png"
            ]

            # [
            #     "explain this meme",
            #     "images/doge.png",
            # ],
            # [
            #     "Convert the formula into latex code.",
            #     "images/equation.png",
            # ],
        ],
        inputs=[question_input, image_input],
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[model_selector, saliency_map_method, visual_pooling_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
                visualization_layers_min, visualization_layers_max, focus, response_type],
        outputs=[understanding_output, saliency_map_output, understanding_target_token_decoded_output]
    )

demo.launch(share=True)
# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
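The focus.change() and understanding_button.click() calls above are the only place where UI state reaches the saliency pipeline. For readers unfamiliar with this Gradio pattern, the following is a minimal, self-contained sketch of the same idea, modelled on the update functions used later in demo/app_vqa.py; component names, ranges, and choices here are illustrative only, not taken from app.py. Event handlers return fresh component definitions and Gradio applies them as updates to the wired outputs.

import gradio as gr

def focus_change(focus):
    # Swap the saliency-map choices and slider ranges depending on the selected focus.
    if focus == "Visual Encoder":
        return (
            gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
        )
    return (
        gr.Dropdown(choices=["GradCAM", "Attention_Map"], value="GradCAM", label="saliency map type"),
        gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers min"),
        gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers max"),
    )

with gr.Blocks() as mini_demo:
    focus = gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Visual Encoder", label="focus")
    saliency_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
    vis_min = gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min")
    vis_max = gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
    focus.change(fn=focus_change, inputs=focus, outputs=[saliency_map_method, vis_min, vis_max])

if __name__ == "__main__":
    mini_demo.launch()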
demo/Janus_colab_demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
demo/app.py
ADDED
@@ -0,0 +1,224 @@
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image

import numpy as np


# Load model and processor
model_path = "deepseek-ai/Janus-1.3B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Multimodal Understanding function
@torch.inference_mode()
def multimodal_understanding(image, question, seed, top_p, temperature):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    conversation = [
        {
            "role": "User",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False if temperature == 0 else True,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer


def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                              use_cache=True,
                                              past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img


@torch.inference_mode()
def generate_image(prompt,
                   seed=None,
                   guidance=5):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
    width = 384
    height = 384
    parallel_size = 5

    with torch.no_grad():
        messages = [{'role': 'User', 'content': prompt},
                    {'role': 'Assistant', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                           sft_format=vl_chat_processor.sft_format,
                                                                           system_prompt='')
        text = text + vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance,
                                   parallel_size=parallel_size)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16)

        return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(value="# Multimodal Understanding")
    # with gr.Row():
    with gr.Row():
        image_input = gr.Image()
        with gr.Column():
            question_input = gr.Textbox(label="Question")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Response")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "explain this meme",
                "images/doge.png",
            ],
            [
                "Convert the formula into latex code.",
                "images/equation.png",
            ],
        ],
        inputs=[question_input, image_input],
    )

    gr.Markdown(value="# Text-to-Image Generation")

    with gr.Row():
        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")

    prompt_input = gr.Textbox(label="Prompt")
    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)

    generation_button = gr.Button("Generate Images")

    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)

    examples_t2i = gr.Examples(
        label="Text to image generation examples. (Tips for designing prompts: Adding description like 'digital art' at the end of the prompt or writing the prompt in more detail can help produce better images!)",
        examples=[
            "Master shifu racoon wearing drip attire as a street gangster.",
            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
        ],
        inputs=prompt_input,
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
        outputs=understanding_output
    )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, cfg_weight_input],
        outputs=image_output
    )

demo.launch(share=True)
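The sampling loop in generate() interleaves prompt-conditioned rows (even batch indices) with padded, unconditional rows (odd indices) and combines their logits with classifier-free guidance before sampling each image token. The following toy, self-contained sketch shows just that mixing step; random logits stand in for the gen_head outputs, and the vocabulary size is arbitrary.

import torch

cfg_weight = 5.0
temperature = 1.0
logit_cond = torch.randn(2, 6)    # logits from the prompt-conditioned rows (even indices)
logit_uncond = torch.randn(2, 6)  # logits from the padded, unconditional rows (odd indices)

# Same rule as in generate(): push the distribution away from the unconditional
# prediction and toward the conditional one, scaled by cfg_weight.
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token)  # one sampled "image token" per parallel sample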
demo/app_janusflow.py
ADDED
@@ -0,0 +1,247 @@
import gradio as gr
import torch
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers.models import AutoencoderKL
import numpy as np

cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model and processor
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()

# remember to use bfloat16 dtype, this vae doesn't work with fp16
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = vae.to(torch.bfloat16).to(cuda_device).eval()

# Multimodal Understanding function
@torch.inference_mode()
def multimodal_understanding(image, question, seed, top_p, temperature):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    conversation = [
        {
            "role": "User",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False if temperature == 0 else True,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

    return answer


@torch.inference_mode()
def generate(
    input_ids,
    cfg_weight: float = 2.0,
    num_inference_steps: int = 30
):
    # we generate 5 images at a time, *2 for CFG
    tokens = torch.stack([input_ids] * 10).cuda()
    tokens[5:, 1:] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    print(inputs_embeds.shape)

    # we remove the last <bog> token and replace it with t_emb later
    inputs_embeds = inputs_embeds[:, :-1, :]

    # generate with rectified flow ode
    # step 1: encode with vision_gen_enc
    z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()

    dt = 1.0 / num_inference_steps
    dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt

    # step 2: run ode
    attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
    attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
    attention_mask = attention_mask.int()
    for step in range(num_inference_steps):
        # prepare inputs for the llm
        z_input = torch.cat([z, z], dim=0)  # for cfg
        t = step / num_inference_steps * 1000.
        t = torch.tensor([t] * z_input.shape[0]).to(dt)
        z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
        z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
        z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
        z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
        llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)

        # input to the llm
        # we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
        if step == 0:
            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                  use_cache=True,
                                                  attention_mask=attention_mask,
                                                  past_key_values=None)
            past_key_values = []
            # NOTE: this loop iterates the empty list created on the previous line, so the
            # prompt KV cache is never actually truncated; it presumably was intended to
            # iterate outputs.past_key_values.
            for kv_cache in past_key_values:
                k, v = kv_cache[0], kv_cache[1]
                past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
            past_key_values = tuple(past_key_values)
        else:
            outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                  use_cache=True,
                                                  attention_mask=attention_mask,
                                                  past_key_values=past_key_values)
        hidden_states = outputs.last_hidden_state

        # transform hidden_states back to v
        hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
        hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
        v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
        v_cond, v_uncond = torch.chunk(v, 2)
        v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
        z = z + dt * v

    # step 3: decode with vision_gen_dec and sdxl vae
    decoded_image = vae.decode(z / vae.config.scaling_factor).sample

    images = decoded_image.float().clip_(-1., 1.).permute(0, 2, 3, 1).cpu().numpy()
    images = ((images + 1) / 2. * 255).astype(np.uint8)

    return images


def unpack(dec, width, height, parallel_size=5):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img


@torch.inference_mode()
def generate_image(prompt,
                   seed=None,
                   guidance=5,
                   num_inference_steps=30):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

    with torch.no_grad():
        messages = [{'role': 'User', 'content': prompt},
                    {'role': 'Assistant', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                           sft_format=vl_chat_processor.sft_format,
                                                                           system_prompt='')
        text = text + vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        images = generate(input_ids,
                          cfg_weight=guidance,
                          num_inference_steps=num_inference_steps)
        return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(value="# Multimodal Understanding")
    # with gr.Row():
    with gr.Row():
        image_input = gr.Image()
        with gr.Column():
            question_input = gr.Textbox(label="Question")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Response")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "explain this meme",
                "./images/doge.png",
            ],
            [
                "Convert the formula into latex code.",
                "./images/equation.png",
            ],
        ],
        inputs=[question_input, image_input],
    )

    gr.Markdown(value="# Text-to-Image Generation")

    with gr.Row():
        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
        step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")

    prompt_input = gr.Textbox(label="Prompt")
    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)

    generation_button = gr.Button("Generate Images")

    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)

    examples_t2i = gr.Examples(
        label="Text to image generation examples.",
        examples=[
            "Master shifu racoon wearing drip attire as a street gangster.",
            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
        ],
        inputs=prompt_input,
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
        outputs=understanding_output
    )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
        outputs=image_output
    )

demo.launch(share=True)
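The inner loop of generate() above is an explicit Euler integration of the rectified-flow ODE, with the velocity itself combined by classifier-free guidance (v = w * v_cond - (w - 1) * v_uncond) before the update z = z + dt * v. The following toy sketch isolates just that update rule; the latent shape and the stand-in velocity function are arbitrary placeholders, not the JanusFlow encoder/decoder.

import torch

num_inference_steps = 30
cfg_weight = 2.0
dt = 1.0 / num_inference_steps

z = torch.randn(1, 4, 8, 8)  # toy latent in place of the (5, 4, 48, 48) SDXL-VAE latent

def toy_velocity(z, t):
    # Stand-in for the conditional / unconditional velocities the model predicts.
    return -z, -0.5 * z

for step in range(num_inference_steps):
    t = step / num_inference_steps
    v_cond, v_uncond = toy_velocity(z, t)
    v = cfg_weight * v_cond - (cfg_weight - 1.0) * v_uncond  # CFG-mixed velocity
    z = z + dt * v                                           # explicit Euler step of the ODE
print(z.abs().mean())  # the toy latent decays toward zero under this velocity field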
demo/app_januspro.py
ADDED
@@ -0,0 +1,294 @@
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from demo.cam import generate_gradcam, GradCAM, AttentionGuidedCAM
from PIL import Image
from einops import rearrange

import numpy as np
import os
import time
# import spaces  # Import spaces for ZeroGPU compatibility


# Load model and processor
# model_path = "deepseek-ai/Janus-Pro-7B"
model_path = "deepseek-ai/Janus-Pro-1B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)

dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
# dtype = torch.bfloat32 if torch.cuda.is_available() else torch.float32

if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(dtype).cuda()
else:
    # vl_gpt = vl_gpt.to(torch.float16)
    torch.set_default_device("mps")
    vl_gpt = vl_gpt.to(dtype)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'mps'


# @torch.inference_mode()  # cancel inference, for gradcam
# @spaces.GPU(duration=120)
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature, target_token_idx):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    for param in vl_gpt.parameters():
        param.requires_grad = True

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    # Get the last transformer block of the Vision Transformer (ViT)

    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=dtype)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    # print("prepared inputs", prepare_inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False if temperature == 0 else True,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    print("answer generated")

    target_layer = vl_gpt.vision_model.vision_tower.blocks

    gradcam = AttentionGuidedCAM(vl_gpt, target_layer)
    cam_tensor, output, grid_size = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx)
    cam_grid = cam_tensor.reshape(grid_size, grid_size)
    cam = generate_gradcam(cam_grid, image)

    output_arr = output.logits.detach().to(float).to("cpu").numpy()
    predicted_ids = np.argmax(output_arr, axis=-1)  # [1, num_tokens]
    predicted_ids = predicted_ids.squeeze(0)        # [num_tokens]
    target_token_decoded = tokenizer.decode(predicted_ids[target_token_idx].tolist())

    return answer, [cam], target_token_decoded


def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                                  use_cache=True,
                                                  past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img


@torch.inference_mode()
# @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt,
                   seed=None,
                   guidance=5,
                   t2i_temperature=1.0):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
    width = 384
    height = 384
    parallel_size = 5

    with torch.no_grad():
        messages = [{'role': '<|User|>', 'content': prompt},
                    {'role': '<|Assistant|>', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                           sft_format=vl_chat_processor.sft_format,
                                                                           system_prompt='')
        text = text + vl_chat_processor.image_start_tag

        input_ids = torch.LongTensor(tokenizer.encode(text))
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance,
                                   parallel_size=parallel_size,
                                   temperature=t2i_temperature)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16,
                        parallel_size=parallel_size)

        return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(value="# Multimodal Understanding")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image()
            saliency_map_output = gr.Gallery(label="Saliency Map", columns=1, rows=1, height=300)

        with gr.Column():
            question_input = gr.Textbox(label="Question")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
            target_token_idx = gr.Number(label="target_token_idx", precision=0, value=300)

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Response")
    understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "explain this meme",
                "images/doge.png",
            ],
            [
                "Convert the formula into latex code.",
                "images/equation.png",
            ],
        ],
        inputs=[question_input, image_input],
    )

    gr.Markdown(value="# Text-to-Image Generation")

    with gr.Row():
        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
        t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")

    prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)

    generation_button = gr.Button("Generate Images")

    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)

    examples_t2i = gr.Examples(
        label="Text to image generation examples.",
        examples=[
            "Master shifu racoon wearing drip attire as a street gangster.",
            "The face of a beautiful girl",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A glass of red wine on a reflective surface.",
            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
        ],
        inputs=prompt_input,
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature, target_token_idx],
        outputs=[understanding_output, saliency_map_output, understanding_target_token_decoded_output]
    )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
        outputs=image_output
    )

demo.launch(share=True)
# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
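Before the PIL resize, unpack() converts the decoder output from [-1, 1] floats in NCHW layout into uint8 HWC frames. The sketch below replays that post-processing on random data so it can be run without the model; unpack_np and the output filename are illustrative stand-ins, and the input array merely imitates the shape of gen_vision_model.decode_code() output.

import numpy as np
from PIL import Image

def unpack_np(dec, width, height, parallel_size=5):
    # Same post-processing as unpack() above, but taking a NumPy array directly.
    dec = dec.transpose(0, 2, 3, 1)             # NCHW -> NHWC
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)  # [-1, 1] -> [0, 255]
    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img

# Random patches standing in for the decoded image tokens.
patches = np.random.uniform(-1, 1, size=(5, 3, 384, 384)).astype(np.float32)
images = unpack_np(patches, 384, 384)
Image.fromarray(images[0]).resize((768, 768), Image.LANCZOS).save("toy_sample.png")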
demo/app_vqa.py
ADDED
@@ -0,0 +1,333 @@
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from demo.cam import generate_gradcam, AttentionGuidedCAMJanus, AttentionGuidedCAMClip
from demo.model_utils import Clip_Utils, Janus_Utils, add_title_to_image

import numpy as np
import matplotlib.pyplot as plt
import gc
from PIL import Image

model_seed = 42
torch.manual_seed(model_seed)
np.random.seed(model_seed)
torch.cuda.manual_seed(model_seed)

model_type = "Janus-1B"
janus_utils = Janus_Utils()
vl_gpt, tokenizer = janus_utils.init_Janus(model_type.split('-')[-1])

clip_utils = Clip_Utils()
clip_utils.init_Clip()

# @torch.inference_mode()  # cancel inference, for gradcam
# @spaces.GPU(duration=120)
# Multimodal Understanding function
def multimodal_understanding(model_type,
                             saliency_map_method,
                             visual_pooling_method,
                             image, question, seed, top_p, temperature, target_token_idx,
                             visualization_layer_min, visualization_layer_max, focus):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    input_text_decoded = ""
    if model_type == "Clip":

        inputs = clip_utils.prepare_inputs([question], image)

        if saliency_map_method == "GradCAM":
            # Generate Grad-CAM
            all_layers = [layer.layer_norm1 for layer in clip_utils.model.vision_model.encoder.layers]
            # compare the function arguments, not the Gradio components' static defaults
            if visualization_layer_min != visualization_layer_max:
                target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
            else:
                target_layers = [all_layers[visualization_layer_min-1]]
            grad_cam = AttentionGuidedCAMClip(clip_utils.model, target_layers)
            cam, outputs, grid_size = grad_cam.generate_cam(inputs, class_idx=0, visual_pooling_method=visual_pooling_method)
            cam = [generate_gradcam(cam, image, size=(224, 224))]
            grad_cam.remove_hooks()
            target_token_decoded = ""
            answer = ""

    elif model_type == "Janus-1B":

        for param in vl_gpt.parameters():
            param.requires_grad = True

        prepare_inputs = janus_utils.prepare_inputs(question, image)
        inputs_embeds = janus_utils.generate_inputs_embeddings(prepare_inputs)
        outputs = janus_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)

        sequences = outputs.sequences.cpu().tolist()
        answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
        attention_raw = outputs.attentions
        print("answer generated")

        input_ids = prepare_inputs.input_ids[0].cpu().tolist()
        input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
        start = 620

        if saliency_map_method == "GradCAM":
            # target_layers = vl_gpt.vision_model.vision_tower.blocks
            if focus == "Visual Encoder":
                all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
            else:
                all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]

            # compare the function arguments, not the Gradio components' static defaults
            if visualization_layer_min != visualization_layer_max:
                target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
            else:
                target_layers = [all_layers[visualization_layer_min-1]]

            gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
            cam_tensors, grid_size = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
            if focus == "Visual Encoder":
                cam_grid = cam_tensors.reshape(grid_size, grid_size)
                cam = [generate_gradcam(cam_grid, image)]
            else:
                if target_token_idx != -1:
                    input_text_decoded = input_ids_decoded[start + target_token_idx]
                    for i, cam_tensor in enumerate(cam_tensors):
                        if i == target_token_idx:
                            cam_grid = cam_tensor.reshape(grid_size, grid_size)
                            cam_i = generate_gradcam(cam_grid, image)
                            cam = [add_title_to_image(cam_i, input_text_decoded)]
                            break
                else:
                    cam = []
                    for i, cam_tensor in enumerate(cam_tensors):
                        cam_grid = cam_tensor.reshape(24, 24)
                        cam_i = generate_gradcam(cam_grid, image)
                        cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])

                        cam.append(cam_i)

                    # widths, heights = zip(*(img.size for img in heatmaps))
                    # total_height = sum(heights)
                    # max_width = max(widths)

                    # combined_img = Image.new("RGB", (max_width, total_height))

                    # y_offset = 0
                    # for img in heatmaps:
                    #     combined_img.paste(img, (0, y_offset))  # Stack vertically
                    #     y_offset += img.height
                    # cam = combined_img

        elif saliency_map_method == "Attention_Map":
            attn_m_token = attention_raw[target_token_idx]
            img_token_positions = prepare_inputs.images_seq_mask
            mask = img_token_positions[0]

            tg = attn_m_token[1][:, :, :, :len(mask)]
            tg = tg[:, :, :, mask]
            head = 0

            # res = tg[0, head, 0].to(torch.float32)
            res, _ = tg.max(dim=1)
            # res = tg.sum(dim=1)
            res = res.to(torch.float32)
            grid_size = int(res.shape[-1] ** 0.5)
            res = res.view(grid_size, grid_size)
            cam = [generate_gradcam(res, image)]

    # output_arr = output.logits.detach().to(float).to("cpu").numpy()
    # predicted_ids = np.argmax(output_arr, axis=-1)  # [1, num_tokens]
    # predicted_ids = predicted_ids.squeeze(0)        # [num_tokens]
    # target_token_decoded = tokenizer.decode(predicted_ids[target_token_idx].tolist())

    return answer, cam, input_text_decoded


# Gradio interface

def update_sliders(model):
    if model == "Clip":
        res = (
            gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
            gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max"),
            gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
        )
        return res
    else:
        res = (
            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
            gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Visual Encoder", label="focus")
        )
        return res

def update_visualization_layers_sliders(focus):
    if focus == "Visual Encoder":
        res = (
            gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
            gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
        )
        return res
    else:
        res = (
            gr.Dropdown(choices=["GradCAM", "Attention_Map"], value="GradCAM", label="saliency map type"),
            gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers min"),
            gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers max")
        )
        return res

with gr.Blocks() as demo:
    gr.Markdown(value="# Multimodal Understanding")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image()
            saliency_map_output = gr.Gallery(label="Saliency Map", columns=1)

        with gr.Column():
            model_selector = gr.Dropdown(choices=["Clip", "Janus-1B"], value="Clip", label="model")
            focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
            saliency_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
            visual_pooling_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")

            visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
            visualization_layers_max = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")

            question_input = gr.Textbox(label="Question")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
            target_token_idx = gr.Number(label="target_token_idx (-1 means all)", precision=0, value=-1)

            model_selector.change(
                fn=update_sliders,
                inputs=model_selector,
                outputs=[
                    visualization_layers_min,
                    visualization_layers_max,
                    focus
                ]
            )

            focus.change(
                fn=update_visualization_layers_sliders,
                inputs=focus,
                outputs=[
                    saliency_map_method,
                    visualization_layers_min,
                    visualization_layers_max,
                ]
            )

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Response")
    understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "What is the approximate global smartphone market share of Samsung?",
                "images/PieChart.png"
            ],
            [
                "What is the average internet speed in Japan?",
                "images/BarChart.png"
            ],
            [
                "What was the average price of coffee beans in October 2019?",
                "images/AreaChart.png"
            ],
            [
                "Which city's metro system has the largest number of stations?",
                "images/BubbleChart.png"
            ],
            [
                "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
                "images/Choropleth_New.png"
            ],
            [
                "What distance have customers traveled in the taxi the most?",
                "images/Histogram.png"
            ],
            [
                "What was the price of a barrel of oil in February 2020?",
                "images/LineChart.png"
            ],
            [
                "True/False: eBay is nested in the Software category.",
                "images/Treemap.png"
            ],
            [
                "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
                "images/Scatterplot.png"
            ],
            [
                "Which country has the lowest proportion of Gold medals?",
                "images/Stacked100.png"
            ],
            [
                "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
                "images/StackedArea.png"
            ],
            [
                "What is the cost of peanuts in Seoul?",
                "images/StackedBar.png"
            ],

            # [
            #     "explain this meme",
            #     "images/doge.png",
            # ],
            # [
            #     "Convert the formula into latex code.",
            #     "images/equation.png",
            # ],
        ],
        inputs=[question_input, image_input],
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[model_selector, saliency_map_method, visual_pooling_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
                visualization_layers_min, visualization_layers_max, focus],
        outputs=[understanding_output, saliency_map_output, understanding_target_token_decoded_output]
    )

demo.launch(share=True)
# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
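app_vqa.py delegates the saliency computation to the hook-based CAM classes defined in demo/cam.py (the next file). The core mechanism is a pair of forward/backward hooks that cache a layer's activation and the gradient flowing back through it, then weight one by the other. The following is a minimal stand-alone sketch of that mechanism on a toy linear layer rather than the Janus model; it uses register_full_backward_hook (the non-deprecated counterpart of the register_backward_hook call used in cam.py).

import torch
import torch.nn as nn

layer = nn.Linear(16, 16)          # toy stand-in for an attention block / layer norm
activations, gradients = [], []

def forward_hook(module, inputs, output):
    activations.append(output)      # cache the layer's activation

def backward_hook(module, grad_in, grad_out):
    gradients.append(grad_out[0])   # cache the gradient w.r.t. the layer's output

h1 = layer.register_forward_hook(forward_hook)
h2 = layer.register_full_backward_hook(backward_hook)

x = torch.randn(1, 16, requires_grad=True)
score = layer(x).relu().sum()       # stand-in for the target logit / embedding score
score.backward()

# Grad-CAM-style weighting: activation times mean gradient, rectified.
cam = torch.relu(activations[0] * gradients[0].mean(dim=-1, keepdim=True)).detach()
print(cam.shape)

h1.remove()
h2.remove()                         # mirrors AttentionGuidedCAM.remove_hooks()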
demo/cam.py
ADDED
@@ -0,0 +1,486 @@
import cv2
import numpy as np
import types
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from torch import nn
import spaces
from demo.modify_llama import *


class AttentionGuidedCAM:
    def __init__(self, model):
        self.model = model
        self.gradients = []
        self.activations = []
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """ Registers hooks to extract activations and gradients from ALL attention layers. """
        for layer in self.target_layers:
            self.hooks.append(layer.register_forward_hook(self._forward_hook))
            self.hooks.append(layer.register_backward_hook(self._backward_hook))

    def _forward_hook(self, module, input, output):
        """ Stores attention maps (before softmax) """
        self.activations.append(output)

    def _backward_hook(self, module, grad_in, grad_out):
        """ Stores gradients """
        self.gradients.append(grad_out[0])

    def remove_hooks(self):
        """ Remove hooks after usage. """
        for hook in self.hooks:
            hook.remove()

    @spaces.GPU(duration=120)
    def generate_cam(self, input_tensor, class_idx=None):
        raise NotImplementedError


class AttentionGuidedCAMClip(AttentionGuidedCAM):
    def __init__(self, model, target_layers):
        self.target_layers = target_layers
        super().__init__(model)

    @spaces.GPU(duration=120)
    def generate_cam(self, input_tensor, class_idx=None, visual_pooling_method="CLS"):
        """ Generates Grad-CAM heatmap for ViT. """

        # Forward pass
        output_full = self.model(**input_tensor)

        if class_idx is None:
            class_idx = torch.argmax(output_full.logits, dim=1).item()

        if visual_pooling_method == "CLS":
            output = output_full.image_embeds
        elif visual_pooling_method == "avg":
            output = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).mean(dim=1)
        else:
            # project -> pooling
            output, _ = self.model.visual_projection(output_full.vision_model_output.last_hidden_state).max(dim=1)
            # pooling -> project
            # output_mx, _ = output_full.vision_model_output.last_hidden_state.max(dim=1)
            # output = self.model.visual_projection(output_mx)

        output.backward(output_full.text_embeds[class_idx:class_idx+1], retain_graph=True)

        # Aggregate activations and gradients from ALL layers
        print(self.activations, self.gradients)
        self.model.zero_grad()
        cam_sum = None
        for act, grad in zip(self.activations, self.gradients):
            # act = torch.sigmoid(act[0])
            act = F.relu(act[0])

            grad_weights = grad.mean(dim=-1, keepdim=True)

            print("act shape", act.shape)
            print("grad_weights shape", grad_weights.shape)

            # cam = (act * grad_weights).sum(dim=-1)  # Weighted activation map
            cam, _ = (act * grad_weights).max(dim=-1)
            # cam, _ = grad_weights.max(dim=-1)
            # cam = self.normalize(cam)
            print(cam.shape)

            # Sum across all layers
            if cam_sum is None:
                cam_sum = cam
            else:
                cam_sum += cam

        # Normalize
        cam_sum = F.relu(cam_sum)

        # thresholding
        cam_sum = cam_sum.to(torch.float32)
        percentile = torch.quantile(cam_sum, 0.2)  # Adjust threshold dynamically
        cam_sum[cam_sum < percentile] = 0

        # Reshape: drop the CLS position, keep the patch tokens
        print("cam_sum shape: ", cam_sum.shape)
        cam_sum = cam_sum[0, 1:]

        num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
        grid_size = int(num_patches ** 0.5)
        print(f"Detected grid size: {grid_size}x{grid_size}")

        cam_sum = cam_sum.view(grid_size, grid_size).detach()
        cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())

        return cam_sum, output_full, grid_size


class AttentionGuidedCAMJanus(AttentionGuidedCAM):
    def __init__(self, model, target_layers):
        self.target_layers = target_layers
        super().__init__(model)
        self._modify_layers()
        self._register_hooks_activations()

    def _modify_layers(self):
        for layer in self.target_layers:
            setattr(layer, "attn_gradients", None)
            setattr(layer, "attention_map", None)

            layer.save_attn_gradients = types.MethodType(save_attn_gradients, layer)
            layer.get_attn_gradients = types.MethodType(get_attn_gradients, layer)
            layer.save_attn_map = types.MethodType(save_attn_map, layer)
            layer.get_attn_map = types.MethodType(get_attn_map, layer)

    def _forward_activate_hooks(self, module, input, output):
        attn_output, attn_weights = output  # Unpack outputs
        module.save_attn_map(attn_weights)
        attn_weights.register_hook(module.save_attn_gradients)

    def _register_hooks_activations(self):
        for layer in self.target_layers:
            if hasattr(layer, "q_proj"):  # is an attention layer
                self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))

    @spaces.GPU(duration=120)
    def generate_cam(self, input_tensor, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
        """ Generates Grad-CAM heatmap for ViT. """

        # Forward pass
        image_embeddings, inputs_embeddings, outputs = self.model(input_tensor, tokenizer, temperature, top_p)

        input_ids = input_tensor.input_ids

        if focus == "Visual Encoder":
            # Pooling
            if visual_pooling_method == "CLS":
                image_embeddings_pooled = image_embeddings[:, 0, :]
            elif visual_pooling_method == "avg":
                image_embeddings_pooled = image_embeddings[:, 1:, :].mean(dim=1)  # end of image: 618
            elif visual_pooling_method == "max":
                image_embeddings_pooled, _ = image_embeddings[:, 1:, :].max(dim=1)

            print("image_embeddings_shape: ", image_embeddings_pooled.shape)

            inputs_embeddings_pooled = inputs_embeddings[:, 620: -4].mean(dim=1)
            self.model.zero_grad()
            image_embeddings_pooled.backward(inputs_embeddings_pooled, retain_graph=True)

            cam_sum = None
            for act, grad in zip(self.activations, self.gradients):
                # act = torch.sigmoid(act)
                act = F.relu(act[0])

                # Compute mean of gradients
                grad_weights = grad.mean(dim=-1, keepdim=True)

                print("act shape", act.shape)
                print("grad_weights shape", grad_weights.shape)

                cam, _ = (act * grad_weights).max(dim=-1)
                print(cam.shape)

                # Sum across all layers
                if cam_sum is None:
                    cam_sum = cam
                else:
                    cam_sum += cam

            # Normalize
            cam_sum = F.relu(cam_sum)

            # thresholding
            cam_sum = cam_sum.to(torch.float32)
            percentile = torch.quantile(cam_sum, 0.2)  # Adjust threshold dynamically
            cam_sum[cam_sum < percentile] = 0

            # Reshape
            # if visual_pooling_method == "CLS":
            cam_sum = cam_sum[0, 1:]
            print("cam_sum shape: ", cam_sum.shape)
            num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
            grid_size = int(num_patches ** 0.5)
            print(f"Detected grid size: {grid_size}x{grid_size}")

            cam_sum = cam_sum.view(grid_size, grid_size)
            cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
            cam_sum = cam_sum.detach().to("cpu")

            return cam_sum, grid_size

        elif focus == "Language Model":
            loss = self.target_layers[-1].attention_map.sum()
            self.model.zero_grad()
            loss.backward()

            self.activations = [layer.get_attn_map() for layer in self.target_layers]
            self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]

            cam_sum = None
            for act, grad in zip(self.activations, self.gradients):
                # act = torch.sigmoid(act)
                print("act:", act)
                print(len(act))
                print("act_shape:", act.shape)
                # print("act1_shape:", act[1].shape)

                act = F.relu(act.mean(dim=1))

                # Compute mean of gradients
                print("grad:", grad)
                print(len(grad))
                print("grad_shape:", grad.shape)
                grad_weights = grad.mean(dim=1)

                print("act:", act)
                print("act shape", act.shape)
                print("grad_weights shape", grad_weights.shape)

                # cam, _ = (act * grad_weights).max(dim=-1)
                cam = act * grad_weights
                print(cam.shape)

                # Sum across all layers
                if cam_sum is None:
                    cam_sum = cam
                else:
                    cam_sum += cam

            # Normalize
            cam_sum = F.relu(cam_sum)
            # cam_sum = cam_sum - cam_sum.min()
            # cam_sum = cam_sum / cam_sum.max()

            # thresholding
            cam_sum = cam_sum.to(torch.float32)
            percentile = torch.quantile(cam_sum, 0.2)  # Adjust threshold dynamically
            cam_sum[cam_sum < percentile] = 0

            # Reshape
            # if visual_pooling_method == "CLS":
            #     cam_sum = cam_sum[0, 1:]

            # cam_sum shape: [1, seq_len, seq_len]
            cam_sum_lst = []
            cam_sum_raw = cam_sum
            for i in range(620, cam_sum_raw.shape[1]):
                cam_sum = cam_sum_raw[:, i, :]  # shape: [1, seq_len]
                cam_sum = cam_sum[input_tensor.images_seq_mask].unsqueeze(0)  # shape: [1, 576]
                print("cam_sum shape: ", cam_sum.shape)
                num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
                grid_size = int(num_patches ** 0.5)
                print(f"Detected grid size: {grid_size}x{grid_size}")

                # Fix the reshaping step dynamically
                cam_sum = cam_sum.view(grid_size, grid_size)
                cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
                cam_sum = cam_sum.detach().to("cpu")
                cam_sum_lst.append(cam_sum)

            return cam_sum_lst, grid_size


class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
    def __init__(self, model, target_layers):
        self.target_layers = target_layers
        super().__init__(model)
        self._modify_layers()
        self._register_hooks_activations()

    def _modify_layers(self):
        for layer in self.target_layers:
            setattr(layer, "attn_gradients", None)
            setattr(layer, "attention_map", None)

            layer.save_attn_gradients = types.MethodType(save_attn_gradients, layer)
            layer.get_attn_gradients = types.MethodType(get_attn_gradients, layer)
            layer.save_attn_map = types.MethodType(save_attn_map, layer)
            layer.get_attn_map = types.MethodType(get_attn_map, layer)

    def _forward_activate_hooks(self, module, input, output):
        attn_output, attn_weights = output  # Unpack outputs
        attn_weights.requires_grad_()
        module.save_attn_map(attn_weights)
        attn_weights.register_hook(module.save_attn_gradients)

    def _register_hooks_activations(self):
        for layer in self.target_layers:
            if hasattr(layer, "q_proj"):  # is an attention layer
                self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))

    @spaces.GPU(duration=120)
    def generate_cam(self, input_tensor, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
        """ Generates Grad-CAM heatmap for ViT. """

        # Forward pass
        outputs_raw = self.model(**input_tensor)

        if focus == "Language Model":
            loss = self.target_layers[-1].attention_map.sum()
            self.model.zero_grad()
            loss.backward()

            self.activations = [layer.get_attn_map() for layer in self.target_layers]
            self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]

            cam_sum = None
            for act, grad in zip(self.activations, self.gradients):
                # act = torch.sigmoid(act)
                print("act:", act)
                print(len(act))
                print("act_shape:", act.shape)
                # print("act1_shape:", act[1].shape)

                act = F.relu(act.mean(dim=1))

                # Compute mean of gradients
                print("grad:", grad)
                print(len(grad))
                print("grad_shape:", grad.shape)
                grad_weights = grad.mean(dim=1)

                print("act:", act)
                print("act shape", act.shape)
                print("grad_weights shape", grad_weights.shape)

                # cam, _ = (act * grad_weights).max(dim=-1)
                cam = act * grad_weights
                print(cam.shape)

                # Sum across all layers
                if cam_sum is None:
                    cam_sum = cam
                else:
                    cam_sum += cam

            # Normalize
            cam_sum = F.relu(cam_sum)
            # cam_sum = cam_sum - cam_sum.min()
            # cam_sum = cam_sum / cam_sum.max()

            # thresholding
            cam_sum = cam_sum.to(torch.float32)
            percentile = torch.quantile(cam_sum, 0.2)  # Adjust threshold dynamically
            cam_sum[cam_sum < percentile] = 0

            # Reshape
            # if visual_pooling_method == "CLS":
            #     cam_sum = cam_sum[0, 1:]

            # cam_sum shape: [1, seq_len, seq_len]
            cam_sum_lst = []
            cam_sum_raw = cam_sum
            grid_size = 32
            for i in range(512, cam_sum_raw.shape[1]):
                cam_sum = cam_sum_raw[:, i, :]  # shape: [1, seq_len]
                cam_sum = cam_sum[input_tensor.images_seq_mask].unsqueeze(0)  # shape: [1, 576]
                print("cam_sum shape: ", cam_sum.shape)
                num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
                grid_size = int(num_patches ** 0.5)
                print(f"Detected grid size: {grid_size}x{grid_size}")

                # Fix the reshaping step dynamically
                cam_sum = cam_sum.view(grid_size, grid_size)
                cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
                cam_sum = cam_sum.detach().to("cpu")
                cam_sum_lst.append(cam_sum)

            return cam_sum_lst, grid_size


def generate_gradcam(
    cam,
    image,
    size=(384, 384),
    alpha=0.5,
    colormap=cv2.COLORMAP_JET,
    aggregation='mean',
    normalize=True
):
    """
    Generates a Grad-CAM heatmap overlay on top of the input image.

    Parameters:
        cam (torch.Tensor): The (grid_size x grid_size) class activation map produced above.
        image (PIL.Image): The original image.
        size (tuple): Output resolution of the overlay as (width, height).
        alpha (float): The blending factor for the heatmap overlay (default 0.5).
        colormap (int): OpenCV colormap to apply (default cv2.COLORMAP_JET).
        aggregation (str): How to aggregate across channels; either 'mean' or 'sum'.
        normalize (bool): Whether to min-max normalize the map before colouring.

    Returns:
        PIL.Image: The image overlaid with the Grad-CAM heatmap.
    """
    print("Generating Grad-CAM with shape:", cam.shape)

    if normalize:
        cam_min, cam_max = cam.min(), cam.max()
        cam = cam - cam_min
        cam = cam / (cam_max - cam_min)
    # Upsample to the target size and convert the tensor to a numpy array
    cam = torch.nn.functional.interpolate(cam.unsqueeze(0).unsqueeze(0), size=size, mode='bilinear').squeeze()
    cam_np = cam.squeeze().detach().cpu().numpy()

    # Apply Gaussian blur for smoother heatmaps
    cam_np = cv2.GaussianBlur(cam_np, (5, 5), sigmaX=0.8)

    # Resize the cam to match the image size
    width, height = size
    cam_resized = cv2.resize(cam_np, (width, height))

    # Convert the normalized map to a heatmap (0-255 uint8)
    heatmap = np.uint8(255 * cam_resized)
    heatmap = cv2.applyColorMap(heatmap, colormap)
    # OpenCV produces heatmaps in BGR, so convert to RGB for consistency
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Convert original image to a numpy array
    image_np = np.array(image)
    image_np = cv2.resize(image_np, (width, height))

    # Blend the heatmap with the original image
    overlay = cv2.addWeighted(image_np, 1 - alpha, heatmap, alpha, 0)

    return Image.fromarray(overlay)
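A minimal usage sketch of the classes above, assuming the same `transformers` CLIP checkpoint used elsewhere in this demo (`openai/clip-vit-base-patch32`); the image path, the prompt list, and the choice of the last three vision-encoder layers as hook targets are illustrative assumptions, not code from this commit.

from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Illustrative setup: any CLIP checkpoint with image_embeds/text_embeds outputs should work similarly.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("images/cat_dog.png").convert("RGB")
inputs = clip_processor(text=["a cat", "a dog"], images=image, return_tensors="pt", padding=True)

# Hook the last three attention blocks of the CLIP vision tower (illustrative layer choice).
target_layers = list(clip_model.vision_model.encoder.layers[-3:])
cam_extractor = AttentionGuidedCAMClip(clip_model, target_layers)
cam, output_full, grid_size = cam_extractor.generate_cam(inputs, class_idx=0, visual_pooling_method="CLS")
cam_extractor.remove_hooks()

# Blend the map over the input image and save the overlay.
overlay = generate_gradcam(cam, image, size=(224, 224), alpha=0.5)
overlay.save("clip_gradcam_overlay.png")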
demo/demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
demo/demo_attn.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
demo/fastapi_app.py
ADDED
@@ -0,0 +1,178 @@
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import io

app = FastAPI()

# Load model and processor
model_path = "deepseek-ai/Janus-1.3B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@torch.inference_mode()
def multimodal_understanding(image_data, question, seed, top_p, temperature):
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    conversation = [
        {
            "role": "User",
            "content": f"<image_placeholder>\n{question}",
            "images": [image_data],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = [Image.open(io.BytesIO(image_data))]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False if temperature == 0 else True,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer


@app.post("/understand_image_and_question/")
async def understand_image_and_question(
    file: UploadFile = File(...),
    question: str = Form(...),
    seed: int = Form(42),
    top_p: float = Form(0.95),
    temperature: float = Form(0.1)
):
    image_data = await file.read()
    response = multimodal_understanding(image_data, question, seed, top_p, temperature)
    return JSONResponse({"response": response})


def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    torch.cuda.empty_cache()
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size]
    )

    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img


@torch.inference_mode()
def generate_image(prompt, seed, guidance):
    torch.cuda.empty_cache()
    seed = seed if seed is not None else 12345
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    width = 384
    height = 384
    parallel_size = 5

    with torch.no_grad():
        messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=''
        )
        text = text + vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
        images = unpack(patches, width // 16 * 16, height // 16 * 16)

        return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]


@app.post("/generate_images/")
async def generate_images(
    prompt: str = Form(...),
    seed: int = Form(None),
    guidance: float = Form(5.0),
):
    try:
        images = generate_image(prompt, seed, guidance)

        def image_stream():
            for img in images:
                buf = io.BytesIO()
                img.save(buf, format='PNG')
                buf.seek(0)
                yield buf.read()

        return StreamingResponse(image_stream(), media_type="multipart/related")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
demo/fastapi_client.py
ADDED
@@ -0,0 +1,78 @@
import requests
from PIL import Image
import io

# Endpoint URLs
understand_image_url = "http://localhost:8000/understand_image_and_question/"
generate_images_url = "http://localhost:8000/generate_images/"

# Use your image file path here
image_path = "images/equation.png"

# Function to call the image understanding endpoint
def understand_image_and_question(image_path, question, seed=42, top_p=0.95, temperature=0.1):
    files = {'file': open(image_path, 'rb')}
    data = {
        'question': question,
        'seed': seed,
        'top_p': top_p,
        'temperature': temperature
    }
    response = requests.post(understand_image_url, files=files, data=data)
    response_data = response.json()
    print("Image Understanding Response:", response_data['response'])


# Function to call the text-to-image generation endpoint
def generate_images(prompt, seed=None, guidance=5.0):
    data = {
        'prompt': prompt,
        'seed': seed,
        'guidance': guidance
    }
    response = requests.post(generate_images_url, data=data, stream=True)

    if response.ok:
        img_idx = 1

        # We will create a new BytesIO for each image
        buffers = {}

        try:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    # Use a boundary detection to determine new image start
                    if img_idx not in buffers:
                        buffers[img_idx] = io.BytesIO()

                    buffers[img_idx].write(chunk)

                    # Attempt to open the image
                    try:
                        buffer = buffers[img_idx]
                        buffer.seek(0)
                        image = Image.open(buffer)
                        img_path = f"generated_image_{img_idx}.png"
                        image.save(img_path)
                        print(f"Saved: {img_path}")

                        # Prepare the next image buffer
                        buffer.close()
                        img_idx += 1

                    except Exception as e:
                        # Continue loading data into the current buffer
                        continue

        except Exception as e:
            print("Error processing image:", e)
    else:
        print("Failed to generate images.")


# Example usage
if __name__ == "__main__":
    # Call the image understanding API
    understand_image_and_question(image_path, "What is this image about?")

    # Call the image generation API
    generate_images("A beautiful sunset over a mountain range, digital art.")
demo/model_utils.py
ADDED
@@ -0,0 +1,208 @@
import torch
import numpy as np
import spaces
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoConfig, AutoModelForCausalLM, LlavaForConditionalGeneration, AutoProcessor
from transformers import CLIPProcessor, CLIPModel
from janus.models import MultiModalityCausalLM, VLChatProcessor

@spaces.GPU(duration=120)
def set_dtype_device(model, precision=16):
    # torch has no bfloat32; fall back to float32 when full precision is requested
    dtype = (torch.bfloat16 if torch.cuda.is_available() else torch.float16) if precision == 16 else torch.float32
    cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        model = model.to(dtype).cuda()
    else:
        torch.set_default_device("cpu")
        model = model.to(dtype)
    return model, dtype, cuda_device


class Model_Utils:
    def __init__(self):
        pass

    @spaces.GPU(duration=120)
    def prepare_inputs(self):
        raise NotImplementedError

    @spaces.GPU(duration=120)
    def generate_outputs(self):
        raise NotImplementedError


class Clip_Utils(Model_Utils):
    def __init__(self):
        self.edge = 224
        super().__init__()

    def init_Clip(self):
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.processor.feature_extractor.size = {"height": self.edge, "width": self.edge}

    @spaces.GPU(duration=120)
    def prepare_inputs(self, question_lst, image):
        image = Image.fromarray(image)
        print("image_size: ", image.size)
        inputs = self.processor(text=question_lst, images=image, return_tensors="pt", padding=True)
        return inputs


class Janus_Utils(Model_Utils):
    def __init__(self):
        super().__init__()

    def init_Janus(self, num_params="1B"):
        model_path = f"deepseek-ai/Janus-Pro-{num_params}"
        config = AutoConfig.from_pretrained(model_path)
        language_config = config.language_config
        language_config._attn_implementation = 'eager'
        self.vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                                           language_config=language_config,
                                                           trust_remote_code=True,
                                                           ignore_mismatched_sizes=True,
                                                           )
        self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
        self.vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
        self.tokenizer = self.vl_chat_processor.tokenizer

        return self.vl_gpt, self.tokenizer

    @spaces.GPU(duration=120)
    def prepare_inputs(self, question, image):
        conversation = [
            {
                "role": "<|User|>",
                "content": f"<image_placeholder>\n{question}",
                "images": [image],
            },
            {"role": "<|Assistant|>", "content": ""},
        ]

        pil_images = [Image.fromarray(image)]
        prepare_inputs = self.vl_chat_processor(
            conversations=conversation, images=pil_images, force_batchify=True
        ).to(self.cuda_device, dtype=self.dtype)

        return prepare_inputs

    @spaces.GPU(duration=120)
    def generate_inputs_embeddings(self, prepare_inputs):
        return self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    @spaces.GPU(duration=120)
    def generate_outputs(self, inputs_embeds, prepare_inputs, temperature, top_p, with_attn=False):
        outputs = self.vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=512,
            do_sample=False if temperature == 0 else True,
            use_cache=True,
            temperature=temperature,
            top_p=top_p,
            return_dict_in_generate=True,
            output_attentions=True
        )

        return outputs


class LLaVA_Utils(Model_Utils):
    def __init__(self):
        super().__init__()

    def init_LLaVA(self):
        model_path = "llava-hf/llava-1.5-7b-hf"
        config = AutoConfig.from_pretrained(model_path)

        self.vl_gpt = LlavaForConditionalGeneration.from_pretrained(model_path,
                                                                    low_cpu_mem_usage=True,
                                                                    attn_implementation='eager',
                                                                    output_attentions=True
                                                                    )
        self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.tokenizer = self.processor.tokenizer

        return self.vl_gpt, self.tokenizer

    @spaces.GPU(duration=120)
    def prepare_inputs(self, question, image):
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image"},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        pil_images = [Image.fromarray(image)]
        prepare_inputs = self.processor(
            images=pil_images, text=prompt, return_tensors="pt"
        ).to(self.cuda_device, dtype=self.dtype)

        return prepare_inputs

    @spaces.GPU(duration=120)
    def generate_inputs_embeddings(self, prepare_inputs):
        return self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    @spaces.GPU(duration=120)
    def generate_outputs(self, prepare_inputs, temperature, top_p):
        outputs = self.vl_gpt.generate(
            **prepare_inputs,
            max_new_tokens=512,
            do_sample=False if temperature == 0 else True,
            use_cache=True,
            return_dict_in_generate=True,
            output_attentions=True
        )

        return outputs


def add_title_to_image(image, title, font_size=20):
    """Adds a title above an image using PIL and textbbox()."""
    img_width, img_height = image.size

    # Create a blank image for title
    title_height = font_size + 10  # Some padding
    title_image = Image.new("RGB", (img_width, title_height), color=(255, 255, 255))  # White background
    draw = ImageDraw.Draw(title_image)

    # Load font
    try:
        font = ImageFont.truetype("arial.ttf", font_size)  # Use Arial if available
    except:
        font = ImageFont.load_default()  # Use default if Arial not found

    # Get text size (updated for PIL >= 10)
    text_bbox = draw.textbbox((0, 0), title, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    # Center the title
    text_position = ((img_width - text_width) // 2, (title_height - text_height) // 2)

    draw.text(text_position, title, fill="black", font=font)

    # Concatenate title with image
    combined = Image.new("RGB", (img_width, img_height + title_height))
    combined.paste(title_image, (0, 0))  # Place title at the top
    combined.paste(image, (0, title_height))  # Place original image below

    return combined
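The Janus_Utils methods above are meant to be chained: load the model, build processor inputs from a question and an RGB array, embed them, then decode the generated sequence. A minimal sketch of that flow, assuming a local image file and a short question (both illustrative):

import numpy as np
from PIL import Image

janus_utils = Janus_Utils()
vl_gpt, tokenizer = janus_utils.init_Janus(num_params="1B")

# Build inputs from an illustrative question/image pair.
image = np.array(Image.open("images/pie_chart.png").convert("RGB"))
prepare_inputs = janus_utils.prepare_inputs("What does this chart show?", image)
inputs_embeds = janus_utils.generate_inputs_embeddings(prepare_inputs)
outputs = janus_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature=0.1, top_p=0.95)

# return_dict_in_generate=True, so the generated token ids live in outputs.sequences
answer = tokenizer.decode(outputs.sequences[0].cpu().tolist(), skip_special_tokens=True)
print(answer)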
demo/modify_llama.py
ADDED
@@ -0,0 +1,11 @@
# Helper functions bound onto attention modules (via types.MethodType) so that each
# layer can cache its own attention map and the gradient flowing through it.

def save_attn_gradients(self, attn_gradients):
    self.attn_gradients = attn_gradients

def get_attn_gradients(self):
    return self.attn_gradients

def save_attn_map(self, attention_map):
    self.attention_map = attention_map

def get_attn_map(self):
    return self.attention_map
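These four helpers carry no state of their own; demo/cam.py binds them onto individual attention layers with types.MethodType so each layer stores its attention map and gradient. A minimal sketch of that wiring, using a stand-in nn.Module in place of a real LLaMA attention layer:

import types
import torch.nn as nn

attn_layer = nn.MultiheadAttention(embed_dim=8, num_heads=2)  # stand-in for a LLaMA attention block
attn_layer.attn_gradients = None
attn_layer.attention_map = None
attn_layer.save_attn_map = types.MethodType(save_attn_map, attn_layer)
attn_layer.get_attn_map = types.MethodType(get_attn_map, attn_layer)
attn_layer.save_attn_gradients = types.MethodType(save_attn_gradients, attn_layer)
attn_layer.get_attn_gradients = types.MethodType(get_attn_gradients, attn_layer)

# A forward hook can then cache the attention weights and register a gradient hook:
#   module.save_attn_map(attn_weights)
#   attn_weights.register_hook(module.save_attn_gradients)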
demo/visualize_architecture.ipynb
ADDED
@@ -0,0 +1,1715 @@
Jupyter notebook (JSON); the cells visible in this portion of the diff contain the following sources and outputs.

Cell 1 (code) – imports:

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from demo.cam import generate_gradcam, AttentionGuidedCAM
from captum.attr import LayerGradCam
from PIL import Image
from einops import rearrange

import numpy as np
import matplotlib.pyplot as plt
import os
import time

import torch.nn.functional as F
from scipy.ndimage import filters
from torch import nn

Cell 1 output: a tqdm IProgress warning, "Python version is above 3.10, patching the collections module.", and a transformers FutureWarning about the deprecated image_processor_class argument.

Cell 2 (code) – load Janus-Pro-1B with an extra CLS token:

model_path = "deepseek-ai/Janus-Pro-1B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True,
                                              ignore_mismatched_sizes=True  # Adding CLS token, will be handled manually
                                              )

dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16

if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(dtype).cuda()
else:
    torch.set_default_device("mps")
    vl_gpt = vl_gpt.to(dtype)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'mps'

Cell 2 output: "Usage Class Token: True", plus warnings that vision_model.vision_tower.cls_token is newly initialized and that vision_model.vision_tower.pos_embed was resized from [1, 576, 1024] in the checkpoint to [1, 577, 1024] in the instantiated model, along with the usual slow-image-processor, legacy-tokenizer, and unused-processor-kwargs notices.

Cell 3 (code) – print the vision encoder. Its stdout lists the CLIPVisionTower architecture: a VisionTransformer whose PatchEmbed is a 16x16 Conv2d (3 -> 1024 channels), followed by a sequence of identical transformer Blocks, each LayerNorm -> Attention (qkv Linear 1024 -> 3072, proj Linear 1024 -> 1024) -> LayerNorm -> Mlp (1024 -> 4096 -> 1024 with GELU). The listing repeats this Block structure for every layer of the tower and continues below.
|
358 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
359 |
+
" (norm): Identity()\n",
|
360 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
361 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
362 |
+
" )\n",
|
363 |
+
" (ls2): Identity()\n",
|
364 |
+
" (drop_path2): Identity()\n",
|
365 |
+
" )\n",
|
366 |
+
" (10): Block(\n",
|
367 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
368 |
+
" (attn): Attention(\n",
|
369 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
370 |
+
" (q_norm): Identity()\n",
|
371 |
+
" (k_norm): Identity()\n",
|
372 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
373 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
374 |
+
" (proj_drop): Identity()\n",
|
375 |
+
" )\n",
|
376 |
+
" (ls1): Identity()\n",
|
377 |
+
" (drop_path1): Identity()\n",
|
378 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
379 |
+
" (mlp): Mlp(\n",
|
380 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
381 |
+
" (act): GELU(approximate='none')\n",
|
382 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
383 |
+
" (norm): Identity()\n",
|
384 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
385 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
386 |
+
" )\n",
|
387 |
+
" (ls2): Identity()\n",
|
388 |
+
" (drop_path2): Identity()\n",
|
389 |
+
" )\n",
|
390 |
+
" (11): Block(\n",
|
391 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
392 |
+
" (attn): Attention(\n",
|
393 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
394 |
+
" (q_norm): Identity()\n",
|
395 |
+
" (k_norm): Identity()\n",
|
396 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
397 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
398 |
+
" (proj_drop): Identity()\n",
|
399 |
+
" )\n",
|
400 |
+
" (ls1): Identity()\n",
|
401 |
+
" (drop_path1): Identity()\n",
|
402 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
403 |
+
" (mlp): Mlp(\n",
|
404 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
405 |
+
" (act): GELU(approximate='none')\n",
|
406 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
407 |
+
" (norm): Identity()\n",
|
408 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
409 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
410 |
+
" )\n",
|
411 |
+
" (ls2): Identity()\n",
|
412 |
+
" (drop_path2): Identity()\n",
|
413 |
+
" )\n",
|
414 |
+
" (12): Block(\n",
|
415 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
416 |
+
" (attn): Attention(\n",
|
417 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
418 |
+
" (q_norm): Identity()\n",
|
419 |
+
" (k_norm): Identity()\n",
|
420 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
421 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
422 |
+
" (proj_drop): Identity()\n",
|
423 |
+
" )\n",
|
424 |
+
" (ls1): Identity()\n",
|
425 |
+
" (drop_path1): Identity()\n",
|
426 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
427 |
+
" (mlp): Mlp(\n",
|
428 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
429 |
+
" (act): GELU(approximate='none')\n",
|
430 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
431 |
+
" (norm): Identity()\n",
|
432 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
433 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
434 |
+
" )\n",
|
435 |
+
" (ls2): Identity()\n",
|
436 |
+
" (drop_path2): Identity()\n",
|
437 |
+
" )\n",
|
438 |
+
" (13): Block(\n",
|
439 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
440 |
+
" (attn): Attention(\n",
|
441 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
442 |
+
" (q_norm): Identity()\n",
|
443 |
+
" (k_norm): Identity()\n",
|
444 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
445 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
446 |
+
" (proj_drop): Identity()\n",
|
447 |
+
" )\n",
|
448 |
+
" (ls1): Identity()\n",
|
449 |
+
" (drop_path1): Identity()\n",
|
450 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
451 |
+
" (mlp): Mlp(\n",
|
452 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
453 |
+
" (act): GELU(approximate='none')\n",
|
454 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
455 |
+
" (norm): Identity()\n",
|
456 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
457 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
458 |
+
" )\n",
|
459 |
+
" (ls2): Identity()\n",
|
460 |
+
" (drop_path2): Identity()\n",
|
461 |
+
" )\n",
|
462 |
+
" (14): Block(\n",
|
463 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
464 |
+
" (attn): Attention(\n",
|
465 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
466 |
+
" (q_norm): Identity()\n",
|
467 |
+
" (k_norm): Identity()\n",
|
468 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
469 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
470 |
+
" (proj_drop): Identity()\n",
|
471 |
+
" )\n",
|
472 |
+
" (ls1): Identity()\n",
|
473 |
+
" (drop_path1): Identity()\n",
|
474 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
475 |
+
" (mlp): Mlp(\n",
|
476 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
477 |
+
" (act): GELU(approximate='none')\n",
|
478 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
479 |
+
" (norm): Identity()\n",
|
480 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
481 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
482 |
+
" )\n",
|
483 |
+
" (ls2): Identity()\n",
|
484 |
+
" (drop_path2): Identity()\n",
|
485 |
+
" )\n",
|
486 |
+
" (15): Block(\n",
|
487 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
488 |
+
" (attn): Attention(\n",
|
489 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
490 |
+
" (q_norm): Identity()\n",
|
491 |
+
" (k_norm): Identity()\n",
|
492 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
493 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
494 |
+
" (proj_drop): Identity()\n",
|
495 |
+
" )\n",
|
496 |
+
" (ls1): Identity()\n",
|
497 |
+
" (drop_path1): Identity()\n",
|
498 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
499 |
+
" (mlp): Mlp(\n",
|
500 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
501 |
+
" (act): GELU(approximate='none')\n",
|
502 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
503 |
+
" (norm): Identity()\n",
|
504 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
505 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
506 |
+
" )\n",
|
507 |
+
" (ls2): Identity()\n",
|
508 |
+
" (drop_path2): Identity()\n",
|
509 |
+
" )\n",
|
510 |
+
" (16): Block(\n",
|
511 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
512 |
+
" (attn): Attention(\n",
|
513 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
514 |
+
" (q_norm): Identity()\n",
|
515 |
+
" (k_norm): Identity()\n",
|
516 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
517 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
518 |
+
" (proj_drop): Identity()\n",
|
519 |
+
" )\n",
|
520 |
+
" (ls1): Identity()\n",
|
521 |
+
" (drop_path1): Identity()\n",
|
522 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
523 |
+
" (mlp): Mlp(\n",
|
524 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
525 |
+
" (act): GELU(approximate='none')\n",
|
526 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
527 |
+
" (norm): Identity()\n",
|
528 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
529 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
530 |
+
" )\n",
|
531 |
+
" (ls2): Identity()\n",
|
532 |
+
" (drop_path2): Identity()\n",
|
533 |
+
" )\n",
|
534 |
+
" (17): Block(\n",
|
535 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
536 |
+
" (attn): Attention(\n",
|
537 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
538 |
+
" (q_norm): Identity()\n",
|
539 |
+
" (k_norm): Identity()\n",
|
540 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
541 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
542 |
+
" (proj_drop): Identity()\n",
|
543 |
+
" )\n",
|
544 |
+
" (ls1): Identity()\n",
|
545 |
+
" (drop_path1): Identity()\n",
|
546 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
547 |
+
" (mlp): Mlp(\n",
|
548 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
549 |
+
" (act): GELU(approximate='none')\n",
|
550 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
551 |
+
" (norm): Identity()\n",
|
552 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
553 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
554 |
+
" )\n",
|
555 |
+
" (ls2): Identity()\n",
|
556 |
+
" (drop_path2): Identity()\n",
|
557 |
+
" )\n",
|
558 |
+
" (18): Block(\n",
|
559 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
560 |
+
" (attn): Attention(\n",
|
561 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
562 |
+
" (q_norm): Identity()\n",
|
563 |
+
" (k_norm): Identity()\n",
|
564 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
565 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
566 |
+
" (proj_drop): Identity()\n",
|
567 |
+
" )\n",
|
568 |
+
" (ls1): Identity()\n",
|
569 |
+
" (drop_path1): Identity()\n",
|
570 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
571 |
+
" (mlp): Mlp(\n",
|
572 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
573 |
+
" (act): GELU(approximate='none')\n",
|
574 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
575 |
+
" (norm): Identity()\n",
|
576 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
577 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
578 |
+
" )\n",
|
579 |
+
" (ls2): Identity()\n",
|
580 |
+
" (drop_path2): Identity()\n",
|
581 |
+
" )\n",
|
582 |
+
" (19): Block(\n",
|
583 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
584 |
+
" (attn): Attention(\n",
|
585 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
586 |
+
" (q_norm): Identity()\n",
|
587 |
+
" (k_norm): Identity()\n",
|
588 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
589 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
590 |
+
" (proj_drop): Identity()\n",
|
591 |
+
" )\n",
|
592 |
+
" (ls1): Identity()\n",
|
593 |
+
" (drop_path1): Identity()\n",
|
594 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
595 |
+
" (mlp): Mlp(\n",
|
596 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
597 |
+
" (act): GELU(approximate='none')\n",
|
598 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
599 |
+
" (norm): Identity()\n",
|
600 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
601 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
602 |
+
" )\n",
|
603 |
+
" (ls2): Identity()\n",
|
604 |
+
" (drop_path2): Identity()\n",
|
605 |
+
" )\n",
|
606 |
+
" (20): Block(\n",
|
607 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
608 |
+
" (attn): Attention(\n",
|
609 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
610 |
+
" (q_norm): Identity()\n",
|
611 |
+
" (k_norm): Identity()\n",
|
612 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
613 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
614 |
+
" (proj_drop): Identity()\n",
|
615 |
+
" )\n",
|
616 |
+
" (ls1): Identity()\n",
|
617 |
+
" (drop_path1): Identity()\n",
|
618 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
619 |
+
" (mlp): Mlp(\n",
|
620 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
621 |
+
" (act): GELU(approximate='none')\n",
|
622 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
623 |
+
" (norm): Identity()\n",
|
624 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
625 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
626 |
+
" )\n",
|
627 |
+
" (ls2): Identity()\n",
|
628 |
+
" (drop_path2): Identity()\n",
|
629 |
+
" )\n",
|
630 |
+
" (21): Block(\n",
|
631 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
632 |
+
" (attn): Attention(\n",
|
633 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
634 |
+
" (q_norm): Identity()\n",
|
635 |
+
" (k_norm): Identity()\n",
|
636 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
637 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
638 |
+
" (proj_drop): Identity()\n",
|
639 |
+
" )\n",
|
640 |
+
" (ls1): Identity()\n",
|
641 |
+
" (drop_path1): Identity()\n",
|
642 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
643 |
+
" (mlp): Mlp(\n",
|
644 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
645 |
+
" (act): GELU(approximate='none')\n",
|
646 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
647 |
+
" (norm): Identity()\n",
|
648 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
649 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
650 |
+
" )\n",
|
651 |
+
" (ls2): Identity()\n",
|
652 |
+
" (drop_path2): Identity()\n",
|
653 |
+
" )\n",
|
654 |
+
" (22): Block(\n",
|
655 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
656 |
+
" (attn): Attention(\n",
|
657 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
658 |
+
" (q_norm): Identity()\n",
|
659 |
+
" (k_norm): Identity()\n",
|
660 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
661 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
662 |
+
" (proj_drop): Identity()\n",
|
663 |
+
" )\n",
|
664 |
+
" (ls1): Identity()\n",
|
665 |
+
" (drop_path1): Identity()\n",
|
666 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
667 |
+
" (mlp): Mlp(\n",
|
668 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
669 |
+
" (act): GELU(approximate='none')\n",
|
670 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
671 |
+
" (norm): Identity()\n",
|
672 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
673 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
674 |
+
" )\n",
|
675 |
+
" (ls2): Identity()\n",
|
676 |
+
" (drop_path2): Identity()\n",
|
677 |
+
" )\n",
|
678 |
+
" (23): Block(\n",
|
679 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
680 |
+
" (attn): Attention(\n",
|
681 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
682 |
+
" (q_norm): Identity()\n",
|
683 |
+
" (k_norm): Identity()\n",
|
684 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
685 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
686 |
+
" (proj_drop): Identity()\n",
|
687 |
+
" )\n",
|
688 |
+
" (ls1): Identity()\n",
|
689 |
+
" (drop_path1): Identity()\n",
|
690 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
691 |
+
" (mlp): Mlp(\n",
|
692 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
693 |
+
" (act): GELU(approximate='none')\n",
|
694 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
695 |
+
" (norm): Identity()\n",
|
696 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
697 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
698 |
+
" )\n",
|
699 |
+
" (ls2): Identity()\n",
|
700 |
+
" (drop_path2): Identity()\n",
|
701 |
+
" )\n",
|
702 |
+
" )\n",
|
703 |
+
" (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
704 |
+
" (attn_pool): AttentionPoolLatent(\n",
|
705 |
+
" (q): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
706 |
+
" (kv): Linear(in_features=1024, out_features=2048, bias=True)\n",
|
707 |
+
" (q_norm): Identity()\n",
|
708 |
+
" (k_norm): Identity()\n",
|
709 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
710 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
711 |
+
" (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
712 |
+
" (mlp): Mlp(\n",
|
713 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
714 |
+
" (act): GELU(approximate='none')\n",
|
715 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
716 |
+
" (norm): Identity()\n",
|
717 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
718 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
719 |
+
" )\n",
|
720 |
+
" )\n",
|
721 |
+
" (fc_norm): Identity()\n",
|
722 |
+
" (head_drop): Dropout(p=0.0, inplace=False)\n",
|
723 |
+
" (head): Identity()\n",
|
724 |
+
" )\n",
|
725 |
+
")\n"
|
726 |
+
]
|
727 |
+
}
|
728 |
+
],
|
729 |
+
"source": [
|
730 |
+
"print(vl_gpt.vision_model)"
|
731 |
+
]
|
732 |
+
},
|
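The cell above dumps the vision tower: a ViT with 24 identical Blocks at width 1024 (MLP width 4096, GELU activations, all Dropout modules at p=0.0) followed by a final LayerNorm and an AttentionPoolLatent head. To put that printout in proportion against the other components shown below, a per-submodule parameter tally is handy — a minimal sketch, assuming vl_gpt is the MultiModalityCausalLM instance already loaded in the earlier cells of this notebook:

    # Sketch: parameter counts per top-level submodule of vl_gpt
    # (vl_gpt is assumed to be the model loaded earlier in this notebook).
    def count_params(module):
        return sum(p.numel() for p in module.parameters())

    for name, child in vl_gpt.named_children():
        print(f"{name}: {count_params(child) / 1e6:.1f}M parameters")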
733 |
+
{
|
734 |
+
"cell_type": "code",
|
735 |
+
"execution_count": 4,
|
736 |
+
"metadata": {},
|
737 |
+
"outputs": [
|
738 |
+
{
|
739 |
+
"name": "stdout",
|
740 |
+
"output_type": "stream",
|
741 |
+
"text": [
|
742 |
+
"LlamaForCausalLM(\n",
|
743 |
+
" (model): LlamaModel(\n",
|
744 |
+
" (embed_tokens): Embedding(102400, 2048)\n",
|
745 |
+
" (layers): ModuleList(\n",
|
746 |
+
" (0-23): 24 x LlamaDecoderLayer(\n",
|
747 |
+
" (self_attn): LlamaAttention(\n",
|
748 |
+
" (q_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
749 |
+
" (k_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
750 |
+
" (v_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
751 |
+
" (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
752 |
+
" )\n",
|
753 |
+
" (mlp): LlamaMLP(\n",
|
754 |
+
" (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
|
755 |
+
" (up_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
|
756 |
+
" (down_proj): Linear(in_features=5632, out_features=2048, bias=False)\n",
|
757 |
+
" (act_fn): SiLU()\n",
|
758 |
+
" )\n",
|
759 |
+
" (input_layernorm): LlamaRMSNorm((2048,), eps=1e-06)\n",
|
760 |
+
" (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-06)\n",
|
761 |
+
" )\n",
|
762 |
+
" )\n",
|
763 |
+
" (norm): LlamaRMSNorm((2048,), eps=1e-06)\n",
|
764 |
+
" (rotary_emb): LlamaRotaryEmbedding()\n",
|
765 |
+
" )\n",
|
766 |
+
" (lm_head): Linear(in_features=2048, out_features=102400, bias=False)\n",
|
767 |
+
")\n"
|
768 |
+
]
|
769 |
+
}
|
770 |
+
],
|
771 |
+
"source": [
|
772 |
+
"print(vl_gpt.language_model)"
|
773 |
+
]
|
774 |
+
},
|
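The cell above prints the language model: a LlamaForCausalLM with 24 decoder layers, hidden size 2048, SiLU-gated MLPs of width 5632, RMSNorm, and a 102400-token vocabulary. The key shapes can be read back off the loaded modules directly — a minimal sketch, again assuming the vl_gpt instance from the earlier cells:

    # Sketch: read the vocab/hidden sizes straight off the loaded modules
    # (vl_gpt is assumed to be the model loaded earlier in this notebook).
    emb = vl_gpt.language_model.model.embed_tokens   # Embedding(102400, 2048)
    head = vl_gpt.language_model.lm_head             # Linear(2048 -> 102400, bias=False)
    print(emb.num_embeddings, emb.embedding_dim)     # 102400 2048
    print(head.in_features, head.out_features)       # 2048 102400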
775 |
+
{
|
776 |
+
"cell_type": "code",
|
777 |
+
"execution_count": 5,
|
778 |
+
"metadata": {},
|
779 |
+
"outputs": [
|
780 |
+
{
|
781 |
+
"name": "stdout",
|
782 |
+
"output_type": "stream",
|
783 |
+
"text": [
|
784 |
+
"MultiModalityCausalLM(\n",
|
785 |
+
" (vision_model): CLIPVisionTower(\n",
|
786 |
+
" (vision_tower): VisionTransformer(\n",
|
787 |
+
" (patch_embed): PatchEmbed(\n",
|
788 |
+
" (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))\n",
|
789 |
+
" (norm): Identity()\n",
|
790 |
+
" )\n",
|
791 |
+
" (pos_drop): Dropout(p=0.0, inplace=False)\n",
|
792 |
+
" (patch_drop): Identity()\n",
|
793 |
+
" (norm_pre): Identity()\n",
|
794 |
+
" (blocks): Sequential(\n",
|
795 |
+
" (0): Block(\n",
|
796 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
797 |
+
" (attn): Attention(\n",
|
798 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
799 |
+
" (q_norm): Identity()\n",
|
800 |
+
" (k_norm): Identity()\n",
|
801 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
802 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
803 |
+
" (proj_drop): Identity()\n",
|
804 |
+
" )\n",
|
805 |
+
" (ls1): Identity()\n",
|
806 |
+
" (drop_path1): Identity()\n",
|
807 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
808 |
+
" (mlp): Mlp(\n",
|
809 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
810 |
+
" (act): GELU(approximate='none')\n",
|
811 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
812 |
+
" (norm): Identity()\n",
|
813 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
814 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
815 |
+
" )\n",
|
816 |
+
" (ls2): Identity()\n",
|
817 |
+
" (drop_path2): Identity()\n",
|
818 |
+
" )\n",
|
819 |
+
" (1): Block(\n",
|
820 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
821 |
+
" (attn): Attention(\n",
|
822 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
823 |
+
" (q_norm): Identity()\n",
|
824 |
+
" (k_norm): Identity()\n",
|
825 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
826 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
827 |
+
" (proj_drop): Identity()\n",
|
828 |
+
" )\n",
|
829 |
+
" (ls1): Identity()\n",
|
830 |
+
" (drop_path1): Identity()\n",
|
831 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
832 |
+
" (mlp): Mlp(\n",
|
833 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
834 |
+
" (act): GELU(approximate='none')\n",
|
835 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
836 |
+
" (norm): Identity()\n",
|
837 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
838 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
839 |
+
" )\n",
|
840 |
+
" (ls2): Identity()\n",
|
841 |
+
" (drop_path2): Identity()\n",
|
842 |
+
" )\n",
|
843 |
+
" (2): Block(\n",
|
844 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
845 |
+
" (attn): Attention(\n",
|
846 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
847 |
+
" (q_norm): Identity()\n",
|
848 |
+
" (k_norm): Identity()\n",
|
849 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
850 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
851 |
+
" (proj_drop): Identity()\n",
|
852 |
+
" )\n",
|
853 |
+
" (ls1): Identity()\n",
|
854 |
+
" (drop_path1): Identity()\n",
|
855 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
856 |
+
" (mlp): Mlp(\n",
|
857 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
858 |
+
" (act): GELU(approximate='none')\n",
|
859 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
860 |
+
" (norm): Identity()\n",
|
861 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
862 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
863 |
+
" )\n",
|
864 |
+
" (ls2): Identity()\n",
|
865 |
+
" (drop_path2): Identity()\n",
|
866 |
+
" )\n",
|
867 |
+
" (3): Block(\n",
|
868 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
869 |
+
" (attn): Attention(\n",
|
870 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
871 |
+
" (q_norm): Identity()\n",
|
872 |
+
" (k_norm): Identity()\n",
|
873 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
874 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
875 |
+
" (proj_drop): Identity()\n",
|
876 |
+
" )\n",
|
877 |
+
" (ls1): Identity()\n",
|
878 |
+
" (drop_path1): Identity()\n",
|
879 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
880 |
+
" (mlp): Mlp(\n",
|
881 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
882 |
+
" (act): GELU(approximate='none')\n",
|
883 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
884 |
+
" (norm): Identity()\n",
|
885 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
886 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
887 |
+
" )\n",
|
888 |
+
" (ls2): Identity()\n",
|
889 |
+
" (drop_path2): Identity()\n",
|
890 |
+
" )\n",
|
891 |
+
" (4): Block(\n",
|
892 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
893 |
+
" (attn): Attention(\n",
|
894 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
895 |
+
" (q_norm): Identity()\n",
|
896 |
+
" (k_norm): Identity()\n",
|
897 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
898 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
899 |
+
" (proj_drop): Identity()\n",
|
900 |
+
" )\n",
|
901 |
+
" (ls1): Identity()\n",
|
902 |
+
" (drop_path1): Identity()\n",
|
903 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
904 |
+
" (mlp): Mlp(\n",
|
905 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
906 |
+
" (act): GELU(approximate='none')\n",
|
907 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
908 |
+
" (norm): Identity()\n",
|
909 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
910 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
911 |
+
" )\n",
|
912 |
+
" (ls2): Identity()\n",
|
913 |
+
" (drop_path2): Identity()\n",
|
914 |
+
" )\n",
|
915 |
+
" (5): Block(\n",
|
916 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
917 |
+
" (attn): Attention(\n",
|
918 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
919 |
+
" (q_norm): Identity()\n",
|
920 |
+
" (k_norm): Identity()\n",
|
921 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
922 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
923 |
+
" (proj_drop): Identity()\n",
|
924 |
+
" )\n",
|
925 |
+
" (ls1): Identity()\n",
|
926 |
+
" (drop_path1): Identity()\n",
|
927 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
928 |
+
" (mlp): Mlp(\n",
|
929 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
930 |
+
" (act): GELU(approximate='none')\n",
|
931 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
932 |
+
" (norm): Identity()\n",
|
933 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
934 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
935 |
+
" )\n",
|
936 |
+
" (ls2): Identity()\n",
|
937 |
+
" (drop_path2): Identity()\n",
|
938 |
+
" )\n",
|
939 |
+
" (6): Block(\n",
|
940 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
941 |
+
" (attn): Attention(\n",
|
942 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
943 |
+
" (q_norm): Identity()\n",
|
944 |
+
" (k_norm): Identity()\n",
|
945 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
946 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
947 |
+
" (proj_drop): Identity()\n",
|
948 |
+
" )\n",
|
949 |
+
" (ls1): Identity()\n",
|
950 |
+
" (drop_path1): Identity()\n",
|
951 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
952 |
+
" (mlp): Mlp(\n",
|
953 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
954 |
+
" (act): GELU(approximate='none')\n",
|
955 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
956 |
+
" (norm): Identity()\n",
|
957 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
958 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
959 |
+
" )\n",
|
960 |
+
" (ls2): Identity()\n",
|
961 |
+
" (drop_path2): Identity()\n",
|
962 |
+
" )\n",
|
963 |
+
" (7): Block(\n",
|
964 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
965 |
+
" (attn): Attention(\n",
|
966 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
967 |
+
" (q_norm): Identity()\n",
|
968 |
+
" (k_norm): Identity()\n",
|
969 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
970 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
971 |
+
" (proj_drop): Identity()\n",
|
972 |
+
" )\n",
|
973 |
+
" (ls1): Identity()\n",
|
974 |
+
" (drop_path1): Identity()\n",
|
975 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
976 |
+
" (mlp): Mlp(\n",
|
977 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
978 |
+
" (act): GELU(approximate='none')\n",
|
979 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
980 |
+
" (norm): Identity()\n",
|
981 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
982 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
983 |
+
" )\n",
|
984 |
+
" (ls2): Identity()\n",
|
985 |
+
" (drop_path2): Identity()\n",
|
986 |
+
" )\n",
|
987 |
+
" (8): Block(\n",
|
988 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
989 |
+
" (attn): Attention(\n",
|
990 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
991 |
+
" (q_norm): Identity()\n",
|
992 |
+
" (k_norm): Identity()\n",
|
993 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
994 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
995 |
+
" (proj_drop): Identity()\n",
|
996 |
+
" )\n",
|
997 |
+
" (ls1): Identity()\n",
|
998 |
+
" (drop_path1): Identity()\n",
|
999 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1000 |
+
" (mlp): Mlp(\n",
|
1001 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1002 |
+
" (act): GELU(approximate='none')\n",
|
1003 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1004 |
+
" (norm): Identity()\n",
|
1005 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1006 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1007 |
+
" )\n",
|
1008 |
+
" (ls2): Identity()\n",
|
1009 |
+
" (drop_path2): Identity()\n",
|
1010 |
+
" )\n",
|
1011 |
+
" (9): Block(\n",
|
1012 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1013 |
+
" (attn): Attention(\n",
|
1014 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1015 |
+
" (q_norm): Identity()\n",
|
1016 |
+
" (k_norm): Identity()\n",
|
1017 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1018 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1019 |
+
" (proj_drop): Identity()\n",
|
1020 |
+
" )\n",
|
1021 |
+
" (ls1): Identity()\n",
|
1022 |
+
" (drop_path1): Identity()\n",
|
1023 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1024 |
+
" (mlp): Mlp(\n",
|
1025 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1026 |
+
" (act): GELU(approximate='none')\n",
|
1027 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1028 |
+
" (norm): Identity()\n",
|
1029 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1030 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1031 |
+
" )\n",
|
1032 |
+
" (ls2): Identity()\n",
|
1033 |
+
" (drop_path2): Identity()\n",
|
1034 |
+
" )\n",
|
1035 |
+
" (10): Block(\n",
|
1036 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1037 |
+
" (attn): Attention(\n",
|
1038 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1039 |
+
" (q_norm): Identity()\n",
|
1040 |
+
" (k_norm): Identity()\n",
|
1041 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1042 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1043 |
+
" (proj_drop): Identity()\n",
|
1044 |
+
" )\n",
|
1045 |
+
" (ls1): Identity()\n",
|
1046 |
+
" (drop_path1): Identity()\n",
|
1047 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1048 |
+
" (mlp): Mlp(\n",
|
1049 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1050 |
+
" (act): GELU(approximate='none')\n",
|
1051 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1052 |
+
" (norm): Identity()\n",
|
1053 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1054 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1055 |
+
" )\n",
|
1056 |
+
" (ls2): Identity()\n",
|
1057 |
+
" (drop_path2): Identity()\n",
|
1058 |
+
" )\n",
|
1059 |
+
" (11): Block(\n",
|
1060 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1061 |
+
" (attn): Attention(\n",
|
1062 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1063 |
+
" (q_norm): Identity()\n",
|
1064 |
+
" (k_norm): Identity()\n",
|
1065 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1066 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1067 |
+
" (proj_drop): Identity()\n",
|
1068 |
+
" )\n",
|
1069 |
+
" (ls1): Identity()\n",
|
1070 |
+
" (drop_path1): Identity()\n",
|
1071 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1072 |
+
" (mlp): Mlp(\n",
|
1073 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1074 |
+
" (act): GELU(approximate='none')\n",
|
1075 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1076 |
+
" (norm): Identity()\n",
|
1077 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1078 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1079 |
+
" )\n",
|
1080 |
+
" (ls2): Identity()\n",
|
1081 |
+
" (drop_path2): Identity()\n",
|
1082 |
+
" )\n",
|
1083 |
+
" (12): Block(\n",
|
1084 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1085 |
+
" (attn): Attention(\n",
|
1086 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1087 |
+
" (q_norm): Identity()\n",
|
1088 |
+
" (k_norm): Identity()\n",
|
1089 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1090 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1091 |
+
" (proj_drop): Identity()\n",
|
1092 |
+
" )\n",
|
1093 |
+
" (ls1): Identity()\n",
|
1094 |
+
" (drop_path1): Identity()\n",
|
1095 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1096 |
+
" (mlp): Mlp(\n",
|
1097 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1098 |
+
" (act): GELU(approximate='none')\n",
|
1099 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1100 |
+
" (norm): Identity()\n",
|
1101 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1102 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1103 |
+
" )\n",
|
1104 |
+
" (ls2): Identity()\n",
|
1105 |
+
" (drop_path2): Identity()\n",
|
1106 |
+
" )\n",
|
1107 |
+
" (13): Block(\n",
|
1108 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1109 |
+
" (attn): Attention(\n",
|
1110 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1111 |
+
" (q_norm): Identity()\n",
|
1112 |
+
" (k_norm): Identity()\n",
|
1113 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1114 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1115 |
+
" (proj_drop): Identity()\n",
|
1116 |
+
" )\n",
|
1117 |
+
" (ls1): Identity()\n",
|
1118 |
+
" (drop_path1): Identity()\n",
|
1119 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1120 |
+
" (mlp): Mlp(\n",
|
1121 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1122 |
+
" (act): GELU(approximate='none')\n",
|
1123 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1124 |
+
" (norm): Identity()\n",
|
1125 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1126 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1127 |
+
" )\n",
|
1128 |
+
" (ls2): Identity()\n",
|
1129 |
+
" (drop_path2): Identity()\n",
|
1130 |
+
" )\n",
|
1131 |
+
" (14): Block(\n",
|
1132 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1133 |
+
" (attn): Attention(\n",
|
1134 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1135 |
+
" (q_norm): Identity()\n",
|
1136 |
+
" (k_norm): Identity()\n",
|
1137 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1138 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1139 |
+
" (proj_drop): Identity()\n",
|
1140 |
+
" )\n",
|
1141 |
+
" (ls1): Identity()\n",
|
1142 |
+
" (drop_path1): Identity()\n",
|
1143 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1144 |
+
" (mlp): Mlp(\n",
|
1145 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1146 |
+
" (act): GELU(approximate='none')\n",
|
1147 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1148 |
+
" (norm): Identity()\n",
|
1149 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1150 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1151 |
+
" )\n",
|
1152 |
+
" (ls2): Identity()\n",
|
1153 |
+
" (drop_path2): Identity()\n",
|
1154 |
+
" )\n",
|
1155 |
+
" (15): Block(\n",
|
1156 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1157 |
+
" (attn): Attention(\n",
|
1158 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1159 |
+
" (q_norm): Identity()\n",
|
1160 |
+
" (k_norm): Identity()\n",
|
1161 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1162 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1163 |
+
" (proj_drop): Identity()\n",
|
1164 |
+
" )\n",
|
1165 |
+
" (ls1): Identity()\n",
|
1166 |
+
" (drop_path1): Identity()\n",
|
1167 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1168 |
+
" (mlp): Mlp(\n",
|
1169 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1170 |
+
" (act): GELU(approximate='none')\n",
|
1171 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1172 |
+
" (norm): Identity()\n",
|
1173 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1174 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1175 |
+
" )\n",
|
1176 |
+
" (ls2): Identity()\n",
|
1177 |
+
" (drop_path2): Identity()\n",
|
1178 |
+
" )\n",
|
1179 |
+
" (16): Block(\n",
|
1180 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1181 |
+
" (attn): Attention(\n",
|
1182 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1183 |
+
" (q_norm): Identity()\n",
|
1184 |
+
" (k_norm): Identity()\n",
|
1185 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1186 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1187 |
+
" (proj_drop): Identity()\n",
|
1188 |
+
" )\n",
|
1189 |
+
" (ls1): Identity()\n",
|
1190 |
+
" (drop_path1): Identity()\n",
|
1191 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1192 |
+
" (mlp): Mlp(\n",
|
1193 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1194 |
+
" (act): GELU(approximate='none')\n",
|
1195 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1196 |
+
" (norm): Identity()\n",
|
1197 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1198 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1199 |
+
" )\n",
|
1200 |
+
" (ls2): Identity()\n",
|
1201 |
+
" (drop_path2): Identity()\n",
|
1202 |
+
" )\n",
|
1203 |
+
" (17): Block(\n",
|
1204 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1205 |
+
" (attn): Attention(\n",
|
1206 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1207 |
+
" (q_norm): Identity()\n",
|
1208 |
+
" (k_norm): Identity()\n",
|
1209 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1210 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1211 |
+
" (proj_drop): Identity()\n",
|
1212 |
+
" )\n",
|
1213 |
+
" (ls1): Identity()\n",
|
1214 |
+
" (drop_path1): Identity()\n",
|
1215 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1216 |
+
" (mlp): Mlp(\n",
|
1217 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1218 |
+
" (act): GELU(approximate='none')\n",
|
1219 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1220 |
+
" (norm): Identity()\n",
|
1221 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1222 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1223 |
+
" )\n",
|
1224 |
+
" (ls2): Identity()\n",
|
1225 |
+
" (drop_path2): Identity()\n",
|
1226 |
+
" )\n",
|
1227 |
+
" (18): Block(\n",
|
1228 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1229 |
+
" (attn): Attention(\n",
|
1230 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1231 |
+
" (q_norm): Identity()\n",
|
1232 |
+
" (k_norm): Identity()\n",
|
1233 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1234 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1235 |
+
" (proj_drop): Identity()\n",
|
1236 |
+
" )\n",
|
1237 |
+
" (ls1): Identity()\n",
|
1238 |
+
" (drop_path1): Identity()\n",
|
1239 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1240 |
+
" (mlp): Mlp(\n",
|
1241 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1242 |
+
" (act): GELU(approximate='none')\n",
|
1243 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1244 |
+
" (norm): Identity()\n",
|
1245 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1246 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1247 |
+
" )\n",
|
1248 |
+
" (ls2): Identity()\n",
|
1249 |
+
" (drop_path2): Identity()\n",
|
1250 |
+
" )\n",
|
1251 |
+
" (19): Block(\n",
|
1252 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1253 |
+
" (attn): Attention(\n",
|
1254 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1255 |
+
" (q_norm): Identity()\n",
|
1256 |
+
" (k_norm): Identity()\n",
|
1257 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1258 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1259 |
+
" (proj_drop): Identity()\n",
|
1260 |
+
" )\n",
|
1261 |
+
" (ls1): Identity()\n",
|
1262 |
+
" (drop_path1): Identity()\n",
|
1263 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1264 |
+
" (mlp): Mlp(\n",
|
1265 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1266 |
+
" (act): GELU(approximate='none')\n",
|
1267 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1268 |
+
" (norm): Identity()\n",
|
1269 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1270 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1271 |
+
" )\n",
|
1272 |
+
" (ls2): Identity()\n",
|
1273 |
+
" (drop_path2): Identity()\n",
|
1274 |
+
" )\n",
|
1275 |
+
" (20): Block(\n",
|
1276 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1277 |
+
" (attn): Attention(\n",
|
1278 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1279 |
+
" (q_norm): Identity()\n",
|
1280 |
+
" (k_norm): Identity()\n",
|
1281 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1282 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1283 |
+
" (proj_drop): Identity()\n",
|
1284 |
+
" )\n",
|
1285 |
+
" (ls1): Identity()\n",
|
1286 |
+
" (drop_path1): Identity()\n",
|
1287 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1288 |
+
" (mlp): Mlp(\n",
|
1289 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1290 |
+
" (act): GELU(approximate='none')\n",
|
1291 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1292 |
+
" (norm): Identity()\n",
|
1293 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1294 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1295 |
+
" )\n",
|
1296 |
+
" (ls2): Identity()\n",
|
1297 |
+
" (drop_path2): Identity()\n",
|
1298 |
+
" )\n",
|
1299 |
+
" (21): Block(\n",
|
1300 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1301 |
+
" (attn): Attention(\n",
|
1302 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1303 |
+
" (q_norm): Identity()\n",
|
1304 |
+
" (k_norm): Identity()\n",
|
1305 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1306 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1307 |
+
" (proj_drop): Identity()\n",
|
1308 |
+
" )\n",
|
1309 |
+
" (ls1): Identity()\n",
|
1310 |
+
" (drop_path1): Identity()\n",
|
1311 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1312 |
+
" (mlp): Mlp(\n",
|
1313 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1314 |
+
" (act): GELU(approximate='none')\n",
|
1315 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1316 |
+
" (norm): Identity()\n",
|
1317 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1318 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1319 |
+
" )\n",
|
1320 |
+
" (ls2): Identity()\n",
|
1321 |
+
" (drop_path2): Identity()\n",
|
1322 |
+
" )\n",
|
1323 |
+
" (22): Block(\n",
|
1324 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1325 |
+
" (attn): Attention(\n",
|
1326 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1327 |
+
" (q_norm): Identity()\n",
|
1328 |
+
" (k_norm): Identity()\n",
|
1329 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1330 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1331 |
+
" (proj_drop): Identity()\n",
|
1332 |
+
" )\n",
|
1333 |
+
" (ls1): Identity()\n",
|
1334 |
+
" (drop_path1): Identity()\n",
|
1335 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1336 |
+
" (mlp): Mlp(\n",
|
1337 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1338 |
+
" (act): GELU(approximate='none')\n",
|
1339 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1340 |
+
" (norm): Identity()\n",
|
1341 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1342 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1343 |
+
" )\n",
|
1344 |
+
" (ls2): Identity()\n",
|
1345 |
+
" (drop_path2): Identity()\n",
|
1346 |
+
" )\n",
|
1347 |
+
" (23): Block(\n",
|
1348 |
+
" (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1349 |
+
" (attn): Attention(\n",
|
1350 |
+
" (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
|
1351 |
+
" (q_norm): Identity()\n",
|
1352 |
+
" (k_norm): Identity()\n",
|
1353 |
+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
|
1354 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1355 |
+
" (proj_drop): Identity()\n",
|
1356 |
+
" )\n",
|
1357 |
+
" (ls1): Identity()\n",
|
1358 |
+
" (drop_path1): Identity()\n",
|
1359 |
+
" (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1360 |
+
" (mlp): Mlp(\n",
|
1361 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1362 |
+
" (act): GELU(approximate='none')\n",
|
1363 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1364 |
+
" (norm): Identity()\n",
|
1365 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1366 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1367 |
+
" )\n",
|
1368 |
+
" (ls2): Identity()\n",
|
1369 |
+
" (drop_path2): Identity()\n",
|
1370 |
+
" )\n",
|
1371 |
+
" )\n",
|
1372 |
+
" (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1373 |
+
" (attn_pool): AttentionPoolLatent(\n",
|
1374 |
+
" (q): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1375 |
+
" (kv): Linear(in_features=1024, out_features=2048, bias=True)\n",
|
1376 |
+
" (q_norm): Identity()\n",
|
1377 |
+
" (k_norm): Identity()\n",
|
1378 |
+
" (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
|
1379 |
+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
|
1380 |
+
" (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)\n",
|
1381 |
+
" (mlp): Mlp(\n",
|
1382 |
+
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
|
1383 |
+
" (act): GELU(approximate='none')\n",
|
1384 |
+
" (drop1): Dropout(p=0.0, inplace=False)\n",
|
1385 |
+
" (norm): Identity()\n",
|
1386 |
+
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
|
1387 |
+
" (drop2): Dropout(p=0.0, inplace=False)\n",
|
1388 |
+
" )\n",
|
1389 |
+
" )\n",
|
1390 |
+
" (fc_norm): Identity()\n",
|
1391 |
+
" (head_drop): Dropout(p=0.0, inplace=False)\n",
|
1392 |
+
" (head): Identity()\n",
|
1393 |
+
" )\n",
|
1394 |
+
" )\n",
|
1395 |
+
" (aligner): MlpProjector(\n",
|
1396 |
+
" (layers): Sequential(\n",
|
1397 |
+
" (0): Linear(in_features=1024, out_features=2048, bias=True)\n",
|
1398 |
+
" (1): GELU(approximate='none')\n",
|
1399 |
+
" (2): Linear(in_features=2048, out_features=2048, bias=True)\n",
|
1400 |
+
" )\n",
|
1401 |
+
" )\n",
|
1402 |
+
" (gen_vision_model): VQModel(\n",
|
1403 |
+
" (encoder): Encoder(\n",
|
1404 |
+
" (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1405 |
+
" (conv_blocks): ModuleList(\n",
|
1406 |
+
" (0-1): 2 x Module(\n",
|
1407 |
+
" (res): ModuleList(\n",
|
1408 |
+
" (0-1): 2 x ResnetBlock(\n",
|
1409 |
+
" (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1410 |
+
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1411 |
+
" (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1412 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1413 |
+
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1414 |
+
" )\n",
|
1415 |
+
" )\n",
|
1416 |
+
" (attn): ModuleList()\n",
|
1417 |
+
" (downsample): Downsample(\n",
|
1418 |
+
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
|
1419 |
+
" )\n",
|
1420 |
+
" )\n",
|
1421 |
+
" (2): Module(\n",
|
1422 |
+
" (res): ModuleList(\n",
|
1423 |
+
" (0): ResnetBlock(\n",
|
1424 |
+
" (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1425 |
+
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1426 |
+
" (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1427 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1428 |
+
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1429 |
+
" (nin_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
|
1430 |
+
" )\n",
|
1431 |
+
" (1): ResnetBlock(\n",
|
1432 |
+
" (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1433 |
+
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1434 |
+
" (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1435 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1436 |
+
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1437 |
+
" )\n",
|
1438 |
+
" )\n",
|
1439 |
+
" (attn): ModuleList()\n",
|
1440 |
+
" (downsample): Downsample(\n",
|
1441 |
+
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
|
1442 |
+
" )\n",
|
1443 |
+
" )\n",
|
1444 |
+
" (3): Module(\n",
|
1445 |
+
" (res): ModuleList(\n",
|
1446 |
+
" (0-1): 2 x ResnetBlock(\n",
|
1447 |
+
" (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1448 |
+
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1449 |
+
" (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1450 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1451 |
+
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1452 |
+
" )\n",
|
1453 |
+
" )\n",
|
1454 |
+
" (attn): ModuleList()\n",
|
1455 |
+
" (downsample): Downsample(\n",
|
1456 |
+
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
|
1457 |
+
" )\n",
|
1458 |
+
" )\n",
|
1459 |
+
" (4): Module(\n",
|
1460 |
+
" (res): ModuleList(\n",
|
1461 |
+
" (0): ResnetBlock(\n",
|
1462 |
+
" (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1463 |
+
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1464 |
+
" (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1465 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1466 |
+
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1467 |
+
" (nin_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1468 |
+
" )\n",
|
1469 |
+
" (1): ResnetBlock(\n",
|
1470 |
+
" (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1471 |
+
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1472 |
+
" (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1473 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1474 |
+
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1475 |
+
" )\n",
|
1476 |
+
" )\n",
|
1477 |
+
" (attn): ModuleList(\n",
|
1478 |
+
" (0-1): 2 x AttnBlock(\n",
|
1479 |
+
" (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1480 |
+
" (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1481 |
+
" (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1482 |
+
" (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1483 |
+
" (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1484 |
+
" )\n",
|
1485 |
+
" )\n",
|
1486 |
+
" )\n",
|
1487 |
+
" )\n",
|
1488 |
+
" (mid): ModuleList(\n",
|
1489 |
+
" (0): ResnetBlock(\n",
|
1490 |
+
" (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1491 |
+
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1492 |
+
" (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1493 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1494 |
+
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1495 |
+
" )\n",
|
1496 |
+
" (1): AttnBlock(\n",
|
1497 |
+
" (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1498 |
+
" (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1499 |
+
" (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1500 |
+
" (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1501 |
+
" (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1502 |
+
" )\n",
|
1503 |
+
" (2): ResnetBlock(\n",
|
1504 |
+
" (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1505 |
+
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1506 |
+
" (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1507 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1508 |
+
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1509 |
+
" )\n",
|
1510 |
+
" )\n",
|
1511 |
+
" (norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1512 |
+
" (conv_out): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1513 |
+
" )\n",
|
1514 |
+
" (decoder): Decoder(\n",
|
1515 |
+
" (conv_in): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1516 |
+
" (mid): ModuleList(\n",
|
1517 |
+
" (0): ResnetBlock(\n",
|
1518 |
+
" (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1519 |
+
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1520 |
+
" (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1521 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1522 |
+
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1523 |
+
" )\n",
|
1524 |
+
" (1): AttnBlock(\n",
|
1525 |
+
" (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1526 |
+
" (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1527 |
+
" (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1528 |
+
" (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1529 |
+
" (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1530 |
+
" )\n",
|
1531 |
+
" (2): ResnetBlock(\n",
|
1532 |
+
" (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1533 |
+
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1534 |
+
" (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1535 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1536 |
+
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1537 |
+
" )\n",
|
1538 |
+
" )\n",
|
1539 |
+
" (conv_blocks): ModuleList(\n",
|
1540 |
+
" (0): Module(\n",
|
1541 |
+
" (res): ModuleList(\n",
|
1542 |
+
" (0-2): 3 x ResnetBlock(\n",
|
1543 |
+
" (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1544 |
+
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1545 |
+
" (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1546 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1547 |
+
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1548 |
+
" )\n",
|
1549 |
+
" )\n",
|
1550 |
+
" (attn): ModuleList(\n",
|
1551 |
+
" (0-2): 3 x AttnBlock(\n",
|
1552 |
+
" (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1553 |
+
" (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1554 |
+
" (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1555 |
+
" (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1556 |
+
" (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
1557 |
+
" )\n",
|
1558 |
+
" )\n",
|
1559 |
+
" (upsample): Upsample(\n",
|
1560 |
+
" (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1561 |
+
" )\n",
|
1562 |
+
" )\n",
|
1563 |
+
" (1): Module(\n",
|
1564 |
+
" (res): ModuleList(\n",
|
1565 |
+
" (0): ResnetBlock(\n",
|
1566 |
+
" (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
|
1567 |
+
" (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1568 |
+
" (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1569 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1570 |
+
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1571 |
+
" (nin_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
|
1572 |
+
" )\n",
|
1573 |
+
" (1-2): 2 x ResnetBlock(\n",
|
1574 |
+
" (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1575 |
+
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1576 |
+
" (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1577 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1578 |
+
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1579 |
+
" )\n",
|
1580 |
+
" )\n",
|
1581 |
+
" (attn): ModuleList()\n",
|
1582 |
+
" (upsample): Upsample(\n",
|
1583 |
+
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1584 |
+
" )\n",
|
1585 |
+
" )\n",
|
1586 |
+
" (2): Module(\n",
|
1587 |
+
" (res): ModuleList(\n",
|
1588 |
+
" (0-2): 3 x ResnetBlock(\n",
|
1589 |
+
" (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1590 |
+
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1591 |
+
" (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1592 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1593 |
+
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1594 |
+
" )\n",
|
1595 |
+
" )\n",
|
1596 |
+
" (attn): ModuleList()\n",
|
1597 |
+
" (upsample): Upsample(\n",
|
1598 |
+
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1599 |
+
" )\n",
|
1600 |
+
" )\n",
|
1601 |
+
" (3): Module(\n",
|
1602 |
+
" (res): ModuleList(\n",
|
1603 |
+
" (0): ResnetBlock(\n",
|
1604 |
+
" (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
|
1605 |
+
" (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1606 |
+
" (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1607 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1608 |
+
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1609 |
+
" (nin_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
|
1610 |
+
" )\n",
|
1611 |
+
" (1-2): 2 x ResnetBlock(\n",
|
1612 |
+
" (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1613 |
+
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1614 |
+
" (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1615 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1616 |
+
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1617 |
+
" )\n",
|
1618 |
+
" )\n",
|
1619 |
+
" (attn): ModuleList()\n",
|
1620 |
+
" (upsample): Upsample(\n",
|
1621 |
+
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1622 |
+
" )\n",
|
1623 |
+
" )\n",
|
1624 |
+
" (4): Module(\n",
|
1625 |
+
" (res): ModuleList(\n",
|
1626 |
+
" (0-2): 3 x ResnetBlock(\n",
|
1627 |
+
" (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1628 |
+
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1629 |
+
" (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1630 |
+
" (dropout): Dropout(p=0.0, inplace=False)\n",
|
1631 |
+
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1632 |
+
" )\n",
|
1633 |
+
" )\n",
|
1634 |
+
" (attn): ModuleList()\n",
|
1635 |
+
" )\n",
|
1636 |
+
" )\n",
|
1637 |
+
" (norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
|
1638 |
+
" (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
1639 |
+
" )\n",
|
1640 |
+
" (quantize): VectorQuantizer(\n",
|
1641 |
+
" (embedding): Embedding(16384, 8)\n",
|
1642 |
+
" )\n",
|
1643 |
+
" (quant_conv): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1))\n",
|
1644 |
+
" (post_quant_conv): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1))\n",
|
1645 |
+
" )\n",
|
1646 |
+
" (gen_aligner): MlpProjector(\n",
|
1647 |
+
" (layers): Sequential(\n",
|
1648 |
+
" (0): Linear(in_features=8, out_features=2048, bias=True)\n",
|
1649 |
+
" (1): GELU(approximate='none')\n",
|
1650 |
+
" (2): Linear(in_features=2048, out_features=2048, bias=True)\n",
|
1651 |
+
" )\n",
|
1652 |
+
" )\n",
|
1653 |
+
" (gen_head): vision_head(\n",
|
1654 |
+
" (output_mlp_projector): Linear(in_features=2048, out_features=2048, bias=True)\n",
|
1655 |
+
" (vision_activation): GELU(approximate='none')\n",
|
1656 |
+
" (vision_head): Linear(in_features=2048, out_features=16384, bias=True)\n",
|
1657 |
+
" )\n",
|
1658 |
+
" (gen_embed): Embedding(16384, 8)\n",
|
1659 |
+
" (language_model): LlamaForCausalLM(\n",
|
1660 |
+
" (model): LlamaModel(\n",
|
1661 |
+
" (embed_tokens): Embedding(102400, 2048)\n",
|
1662 |
+
" (layers): ModuleList(\n",
|
1663 |
+
" (0-23): 24 x LlamaDecoderLayer(\n",
|
1664 |
+
" (self_attn): LlamaAttention(\n",
|
1665 |
+
" (q_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
1666 |
+
" (k_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
1667 |
+
" (v_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
1668 |
+
" (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
|
1669 |
+
" )\n",
|
1670 |
+
" (mlp): LlamaMLP(\n",
|
1671 |
+
" (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
|
1672 |
+
" (up_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
|
1673 |
+
" (down_proj): Linear(in_features=5632, out_features=2048, bias=False)\n",
|
1674 |
+
" (act_fn): SiLU()\n",
|
1675 |
+
" )\n",
|
1676 |
+
" (input_layernorm): LlamaRMSNorm((2048,), eps=1e-06)\n",
|
1677 |
+
" (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-06)\n",
|
1678 |
+
" )\n",
|
1679 |
+
" )\n",
|
1680 |
+
" (norm): LlamaRMSNorm((2048,), eps=1e-06)\n",
|
1681 |
+
" (rotary_emb): LlamaRotaryEmbedding()\n",
|
1682 |
+
" )\n",
|
1683 |
+
" (lm_head): Linear(in_features=2048, out_features=102400, bias=False)\n",
|
1684 |
+
" )\n",
|
1685 |
+
")\n"
|
1686 |
+
]
|
1687 |
+
}
|
1688 |
+
],
|
1689 |
+
"source": [
|
1690 |
+
"print(vl_gpt)"
|
1691 |
+
]
|
1692 |
+
}
|
1693 |
+
],
|
1694 |
+
"metadata": {
|
1695 |
+
"kernelspec": {
|
1696 |
+
"display_name": "janus_env",
|
1697 |
+
"language": "python",
|
1698 |
+
"name": "python3"
|
1699 |
+
},
|
1700 |
+
"language_info": {
|
1701 |
+
"codemirror_mode": {
|
1702 |
+
"name": "ipython",
|
1703 |
+
"version": 3
|
1704 |
+
},
|
1705 |
+
"file_extension": ".py",
|
1706 |
+
"mimetype": "text/x-python",
|
1707 |
+
"name": "python",
|
1708 |
+
"nbconvert_exporter": "python",
|
1709 |
+
"pygments_lexer": "ipython3",
|
1710 |
+
"version": "3.10.16"
|
1711 |
+
}
|
1712 |
+
},
|
1713 |
+
"nbformat": 4,
|
1714 |
+
"nbformat_minor": 2
|
1715 |
+
}
|
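The cell output above is the module tree printed by the notebook's final `print(vl_gpt)` call. A minimal sketch of how a printout like this could be reproduced outside the notebook follows; the checkpoint id and the `janus.models` import are assumptions inferred from the rest of this repository, not taken from the notebook itself.

import torch
from transformers import AutoModelForCausalLM

from janus.models import VLChatProcessor  # assumption: mirrors the demo scripts in this repo

# Assumed checkpoint id; the 24-layer, 2048-dim LlamaForCausalLM in the dump matches a ~1B Janus variant.
model_path = "deepseek-ai/Janus-Pro-1B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).eval()

print(vl_gpt)  # prints a module tree like the one shown above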
images/AreaChart.png
ADDED
images/BarChart.png
ADDED
images/BubbleChart.png
ADDED
images/Choropleth_New.png
ADDED
images/Histogram.png
ADDED
images/LineChart.png
ADDED
images/PieChart.png
ADDED
images/Scatterplot.png
ADDED
images/Stacked100.png
ADDED
images/StackedArea.png
ADDED
images/StackedBar.png
ADDED
images/TreeMap.png
ADDED
images/badge.svg
ADDED
images/cat_dog.png
ADDED
images/doge.png
ADDED
images/equation.png
ADDED
images/logo.png
ADDED
images/logo.svg
ADDED
images/pie_chart.png
ADDED
images/ve.png
ADDED
janus/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


# check if python version is above 3.10
import sys

if sys.version_info >= (3, 10):
    print("Python version is above 3.10, patching the collections module.")
    # Monkey patch collections
    import collections
    import collections.abc

    for type_name in collections.abc.__all__:
        setattr(collections, type_name, getattr(collections.abc, type_name))
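The conditional block above restores the abstract-base-class aliases (collections.Mapping, collections.Sequence, ...) that Python 3.10 removed from the top-level collections module, so that older dependencies such as attrdict keep importing. A small illustrative check of the patch, written as a sketch rather than repo code:

import collections
import collections.abc

import janus  # importing the package runs the patch above as a side effect

# After the patch, the removed aliases resolve to their collections.abc counterparts.
assert collections.Mapping is collections.abc.Mapping
assert collections.MutableMapping is collections.abc.MutableMapping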
janus/janusflow/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


# check if python version is above 3.10
import sys

if sys.version_info >= (3, 10):
    print("Python version is above 3.10, patching the collections module.")
    # Monkey patch collections
    import collections
    import collections.abc

    for type_name in collections.abc.__all__:
        setattr(collections, type_name, getattr(collections.abc, type_name))
janus/janusflow/models/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from .image_processing_vlm import VLMImageProcessor
from .modeling_vlm import MultiModalityCausalLM
from .processing_vlm import VLChatProcessor

__all__ = [
    "VLMImageProcessor",
    "VLChatProcessor",
    "MultiModalityCausalLM",
]
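These three exports are what the JanusFlow demo code builds on. A hedged loading sketch follows; the checkpoint id is an assumption, not something stated in this package.

from transformers import AutoModelForCausalLM

from janus.janusflow.models import VLChatProcessor

model_path = "deepseek-ai/JanusFlow-1.3B"  # assumed checkpoint id
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).eval()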
janus/janusflow/models/clip_encoder.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
from typing import Dict, List, Literal, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torchvision.transforms
|
25 |
+
from einops import rearrange
|
26 |
+
|
27 |
+
from janus.janusflow.models.siglip_vit import create_siglip_vit
|
28 |
+
|
29 |
+
|
30 |
+
class CLIPVisionTower(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
model_name: str = "siglip_large_patch16_384",
|
34 |
+
image_size: Union[Tuple[int, int], int] = 336,
|
35 |
+
select_feature: str = "patch",
|
36 |
+
select_layer: int = -2,
|
37 |
+
select_layers: list = None,
|
38 |
+
ckpt_path: str = "",
|
39 |
+
pixel_mean: Optional[List[float]] = None,
|
40 |
+
pixel_std: Optional[List[float]] = None,
|
41 |
+
**kwargs,
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
self.model_name = model_name
|
46 |
+
self.select_feature = select_feature
|
47 |
+
self.select_layer = select_layer
|
48 |
+
self.select_layers = select_layers
|
49 |
+
|
50 |
+
vision_tower_params = {
|
51 |
+
"model_name": model_name,
|
52 |
+
"image_size": image_size,
|
53 |
+
"ckpt_path": ckpt_path,
|
54 |
+
"select_layer": select_layer,
|
55 |
+
}
|
56 |
+
vision_tower_params.update(kwargs)
|
57 |
+
self.vision_tower, self.forward_kwargs = self.build_vision_tower(
|
58 |
+
vision_tower_params
|
59 |
+
)
|
60 |
+
|
61 |
+
if pixel_mean is not None and pixel_std is not None:
|
62 |
+
image_norm = torchvision.transforms.Normalize(
|
63 |
+
mean=pixel_mean, std=pixel_std
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
image_norm = None
|
67 |
+
|
68 |
+
self.image_norm = image_norm
|
69 |
+
|
70 |
+
def build_vision_tower(self, vision_tower_params):
|
71 |
+
if self.model_name.startswith("siglip"):
|
72 |
+
self.select_feature = "same"
|
73 |
+
vision_tower = create_siglip_vit(**vision_tower_params)
|
74 |
+
forward_kwargs = dict()
|
75 |
+
|
76 |
+
elif self.model_name.startswith("sam"):
|
77 |
+
vision_tower = create_sam_vit(**vision_tower_params)
|
78 |
+
forward_kwargs = dict()
|
79 |
+
|
80 |
+
else: # huggingface
|
81 |
+
from transformers import CLIPVisionModel
|
82 |
+
|
83 |
+
vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
|
84 |
+
forward_kwargs = dict(output_hidden_states=True)
|
85 |
+
|
86 |
+
return vision_tower, forward_kwargs
|
87 |
+
|
88 |
+
def feature_select(self, image_forward_outs):
|
89 |
+
if isinstance(image_forward_outs, torch.Tensor):
|
90 |
+
# the output has been the self.select_layer"s features
|
91 |
+
image_features = image_forward_outs
|
92 |
+
else:
|
93 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
94 |
+
|
95 |
+
if self.select_feature == "patch":
|
96 |
+
# if the output has cls_token
|
97 |
+
image_features = image_features[:, 1:]
|
98 |
+
elif self.select_feature == "cls_patch":
|
99 |
+
image_features = image_features
|
100 |
+
elif self.select_feature == "same":
|
101 |
+
image_features = image_features
|
102 |
+
|
103 |
+
else:
|
104 |
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
105 |
+
return image_features
|
106 |
+
|
107 |
+
def forward(self, images):
|
108 |
+
"""
|
109 |
+
|
110 |
+
Args:
|
111 |
+
images (torch.Tensor): [b, 3, H, W]
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
image_features (torch.Tensor): [b, n_patch, d]
|
115 |
+
"""
|
116 |
+
|
117 |
+
if self.image_norm is not None:
|
118 |
+
images = self.image_norm(images)
|
119 |
+
|
120 |
+
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
|
121 |
+
image_features = self.feature_select(image_forward_outs)
|
122 |
+
return image_features
|
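A hedged usage sketch of the CLIPVisionTower defined above: it wraps a SigLIP ViT and returns per-patch features. The 384px size and the output shape in the comment follow the class defaults and are assumptions about the local siglip_vit definition, not guarantees.

import torch

from janus.janusflow.models.clip_encoder import CLIPVisionTower

# Assumes the default "siglip_large_patch16_384" definition in siglip_vit.py can be built locally
# (randomly initialized here, since no ckpt_path is given).
tower = CLIPVisionTower(
    model_name="siglip_large_patch16_384",
    image_size=384,
    pixel_mean=(0.5, 0.5, 0.5),  # optional normalization, as in the constructor
    pixel_std=(0.5, 0.5, 0.5),
).eval()

images = torch.rand(2, 3, 384, 384)  # [b, 3, H, W] in [0, 1]
with torch.no_grad():
    feats = tower(images)  # [b, n_patch, d], e.g. [2, 576, 1024] for a 16px patch at 384px
print(feats.shape)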
janus/janusflow/models/image_processing_vlm.py
ADDED
@@ -0,0 +1,208 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
from typing import List, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import torchvision
|
25 |
+
import torchvision.transforms.functional
|
26 |
+
from PIL import Image
|
27 |
+
from transformers import AutoImageProcessor, PretrainedConfig
|
28 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
29 |
+
from transformers.image_utils import to_numpy_array
|
30 |
+
from transformers.utils import logging
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
|
35 |
+
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
36 |
+
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
|
37 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
38 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
39 |
+
|
40 |
+
|
41 |
+
def expand2square(pil_img, background_color):
|
42 |
+
width, height = pil_img.size
|
43 |
+
if width == height:
|
44 |
+
return pil_img
|
45 |
+
elif width > height:
|
46 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
47 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
48 |
+
return result
|
49 |
+
else:
|
50 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
51 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
52 |
+
return result
|
53 |
+
|
54 |
+
|
55 |
+
class VLMImageProcessorConfig(PretrainedConfig):
|
56 |
+
model_type = "deepseek_vlm"
|
57 |
+
image_size: int
|
58 |
+
min_size: int
|
59 |
+
image_mean: Union[Tuple[float, float, float], List[float]]
|
60 |
+
image_std: Union[Tuple[float, float, float], List[float]]
|
61 |
+
rescale_factor: float
|
62 |
+
do_normalize: bool
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
image_size: int,
|
67 |
+
min_size: int = 14,
|
68 |
+
image_mean: Union[Tuple[float, float, float], List[float]] = (
|
69 |
+
0.48145466,
|
70 |
+
0.4578275,
|
71 |
+
0.40821073,
|
72 |
+
),
|
73 |
+
image_std: Union[Tuple[float, float, float], List[float]] = (
|
74 |
+
0.26862954,
|
75 |
+
0.26130258,
|
76 |
+
0.27577711,
|
77 |
+
),
|
78 |
+
rescale_factor: float = 1.0 / 255.0,
|
79 |
+
do_normalize: bool = True,
|
80 |
+
**kwargs,
|
81 |
+
):
|
82 |
+
self.image_size = image_size
|
83 |
+
self.min_size = min_size
|
84 |
+
self.image_mean = image_mean
|
85 |
+
self.image_std = image_std
|
86 |
+
self.rescale_factor = rescale_factor
|
87 |
+
self.do_normalize = do_normalize
|
88 |
+
|
89 |
+
super().__init__(**kwargs)
|
90 |
+
|
91 |
+
|
92 |
+
class VLMImageProcessor(BaseImageProcessor):
|
93 |
+
model_input_names = ["pixel_values"]
|
94 |
+
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
image_size: int,
|
98 |
+
min_size: int = 14,
|
99 |
+
image_mean: Union[Tuple[float, float, float], List[float]] = (
|
100 |
+
0.48145466,
|
101 |
+
0.4578275,
|
102 |
+
0.40821073,
|
103 |
+
),
|
104 |
+
image_std: Union[Tuple[float, float, float], List[float]] = (
|
105 |
+
0.26862954,
|
106 |
+
0.26130258,
|
107 |
+
0.27577711,
|
108 |
+
),
|
109 |
+
rescale_factor: float = 1.0 / 255.0,
|
110 |
+
do_normalize: bool = True,
|
111 |
+
**kwargs,
|
112 |
+
):
|
113 |
+
super().__init__(**kwargs)
|
114 |
+
|
115 |
+
self.image_size = image_size
|
116 |
+
self.rescale_factor = rescale_factor
|
117 |
+
self.image_mean = image_mean
|
118 |
+
self.image_std = image_std
|
119 |
+
self.min_size = min_size
|
120 |
+
self.do_normalize = do_normalize
|
121 |
+
|
122 |
+
if image_mean is None:
|
123 |
+
self.background_color = (127, 127, 127)
|
124 |
+
else:
|
125 |
+
self.background_color = tuple([int(x * 255) for x in image_mean])
|
126 |
+
|
127 |
+
def resize(self, pil_img: Image) -> np.ndarray:
|
128 |
+
"""
|
129 |
+
|
130 |
+
Args:
|
131 |
+
pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
x (np.ndarray): [3, self.image_size, self.image_size]
|
135 |
+
"""
|
136 |
+
|
137 |
+
width, height = pil_img.size
|
138 |
+
max_size = max(width, height)
|
139 |
+
|
140 |
+
size = [
|
141 |
+
max(int(height / max_size * self.image_size), self.min_size),
|
142 |
+
max(int(width / max_size * self.image_size), self.min_size),
|
143 |
+
]
|
144 |
+
|
145 |
+
if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
|
146 |
+
print(f"orig size = {pil_img.size}, new size = {size}")
|
147 |
+
raise ValueError("Invalid size!")
|
148 |
+
|
149 |
+
pil_img = torchvision.transforms.functional.resize(
|
150 |
+
pil_img,
|
151 |
+
size,
|
152 |
+
interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
|
153 |
+
antialias=True,
|
154 |
+
)
|
155 |
+
|
156 |
+
pil_img = expand2square(pil_img, self.background_color)
|
157 |
+
x = to_numpy_array(pil_img)
|
158 |
+
|
159 |
+
# [H, W, 3] -> [3, H, W]
|
160 |
+
x = np.transpose(x, (2, 0, 1))
|
161 |
+
|
162 |
+
return x
|
163 |
+
|
164 |
+
def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
|
165 |
+
# resize and pad to [self.image_size, self.image_size]
|
166 |
+
# then convert from [H, W, 3] to [3, H, W]
|
167 |
+
images: List[np.ndarray] = [self.resize(image) for image in images]
|
168 |
+
|
169 |
+
# resacle from [0, 255] -> [0, 1]
|
170 |
+
images = [
|
171 |
+
self.rescale(
|
172 |
+
image=image,
|
173 |
+
scale=self.rescale_factor,
|
174 |
+
input_data_format="channels_first",
|
175 |
+
)
|
176 |
+
for image in images
|
177 |
+
]
|
178 |
+
|
179 |
+
# normalize
|
180 |
+
if self.do_normalize:
|
181 |
+
images = [
|
182 |
+
self.normalize(
|
183 |
+
image=image,
|
184 |
+
mean=self.image_mean,
|
185 |
+
std=self.image_std,
|
186 |
+
input_data_format="channels_first",
|
187 |
+
)
|
188 |
+
for image in images
|
189 |
+
]
|
190 |
+
|
191 |
+
data = {"pixel_values": images}
|
192 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
193 |
+
|
194 |
+
@property
|
195 |
+
def default_shape(self):
|
196 |
+
return [3, self.image_size, self.image_size]
|
197 |
+
|
198 |
+
|
199 |
+
AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
|
200 |
+
|
201 |
+
|
202 |
+
if __name__ == "__main__":
|
203 |
+
image_processor = VLMImageProcessor(
|
204 |
+
image_size=1024,
|
205 |
+
image_mean=IMAGENET_INCEPTION_MEAN,
|
206 |
+
image_std=IMAGENET_INCEPTION_STD,
|
207 |
+
do_normalize=True,
|
208 |
+
)
|
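A short usage sketch of VLMImageProcessor: it resizes the long side to image_size, pads to a square with the mean color, rescales to [0, 1], and normalizes. Everything below mirrors the __main__ block above except the input image, which is synthetic.

import numpy as np
from PIL import Image

from janus.janusflow.models.image_processing_vlm import (
    IMAGENET_INCEPTION_MEAN,
    IMAGENET_INCEPTION_STD,
    VLMImageProcessor,
)

processor = VLMImageProcessor(
    image_size=384,
    image_mean=IMAGENET_INCEPTION_MEAN,
    image_std=IMAGENET_INCEPTION_STD,
    do_normalize=True,
)

# Synthetic 640x480 RGB image stands in for a real input.
img = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))
batch = processor.preprocess([img], return_tensors="pt")
print(batch.pixel_values.shape)  # expected: torch.Size([1, 3, 384, 384])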
janus/janusflow/models/modeling_vlm.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
from attrdict import AttrDict
|
21 |
+
from einops import rearrange
|
22 |
+
import torch
|
23 |
+
from transformers.configuration_utils import PretrainedConfig
|
24 |
+
from transformers import (
|
25 |
+
AutoConfig,
|
26 |
+
AutoModelForCausalLM,
|
27 |
+
PreTrainedModel,
|
28 |
+
LlamaConfig,
|
29 |
+
LlamaForCausalLM,
|
30 |
+
)
|
31 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
32 |
+
from janus.janusflow.models.clip_encoder import CLIPVisionTower
|
33 |
+
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
|
34 |
+
import torch.nn as nn
|
35 |
+
|
36 |
+
|
37 |
+
def model_name_to_cls(cls_name):
|
38 |
+
|
39 |
+
if "CLIPVisionTower" in cls_name:
|
40 |
+
cls = CLIPVisionTower
|
41 |
+
elif "ShallowUViTEncoder" in cls_name:
|
42 |
+
cls = ShallowUViTEncoder
|
43 |
+
elif "ShallowUViTDecoder" in cls_name:
|
44 |
+
cls = ShallowUViTDecoder
|
45 |
+
else:
|
46 |
+
raise ValueError(f"class_name {cls_name} is invalid.")
|
47 |
+
|
48 |
+
return cls
|
49 |
+
|
50 |
+
|
51 |
+
class VisionUnderstandEncoderConfig(PretrainedConfig):
|
52 |
+
model_type = "vision_und_enc"
|
53 |
+
cls: str = ""
|
54 |
+
params: AttrDict = {}
|
55 |
+
|
56 |
+
def __init__(self, **kwargs):
|
57 |
+
super().__init__(**kwargs)
|
58 |
+
|
59 |
+
self.cls = kwargs.get("cls", "")
|
60 |
+
if not isinstance(self.cls, str):
|
61 |
+
self.cls = self.cls.__name__
|
62 |
+
|
63 |
+
self.params = AttrDict(kwargs.get("params", {}))
|
64 |
+
|
65 |
+
|
66 |
+
class VisionGenerationEncoderConfig(PretrainedConfig):
|
67 |
+
model_type = "vision_gen_enc"
|
68 |
+
cls: str = ""
|
69 |
+
params: AttrDict = {}
|
70 |
+
|
71 |
+
def __init__(self, **kwargs):
|
72 |
+
super().__init__(**kwargs)
|
73 |
+
|
74 |
+
self.cls = kwargs.get("cls", "")
|
75 |
+
if not isinstance(self.cls, str):
|
76 |
+
self.cls = self.cls.__name__
|
77 |
+
|
78 |
+
self.params = AttrDict(kwargs.get("params", {}))
|
79 |
+
|
80 |
+
|
81 |
+
class VisionGenerationDecoderConfig(PretrainedConfig):
|
82 |
+
model_type = "vision_gen_dec"
|
83 |
+
cls: str = ""
|
84 |
+
params: AttrDict = {}
|
85 |
+
|
86 |
+
def __init__(self, **kwargs):
|
87 |
+
super().__init__(**kwargs)
|
88 |
+
|
89 |
+
self.cls = kwargs.get("cls", "")
|
90 |
+
if not isinstance(self.cls, str):
|
91 |
+
self.cls = self.cls.__name__
|
92 |
+
|
93 |
+
self.params = AttrDict(kwargs.get("params", {}))
|
94 |
+
|
95 |
+
|
96 |
+
class MultiModalityConfig(PretrainedConfig):
|
97 |
+
model_type = "multi_modality"
|
98 |
+
vision_und_enc_config: VisionUnderstandEncoderConfig
|
99 |
+
language_config: LlamaConfig
|
100 |
+
|
101 |
+
def __init__(self, **kwargs):
|
102 |
+
super().__init__(**kwargs)
|
103 |
+
vision_und_enc_config = kwargs.get("vision_und_enc_config", {})
|
104 |
+
self.vision_und_enc_config = VisionUnderstandEncoderConfig(
|
105 |
+
**vision_und_enc_config
|
106 |
+
)
|
107 |
+
|
108 |
+
vision_gen_enc_config = kwargs.get("vision_gen_enc_config", {})
|
109 |
+
self.vision_gen_enc_config = VisionGenerationEncoderConfig(
|
110 |
+
**vision_gen_enc_config
|
111 |
+
)
|
112 |
+
|
113 |
+
vision_gen_dec_config = kwargs.get("vision_gen_dec_config", {})
|
114 |
+
self.vision_gen_dec_config = VisionGenerationDecoderConfig(
|
115 |
+
**vision_gen_dec_config
|
116 |
+
)
|
117 |
+
|
118 |
+
language_config = kwargs.get("language_config", {})
|
119 |
+
if isinstance(language_config, LlamaConfig):
|
120 |
+
self.language_config = language_config
|
121 |
+
else:
|
122 |
+
self.language_config = LlamaConfig(**language_config)
|
123 |
+
|
124 |
+
|
125 |
+
class MultiModalityPreTrainedModel(PreTrainedModel):
|
126 |
+
config_class = MultiModalityConfig
|
127 |
+
base_model_prefix = "multi_modality"
|
128 |
+
_no_split_modules = []
|
129 |
+
_skip_keys_device_placement = "past_key_values"
|
130 |
+
|
131 |
+
|
132 |
+
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
133 |
+
|
134 |
+
def __init__(self, config: MultiModalityConfig):
|
135 |
+
super().__init__(config)
|
136 |
+
|
137 |
+
# vision understanding encoder
|
138 |
+
vision_und_enc_config = config.vision_und_enc_config
|
139 |
+
vision_und_enc_cls = model_name_to_cls(vision_und_enc_config.cls)
|
140 |
+
self.vision_und_enc_model = vision_und_enc_cls(**vision_und_enc_config.params)
|
141 |
+
|
142 |
+
# vision understanding aligner
|
143 |
+
self.vision_und_enc_aligner = nn.Linear(1024, 2048, bias=True)
|
144 |
+
|
145 |
+
# begin of understanding embedding
|
146 |
+
self.beg_of_und_embed = nn.Parameter(torch.zeros(1, 2048))
|
147 |
+
|
148 |
+
# vision generation encoder
|
149 |
+
vision_gen_enc_config = config.vision_gen_enc_config
|
150 |
+
vision_gen_enc_cls = model_name_to_cls(vision_gen_enc_config.cls)
|
151 |
+
self.vision_gen_enc_model = vision_gen_enc_cls(**vision_gen_enc_config.params)
|
152 |
+
|
153 |
+
# vision generation encoder aligner
|
154 |
+
self.vision_gen_enc_aligner = nn.Linear(768, 2048, bias=True)
|
155 |
+
|
156 |
+
# vision generation decoder
|
157 |
+
vision_gen_dec_config = config.vision_gen_dec_config
|
158 |
+
vision_gen_dec_cls = model_name_to_cls(vision_gen_dec_config.cls)
|
159 |
+
self.vision_gen_dec_model = vision_gen_dec_cls(**vision_gen_dec_config.params)
|
160 |
+
|
161 |
+
# language model
|
162 |
+
language_config = config.language_config
|
163 |
+
self.language_model = LlamaForCausalLM(language_config)
|
164 |
+
|
165 |
+
# vision generation decoder aligner
|
166 |
+
self.vision_gen_dec_aligner_norm = LlamaRMSNorm(
|
167 |
+
2048, eps=language_config.rms_norm_eps
|
168 |
+
)
|
169 |
+
self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)
|
170 |
+
|
171 |
+
def prepare_inputs_embeds(
|
172 |
+
self,
|
173 |
+
input_ids: torch.LongTensor,
|
174 |
+
pixel_values: torch.FloatTensor,
|
175 |
+
images_seq_mask: torch.LongTensor,
|
176 |
+
images_emb_mask: torch.LongTensor,
|
177 |
+
**kwargs,
|
178 |
+
):
|
179 |
+
"""
|
180 |
+
|
181 |
+
Args:
|
182 |
+
input_ids (torch.LongTensor): [b, T]
|
183 |
+
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
|
184 |
+
images_seq_mask (torch.BoolTensor): [b, T]
|
185 |
+
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
|
186 |
+
|
187 |
+
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
input_embeds (torch.Tensor): [b, T, D]
|
191 |
+
"""
|
192 |
+
|
193 |
+
bs, n = pixel_values.shape[0:2]
|
194 |
+
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
|
195 |
+
# [b x n, T2, D]
|
196 |
+
images_embeds = self.vision_und_enc_model(images)
|
197 |
+
images_embeds = self.vision_und_enc_aligner(images_embeds)
|
198 |
+
# print(images_embeds.shape, self.beg_of_und_embed.shape, images_seq_mask.shape, input_ids.shape)
|
199 |
+
beg_of_und_embed = self.beg_of_und_embed[0].detach().clone()
|
200 |
+
images_embeds = torch.cat(
|
201 |
+
[
|
202 |
+
beg_of_und_embed.view(1, 1, -1).repeat(images_embeds.shape[0], 1, 1),
|
203 |
+
images_embeds,
|
204 |
+
],
|
205 |
+
dim=1,
|
206 |
+
)
|
207 |
+
# [b x n, T2, D] -> [b, n x T2, D]
|
208 |
+
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
|
209 |
+
# [b, n, T2] -> [b, n x T2]
|
210 |
+
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
|
211 |
+
|
212 |
+
# [b, T, D]
|
213 |
+
input_ids[input_ids < 0] = 0 # ignore the image embeddings
|
214 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
215 |
+
|
216 |
+
# replace with the image embeddings
|
217 |
+
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
|
218 |
+
|
219 |
+
return inputs_embeds
|
220 |
+
|
221 |
+
|
222 |
+
AutoConfig.register("vision_und_enc", VisionUnderstandEncoderConfig)
|
223 |
+
AutoConfig.register("vision_gen_enc", VisionGenerationEncoderConfig)
|
224 |
+
AutoConfig.register("vision_gen_dec", VisionGenerationDecoderConfig)
|
225 |
+
AutoConfig.register("multi_modality", MultiModalityConfig)
|
226 |
+
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
|
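The core of MultiModalityCausalLM for understanding tasks is prepare_inputs_embeds, which splices the aligned SigLIP image features (prefixed by the learned begin-of-understanding embedding) into the token embedding sequence wherever the processor placed image tokens. The sketch below shows the intended call pattern; it assumes `vl_gpt` and `vl_chat_processor` were loaded as in the earlier JanusFlow sketch and uses an image that ships with this repo, so treat it as illustrative rather than as the repo's own demo code.

import torch
from PIL import Image

pil_images = [Image.open("images/cat_dog.png").convert("RGB")]
conversation = [
    {"role": "User", "content": "<image_placeholder>\nDescribe the image.", "images": pil_images},
    {"role": "Assistant", "content": ""},
]

inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)

# [b, T, 2048]: token embeddings with image positions replaced by aligned vision features
inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

with torch.no_grad():
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
        max_new_tokens=128,
        do_sample=False,
    )
print(vl_chat_processor.tokenizer.decode(outputs[0], skip_special_tokens=True))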
janus/janusflow/models/processing_vlm.py
ADDED
@@ -0,0 +1,455 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Dict, List
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from PIL.Image import Image
|
25 |
+
from transformers import LlamaTokenizerFast
|
26 |
+
from transformers.processing_utils import ProcessorMixin
|
27 |
+
|
28 |
+
from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
|
29 |
+
from janus.utils.conversation import get_conv_template
|
30 |
+
|
31 |
+
|
32 |
+
class DictOutput(object):
|
33 |
+
def keys(self):
|
34 |
+
return self.__dict__.keys()
|
35 |
+
|
36 |
+
def __getitem__(self, item):
|
37 |
+
return self.__dict__[item]
|
38 |
+
|
39 |
+
def __setitem__(self, key, value):
|
40 |
+
self.__dict__[key] = value
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class VLChatProcessorOutput(DictOutput):
|
45 |
+
sft_format: str
|
46 |
+
input_ids: torch.Tensor
|
47 |
+
pixel_values: torch.Tensor
|
48 |
+
num_und_image_tokens: torch.IntTensor
|
49 |
+
|
50 |
+
def __len__(self):
|
51 |
+
return len(self.input_ids)
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class BatchedVLChatProcessorOutput(DictOutput):
|
56 |
+
sft_format: List[str]
|
57 |
+
input_ids: torch.Tensor
|
58 |
+
pixel_values: torch.Tensor
|
59 |
+
attention_mask: torch.Tensor
|
60 |
+
images_seq_mask: torch.BoolTensor
|
61 |
+
images_emb_mask: torch.BoolTensor
|
62 |
+
|
63 |
+
def to(self, device, dtype=torch.bfloat16):
|
64 |
+
self.input_ids = self.input_ids.to(device)
|
65 |
+
self.attention_mask = self.attention_mask.to(device)
|
66 |
+
self.images_seq_mask = self.images_seq_mask.to(device)
|
67 |
+
self.images_emb_mask = self.images_emb_mask.to(device)
|
68 |
+
self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
|
69 |
+
return self
|
70 |
+
|
71 |
+
|
72 |
+
class VLChatProcessor(ProcessorMixin):
|
73 |
+
image_processor_class = "AutoImageProcessor"
|
74 |
+
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
75 |
+
|
76 |
+
attributes = ["image_processor", "tokenizer"]
|
77 |
+
|
78 |
+
system_prompt = (
|
79 |
+
"You are a helpful language and vision assistant. "
|
80 |
+
"You are able to understand the visual content that the user provides, "
|
81 |
+
"and assist the user with a variety of tasks using natural language."
|
82 |
+
)
|
83 |
+
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
image_processor: VLMImageProcessor,
|
87 |
+
tokenizer: LlamaTokenizerFast,
|
88 |
+
image_tag: str = "<image_placeholder>",
|
89 |
+
image_start_tag: str = "<begin_of_image>",
|
90 |
+
image_end_tag: str = "<end_of_image>",
|
91 |
+
image_gen_tag: str = "<|begin▁of▁generation|>",
|
92 |
+
num_image_tokens: int = 576,
|
93 |
+
add_special_token: bool = False,
|
94 |
+
sft_format: str = "deepseek",
|
95 |
+
mask_prompt: bool = True,
|
96 |
+
ignore_id: int = -100,
|
97 |
+
**kwargs,
|
98 |
+
):
|
99 |
+
self.image_processor = image_processor
|
100 |
+
self.tokenizer = tokenizer
|
101 |
+
|
102 |
+
image_id = self.tokenizer.vocab.get(image_tag)
|
103 |
+
if image_id is None:
|
104 |
+
special_tokens = [image_tag]
|
105 |
+
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
106 |
+
self.tokenizer.add_special_tokens(special_tokens_dict)
|
107 |
+
print(f"Add image tag = {image_tag} to the tokenizer")
|
108 |
+
|
109 |
+
image_gen_id = self.tokenizer.vocab.get(image_gen_tag)
|
110 |
+
if image_gen_id is None:
|
111 |
+
special_tokens = [image_gen_tag]
|
112 |
+
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
113 |
+
self.tokenizer.add_special_tokens(special_tokens_dict)
|
114 |
+
print(f"Add generation tag = {image_gen_tag} to the tokenizer")
|
115 |
+
|
116 |
+
assert image_start_tag is not None and image_end_tag is not None
|
117 |
+
boi_id = self.tokenizer.vocab.get(image_start_tag)
|
118 |
+
eoi_id = self.tokenizer.vocab.get(image_end_tag)
|
119 |
+
if boi_id is None:
|
120 |
+
special_tokens = [image_start_tag]
|
121 |
+
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
122 |
+
self.tokenizer.add_special_tokens(special_tokens_dict)
|
123 |
+
print(f"Add boi tag = {image_start_tag} to the tokenizer")
|
124 |
+
if eoi_id is None:
|
125 |
+
special_tokens = [image_end_tag]
|
126 |
+
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
127 |
+
self.tokenizer.add_special_tokens(special_tokens_dict)
|
128 |
+
print(f"Add eoi tag = {image_end_tag} to the tokenizer")
|
129 |
+
|
130 |
+
self.image_tag = image_tag
|
131 |
+
self.image_gen_tag = image_gen_tag
|
132 |
+
self.image_start_tag = image_start_tag
|
133 |
+
self.image_end_tag = image_end_tag
|
134 |
+
|
135 |
+
self.num_image_tokens = num_image_tokens
|
136 |
+
self.add_special_token = add_special_token
|
137 |
+
self.sft_format = sft_format
|
138 |
+
self.mask_prompt = mask_prompt
|
139 |
+
self.ignore_id = ignore_id
|
140 |
+
self.tokenizer.pad_token_id = self.tokenizer.vocab.get("<|▁pad▁|>")
|
141 |
+
|
142 |
+
super().__init__(
|
143 |
+
image_processor,
|
144 |
+
tokenizer,
|
145 |
+
image_tag,
|
146 |
+
num_image_tokens,
|
147 |
+
add_special_token,
|
148 |
+
sft_format,
|
149 |
+
mask_prompt,
|
150 |
+
ignore_id,
|
151 |
+
**kwargs,
|
152 |
+
)
|
153 |
+
|
154 |
+
def new_chat_template(self):
|
155 |
+
conv = get_conv_template(self.sft_format)
|
156 |
+
conv.set_system_message(self.system_prompt)
|
157 |
+
return conv
|
158 |
+
|
159 |
+
def apply_sft_template_for_multi_turn_prompts(
|
160 |
+
self,
|
161 |
+
conversations: List[Dict[str, str]],
|
162 |
+
sft_format: str = "deepseek",
|
163 |
+
system_prompt: str = "",
|
164 |
+
):
|
165 |
+
"""
|
166 |
+
Applies the SFT template to conversation.
|
167 |
+
|
168 |
+
An example of conversation:
|
169 |
+
conversation = [
|
170 |
+
{
|
171 |
+
"role": "User",
|
172 |
+
"content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
|
173 |
+
"images": [
|
174 |
+
"./multi-images/attribute_comparison_1.png",
|
175 |
+
"./multi-images/attribute_comparison_2.png"
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"role": "Assistant",
|
180 |
+
"content": ""
|
181 |
+
}
|
182 |
+
]
|
183 |
+
|
184 |
+
Args:
|
185 |
+
conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
|
186 |
+
sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
|
187 |
+
system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
sft_prompt (str): The formatted text.
|
191 |
+
"""
|
192 |
+
|
193 |
+
conv = get_conv_template(sft_format)
|
194 |
+
conv.set_system_message(system_prompt)
|
195 |
+
for message in conversations:
|
196 |
+
conv.append_message(message["role"], message["content"].strip())
|
197 |
+
sft_prompt = conv.get_prompt().strip()
|
198 |
+
|
199 |
+
return sft_prompt
|
200 |
+
|
201 |
+
@property
|
202 |
+
def image_token(self):
|
203 |
+
return self.image_tag
|
204 |
+
|
205 |
+
@property
|
206 |
+
def image_id(self):
|
207 |
+
image_id = self.tokenizer.vocab.get(self.image_tag)
|
208 |
+
return image_id
|
209 |
+
|
210 |
+
@property
|
211 |
+
def image_start_id(self):
|
212 |
+
image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
|
213 |
+
return image_start_id
|
214 |
+
|
215 |
+
@property
|
216 |
+
def image_end_id(self):
|
217 |
+
image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
|
218 |
+
return image_end_id
|
219 |
+
|
220 |
+
@property
|
221 |
+
def image_start_token(self):
|
222 |
+
return self.image_start_tag
|
223 |
+
|
224 |
+
@property
|
225 |
+
def image_end_token(self):
|
226 |
+
return self.image_end_tag
|
227 |
+
|
228 |
+
@property
|
229 |
+
def pad_id(self):
|
230 |
+
pad_id = self.tokenizer.pad_token_id
|
231 |
+
if pad_id is None:
|
232 |
+
pad_id = self.tokenizer.eos_token_id
|
233 |
+
|
234 |
+
return pad_id
|
235 |
+
|
236 |
+
@property
|
237 |
+
def image_gen_id(self):
|
238 |
+
image_gen_id = self.tokenizer.vocab.get(self.image_gen_tag)
|
239 |
+
return image_gen_id
|
240 |
+
|
241 |
+
def add_image_token(
|
242 |
+
self,
|
243 |
+
image_indices: List[int],
|
244 |
+
input_ids: torch.LongTensor,
|
245 |
+
):
|
246 |
+
"""
|
247 |
+
|
248 |
+
Args:
|
249 |
+
image_indices (List[int]): [index_0, index_1, ..., index_j]
|
250 |
+
input_ids (torch.LongTensor): [N]
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
input_ids (torch.LongTensor): [N + image tokens]
|
254 |
+
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor(
            [self.num_image_tokens + 1] * len(image_indices)
        )
        # we add 1 to fit generation

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."

        if prompt is None:
            # apply sft format
            sft_format = self.apply_sft_template_for_multi_turn_prompts(
                conversations=conversations,
                sft_format=self.sft_format,
                system_prompt=self.system_prompt,
            )
        else:
            sft_format = prompt

        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()

        input_ids, num_und_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_und_image_tokens=num_und_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            # we only fill the images for understanding tasks into the mask
            n_images.append(len(prepare.num_und_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (
                batch_size,
                max_n_images,
                self.num_image_tokens + 1,
            )  # add 1 to account for <image_beg>
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_und_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = (input_ids == self.image_id) | (
                input_ids == self.image_start_id
            )

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_und_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
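Note (added): a minimal usage sketch for the processor above. The checkpoint id "deepseek-ai/JanusFlow-1.3B", the conversation format, and the <image_placeholder> tag are assumptions borrowed from the upstream Janus examples; none of them are defined in this diff.

# Hedged sketch, not part of the diff: how the processor is typically driven.
import torch
from PIL import Image
from janus.janusflow.models import VLChatProcessor

processor = VLChatProcessor.from_pretrained("deepseek-ai/JanusFlow-1.3B")

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nDescribe this chart.",
        "images": ["images/PieChart.png"],
    },
    {"role": "Assistant", "content": ""},
]
pil_images = [Image.open(p) for m in conversation for p in m.get("images", [])]

# __call__ runs process_one() and then batchify(), so the result is already
# left-padded, with attention_mask, images_seq_mask and images_emb_mask filled.
inputs = processor(conversations=conversation, images=pil_images, force_batchify=True)
print(inputs.input_ids.shape, inputs.images_emb_mask.shape)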
janus/janusflow/models/siglip_vit.py
ADDED
@@ -0,0 +1,691 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import (
    Callable,
    Dict,
    Final,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import (
    AttentionPoolLatent,
    DropPath,
    LayerType,
    Mlp,
    PatchDropout,
    PatchEmbed,
    resample_abs_pos_embed,
)
from timm.models._manipulate import checkpoint_seq, named_apply


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)  # noqa: E741
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
    r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
    convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
    Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
    from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """

    with torch.no_grad():
        dtype = tensor.dtype
        tensor_fp32 = tensor.float()
        tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
        tensor_dtype = tensor_fp32.to(dtype=dtype)
        tensor.copy_(tensor_dtype)


def init_weights(self):
    if self.pos_embed is not None:
        trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
    trunc_normal_(self.latent, std=self.latent_dim**-0.5)


def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        # self.fused_attn = use_fused_attn()
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    dynamic_img_size: Final[bool]

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: Optional[float] = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        embed_layer: Callable = PatchEmbed,
        norm_layer: Optional[LayerType] = None,
        act_layer: Optional[LayerType] = None,
        block_fn: Type[nn.Module] = Block,
        mlp_layer: Type[nn.Module] = Mlp,
        ignore_head: bool = False,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            embed_layer: Patch embedding layer.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
        """
        super().__init__()
        assert global_pool in ("", "avg", "token", "map")
        assert class_token or global_pool != "token"
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        # act_layer = get_act_layer(act_layer) or nn.GELU
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = (
            no_embed_class  # don't embed prefix positions (includes reg)
        )
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.ignore_head = ignore_head

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        )
        self.reg_token = (
            nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        )
        embed_len = (
            num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        )
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    init_values=init_values,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    mlp_layer=mlp_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == "map":
            AttentionPoolLatent.init_weights = init_weights
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        if weight_init != "skip":
            self.init_weights(weight_init)

    def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
        assert mode in ("jax", "jax_nlhb", "moco", "")
        # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        return dict(
            stem=r"^cls_token|pos_embed|patch_embed",  # stem and embed
            blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None) -> None:
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ("", "avg", "token", "map")
            if global_pool == "map" and self.attn_pool is None:
                assert (
                    False
                ), "Cannot currently add attention pooling in reset_classifier()."
            elif global_pool != "map" and self.attn_pool is not None:
                self.attn_pool = None  # remove attention pooling
            self.global_pool = global_pool
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def _intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,
    ) -> List[torch.Tensor]:
        outputs, num_blocks = [], len(self.blocks)
        take_indices = set(
            range(num_blocks - n, num_blocks) if isinstance(n, int) else n
        )

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)

        return outputs

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,
        reshape: bool = False,
        return_prefix_tokens: bool = False,
        norm: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if n is a sequence, select by matching indices
        outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens :] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return tuple(zip(outputs, prefix_tokens))
        return tuple(outputs)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == "avg":
            x = x[:, self.num_prefix_tokens :].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        if not self.ignore_head:
            x = self.forward_head(x)
        return x


@dataclass
class SigLIPVisionCfg:
    width: int = 1152
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False


SigLIP_MODEL_CONFIG = {
    "siglip_so400m_patch14_384": {
        "image_size": 336,
        "patch_size": 14,
        "width": 1152,
        "layers": 27,
        "heads": 16,
        "mlp_ratio": 3.7362,
        "global_pool": "map",
        "use_checkpoint": False,
    },
    "siglip_so400m_patch14_224": {
        "image_size": 224,
        "patch_size": 14,
        "width": 1152,
        "layers": 27,
        "heads": 16,
        "mlp_ratio": 3.7362,
        "global_pool": "map",
        "use_checkpoint": False,
    },
    "siglip_large_patch16_384": {
        "image_size": 384,
        "patch_size": 16,
        "width": 1024,
        "layers": 24,
        "heads": 16,
        "mlp_ratio": 4,
        "global_pool": "map",
        "use_checkpoint": False,
    },
    "siglip_large_patch16_256": {
        "image_size": 256,
        "patch_size": 16,
        "width": 1024,
        "layers": 24,
        "heads": 16,
        "mlp_ratio": 4,
        "global_pool": "map",
        "use_checkpoint": False,
    },
}


def create_siglip_vit(
    model_name: str = "siglip_so400m_patch14_384",
    image_size: int = 384,
    select_layer: int = -1,
    ckpt_path: str = "",
    **kwargs,
):
    assert (
        model_name in SigLIP_MODEL_CONFIG.keys()
    ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"

    vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])

    if select_layer <= 0:
        layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
    else:
        layers = min(vision_cfg.layers, select_layer)

    model = VisionTransformer(
        img_size=image_size,
        patch_size=vision_cfg.patch_size,
        embed_dim=vision_cfg.width,
        depth=layers,
        num_heads=vision_cfg.heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        class_token=vision_cfg.class_token,
        global_pool=vision_cfg.global_pool,
        ignore_head=kwargs.get("ignore_head", True),
        weight_init=kwargs.get("weight_init", "skip"),
        num_classes=0,
    )

    if ckpt_path:
        state_dict = torch.load(ckpt_path, map_location="cpu")

        incompatible_keys = model.load_state_dict(state_dict, strict=False)
        print(
            f"SigLIP-ViT restores from {ckpt_path},\n"
            f"\tincompatible_keys: {incompatible_keys}."
        )

    return model
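Note (added): an illustrative way to exercise create_siglip_vit as a frozen patch-feature extractor. The dummy input, resolution, and printed shape below are assumptions for a 384x384 input with the default arguments, not something this file asserts.

# Hedged sketch, not part of the diff: build the SigLIP-so400m tower defined
# above and inspect the patch features it returns when the head is ignored.
import torch

vit = create_siglip_vit("siglip_so400m_patch14_384", image_size=384)  # ignore_head defaults to True here
vit.eval()

dummy = torch.randn(2, 3, 384, 384)  # [B, 3, H, W], assumed already normalized
with torch.no_grad():
    feats = vit(dummy)  # forward_head() is skipped, so this is the block output after the final norm

# With patch_size=14 on a 384x384 input the patch grid is 27x27 = 729 tokens of width 1152.
print(feats.shape)  # expected: torch.Size([2, 729, 1152])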
janus/janusflow/models/uvit.py
ADDED
@@ -0,0 +1,714 @@
1 |
+
# Copyright (c) 2023-2024 DeepSeek.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
# this software and associated documentation files (the "Software"), to deal in
|
5 |
+
# the Software without restriction, including without limitation the rights to
|
6 |
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
7 |
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
8 |
+
# subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
15 |
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
16 |
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
17 |
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
18 |
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
19 |
+
|
20 |
+
# modified from: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
|
21 |
+
import math
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.distributed as dist
|
25 |
+
import torch.nn.functional as F
|
26 |
+
from typing import Optional, Tuple, Union
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
import torchvision
|
30 |
+
import torchvision.utils
|
31 |
+
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
|
32 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
|
33 |
+
|
34 |
+
|
35 |
+
class ImageHead(nn.Module):
|
36 |
+
|
37 |
+
def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
|
38 |
+
super().__init__()
|
39 |
+
self.layer_id = layer_id
|
40 |
+
cfg = (
|
41 |
+
AttrDict(
|
42 |
+
norm_type="layernorm",
|
43 |
+
is_exp_norm=False,
|
44 |
+
sequence_parallel=False,
|
45 |
+
use_userbuffer=False,
|
46 |
+
norm_eps=1e-5,
|
47 |
+
norm_bias=True,
|
48 |
+
gradient_accumulation_fusion=True,
|
49 |
+
use_fp32_head_weight=False,
|
50 |
+
)
|
51 |
+
+ gpt_cfg
|
52 |
+
)
|
53 |
+
group = PG.tensor_parallel_group()
|
54 |
+
assert cfg.norm_type in [
|
55 |
+
"layernorm",
|
56 |
+
"rmsnorm",
|
57 |
+
], f"Norm type:{cfg.norm_type} not supported"
|
58 |
+
if cfg.norm_type == "rmsnorm":
|
59 |
+
self.norm = DropoutAddRMSNorm(
|
60 |
+
cfg.n_embed,
|
61 |
+
prenorm=False,
|
62 |
+
eps=cfg.norm_eps,
|
63 |
+
is_exp_norm=cfg.is_exp_norm,
|
64 |
+
sequence_parallel=cfg.sequence_parallel,
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
self.norm = DropoutAddLayerNorm(
|
68 |
+
cfg.n_embed,
|
69 |
+
prenorm=False,
|
70 |
+
eps=cfg.norm_eps,
|
71 |
+
is_exp_norm=cfg.is_exp_norm,
|
72 |
+
sequence_parallel=cfg.sequence_parallel,
|
73 |
+
bias=cfg.norm_bias,
|
74 |
+
)
|
75 |
+
|
76 |
+
multiple_of = 256
|
77 |
+
if decoder_cfg.in_channels % multiple_of != 0:
|
78 |
+
warnings.warn(
|
79 |
+
f"建议把 vocab_size 设置为 {multiple_of} 的倍数, 否则会影响矩阵乘法的性能"
|
80 |
+
)
|
81 |
+
|
82 |
+
dtype = default_dtype = torch.get_default_dtype()
|
83 |
+
if cfg.use_fp32_head_weight:
|
84 |
+
dtype = torch.float32
|
85 |
+
print(
|
86 |
+
"使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
|
87 |
+
end="",
|
88 |
+
flush=True,
|
89 |
+
)
|
90 |
+
torch.set_default_dtype(dtype)
|
91 |
+
self.head = ColumnParallelLinear(
|
92 |
+
cfg.n_embed,
|
93 |
+
decoder_cfg.in_channels,
|
94 |
+
bias=True,
|
95 |
+
group=group,
|
96 |
+
sequence_parallel=cfg.sequence_parallel,
|
97 |
+
use_userbuffer=cfg.use_userbuffer,
|
98 |
+
gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
|
99 |
+
use_fp32_output=False,
|
100 |
+
)
|
101 |
+
torch.set_default_dtype(default_dtype)
|
102 |
+
|
103 |
+
self.use_fp32_head_weight = cfg.use_fp32_head_weight
|
104 |
+
|
105 |
+
def forward(
|
106 |
+
self, input_args, images_split_mask: Optional[torch.BoolTensor] = None, **kwargs
|
107 |
+
):
|
108 |
+
residual = None
|
109 |
+
if isinstance(input_args, tuple):
|
110 |
+
x, residual = input_args
|
111 |
+
else:
|
112 |
+
x = input_args
|
113 |
+
|
114 |
+
x = self.norm(x, residual)
|
115 |
+
|
116 |
+
if self.use_fp32_head_weight:
|
117 |
+
assert (
|
118 |
+
self.head.weight.dtype == torch.float32
|
119 |
+
), f"head.weight is {self.head.weight.dtype}"
|
120 |
+
x = x.float()
|
121 |
+
|
122 |
+
if images_split_mask is None:
|
123 |
+
logits = self.head(x)
|
124 |
+
else:
|
125 |
+
bs, n_images = images_split_mask.shape[:2]
|
126 |
+
n_embed = x.shape[-1]
|
127 |
+
|
128 |
+
images_embed = torch.masked_select(
|
129 |
+
x.unsqueeze(1), images_split_mask.unsqueeze(-1)
|
130 |
+
)
|
131 |
+
images_embed = images_embed.view((bs * n_images, -1, n_embed))
|
132 |
+
logits = self.head(images_embed)
|
133 |
+
|
134 |
+
return logits
|
135 |
+
|
136 |
+
|
137 |
+
class GlobalResponseNorm(nn.Module):
|
138 |
+
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
139 |
+
def __init__(self, dim):
|
140 |
+
super().__init__()
|
141 |
+
self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
142 |
+
self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
146 |
+
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
|
147 |
+
|
148 |
+
return torch.addcmul(self.bias, (self.weight * nx + 1), x, value=1)
|
149 |
+
|
150 |
+
|
151 |
+
class Downsample2D(nn.Module):
|
152 |
+
"""A 2D downsampling layer with an optional convolution.
|
153 |
+
|
154 |
+
Parameters:
|
155 |
+
channels (`int`):
|
156 |
+
number of channels in the inputs and outputs.
|
157 |
+
use_conv (`bool`, default `False`):
|
158 |
+
option to use a convolution.
|
159 |
+
out_channels (`int`, optional):
|
160 |
+
number of output channels. Defaults to `channels`.
|
161 |
+
padding (`int`, default `1`):
|
162 |
+
padding for the convolution.
|
163 |
+
name (`str`, default `conv`):
|
164 |
+
name of the downsampling 2D layer.
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
channels: int,
|
170 |
+
use_conv: bool = False,
|
171 |
+
out_channels: Optional[int] = None,
|
172 |
+
padding: int = 1,
|
173 |
+
name: str = "conv",
|
174 |
+
kernel_size=3,
|
175 |
+
stride=2,
|
176 |
+
norm_type=None,
|
177 |
+
eps=None,
|
178 |
+
elementwise_affine=None,
|
179 |
+
bias=True,
|
180 |
+
):
|
181 |
+
super().__init__()
|
182 |
+
self.channels = channels
|
183 |
+
self.out_channels = out_channels or channels
|
184 |
+
self.use_conv = use_conv
|
185 |
+
self.padding = padding
|
186 |
+
self.name = name
|
187 |
+
|
188 |
+
if norm_type == "ln_norm":
|
189 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
190 |
+
elif norm_type == "rms_norm":
|
191 |
+
self.norm = RMSNorm(channels, eps)
|
192 |
+
elif norm_type is None:
|
193 |
+
self.norm = None
|
194 |
+
else:
|
195 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
196 |
+
|
197 |
+
if use_conv:
|
198 |
+
conv = nn.Conv2d(
|
199 |
+
self.channels,
|
200 |
+
self.out_channels,
|
201 |
+
kernel_size=kernel_size,
|
202 |
+
stride=stride,
|
203 |
+
padding=padding,
|
204 |
+
bias=bias,
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
assert self.channels == self.out_channels
|
208 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
209 |
+
|
210 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
211 |
+
if name == "conv":
|
212 |
+
self.Conv2d_0 = conv
|
213 |
+
self.conv = conv
|
214 |
+
elif name == "Conv2d_0":
|
215 |
+
self.conv = conv
|
216 |
+
else:
|
217 |
+
self.conv = conv
|
218 |
+
|
219 |
+
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
220 |
+
|
221 |
+
assert hidden_states.shape[1] == self.channels
|
222 |
+
|
223 |
+
if self.norm is not None:
|
224 |
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
|
225 |
+
0, 3, 1, 2
|
226 |
+
)
|
227 |
+
|
228 |
+
if self.use_conv and self.padding == 0:
|
229 |
+
pad = (0, 1, 0, 1)
|
230 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
231 |
+
|
232 |
+
assert hidden_states.shape[1] == self.channels
|
233 |
+
|
234 |
+
hidden_states = self.conv(hidden_states)
|
235 |
+
|
236 |
+
return hidden_states
|
237 |
+
|
238 |
+
|
239 |
+
class Upsample2D(nn.Module):
|
240 |
+
"""A 2D upsampling layer with an optional convolution.
|
241 |
+
|
242 |
+
Parameters:
|
243 |
+
channels (`int`):
|
244 |
+
number of channels in the inputs and outputs.
|
245 |
+
use_conv (`bool`, default `False`):
|
246 |
+
option to use a convolution.
|
247 |
+
use_conv_transpose (`bool`, default `False`):
|
248 |
+
option to use a convolution transpose.
|
249 |
+
out_channels (`int`, optional):
|
250 |
+
number of output channels. Defaults to `channels`.
|
251 |
+
name (`str`, default `conv`):
|
252 |
+
name of the upsampling 2D layer.
|
253 |
+
"""
|
254 |
+
|
255 |
+
def __init__(
|
256 |
+
self,
|
257 |
+
channels: int,
|
258 |
+
use_conv: bool = False,
|
259 |
+
use_conv_transpose: bool = False,
|
260 |
+
out_channels: Optional[int] = None,
|
261 |
+
name: str = "conv",
|
262 |
+
kernel_size: Optional[int] = None,
|
263 |
+
padding=1,
|
264 |
+
stride=2,
|
265 |
+
norm_type=None,
|
266 |
+
eps=None,
|
267 |
+
elementwise_affine=None,
|
268 |
+
bias=True,
|
269 |
+
interpolate=True,
|
270 |
+
):
|
271 |
+
super().__init__()
|
272 |
+
self.channels = channels
|
273 |
+
self.out_channels = out_channels or channels
|
274 |
+
self.use_conv = use_conv
|
275 |
+
self.use_conv_transpose = use_conv_transpose
|
276 |
+
self.name = name
|
277 |
+
self.interpolate = interpolate
|
278 |
+
self.stride = stride
|
279 |
+
|
280 |
+
if norm_type == "ln_norm":
|
281 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
282 |
+
elif norm_type == "rms_norm":
|
283 |
+
self.norm = RMSNorm(channels, eps)
|
284 |
+
elif norm_type is None:
|
285 |
+
self.norm = None
|
286 |
+
else:
|
287 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
288 |
+
|
289 |
+
conv = None
|
290 |
+
if use_conv_transpose:
|
291 |
+
if kernel_size is None:
|
292 |
+
kernel_size = 4
|
293 |
+
conv = nn.ConvTranspose2d(
|
294 |
+
channels,
|
295 |
+
self.out_channels,
|
296 |
+
kernel_size=kernel_size,
|
297 |
+
stride=stride,
|
298 |
+
padding=padding,
|
299 |
+
bias=bias,
|
300 |
+
)
|
301 |
+
elif use_conv:
|
302 |
+
if kernel_size is None:
|
303 |
+
kernel_size = 3
|
304 |
+
conv = nn.Conv2d(
|
305 |
+
self.channels,
|
306 |
+
self.out_channels,
|
307 |
+
kernel_size=kernel_size,
|
308 |
+
padding=padding,
|
309 |
+
bias=bias,
|
310 |
+
)
|
311 |
+
|
312 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
313 |
+
if name == "conv":
|
314 |
+
self.conv = conv
|
315 |
+
else:
|
316 |
+
self.Conv2d_0 = conv
|
317 |
+
|
318 |
+
def forward(
|
319 |
+
self,
|
320 |
+
hidden_states: torch.Tensor,
|
321 |
+
output_size: Optional[int] = None,
|
322 |
+
*args,
|
323 |
+
**kwargs,
|
324 |
+
) -> torch.Tensor:
|
325 |
+
|
326 |
+
assert hidden_states.shape[1] == self.channels
|
327 |
+
|
328 |
+
if self.norm is not None:
|
329 |
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
|
330 |
+
0, 3, 1, 2
|
331 |
+
)
|
332 |
+
|
333 |
+
if self.use_conv_transpose:
|
334 |
+
return self.conv(hidden_states)
|
335 |
+
|
336 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
337 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
338 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
339 |
+
dtype = hidden_states.dtype
|
340 |
+
if dtype == torch.bfloat16:
|
341 |
+
hidden_states = hidden_states.to(torch.float32)
|
342 |
+
|
343 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
344 |
+
if hidden_states.shape[0] >= 64:
|
345 |
+
hidden_states = hidden_states.contiguous()
|
346 |
+
|
347 |
+
# if `output_size` is passed we force the interpolation output
|
348 |
+
# size and do not make use of `scale_factor=2`
|
349 |
+
if self.interpolate:
|
350 |
+
if output_size is None:
|
351 |
+
hidden_states = F.interpolate(
|
352 |
+
hidden_states, scale_factor=self.stride, mode="nearest"
|
353 |
+
)
|
354 |
+
else:
|
355 |
+
hidden_states = F.interpolate(
|
356 |
+
hidden_states, size=output_size, mode="nearest"
|
357 |
+
)
|
358 |
+
|
359 |
+
# If the input is bfloat16, we cast back to bfloat16
|
360 |
+
if dtype == torch.bfloat16:
|
361 |
+
hidden_states = hidden_states.to(dtype)
|
362 |
+
|
363 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
364 |
+
if self.use_conv:
|
365 |
+
if self.name == "conv":
|
366 |
+
hidden_states = self.conv(hidden_states)
|
367 |
+
else:
|
368 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
369 |
+
|
370 |
+
return hidden_states
|
371 |
+
|
372 |
+
|
373 |
+
class ConvNextBlock(nn.Module):
|
374 |
+
def __init__(
|
375 |
+
self,
|
376 |
+
channels,
|
377 |
+
norm_eps,
|
378 |
+
elementwise_affine,
|
379 |
+
use_bias,
|
380 |
+
hidden_dropout,
|
381 |
+
hidden_size,
|
382 |
+
res_ffn_factor: int = 4,
|
383 |
+
):
|
384 |
+
super().__init__()
|
385 |
+
self.depthwise = nn.Conv2d(
|
386 |
+
channels,
|
387 |
+
channels,
|
388 |
+
kernel_size=7,
|
389 |
+
padding=3,
|
390 |
+
groups=channels,
|
391 |
+
bias=use_bias,
|
392 |
+
)
|
393 |
+
self.norm = RMSNorm(channels, norm_eps)
|
394 |
+
self.channelwise_linear_1 = nn.Linear(
|
395 |
+
channels, int(channels * res_ffn_factor), bias=use_bias
|
396 |
+
)
|
397 |
+
self.channelwise_act = nn.GELU()
|
398 |
+
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
|
399 |
+
self.channelwise_linear_2 = nn.Linear(
|
400 |
+
int(channels * res_ffn_factor), channels, bias=use_bias
|
401 |
+
)
|
402 |
+
self.channelwise_dropout = nn.Dropout(hidden_dropout)
|
403 |
+
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
|
404 |
+
|
405 |
+
def forward(self, x, cond_embeds):
|
406 |
+
x_res = x
|
407 |
+
|
408 |
+
x = self.depthwise(x)
|
409 |
+
|
410 |
+
x = x.permute(0, 2, 3, 1)
|
411 |
+
x = self.norm(x)
|
412 |
+
x = self.channelwise_linear_1(x)
|
413 |
+
x = self.channelwise_act(x)
|
414 |
+
x = self.channelwise_norm(x)
|
415 |
+
x = self.channelwise_linear_2(x)
|
416 |
+
x = self.channelwise_dropout(x)
|
417 |
+
x = x.permute(0, 3, 1, 2)
|
418 |
+
|
419 |
+
x = x + x_res
|
420 |
+
|
421 |
+
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
|
422 |
+
# x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
|
423 |
+
x = torch.addcmul(
|
424 |
+
shift[:, :, None, None], x, (1 + scale)[:, :, None, None], value=1
|
425 |
+
)
|
426 |
+
|
427 |
+
return x
|
428 |
+
|
429 |
+
|
430 |
+
class Patchify(nn.Module):
|
431 |
+
def __init__(
|
432 |
+
self,
|
433 |
+
in_channels,
|
434 |
+
block_out_channels,
|
435 |
+
patch_size,
|
436 |
+
bias,
|
437 |
+
elementwise_affine,
|
438 |
+
eps,
|
439 |
+
kernel_size=None,
|
440 |
+
):
|
441 |
+
super().__init__()
|
442 |
+
if kernel_size is None:
|
443 |
+
kernel_size = patch_size
|
444 |
+
self.patch_conv = nn.Conv2d(
|
445 |
+
in_channels,
|
446 |
+
block_out_channels,
|
447 |
+
kernel_size=kernel_size,
|
448 |
+
stride=patch_size,
|
449 |
+
bias=bias,
|
450 |
+
)
|
451 |
+
self.norm = RMSNorm(block_out_channels, eps)
|
452 |
+
|
453 |
+
def forward(self, x):
|
454 |
+
embeddings = self.patch_conv(x)
|
455 |
+
embeddings = embeddings.permute(0, 2, 3, 1)
|
456 |
+
embeddings = self.norm(embeddings)
|
457 |
+
embeddings = embeddings.permute(0, 3, 1, 2)
|
458 |
+
return embeddings
|
459 |
+
|
460 |
+
|
461 |
+
class Unpatchify(nn.Module):
|
462 |
+
def __init__(
|
463 |
+
self, in_channels, out_channels, patch_size, bias, elementwise_affine, eps
|
464 |
+
):
|
465 |
+
super().__init__()
|
466 |
+
self.norm = RMSNorm(in_channels, eps)
|
467 |
+
self.unpatch_conv = nn.Conv2d(
|
468 |
+
in_channels,
|
469 |
+
out_channels * patch_size * patch_size,
|
470 |
+
kernel_size=1,
|
471 |
+
bias=bias,
|
472 |
+
)
|
473 |
+
self.pixel_shuffle = nn.PixelShuffle(patch_size)
|
474 |
+
self.patch_size = patch_size
|
475 |
+
|
476 |
+
def forward(self, x):
|
477 |
+
# [b, c, h, w]
|
478 |
+
x = x.permute(0, 2, 3, 1)
|
479 |
+
x = self.norm(x)
|
480 |
+
x = x.permute(0, 3, 1, 2)
|
481 |
+
x = self.unpatch_conv(x)
|
482 |
+
x = self.pixel_shuffle(x)
|
483 |
+
return x
|
484 |
+
|
485 |
+
|
486 |
+
class UVitBlock(nn.Module):
|
487 |
+
def __init__(
|
488 |
+
self,
|
489 |
+
channels,
|
490 |
+
out_channels,
|
491 |
+
num_res_blocks,
|
492 |
+
stride,
|
493 |
+
hidden_size,
|
494 |
+
hidden_dropout,
|
495 |
+
elementwise_affine,
|
496 |
+
norm_eps,
|
497 |
+
use_bias,
|
498 |
+
downsample: bool,
|
499 |
+
upsample: bool,
|
500 |
+
res_ffn_factor: int = 4,
|
501 |
+
seq_len=None,
|
502 |
+
concat_input=False,
|
503 |
+
original_input_channels=None,
|
504 |
+
use_zero=True,
|
505 |
+
norm_type="RMS",
|
506 |
+
):
|
507 |
+
super().__init__()
|
508 |
+
|
509 |
+
self.res_blocks = nn.ModuleList()
|
510 |
+
for i in range(num_res_blocks):
|
511 |
+
conv_block = ConvNextBlock(
|
512 |
+
channels,
|
513 |
+
norm_eps,
|
514 |
+
elementwise_affine,
|
515 |
+
use_bias,
|
516 |
+
hidden_dropout,
|
517 |
+
hidden_size,
|
518 |
+
res_ffn_factor=res_ffn_factor,
|
519 |
+
)
|
520 |
+
|
521 |
+
self.res_blocks.append(conv_block)
|
522 |
+
|
523 |
+
if downsample:
|
524 |
+
self.downsample = Downsample2D(
|
525 |
+
channels=channels,
|
526 |
+
out_channels=out_channels,
|
527 |
+
use_conv=True,
|
528 |
+
name="Conv2d_0",
|
529 |
+
kernel_size=3,
|
530 |
+
                padding=1,
                stride=stride,
                norm_type="rms_norm",
                eps=norm_eps,
                elementwise_affine=elementwise_affine,
                bias=use_bias,
            )
        else:
            self.downsample = None

        if upsample:
            self.upsample = Upsample2D(
                channels=channels,
                out_channels=out_channels,
                use_conv_transpose=False,
                use_conv=True,
                kernel_size=3,
                padding=1,
                stride=stride,
                name="conv",
                norm_type="rms_norm",
                eps=norm_eps,
                elementwise_affine=elementwise_affine,
                bias=use_bias,
                interpolate=True,
            )
        else:
            self.upsample = None

    def forward(self, x, emb, recompute=False):
        for res_block in self.res_blocks:
            x = res_block(x, emb)

        if self.downsample is not None:
            x = self.downsample(x)

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class ShallowUViTEncoder(nn.Module):
    def __init__(
        self,
        input_channels=3,
        stride=4,
        kernel_size=7,
        padding=None,
        block_out_channels=(768,),
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        use_bias=True,
        norm_eps=1e-6,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()

        self.time_proj = Timesteps(
            block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.time_embed = TimestepEmbedding(
            block_out_channels[0], hidden_size, sample_proj_bias=use_bias
        )

        if padding is None:
            padding = math.ceil(kernel_size - stride)
        self.in_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=block_out_channels[0],
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        if use_mid_block:
            self.mid_block = UVitBlock(
                block_out_channels[-1],
                block_out_channels[-1],
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )
        else:
            self.mid_block = None

    def get_num_extra_tensors(self):
        return 2

    def forward(self, x, timesteps):
        bs = x.shape[0]
        dtype = x.dtype

        t_emb = self.time_proj(timesteps.flatten()).view(bs, -1).to(dtype)
        t_emb = self.time_embed(t_emb)
        x_emb = self.in_conv(x)

        if self.mid_block is not None:
            x_emb = self.mid_block(x_emb, t_emb)

        hs = [x_emb]
        return x_emb, t_emb, hs


class ShallowUViTDecoder(nn.Module):
    def __init__(
        self,
        in_channels=768,
        out_channels=3,
        block_out_channels: Tuple[int] = (768,),
        upsamples=2,
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        norm_eps=1e-6,
        use_bias=True,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()
        if use_mid_block:
            self.mid_block = UVitBlock(
                in_channels + block_out_channels[-1],
                block_out_channels[-1],  # unused: has no effect when both downsample and upsample are False
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )
        else:
            self.mid_block = None
        self.out_convs = nn.ModuleList()
        for rank in range(upsamples):
            if rank == upsamples - 1:
                curr_out_channels = out_channels
            else:
                curr_out_channels = block_out_channels[-1]
            if rank == 0:
                curr_in_channels = block_out_channels[-1] + in_channels
            else:
                curr_in_channels = block_out_channels[-1]
            self.out_convs.append(
                Unpatchify(
                    curr_in_channels,
                    curr_out_channels,
                    patch_size=2,
                    bias=use_bias,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps,
                )
            )
        self.input_norm = RMSNorm(in_channels, norm_eps)

    def forward(self, x, hs, t_emb):
        x = x.permute(0, 2, 3, 1)
        x = self.input_norm(x)
        x = x.permute(0, 3, 1, 2)

        x = torch.cat([x, hs.pop()], dim=1)
        if self.mid_block is not None:
            x = self.mid_block(x, t_emb)
        for out_conv in self.out_convs:
            x = out_conv(x)
        assert len(hs) == 0
        return x
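The block above completes uvit.py: UVitBlock stacks residual blocks with optional RMS-normalized down/upsampling, ShallowUViTEncoder patchifies the noisy image and produces a timestep embedding plus one skip tensor, and ShallowUViTDecoder fuses the backbone output with that skip and unpatchifies back to pixel space. As a quick sanity check of the tensor flow, here is a minimal sketch assuming the default constructor arguments, a 384x384 input, and an identity stand-in for the transformer backbone that sits between encoder and decoder in JanusFlow; the batch size, resolution, and expected shapes in the comments are illustrative assumptions, not values taken from this file.

import torch
from janus.janusflow.models.uvit import ShallowUViTDecoder, ShallowUViTEncoder

# Default configs: stride-4 patchify with a 7x7 conv, 768-channel features.
enc = ShallowUViTEncoder(input_channels=3, stride=4, kernel_size=7,
                         block_out_channels=(768,), hidden_size=2048)
dec = ShallowUViTDecoder(in_channels=768, out_channels=3,
                         block_out_channels=(768,), upsamples=2)

x = torch.randn(2, 3, 384, 384)     # noisy image batch (hypothetical size)
t = torch.rand(2)                   # flow timesteps in [0, 1)

x_emb, t_emb, hs = enc(x, t)        # x_emb: (2, 768, 96, 96); hs == [x_emb]
backbone_out = x_emb                # stand-in for the transformer backbone output
out = dec(backbone_out, hs, t_emb)  # concat with the skip, then two patch_size=2 Unpatchify steps
print(out.shape)                    # expected: torch.Size([2, 3, 384, 384])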
janus/models/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from .image_processing_vlm import VLMImageProcessor
from .modeling_vlm import MultiModalityCausalLM
from .processing_vlm import VLChatProcessor

__all__ = [
    "VLMImageProcessor",
    "VLChatProcessor",
    "MultiModalityCausalLM",
]
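This __init__.py defines the public surface of janus.models. For orientation, the sketch below shows how these exports are typically wired together, following the upstream Janus README conventions; the checkpoint id, chat roles, image path, and the load_pil_images helper location are assumptions here, not defined by this file.

import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images  # helper assumed from the upstream repo layout

model_path = "deepseek-ai/Janus-Pro-1B"  # placeholder checkpoint id
processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
).to(torch.bfloat16).eval()

# VLChatProcessor converts a chat-formatted conversation (with <image_placeholder>
# markers plus PIL images) into token ids and pixel values; MultiModalityCausalLM
# then embeds both modalities and generates through its language-model head.
conversation = [
    {
        "role": "<|User|>",
        "content": "<image_placeholder>\nDescribe this chart.",
        "images": ["path/to/chart.png"],  # placeholder path
    },
    {"role": "<|Assistant|>", "content": ""},
]
pil_images = load_pil_images(conversation)
inputs = processor(
    conversations=conversation, images=pil_images, force_batchify=True
).to(model.device)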