The architecture uses two nn.ModuleList containers: one for the current-tick dense layers ($U$) and one for the previous-tick recurrent layers ($W$).
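Concretely, writing $h_i^{(t)}$ for the output of layer $i$ at tick $t$ and $x^{(t)}$ for the input, the update the code below implements can be read off as (a sketch; $[\,\cdot\,;\,\cdot\,]$ denotes concatenation and $V$ is shorthand for the final self.out layer):

$$h_i^{(t)} = \tanh\!\left(U_i\,\big[h_1^{(t)};\,\dots;\,h_{i-1}^{(t)}\big] + W_i\,\big[h_1^{(t-1)};\,\dots;\,h_N^{(t-1)}\big]\right), \qquad \hat{y}^{(t)} = \sigma\!\left(V\,\big[h_1^{(t)};\,\dots;\,h_N^{(t)}\big]\right)$$

where layer $1$ takes $x^{(t)}$ in place of the current-tick concatenation, and the $W_i$ term is dropped on the first tick, when no previous state exists.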
import torch
import torch.nn as nn

class TemporalDenseNet(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        self.num_layers = len(hidden_sizes)
        self.hidden_sizes = hidden_sizes
        # Every layer's previous-tick output is concatenated, so each W[i]
        # reads the full sum of hidden sizes.
        self.prev_concat_size = sum(hidden_sizes)
        # Current-tick linear layers U[i]: layer 0 reads the raw input;
        # layer i > 0 reads the concatenation of all earlier layers'
        # outputs at the current tick (dense connectivity).
        self.U = nn.ModuleList()
        for i in range(self.num_layers):
            in_size = input_size if i == 0 else sum(hidden_sizes[:i])
            self.U.append(nn.Linear(in_size, hidden_sizes[i]))
        # Previous-tick linear layers W[i]: each layer also receives the
        # concatenation of every layer's output from the previous tick.
        self.W = nn.ModuleList([nn.Linear(self.prev_concat_size, hidden_sizes[i])
                                for i in range(self.num_layers)])
        self.out = nn.Linear(self.prev_concat_size, output_size)
        self.activation = torch.tanh

    def forward(self, x, prev_outputs=None):
        layer_outputs = []
        # On the first tick there is no recurrent state, so the W terms are skipped.
        prev_cat = torch.cat(prev_outputs, dim=1) if prev_outputs is not None else None
        for i in range(self.num_layers):
            current_input = x if i == 0 else torch.cat(layer_outputs, dim=1)
            out = self.U[i](current_input)
            if prev_cat is not None:
                out = out + self.W[i](prev_cat)
            out = self.activation(out)
            layer_outputs.append(out)
        final_cat = torch.cat(layer_outputs, dim=1)
        # Return the per-layer outputs (to be fed back in as prev_outputs at
        # the next tick) along with this tick's sigmoid prediction.
        return layer_outputs, torch.sigmoid(self.out(final_cat))
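A minimal usage sketch, with hypothetical sizes chosen for illustration: the list of per-layer outputs returned at each tick is passed back in as prev_outputs at the next one.

# Drive the network over a short sequence of ticks (sizes are illustrative).
model = TemporalDenseNet(input_size=8, hidden_sizes=[16, 32, 16], output_size=1)
x_seq = torch.randn(5, 4, 8)  # (ticks, batch, input_size)
prev = None                   # no recurrent state before the first tick
for t in range(x_seq.size(0)):
    prev, y = model(x_seq[t], prev_outputs=prev)
print(y.shape)  # torch.Size([4, 1])

Note that written this way, gradients flow through prev across all ticks (full backpropagation through time); to truncate BPTT, one would detach the state between ticks, e.g. prev = [h.detach() for h in prev].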