[ENH] xLSTMTime implementation
#1709
Changes from 35 commits
pytorch_forecasting/layers/__init__.py
@@ -2,7 +2,12 @@
Architectural deep learning layers from `nn.Module`.
"""

from pytorch_forecasting.layers._attention import AttentionLayer, FullAttention
from pytorch_forecasting.layers._attention import (
    AttentionLayer,
    FullAttention,
    TriangularCausalMask,
)
from pytorch_forecasting.layers._decomposition import SeriesDecomposition
from pytorch_forecasting.layers._embeddings import (
    DataEmbedding_inverted,
    EnEmbedding,

@@ -12,18 +17,27 @@
    Encoder,
    EncoderLayer,
)
from pytorch_forecasting.layers._mlstm import mLSTMCell, mLSTMLayer, mLSTMNetwork
from pytorch_forecasting.layers._output._flatten_head import (
    FlattenHead,
)
from pytorch_forecasting.layers._slstm import sLSTMCell, sLSTMLayer, sLSTMNetwork

__all__ = [
    "FullAttention",
    "TriangularCausalMask",

Collaborator:
why does this line get removed?

Member (Author):
I didn't see any imports for …

Member (Author):
I actually found it in …
    "AttentionLayer",
    "TriangularCausalMask",
    "DataEmbedding_inverted",
    "EnEmbedding",
    "PositionalEmbedding",
    "Encoder",
    "EncoderLayer",
    "FlattenHead",
    "mLSTMCell",
    "mLSTMLayer",
    "mLSTMNetwork",
    "sLSTMCell",
    "sLSTMLayer",
    "sLSTMNetwork",
    "SeriesDecomposition",
]
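
With these additions, the new mLSTM and sLSTM building blocks become importable from the package-level `pytorch_forecasting.layers` namespace. A minimal import sketch (illustrative only, assuming the package layout introduced in this PR):

    from pytorch_forecasting.layers import (
        mLSTMCell,
        mLSTMLayer,
        mLSTMNetwork,
        sLSTMCell,
        sLSTMLayer,
        sLSTMNetwork,
    )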
pytorch_forecasting/layers/_mlstm/__init__.py (new file)
@@ -0,0 +1,7 @@
| """mLSTM layer""" | ||
|
|
||
| from pytorch_forecasting.layers._mlstm.cell import mLSTMCell | ||
| from pytorch_forecasting.layers._mlstm.layer import mLSTMLayer | ||
| from pytorch_forecasting.layers._mlstm.network import mLSTMNetwork | ||
|
|
||
| __all__ = ["mLSTMCell", "mLSTMLayer", "mLSTMNetwork"] |
pytorch_forecasting/layers/_mlstm/cell.py (new file)
@@ -0,0 +1,156 @@
import math

import torch
import torch.nn as nn


class mLSTMCell(nn.Module):
    """Implements the Matrix Long Short-Term Memory (mLSTM) Cell.

    Implements the mLSTM algorithm as described in the paper:
    (https://arxiv.org/pdf/2407.10240).

    Parameters
    ----------
    input_size : int
        Size of the input feature vector.
    hidden_size : int
        Number of hidden units in the LSTM cell.
    dropout : float, optional
        Dropout rate applied to inputs and hidden states, by default 0.2.
    layer_norm : bool, optional
        If True, apply Layer Normalization to gates and interactions, by default True.

    Attributes
    ----------
    Wq : nn.Linear
        Linear layer for computing the query vector.
    Wk : nn.Linear
        Linear layer for computing the key vector.
    Wv : nn.Linear
        Linear layer for computing the value vector.
    Wi : nn.Linear
        Linear layer for the input gate.
    Wf : nn.Linear
        Linear layer for the forget gate.
    Wo : nn.Linear
        Linear layer for the output gate.
    dropout : nn.Dropout
        Dropout regularization layer.
    ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm
        Optional layer normalization layers for respective computations.
    """

    def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.layer_norm = layer_norm

        self.Wq = nn.Linear(input_size, hidden_size)
        self.Wk = nn.Linear(input_size, hidden_size)
        self.Wv = nn.Linear(input_size, hidden_size)

        self.Wi = nn.Linear(input_size, hidden_size)
        self.Wf = nn.Linear(input_size, hidden_size)
        self.Wo = nn.Linear(input_size, hidden_size)

        self.dropout = nn.Dropout(dropout)

        if layer_norm:
            self.ln_q = nn.LayerNorm(hidden_size)
            self.ln_k = nn.LayerNorm(hidden_size)
            self.ln_v = nn.LayerNorm(hidden_size)
            self.ln_i = nn.LayerNorm(hidden_size)
            self.ln_f = nn.LayerNorm(hidden_size)
            self.ln_o = nn.LayerNorm(hidden_size)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, h_prev, c_prev, n_prev):
| """Compute the next hidden, cell, and normalized states in the mLSTM cell. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| x : torch.Tensor | ||
| The number of features in the input. | ||
| h_prev : torch.Tensor | ||
| Previous hidden state | ||
| c_prev : torch.Tensor | ||
| Previous cell state | ||
| n_prev : torch.Tensor | ||
| Previous normalized state | ||
|
|
||
| Returns | ||
| ------- | ||
| tuple of torch.Tensor: | ||
| h : torch.Tensor | ||
| Current hidden state | ||
| c : torch.Tensor | ||
| Current cell state | ||
| n : torch.Tensor | ||
| Current normalized state | ||
| """ | ||

        batch_size = x.size(0)
        assert (
            x.dim() == 2
        ), f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
        assert h_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"h_prev shape mismatch: {h_prev.size()}"
        assert c_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"c_prev shape mismatch: {c_prev.size()}"
        assert n_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"n_prev shape mismatch: {n_prev.size()}"

        x = self.dropout(x)
        h_prev = self.dropout(h_prev)

        # query, key, value projections (key scaled by 1/sqrt(hidden_size))
        q = self.Wq(x)
        k = self.Wk(x) / math.sqrt(self.hidden_size)
        v = self.Wv(x)

        if self.layer_norm:
            q = self.ln_q(q)
            k = self.ln_k(k)
            v = self.ln_v(v)

        # input, forget, and output gates
        i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
        f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
        o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))

        # outer product of key and value, reduced back to (batch_size, hidden_size)
        k_expanded = k.unsqueeze(-1)
        v_expanded = v.unsqueeze(-2)

        kv_interaction = k_expanded @ v_expanded

        kv_sum = kv_interaction.sum(dim=1)

        # gated state updates
        c = f * c_prev + i * kv_sum
        n = f * n_prev + i * k

        # normalize n before mixing it into the hidden state
        epsilon = 1e-8
        normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
        h = o * self.tanh(c * normalized_n)

        return h, c, n

    def init_hidden(self, batch_size, device=None):
        """
        Initialize hidden, cell, and normalization states.
        """
        if device is None:
            device = next(self.parameters()).device
        shape = (batch_size, self.hidden_size)
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
        )
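
For readers skimming the diff, here is a minimal usage sketch of the cell above. Shapes follow the assertions in `forward`; the dimensions are illustrative and not part of the PR:

    import torch
    from pytorch_forecasting.layers._mlstm.cell import mLSTMCell

    batch_size, input_size, hidden_size = 4, 3, 8  # toy sizes for illustration

    cell = mLSTMCell(input_size=input_size, hidden_size=hidden_size)
    h, c, n = cell.init_hidden(batch_size)  # zero states, each (batch_size, hidden_size)

    x = torch.randn(batch_size, input_size)  # one time step, 2D as required
    h, c, n = cell(x, h, c, n)  # single recurrent update
    print(h.shape)  # torch.Size([4, 8])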
pytorch_forecasting/layers/_mlstm/layer.py (new file)
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn

from pytorch_forecasting.layers._mlstm.cell import mLSTMCell


class mLSTMLayer(nn.Module):
    """Implements an mLSTM (Matrix LSTM) layer.

    This class stacks multiple mLSTM cells to form a deep recurrent layer.
    It supports residual connections, layer normalization, and dropout.

    Parameters
    ----------
    input_size : int
        The number of features in the input.
    hidden_size : int
        The number of features in the hidden state.
    num_layers : int
        The number of mLSTM layers to stack.
    dropout : float, optional
        Dropout probability applied to the inputs and intermediate layers,
        by default 0.2.
    layer_norm : bool, optional
        Whether to use layer normalization in each mLSTM cell, by default True.
    residual_conn : bool, optional
        Whether to enable residual connections between layers, by default True.

    Attributes
    ----------
    cells : nn.ModuleList
        A list containing all mLSTM cells in the layer.
    dropout : nn.Dropout
        Dropout layer applied between layers.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        dropout=0.2,
        layer_norm=True,
        residual_conn=True,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.layer_norm = layer_norm
        self.residual_conn = residual_conn
        self.dropout = nn.Dropout(dropout)

        # first cell consumes the raw input, deeper cells consume hidden states
        self.cells = nn.ModuleList(
            [
                mLSTMCell(
                    input_size if i == 0 else hidden_size,
                    hidden_size,
                    dropout,
                    layer_norm,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x, h=None, c=None, n=None):
        """Forward pass through the mLSTM layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (seq_len, batch_size, input_size).
        h : torch.Tensor, optional
            Initial hidden states for all layers.
            If None, initialized to zeros, by default None.
        c : torch.Tensor, optional
            Initial cell states for all layers.
            If None, initialized to zeros, by default None.
        n : torch.Tensor, optional
            Initial normalized states for all layers.
            If None, initialized to zeros, by default None.

        Returns
        -------
        tuple
            output : torch.Tensor
                Final output tensor from the last layer,
                of shape (seq_len, batch_size, hidden_size).
            (h, c, n) : tuple of torch.Tensor
                Final hidden, cell, and normalized states for all layers:
                - h : torch.Tensor
                - c : torch.Tensor
                - n : torch.Tensor
        """

        # (seq_len, batch_size, features) -> (batch_size, seq_len, features)
        x = x.transpose(0, 1)
        batch_size, seq_len, _ = x.size()

        if h is None or c is None or n is None:
            h, c, n = self.init_hidden(batch_size)

        outputs = []

        for t in range(seq_len):
            layer_input = x[:, t, :]
            next_hidden_states = []
            next_cell_states = []
            next_norm_states = []

            for i, cell in enumerate(self.cells):
                h_i, c_i, n_i = cell(layer_input, h[i], c[i], n[i])

                if self.residual_conn and i > 0:
                    h_i = h_i + layer_input

                layer_input = h_i

                next_hidden_states.append(h_i)
                next_cell_states.append(c_i)
                next_norm_states.append(n_i)

            h = torch.stack(next_hidden_states)
            c = torch.stack(next_cell_states)
            n = torch.stack(next_norm_states)

            # keep the last layer's hidden state as this time step's output
            outputs.append(h[-1])

        output = torch.stack(outputs, dim=1)

        # back to (seq_len, batch_size, hidden_size)
        output = output.transpose(0, 1)

        return output, (h, c, n)

    def init_hidden(self, batch_size, device=None):
        """
        Initialize hidden, cell, and normalization states for all layers.
        """
        if device is None:
            device = next(self.parameters()).device
        hidden_states, cell_states, norm_states = zip(
            *[
                self.cells[i].init_hidden(batch_size, device=device)
                for i in range(self.num_layers)
            ]
        )

        return (
            torch.stack(hidden_states),
            torch.stack(cell_states),
            torch.stack(norm_states),
        )

Reviewer:
this is incorrect now

Author:
Oh sorry I forgot to change here
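
As with the cell, a minimal usage sketch of mLSTMLayer may help when reviewing the shape handling. The dimensions below are illustrative only, and the time-first input layout is assumed from the transpose in `forward`:

    import torch
    from pytorch_forecasting.layers._mlstm.layer import mLSTMLayer

    seq_len, batch_size, input_size, hidden_size = 12, 4, 3, 8  # toy sizes

    layer = mLSTMLayer(input_size=input_size, hidden_size=hidden_size, num_layers=2)
    x = torch.randn(seq_len, batch_size, input_size)  # time-first input

    output, (h, c, n) = layer(x)
    print(output.shape)  # torch.Size([12, 4, 8]) -> (seq_len, batch_size, hidden_size)
    print(h.shape)       # torch.Size([2, 4, 8])  -> (num_layers, batch_size, hidden_size)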