File size: 3,195 Bytes
68e98e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math
from transformers import PretrainedConfig, PreTrainedModel
from model.latent_Recurrent import LatentRecurrentDepthLM

# Configuration for the Latent Recurrent Depth Model
class LatentRecurrentDepthConfig(PretrainedConfig):
    """Hyper-parameters for ``LatentRecurrentDepthModel``.

    The four model-specific values are handed positionally to
    ``LatentRecurrentDepthLM`` by the model wrapper.

    Args:
      vocab_size: number of token ids the model can embed/predict.
      d_model: model width passed to the latent LM (presumably the hidden size).
      num_heads: head count passed to the latent LM (presumably attention heads).
      dropout: dropout value passed to the latent LM.
      **kwargs: any generic ``PretrainedConfig`` options (for example
        ``eos_token_id`` / ``pad_token_id``, which the model's ``generate`` reads).
    """

    model_type = "latent_recurrent_depth"

    def __init__(self, vocab_size=50257, d_model=768, num_heads=12, dropout=0.1, **kwargs):
        # The base class consumes the generic Hugging Face options first.
        super().__init__(**kwargs)
        # Model-specific hyper-parameters (assigned in the same order as before).
        self.vocab_size, self.d_model = vocab_size, d_model
        self.num_heads, self.dropout = num_heads, dropout


# Hugging Face-Compatible Model Wrapper
class LatentRecurrentDepthModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``LatentRecurrentDepthLM``.

    Ties the latent-recurrent-depth language model to
    ``LatentRecurrentDepthConfig`` so it plugs into the ``transformers`` model
    machinery, and adds a simple sampling-based :meth:`generate` loop.
    """

    config_class = LatentRecurrentDepthConfig
    base_model_prefix = "latent_recurrent_depth"

    def __init__(self, config: LatentRecurrentDepthConfig):
        super().__init__(config)
        self.latent_model = LatentRecurrentDepthLM(
            config.vocab_size, config.d_model, config.num_heads, config.dropout
        )
        # Run the transformers weight-initialisation hook over the submodules.
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        num_iterations: int,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the wrapped latent model.

        Args:
          input_ids: token ids, presumably of shape (batch, seq_length).
          num_iterations: number of recurrent iterations for this pass.
          mask: optional mask forwarded unchanged to the latent model.

        Returns:
          The latent model's output. ``generate`` indexes it as logits of shape
          (batch, seq_length, vocab_size).
        """
        return self.latent_model(input_ids, num_iterations, mask)

    def _eos_token_ids(self, device: torch.device) -> Optional[torch.Tensor]:
        """Return ``config.eos_token_id`` as a 1-D long tensor, or None if unset.

        Hugging Face configs allow either a single id or a list of ids.
        """
        eos = getattr(self.config, "eos_token_id", None)
        if eos is None:
            return None
        if isinstance(eos, int):
            eos = [eos]
        eos = list(eos)
        if not eos:
            return None
        return torch.tensor(eos, dtype=torch.long, device=device)

    @staticmethod
    def _sample_next_token(
        next_token_logits: torch.Tensor,
        temperature: float,
        top_k: Optional[int],
    ) -> torch.Tensor:
        """Pick one token id per row from last-position logits; returns (batch, 1)."""
        if temperature == 0:
            # Greedy decoding; dividing by zero would yield inf/NaN probabilities.
            return next_token_logits.argmax(dim=-1, keepdim=True)
        next_token_logits = next_token_logits / temperature
        if top_k is not None and top_k > 0:
            # Clamp: torch.topk raises when k exceeds the vocabulary size.
            k = min(top_k, next_token_logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k)
            probabilities = F.softmax(top_k_logits, dim=-1)
            # Sample a position within the top-k, then map it back to a vocab id.
            return top_k_indices.gather(-1, torch.multinomial(probabilities, num_samples=1))
        probabilities = F.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probabilities, num_samples=1)

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 20,
        num_iterations: int = 3,
        temperature: float = 1.0,
        top_k: Optional[int] = 50,
    ) -> torch.Tensor:
        """
        Generate a sequence of tokens given input_ids.

        Each step re-runs the model on the full sequence so far (no KV cache)
        and samples one new token per row. No mask is passed, so the latent
        model's own ``mask=None`` behaviour applies (TODO confirm it is causal).

        Args:
          input_ids: torch.Tensor of shape (batch, seq_length) containing the prompt.
          max_length: The number of tokens to generate (new tokens, not total length).
          num_iterations: The number of recurrent iterations to use in each forward pass.
          temperature: Temperature for scaling logits; 0 means greedy (argmax) decoding.
          top_k: If set and > 0, only sample from the top k tokens (clamped to the
            vocabulary size). None or <= 0 disables top-k filtering.

        Returns:
          generated: torch.Tensor of shape (batch, seq_length + n_new), with
          n_new <= max_length. Generation stops early once every row has
          produced ``config.eos_token_id``. Rows that finish earlier are filled with
          ``config.pad_token_id``, or with the EOS id when no pad id is set.
        """
        eos_ids = self._eos_token_ids(input_ids.device)
        pad_id = getattr(self.config, "pad_token_id", None)
        if pad_id is None and eos_ids is not None:
            pad_id = int(eos_ids[0])

        generated = input_ids.clone()
        # True for rows that have not yet emitted an EOS token.
        unfinished = torch.ones(generated.size(0), dtype=torch.bool, device=generated.device)

        # Remember the caller's mode so generation does not silently disable dropout
        # for a model that is still being trained.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_length):
                    # Call the module (not .forward) so registered hooks run.
                    logits = self(generated, num_iterations=num_iterations)
                    # Only the logits for the last position predict the next token.
                    next_token = self._sample_next_token(logits[:, -1, :], temperature, top_k)
                    if eos_ids is not None:
                        # Rows that already finished keep emitting padding.
                        next_token = torch.where(
                            unfinished.unsqueeze(-1),
                            next_token,
                            torch.full_like(next_token, pad_id),
                        )
                    generated = torch.cat([generated, next_token], dim=1)
                    if eos_ids is not None:
                        # (batch, 1) == (n_eos,) broadcasts to (batch, n_eos).
                        hit_eos = (next_token == eos_ids).any(dim=-1)
                        unfinished = unfinished & ~hit_eos
                        if not unfinished.any():
                            break
        finally:
            self.train(was_training)
        return generated