Commit 251e9cd
Parent(s): 466b8a2
Add layer normalization to LSTM model
Files changed: models/lstm_model.py (+25 -17)
models/lstm_model.py CHANGED
@@ -2,10 +2,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 class DocumentBiLSTM(nn.Module):
     """
-
-    Good for getting started quickly
+    BiLSTM implementation with stability improvements inspired by DocBERT
     """
     def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                  n_layers=2, dropout=0.5, pad_idx=0):
@@ -20,6 +23,9 @@ class DocumentBiLSTM(nn.Module):
                             dropout=dropout if n_layers > 1 else 0,
                             batch_first=True)
 
+        # Add layer normalization for stability (like in DocBERT)
+        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
+
         self.fc = nn.Linear(hidden_dim * 2, output_dim)
 
         self.dropout = nn.Dropout(dropout)
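
For reference, a minimal standalone sketch of what the `nn.LayerNorm(hidden_dim * 2)` added above does to the concatenated forward/backward hidden state. The batch size and hidden size below are illustrative values, not taken from the repository.

```python
import torch
import torch.nn as nn

# Illustrative sizes, not taken from the commit.
batch_size, hidden_dim = 4, 256

# The bidirectional LSTM yields one forward and one backward final hidden state
# per example; concatenating them gives a [batch_size, hidden_dim * 2] tensor.
hidden_cat = torch.randn(batch_size, hidden_dim * 2)

# nn.LayerNorm(hidden_dim * 2) normalizes each example's feature vector to zero
# mean and unit variance over the feature dimension, then applies a learned
# elementwise scale (weight) and shift (bias).
layer_norm = nn.LayerNorm(hidden_dim * 2)
normalized = layer_norm(hidden_cat)

print(normalized.shape)         # torch.Size([4, 512])
print(normalized.mean(dim=-1))  # ~0 for every example
print(normalized.std(dim=-1))   # ~1 for every example
```

Because the normalization is applied per example, it does not mix information across the batch; it only keeps the scale of the classifier head's input stable.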
@@ -33,44 +39,46 @@ class DocumentBiLSTM(nn.Module):
         # Apply dropout to embeddings
         embedded = self.dropout(embedded)
 
+        # Initialize hidden and cell variables
+        hidden = None
+        cell = None
+
         if attention_mask is not None:
             # Convert attention mask to sequence lengths
-            # First, get the length of each sequence by summing the attention mask
             seq_lengths = attention_mask.sum(dim=1).to(torch.int64).cpu()
 
-            # Sort sequences by decreasing length
+            # Sort sequences by decreasing length
             seq_lengths, indices = torch.sort(seq_lengths, descending=True)
-
+            sorted_embedded = embedded[indices]
 
             # Pack the embedded sequences
             packed_embedded = nn.utils.rnn.pack_padded_sequence(
-
+                sorted_embedded, seq_lengths, batch_first=True, enforce_sorted=True
             )
 
-            # Pass
+            # Pass through LSTM
             packed_output, (hidden, cell) = self.lstm(packed_embedded)
 
             # Unpack the sequence
             output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
 
-            #
+            # Get the hidden states in correct order
             _, restore_indices = torch.sort(indices)
+            hidden = hidden[:, restore_indices]
         else:
             # Standard processing without masking
-
-
-            # output = [batch size, seq len, hid dim * num directions]
-            # hidden = [n layers * num directions, batch size, hid dim]
-            # cell = [n layers * num directions, batch size, hid dim]
-            output, (hidden, cell) = self.lstm(embedded)
+            _, (hidden, cell) = self.lstm(embedded)
 
         # Concatenate the final forward and backward hidden states
-
+        hidden_cat = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
+
+        # Apply layer normalization (improves stability)
+        normalized = self.layer_norm(hidden_cat)
 
         # Apply dropout to hidden state
-
+        dropped = self.dropout(normalized)
 
         # prediction = [batch size, output dim]
-        prediction = self.fc(
+        prediction = self.fc(dropped)
 
         return prediction
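
The masked branch in the hunk above relies on PyTorch's packed-sequence utilities. Below is a hedged, self-contained sketch of that sort/pack/unpack round trip; the shapes (3 sequences, embedding dim 8, hidden dim 16, 2 layers) are illustrative, not repository values. `enforce_sorted=True` is why the batch must be sorted by decreasing length first; `enforce_sorted=False` would let PyTorch handle the reordering internally.

```python
import torch
import torch.nn as nn

# Illustrative batch: 3 sequences padded to length 5 (values not from the repo).
embedded = torch.randn(3, 5, 8)                      # [batch, seq len, emb dim]
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0]])

# Mask -> per-sequence lengths; pack_padded_sequence expects lengths on the CPU.
seq_lengths = attention_mask.sum(dim=1).to(torch.int64).cpu()

# enforce_sorted=True requires the batch ordered by decreasing length, hence
# the explicit sort and re-indexing step added in the hunk above.
seq_lengths, indices = torch.sort(seq_lengths, descending=True)
sorted_embedded = embedded[indices]

packed = nn.utils.rnn.pack_padded_sequence(
    sorted_embedded, seq_lengths, batch_first=True, enforce_sorted=True
)

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               bidirectional=True, batch_first=True)
packed_output, (hidden, cell) = lstm(packed)

# Unpack back to a padded [batch, seq len, 2 * hidden dim] tensor.
output, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

# hidden is [n layers * 2, batch, hidden dim] and its batch axis is still in
# sorted order, so it is restored with the inverse permutation, as in the diff.
# (Only hidden is re-ordered here, mirroring the commit; output stays sorted.)
_, restore_indices = torch.sort(indices)
hidden = hidden[:, restore_indices]

# The final layer's forward and backward states live at hidden[-2] and
# hidden[-1]; concatenating them gives the [batch, 2 * hidden dim] document
# representation that the classifier head (and the new LayerNorm) consume.
hidden_cat = torch.cat((hidden[-2], hidden[-1]), dim=1)

print(output.shape)      # torch.Size([3, 5, 32])
print(hidden_cat.shape)  # torch.Size([3, 32])
```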
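
Finally, a usage sketch for the updated model. The constructor arguments match the diff, but the import path and the forward signature `forward(input_ids, attention_mask=None)` are assumptions, since the hunks above only show the body of `forward()`. All hyperparameter values are illustrative.

```python
import torch

from models.lstm_model import DocumentBiLSTM

# Hyperparameters are illustrative, not taken from the repository.
model = DocumentBiLSTM(vocab_size=30_000, embedding_dim=128, hidden_dim=256,
                       output_dim=4, n_layers=2, dropout=0.5, pad_idx=0)
model.eval()  # disable dropout for a deterministic forward pass

# A toy batch: 8 documents of 64 tokens, with the last 16 positions of every
# document treated as padding (pad_idx=0).
batch_size, seq_len = 8, 64
input_ids = torch.randint(1, 30_000, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[:, 48:] = 0
input_ids[attention_mask == 0] = 0

with torch.no_grad():
    # forward(input_ids, attention_mask=None) is assumed; the hunks above only
    # show the body of forward(), not its signature.
    logits = model(input_ids, attention_mask=attention_mask)

print(logits.shape)  # expected: torch.Size([8, 4]) -> [batch size, output dim]
```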