ARtOrias11 commited on
Commit
e737c3b
·
verified ·
1 Parent(s): 381c562

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -2,11 +2,22 @@ import gradio as gr
2
  import torch
3
  import torchaudio
4
  from torchaudio.transforms import Resample
 
5
 
6
- # Load MusicGen model
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
- model = torch.hub.load("facebookresearch/audiocraft", "musicgen", source="github")
9
- model.to(device)
 
 
 
 
 
 
 
 
 
 
10
 
11
  def generate_music(prompt, duration=10):
12
  try:
 
2
  import torch
3
  import torchaudio
4
  from torchaudio.transforms import Resample
5
+ import os
6
 
7
+ # Clone and load MusicGen model
8
+ def load_musicgen_model():
9
+ repo_path = "./audiocraft" # Local path to the cloned repo
10
+ if not os.path.exists(repo_path):
11
+ os.system("git clone https://github.com/facebookresearch/audiocraft.git")
12
+ if repo_path not in os.sys.path:
13
+ os.sys.path.append(repo_path)
14
+
15
+ from audiocraft.models import MusicGen
16
+ model = MusicGen.get_pretrained("small")
17
+ model.set_device("cuda" if torch.cuda.is_available() else "cpu")
18
+ return model
19
+
20
+ model = load_musicgen_model()
21
 
22
  def generate_music(prompt, duration=10):
23
  try: