import torch
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from snac import SNAC
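# Dependencies: torch, transformers, gradio, and the SNAC codec
# (`pip install snac`); `spaces` is the Hugging Face Spaces helper used for
# ZeroGPU allocation and is only needed when running as a Space.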

def redistribute_codes(row):
    """
    Convert a sequence of generated audio tokens into a waveform with SNAC.
    The model emits audio codes in groups of 7 tokens, and each group decodes
    to one frame spread across SNAC's three codebook layers.
    """
    # Trim to a whole number of 7-token groups.
    row_length = row.size(0)
    new_length = (row_length // 7) * 7
    trimmed_row = row[:new_length]
    # Shift token ids down by the audio-token offset so codes start at 0.
    code_list = [t - 128266 for t in trimmed_row]
    
    layer_1, layer_2, layer_3 = [], [], []
    
    # Each 7-token group carries 1 coarse, 2 medium, and 4 fine codes,
    # interleaved. Every position also carries a multiple-of-4096 offset
    # (the codebook size) that must be removed before decoding.
    for i in range(len(code_list) // 7):
        layer_1.append(code_list[7 * i][None])
        layer_2.append(code_list[7 * i + 1][None] - 4096)
        layer_3.append(code_list[7 * i + 2][None] - (2 * 4096))
        layer_3.append(code_list[7 * i + 3][None] - (3 * 4096))
        layer_2.append(code_list[7 * i + 4][None] - (4 * 4096))
        layer_3.append(code_list[7 * i + 5][None] - (5 * 4096))
        layer_3.append(code_list[7 * i + 6][None] - (6 * 4096))
    
    with torch.no_grad():
        # Stack each layer into the per-codebook tensors SNAC expects.
        codes = [
            torch.concat(layer_1),
            torch.concat(layer_2),
            torch.concat(layer_3)
        ]
        for i in range(len(codes)):
            # Clamp any stray negative codes and add a batch dimension.
            codes[i][codes[i] < 0] = 0
            codes[i] = codes[i][None]
        
        audio_hat = snac_model.decode(codes)
        # Drop the batch and channel dimensions to return a 1-D waveform.
        return audio_hat.cpu()[0, 0]
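
# Shape sketch (an assumption based on the layout above): for N generated
# groups, redistribute_codes receives 7*N token ids and hands SNAC three
# codebooks of shape [1, N], [1, 2*N], and [1, 4*N]; snac_model.decode then
# returns a waveform tensor of shape [1, 1, samples] at 24 kHz.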

# Load the SNAC codec used to decode audio tokens into a waveform
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to("cuda")

# Load the single-speaker language model
tokenizer = AutoTokenizer.from_pretrained('prithivMLmods/Llama-3B-Mono-Jim')
model = AutoModelForCausalLM.from_pretrained(
    'prithivMLmods/Llama-3B-Mono-Jim', torch_dtype=torch.bfloat16
).cuda()

@spaces.GPU
def generate_audio(text, temperature, top_p, max_new_tokens):
    """
    Generate 24 kHz mono speech for the given text, returning a
    (sample_rate, waveform) tuple as expected by gr.Audio(type="numpy").
    """
    speaker = "Jim"
    # Wrap the text in the model's expected prompt format; the custom tokens
    # delimit the text and mark where audio generation should begin.
    prompt = f'<custom_token_3><|begin_of_text|>{speaker}: {text}<|eot_id|><custom_token_4><custom_token_5><custom_token_1>'
    input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').to('cuda')
    
    with torch.no_grad():
        generated_ids = model.generate(
            **input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=128258,  # stop at the model's end-of-speech token
        )
    
    # Keep only the newly generated tokens (drop the prompt), then decode.
    row = generated_ids[0, input_ids['input_ids'].shape[1]:]
    y_tensor = redistribute_codes(row)
    y_np = y_tensor.detach().cpu().numpy()
    return (24000, y_np)  # 24 kHz, matching the SNAC 24 kHz decoder

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama-3B-Mono-Jim - Single Speaker Audio Generation")
    gr.Markdown("Generate speech audio using the `prithivMLmods/Llama-3B-Mono-Jim` model.")
    
    with gr.Row():
        text_input = gr.Textbox(lines=4, label="Input Text")
    
    with gr.Row():
        temp_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.9, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.8, label="Top-p")
        tokens_slider = gr.Slider(minimum=100, maximum=2000, step=50, value=1200, label="Max New Tokens")
    
    output_audio = gr.Audio(type="numpy", label="Generated Audio")
    generate_button = gr.Button("Generate Audio")
    
    generate_button.click(
        fn=generate_audio,
        inputs=[text_input, temp_slider, top_p_slider, tokens_slider],
        outputs=output_audio
    )

if __name__ == "__main__":
    demo.launch()
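
# A hypothetical programmatic call, bypassing the UI (names and values here
# are illustrative, not part of the original script):
#   sr, audio = generate_audio("Hello from Jim!", temperature=0.9,
#                              top_p=0.8, max_new_tokens=1200)
#   # `audio` is a float waveform at sr == 24000; it could be written out
#   # with, e.g., soundfile: sf.write("jim.wav", audio, sr)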