jesse-tong committed
Commit 251e9cd · 1 Parent(s): 466b8a2

Add layer normalization to LSTM model

Files changed (1):
  1. models/lstm_model.py +25 -17
models/lstm_model.py CHANGED
@@ -2,10 +2,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 class DocumentBiLSTM(nn.Module):
     """
-    A simpler BiLSTM implementation that doesn't require pre-loaded embeddings
-    Good for getting started quickly
+    BiLSTM implementation with stability improvements inspired by DocBERT
     """
     def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                  n_layers=2, dropout=0.5, pad_idx=0):
@@ -20,6 +23,9 @@ class DocumentBiLSTM(nn.Module):
                             dropout=dropout if n_layers > 1 else 0,
                             batch_first=True)
 
+        # Add layer normalization for stability (like in DocBERT)
+        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
+
         self.fc = nn.Linear(hidden_dim * 2, output_dim)
 
         self.dropout = nn.Dropout(dropout)
@@ -33,44 +39,46 @@ class DocumentBiLSTM(nn.Module):
         # Apply dropout to embeddings
         embedded = self.dropout(embedded)
 
+        # Initialize hidden and cell variables
+        hidden = None
+        cell = None
+
         if attention_mask is not None:
             # Convert attention mask to sequence lengths
-            # First, get the length of each sequence by summing the attention mask
             seq_lengths = attention_mask.sum(dim=1).to(torch.int64).cpu()
 
-            # Sort sequences by decreasing length for pack_padded_sequence
+            # Sort sequences by decreasing length
             seq_lengths, indices = torch.sort(seq_lengths, descending=True)
-            embedded = embedded[indices]
+            sorted_embedded = embedded[indices]
 
             # Pack the embedded sequences
             packed_embedded = nn.utils.rnn.pack_padded_sequence(
-                embedded, seq_lengths, batch_first=True, enforce_sorted=True
+                sorted_embedded, seq_lengths, batch_first=True, enforce_sorted=True
             )
 
-            # Pass the packed sequence through LSTM
+            # Pass through LSTM
             packed_output, (hidden, cell) = self.lstm(packed_embedded)
 
             # Unpack the sequence
             output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
 
-            # Restore the original batch order
+            # Get the hidden states in correct order
            _, restore_indices = torch.sort(indices)
+            hidden = hidden[:, restore_indices]
         else:
             # Standard processing without masking
-            output, (hidden, cell) = self.lstm(embedded)
-
-            # output = [batch size, seq len, hid dim * num directions]
-            # hidden = [n layers * num directions, batch size, hid dim]
-            # cell = [n layers * num directions, batch size, hid dim]
-            output, (hidden, cell) = self.lstm(embedded)
+            _, (hidden, cell) = self.lstm(embedded)
 
         # Concatenate the final forward and backward hidden states
-        hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
+        hidden_cat = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
+
+        # Apply layer normalization (improves stability)
+        normalized = self.layer_norm(hidden_cat)
 
         # Apply dropout to hidden state
-        hidden = self.dropout(hidden)
+        dropped = self.dropout(normalized)
 
         # prediction = [batch size, output dim]
-        prediction = self.fc(hidden)
+        prediction = self.fc(dropped)
 
         return prediction
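
For reference, a minimal smoke test of the changed class, assuming it can be imported from models/lstm_model.py as shown in this commit and that the forward signature is forward(input_ids, attention_mask=None); the hyperparameters and batch shapes below are illustrative only, not values used by the repository.

import torch

from models.lstm_model import DocumentBiLSTM  # path as shown in this commit

# Hypothetical hyperparameters, chosen only to exercise the forward pass
model = DocumentBiLSTM(
    vocab_size=5000,
    embedding_dim=128,
    hidden_dim=256,
    output_dim=4,   # e.g. a 4-class document classifier
    n_layers=2,
    dropout=0.5,
    pad_idx=0,
)
model.eval()

batch_size, seq_len = 8, 32
input_ids = torch.randint(1, 5000, (batch_size, seq_len))

# Mask out the tail of each sequence; every row keeps at least one token,
# so pack_padded_sequence receives lengths >= 1
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[:, 20:] = 0

with torch.no_grad():
    masked_logits = model(input_ids, attention_mask=attention_mask)  # packed-sequence path
    plain_logits = model(input_ids)                                  # unmasked path

print(masked_logits.shape, plain_logits.shape)  # expected: torch.Size([8, 4]) for both

In both paths the classifier head now sees layer_norm(concat(forward_hidden, backward_hidden)), a [batch_size, hidden_dim * 2] tensor kept at a consistent scale before dropout and the linear layer, which is the stability improvement named in the commit message.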