MohamedRashad commited on
Commit
da5e325
·
1 Parent(s): 6eddb24

Update app.py and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +15 -0
  2. requirements.txt +0 -14
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import shutil
 
2
  import time
3
  from pathlib import Path
4
  from typing import Tuple
@@ -10,6 +11,20 @@ import torch
10
  import yaml
11
  from box import Box
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from src.data.datapath import Datapath
14
  from src.data.dataset import DatasetConfig, UniRigDatasetModule
15
  from src.data.extract import extract_builtin, get_files
 
1
  import shutil
2
+ import subprocess
3
  import time
4
  from pathlib import Path
5
  from typing import Tuple
 
11
  import yaml
12
  from box import Box
13
 
14
+ # Get the PyTorch and CUDA versions
15
+ torch_version = torch.__version__.split("+")[0] # Strips any "+cuXXX" suffix
16
+ cuda_version = torch.version.cuda
17
+ spconv_version = "-cu121" if cuda_version else ""
18
+
19
+ # Format CUDA version to match the URL convention (e.g., "cu118" for CUDA 11.8)
20
+ if cuda_version:
21
+ cuda_version = f"cu{cuda_version.replace('.', '')}"
22
+ else:
23
+ cuda_version = "cpu" # Fallback in case CUDA is not available
24
+
25
+ subprocess.run(f'pip install spconv{spconv_version}', shell=True)
26
+ subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True)
27
+
28
  from src.data.datapath import Datapath
29
  from src.data.dataset import DatasetConfig, UniRigDatasetModule
30
  from src.data.extract import extract_builtin, get_files
requirements.txt CHANGED
@@ -17,18 +17,4 @@ scipy
17
  matplotlib
18
  plotly
19
  pyyaml
20
- # PyTorch and related packages - ensure compatibility
21
- torch==2.7.0
22
- torchvision
23
- torchaudio
24
- # PyTorch Geometric ecosystem packages
25
- --find-links https://data.pyg.org/whl/torch-2.7.0+cu126.html
26
- torch-scatter
27
- torch-sparse
28
- torch-cluster
29
- torch-spline-conv
30
- torch-geometric
31
- # Flash attention for PyTorch 2.7
32
  https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl
33
- # Sparse convolution
34
- spconv-cu121
 
17
  matplotlib
18
  plotly
19
  pyyaml
 
 
 
 
 
 
 
 
 
 
 
 
20
  https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl