import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from tqdm import tqdm  # progress bar used by MusicGenModel.sample
#from configs import GPT2Config, MBartConfig, CodeGenConfig, SummarizationConfig, OpenLRMConfig, UNet2DConditionModelConfig, AutoencoderKLConfig, BartConfig, MusicGenConfig
from configs import *
#from extensions import gelu, LayerNorm, Conv1D, Attention, MLP, Block, GPT2Model, GPT2LMHead, MBartEncoderLayer, MBartDecoderLayer, MBartEncoder, MBartDecoder, MBartModel, MBartForConditionalGeneration, CodeGenAttention, CodeGenBlock, CodeGenModel, CodeGenForCausalLM, SummarizationModel, OpenLRM, OpenLRMLayer, OpenLRMAttention, OpenLRMFeedForward, AutoencoderKL, Encoder_, Decoder_, DownBlock, UpBlock, ResnetBlock, MidBlock, Downsample2D, Upsample2D, UNet2DConditionModel, UNetMidBlock2DConditionModel, UNetDownBlock2DConditionModel, UNetUpBlock2DConditionModel, ResnetBlock2D, CrossAttentionBlock2D, CrossAttention, SimpleClassifier
from extensions import *

class SentimentClassifierModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 3)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        pooled = output[:, -1, :]
        logits = self.fc(pooled)
        return logits
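
# Usage sketch (hedged): the stand-in config below is a hypothetical SimpleNamespace,
# not one of the classes from configs.py; any config exposing vocab_size and d_model works.
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(vocab_size=10_000, d_model=128)
#   model = SentimentClassifierModel(cfg)
#   logits = model(torch.randint(0, cfg.vocab_size, (4, 32)))  # -> (4, 3) class logits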

class STTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(128 * 2, config.vocab_size)

    def forward(self, audio_data):
        x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Collapse channels and downsampled length into a single timestep so the feature
        # size matches the LSTM's declared input_size (assumes the input waveform has
        # length config.max_position_embeddings).
        x = x.view(x.size(0), 1, -1)
        packed_output = nn.utils.rnn.pack_padded_sequence(x, lengths=[x.size(1)]*x.size(0), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_output)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)
        return logits
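
# Usage sketch (hedged; the SimpleNamespace config and its field values are assumptions):
# audio_data is a (batch, config.max_position_embeddings) waveform; the conv stack
# downsamples it 8x and the model emits a single frame of vocabulary logits per clip.
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(vocab_size=32, max_position_embeddings=1024)
#   stt = STTModel(cfg)
#   logits = stt(torch.randn(2, cfg.max_position_embeddings))  # -> (2, 1, 32)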

class TTSModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)
        audio = self.sigmoid(logits)
        return audio
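
# Usage sketch (hedged; stand-in config fields are assumptions):
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(vocab_size=10_000, d_model=128)
#   tts = TTSModel(cfg)
#   audio = tts(torch.randint(0, cfg.vocab_size, (2, 50)))  # -> (2, 50, 1), values in (0, 1)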

class MusicGenModel(nn.Module):
    def __init__(self, config: MusicGenConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
        self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids):
        embedded_tokens = self.embedding(input_ids)
        hidden_states = embedded_tokens
        for layer in self.transformer_layers:
            hidden_states = layer(hidden_states)
        logits = self.fc_out(hidden_states)
        return logits

    def sample(self, attributes, sample_rate, duration):
        # Greedy autoregressive generation of discrete audio tokens; `attributes` is currently unused.
        device = next(self.parameters()).device
        input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long, device=device)
        audio_output = []
        num_steps = int(duration * sample_rate / 1024)
        for _ in tqdm(range(num_steps), desc="Generating music"):
            logits = self.forward(input_tokens)
            predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            audio_output.append(predicted_token.cpu())
            input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
        audio_output = torch.cat(audio_output, dim=1).float()
        return audio_output
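
if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of any training pipeline. SimpleNamespace is a
    # hypothetical stand-in for the config classes in configs.py. MusicGenModel is skipped
    # because CodeGenBlock's expected config fields are defined in extensions.py.
    from types import SimpleNamespace

    text_cfg = SimpleNamespace(vocab_size=1000, d_model=64)
    print("sentiment logits:", SentimentClassifierModel(text_cfg)(torch.randint(0, 1000, (2, 16))).shape)

    stt_cfg = SimpleNamespace(vocab_size=32, max_position_embeddings=1024)
    print("stt logits:", STTModel(stt_cfg)(torch.randn(2, 1024)).shape)

    print("tts audio:", TTSModel(text_cfg)(torch.randint(0, 1000, (2, 16))).shape)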