Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
da5e325
1
Parent(s):
6eddb24
Update app.py and requirements.txt
Browse files- app.py +15 -0
- requirements.txt +0 -14
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import shutil
|
|
|
2 |
import time
|
3 |
from pathlib import Path
|
4 |
from typing import Tuple
|
@@ -10,6 +11,20 @@ import torch
|
|
10 |
import yaml
|
11 |
from box import Box
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
from src.data.datapath import Datapath
|
14 |
from src.data.dataset import DatasetConfig, UniRigDatasetModule
|
15 |
from src.data.extract import extract_builtin, get_files
|
|
|
1 |
import shutil
|
2 |
+
import subprocess
|
3 |
import time
|
4 |
from pathlib import Path
|
5 |
from typing import Tuple
|
|
|
11 |
import yaml
|
12 |
from box import Box
|
13 |
|
14 |
+
# Get the PyTorch and CUDA versions
|
15 |
+
torch_version = torch.__version__.split("+")[0] # Strips any "+cuXXX" suffix
|
16 |
+
cuda_version = torch.version.cuda
|
17 |
+
spconv_version = "-cu121" if cuda_version else ""
|
18 |
+
|
19 |
+
# Format CUDA version to match the URL convention (e.g., "cu118" for CUDA 11.8)
|
20 |
+
if cuda_version:
|
21 |
+
cuda_version = f"cu{cuda_version.replace('.', '')}"
|
22 |
+
else:
|
23 |
+
cuda_version = "cpu" # Fallback in case CUDA is not available
|
24 |
+
|
25 |
+
subprocess.run(f'pip install spconv{spconv_version}', shell=True)
|
26 |
+
subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True)
|
27 |
+
|
28 |
from src.data.datapath import Datapath
|
29 |
from src.data.dataset import DatasetConfig, UniRigDatasetModule
|
30 |
from src.data.extract import extract_builtin, get_files
|
requirements.txt
CHANGED
@@ -17,18 +17,4 @@ scipy
|
|
17 |
matplotlib
|
18 |
plotly
|
19 |
pyyaml
|
20 |
-
# PyTorch and related packages - ensure compatibility
|
21 |
-
torch==2.7.0
|
22 |
-
torchvision
|
23 |
-
torchaudio
|
24 |
-
# PyTorch Geometric ecosystem packages
|
25 |
-
--find-links https://data.pyg.org/whl/torch-2.7.0+cu126.html
|
26 |
-
torch-scatter
|
27 |
-
torch-sparse
|
28 |
-
torch-cluster
|
29 |
-
torch-spline-conv
|
30 |
-
torch-geometric
|
31 |
-
# Flash attention for PyTorch 2.7
|
32 |
https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl
|
33 |
-
# Sparse convolution
|
34 |
-
spconv-cu121
|
|
|
17 |
matplotlib
|
18 |
plotly
|
19 |
pyyaml
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl
|
|
|
|