File size: 3,582 Bytes
364cb51
6ee09b1
9fbf2d1
 
 
08d30fe
96784fc
9fbf2d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364cb51
185e14a
9fbf2d1
364cb51
185e14a
96784fc
185e14a
96784fc
185e14a
364cb51
a24564e
185e14a
9fbf2d1
185e14a
9fbf2d1
96784fc
185e14a
9fbf2d1
 
364cb51
9fbf2d1
 
 
 
 
 
 
 
 
 
5e302e0
9fbf2d1
 
 
 
364cb51
9fbf2d1
 
96784fc
 
9fbf2d1
 
 
 
 
 
 
185e14a
9fbf2d1
 
 
 
 
 
185e14a
9fbf2d1
 
a2a8e37
9fbf2d1
96784fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from snac import SNAC


def redistribute_codes(row):
    """
    Decode a flat sequence of generated token ids into a waveform using SNAC.

    The ids are consumed in frames of 7. Every id is first shifted down by
    128266, and slot ``j`` of a frame is shifted down by a further
    ``j * 4096``. Slot 0 feeds the first code layer, slots 1 and 4 the second,
    and slots 2, 3, 5 and 6 the third (in that order within each frame).
    Trailing ids that do not fill a whole frame are dropped, and any code that
    ends up negative after un-shifting is replaced by 0.

    Args:
        row: 1-D integer tensor of generated token ids.

    Returns:
        The ``[0, 0]`` slice of ``snac_model.decode(...)``, moved to the CPU.

    Raises:
        ValueError: If ``row`` holds fewer than 7 ids, i.e. no complete frame.
    """
    tokens_per_frame = 7
    audio_token_offset = 128266  # subtracted from every generated id
    slot_stride = 4096           # additional shift applied per slot index

    frame_count = row.size(0) // tokens_per_frame
    if frame_count == 0:
        # Without this guard the failure would surface later as an opaque
        # "expected a non-empty list of Tensors" error.
        raise ValueError(
            f"need at least {tokens_per_frame} tokens to decode audio, "
            f"got {row.size(0)}"
        )

    with torch.no_grad():
        # One row per frame, one column per slot.
        frames = row[: frame_count * tokens_per_frame].reshape(
            frame_count, tokens_per_frame
        )
        slot_shift = (
            torch.arange(tokens_per_frame, device=row.device, dtype=row.dtype)
            * slot_stride
        )
        frames = frames - audio_token_offset - slot_shift

        # Row-major flattening keeps the per-frame slot order (1, 4) and
        # (2, 3, 5, 6) that the layers are assembled in.
        layers = (
            frames[:, 0],
            frames[:, [1, 4]].reshape(-1),
            frames[:, [2, 3, 5, 6]].reshape(-1),
        )
        # Negative codes are zeroed; a leading batch dimension is added.
        codes = [layer.clamp(min=0).unsqueeze(0) for layer in layers]

        audio_hat = snac_model.decode(codes)
        return audio_hat.cpu()[0, 0]

# SNAC model used to decode code layers into audio; placed on the GPU
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to("cuda")

# Tokenizer and single-speaker language model, both from the same checkpoint
tokenizer = AutoTokenizer.from_pretrained("prithivMLmods/Llama-3B-Mono-Cooper")
model = AutoModelForCausalLM.from_pretrained(
    "prithivMLmods/Llama-3B-Mono-Cooper",
    torch_dtype=torch.bfloat16,
)
model = model.cuda()

@spaces.GPU
def generate_audio(text, temperature, top_p, max_new_tokens):
    """
    Given input text, generate speech audio.

    Args:
        text: Text to be spoken; inserted into the prompt after the speaker name.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens. UI sliders may deliver
            this as a float, so it is converted to ``int`` before use.

    Returns:
        A ``(24000, waveform)`` tuple, where ``waveform`` is the decoded audio
        as a NumPy array.

    Raises:
        gr.Error: If generation yields fewer than 7 new tokens, which is less
            than one complete frame for ``redistribute_codes`` to decode.
    """
    speaker = "Cooper"
    prompt = f'<custom_token_3><|begin_of_text|>{speaker}: {text}<|eot_id|><custom_token_4><custom_token_5><custom_token_1>'
    input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').to('cuda')
    prompt_length = input_ids['input_ids'].shape[1]

    with torch.no_grad():
        generated_ids = model.generate(
            **input_ids,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=128258,
        )

    # Keep only the newly generated ids; the prompt is not audio.
    row = generated_ids[0, prompt_length:]
    if row.size(0) < 7:
        # redistribute_codes consumes ids in groups of 7; report the problem
        # in the UI instead of failing deep inside decoding.
        raise gr.Error(
            "The model did not produce enough audio tokens to decode. "
            "Try again or increase Max New Tokens."
        )

    y_tensor = redistribute_codes(row)
    y_np = y_tensor.detach().cpu().numpy()
    return (24000, y_np)

# Gradio UI: one text box, three sampling controls, and an audio player
with gr.Blocks() as demo:
    gr.Markdown("# Llama-3B-Mono-Cooper - Single Speaker Audio Generation")
    gr.Markdown("Generate speech audio using the `prithivMLmods/Llama-3B-Mono-Cooper` model.")

    # Text to synthesize
    with gr.Row():
        text_input = gr.Textbox(label="Input Text", lines=4)

    # Sampling controls, forwarded to generate_audio in this order
    with gr.Row():
        temp_slider = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.9,
        )
        top_p_slider = gr.Slider(
            label="Top-p",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.8,
        )
        tokens_slider = gr.Slider(
            label="Max New Tokens",
            minimum=100,
            maximum=2000,
            step=50,
            value=1200,
        )

    output_audio = gr.Audio(label="Generated Audio", type="numpy")
    generate_button = gr.Button("Generate Audio")

    # Clicking the button runs generation and sends the result to the player
    generate_button.click(
        generate_audio,
        inputs=[text_input, temp_slider, top_p_slider, tokens_slider],
        outputs=output_audio,
    )

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()