import torch
import torch.nn as nn
import numpy as np


class MetaNet(nn.Module):
    """Maps a scalar context feature to a per-sample weight matrix of shape (input_dim, xprime_dim)."""

    def __init__(self, input_dim, xprime_dim):
        super().__init__()
        # The context feature is a single scalar per sample, hence the input size of 1.
        self.layer1 = nn.Linear(1, input_dim * xprime_dim)
        self.layer2 = nn.Linear(input_dim * xprime_dim, input_dim * xprime_dim)
        self.input_dim = input_dim
        self.xprime_dim = xprime_dim

    def forward(self, x_feat):
        # x_feat: (B, 1) -> (B, input_dim, xprime_dim)
        B = x_feat.size(0)
        out = torch.tanh(self.layer1(x_feat))
        out = torch.tanh(self.layer2(out))
        return out.view(B, self.input_dim, self.xprime_dim)


class GatingNet(nn.Module):
    """Computes mixture weights over the expert meta-networks from an RNN hidden state."""

    def __init__(self, hidden_size, num_experts=3):
        super().__init__()
        self.layer1 = nn.Linear(hidden_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, num_experts)

    def forward(self, h, epoch=None, top_k=None, warmup_epochs=0):
        logits = self.layer2(torch.tanh(self.layer1(h)))

        # Dense softmax gating during warm-up or when top-k routing is disabled.
        if (epoch is None) or (top_k is None) or (epoch < warmup_epochs):
            return torch.softmax(logits, dim=-1)

        # Sparse routing: keep the k largest logits per sample, mask the rest to -inf,
        # and renormalise with softmax so non-selected experts get zero weight.
        _, topk_idx = torch.topk(logits, k=top_k, dim=-1)
        mask = torch.zeros_like(logits).scatter(1, topk_idx, 1.0)
        masked_logits = logits.masked_fill(mask == 0, float('-inf'))
        return torch.softmax(masked_logits, dim=-1)
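

# Minimal smoke test of the gating behaviour, kept as a sketch: the hidden size,
# batch size, and epoch values below are illustrative assumptions, not training settings.
def _demo_gating_topk():
    gate = GatingNet(hidden_size=16, num_experts=3)
    h = torch.randn(2, 16)
    dense = gate(h, epoch=0, top_k=2, warmup_epochs=5)    # warm-up: all experts active
    sparse = gate(h, epoch=10, top_k=2, warmup_epochs=5)  # after warm-up: only top-2 nonzero
    print(dense)
    print(sparse)

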
class MetaTransformBlock(nn.Module):
    """Combines the expert meta-networks into a dynamic input transform.

    Three MetaNets (temperature, workday, season) each propose a weight matrix;
    a gating network mixes them conditioned on the previous RNN hidden state, and
    the load input x_l is projected with the resulting matrix.
    """

    def __init__(self, xprime_dim, num_experts=3, input_dim=1, hidden_size=64):
        super().__init__()
        self.meta_temp = MetaNet(input_dim, xprime_dim)
        self.meta_work = MetaNet(input_dim, xprime_dim)
        self.meta_season = MetaNet(input_dim, xprime_dim)
        self.gating = GatingNet(hidden_size, num_experts)
        self.ln = nn.LayerNorm([input_dim, xprime_dim])
        self.theta0 = nn.Parameter(torch.zeros(1, input_dim, xprime_dim))

    def forward(self, h_prev_rnn, x_l, x_t, x_w, x_s, epoch=None, top_k=None, warmup_epochs=0):
        # Each expert produces a normalised (B, input_dim, xprime_dim) weight matrix.
        w_temp = self.ln(self.meta_temp(x_t))
        w_work = self.ln(self.meta_work(x_w))
        w_seas = self.ln(self.meta_season(x_s))

        # Mix the experts with the gating weights and add the static base matrix theta0.
        gates = self.gating(h_prev_rnn, epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)
        W_experts = torch.stack([w_temp, w_work, w_seas], dim=1)
        gates_expanded = gates.view(gates.size(0), gates.size(1), 1, 1)
        theta_dynamic = (W_experts * gates_expanded).sum(dim=1)
        theta = theta_dynamic + self.theta0

        # Project the load input: (B, 1, input_dim) @ (B, input_dim, xprime_dim) -> (B, xprime_dim).
        x_prime = torch.bmm(x_l.unsqueeze(1), theta).squeeze(1)
        return x_prime, theta
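

# Shape walk-through for the transform block, as a minimal sketch. All sizes here
# (batch size 4, xprime_dim 8, hidden_size 64) are illustrative assumptions.
def _demo_meta_transform_block():
    block = MetaTransformBlock(xprime_dim=8, num_experts=3, input_dim=1, hidden_size=64)
    B = 4
    h_prev = torch.randn(B, 64)                            # previous RNN hidden state
    x_l, x_t, x_w, x_s = (torch.randn(B, 1) for _ in range(4))
    x_prime, theta = block(h_prev, x_l, x_t, x_w, x_s)
    print(x_prime.shape, theta.shape)                      # (4, 8) and (4, 1, 8)

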
class Encoder_meta(nn.Module):
    """GRU encoder that transforms each input step with a MetaTransformBlock before encoding."""

    def __init__(self, xprime_dim, hidden_size, num_layers=1, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.GRU(xprime_dim, hidden_size, num_layers,
                          batch_first=True,
                          dropout=dropout if num_layers > 1 else 0)

    def forward(self, x_l_seq, x_t_seq, x_w_seq, x_s_seq,
                transform_block, h_init=None, epoch=None, top_k=None, warmup_epochs=0):
        B, T, _ = x_l_seq.shape
        h_rnn = torch.zeros(self.num_layers, B, self.hidden_size, device=x_l_seq.device) if h_init is None else h_init

        # Step through the sequence: condition the transform on the top-layer hidden state,
        # then feed the transformed input to the GRU one step at a time.
        for t in range(T):
            h_for_meta = h_rnn[-1]
            x_prime, _ = transform_block(h_for_meta,
                                         x_l_seq[:, t], x_t_seq[:, t],
                                         x_w_seq[:, t], x_s_seq[:, t],
                                         epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)
            x_prime = x_prime.unsqueeze(1)
            _, h_rnn = self.rnn(x_prime, h_rnn)

        return h_rnn


class Decoder_meta(nn.Module):
    """GRU decoder that emits a full forecast window at every decoding step."""

    def __init__(self, xprime_dim, latent_size, output_len, output_dim=1,
                 num_layers=1, dropout=0.1, hidden_size=None):
        super().__init__()
        self.latent_size = latent_size
        self.output_len = output_len
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.rnn = nn.GRU(xprime_dim, latent_size, num_layers,
                          batch_first=True,
                          dropout=dropout if num_layers > 1 else 0)

        self.head = nn.Linear(latent_size, output_len * output_dim)

        # Project each encoder hidden layer (hidden_size) into the decoder state space (latent_size).
        assert hidden_size is not None, "You must provide hidden_size for projection."
        self.project = nn.ModuleList([
            nn.Linear(hidden_size, latent_size) for _ in range(num_layers)
        ])

    def forward(self, x_l_seq, x_t_seq, x_w_seq, x_s_seq,
                h_init, transform_block,
                epoch=None, top_k=None, warmup_epochs=0):
        B, L, _ = x_l_seq.shape

        # Map the encoder hidden states into the decoder's latent space, layer by layer.
        h_rnn = torch.stack([
            self.project[i](h_init[i]) for i in range(self.num_layers)
        ], dim=0)

        preds = []

        # Initial forecast from the projected encoder state, before consuming any decoder input.
        h_last = h_rnn[-1]
        pred_0 = self.head(h_last).view(B, self.output_len, self.output_dim)
        preds.append(pred_0.unsqueeze(1))

        # One refined forecast per decoder input step, giving L + 1 forecasts in total.
        for t in range(L):
            h_for_meta = h_rnn[-1]
            x_prime, _ = transform_block(h_for_meta,
                                         x_l_seq[:, t], x_t_seq[:, t],
                                         x_w_seq[:, t], x_s_seq[:, t],
                                         epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)
            x_prime = x_prime.unsqueeze(1)
            out_t, h_rnn = self.rnn(x_prime, h_rnn)
            pred_t = self.head(out_t.squeeze(1)).view(B, self.output_len, self.output_dim)
            preds.append(pred_t.unsqueeze(1))

        preds = torch.cat(preds, dim=1)  # (B, L + 1, output_len, output_dim)
        return preds


class Seq2Seq_meta(nn.Module):
    """Deterministic seq2seq forecaster with separate meta-transform blocks for encoder and decoder."""

    def __init__(self, xprime_dim, input_dim, hidden_size, latent_size,
                 output_len, output_dim=1, num_layers=1, dropout=0.1, num_experts=3):
        super().__init__()

        self.transform_enc = MetaTransformBlock(
            xprime_dim=xprime_dim,
            num_experts=num_experts,
            input_dim=input_dim,
            hidden_size=hidden_size
        )

        # The decoder's transform is gated on the decoder state, which has size latent_size.
        self.transform_dec = MetaTransformBlock(
            xprime_dim=xprime_dim,
            num_experts=num_experts,
            input_dim=input_dim,
            hidden_size=latent_size
        )

        self.encoder = Encoder_meta(
            xprime_dim=xprime_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout)

        self.decoder = Decoder_meta(
            xprime_dim=xprime_dim,
            latent_size=latent_size,
            output_len=output_len,
            output_dim=output_dim,
            num_layers=num_layers,
            dropout=dropout,
            hidden_size=hidden_size
        )

    def forward(self,
                enc_l, enc_t, enc_w, enc_s,
                dec_l, dec_t, dec_w, dec_s,
                epoch=None, top_k=None, warmup_epochs=0):

        h_enc = self.encoder(enc_l, enc_t, enc_w, enc_s,
                             transform_block=self.transform_enc,
                             epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)

        preds = self.decoder(dec_l, dec_t, dec_w, dec_s,
                             h_init=h_enc,
                             transform_block=self.transform_dec,
                             epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)
        return preds
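

# End-to-end usage sketch for Seq2Seq_meta with random tensors. The batch size,
# sequence lengths, and feature widths are assumptions for illustration only;
# the load sequence has width input_dim and each context feature is a scalar.
def _demo_seq2seq_meta():
    B, T_enc, T_dec = 4, 24, 12
    model = Seq2Seq_meta(xprime_dim=8, input_dim=1, hidden_size=64, latent_size=32,
                         output_len=12, output_dim=1)
    enc = [torch.randn(B, T_enc, 1) for _ in range(4)]  # load, temp, workday, season
    dec = [torch.randn(B, T_dec, 1) for _ in range(4)]
    preds = model(*enc, *dec, epoch=10, top_k=2, warmup_epochs=5)
    print(preds.shape)  # (4, 13, 12, 1): one forecast per decoder step plus the initial one

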
class VariationalEncoder_meta(nn.Module):
    """Variational counterpart of Encoder_meta: encodes the sequence into mu and logvar of a latent."""

    def __init__(self, xprime_dim, hidden_size, latent_size, num_layers=1, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.num_layers = num_layers

        self.rnn = nn.GRU(xprime_dim, hidden_size, num_layers,
                          batch_first=True,
                          dropout=dropout if num_layers > 1 else 0)

        self.mu_layer = nn.Linear(hidden_size, latent_size)
        self.logvar_layer = nn.Linear(hidden_size, latent_size)

    def forward(self, x_l_seq, x_t_seq, x_w_seq, x_s_seq,
                transform_block, h_init=None, epoch=None, top_k=None, warmup_epochs=0):
        B, T, _ = x_l_seq.shape
        h_rnn = torch.zeros(self.num_layers, B, self.hidden_size, device=x_l_seq.device) if h_init is None else h_init

        # Same stepwise meta-transformed encoding as Encoder_meta.
        for t in range(T):
            h_for_meta = h_rnn[-1]
            x_prime, _ = transform_block(h_for_meta,
                                         x_l_seq[:, t], x_t_seq[:, t],
                                         x_w_seq[:, t], x_s_seq[:, t],
                                         epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)
            x_prime = x_prime.unsqueeze(1)
            _, h_rnn = self.rnn(x_prime, h_rnn)

        # Gaussian posterior parameters from the final top-layer hidden state.
        h_last = h_rnn[-1]
        mu = self.mu_layer(h_last)
        logvar = self.logvar_layer(h_last)

        return mu, logvar


class VariationalDecoder_meta_predvar(nn.Module):
    """Variational decoder that predicts both the mean and the log-variance of the forecast window."""

    def __init__(self, xprime_dim, latent_size, output_len, output_dim=1,
                 num_layers=1, dropout=0.1):
        super().__init__()
        self.latent_size = latent_size
        self.output_len = output_len
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.rnn = nn.GRU(xprime_dim, latent_size, num_layers,
                          batch_first=True,
                          dropout=dropout if num_layers > 1 else 0)

        self.head_mu = nn.Linear(latent_size, output_len * output_dim)
        self.head_logvar = nn.Linear(latent_size, output_len * output_dim)

    def forward(self, x_l_seq, x_t_seq, x_w_seq, x_s_seq,
                z_latent, transform_block,
                epoch=None, top_k=None, warmup_epochs=0):
        B, L, _ = x_l_seq.shape

        # Initialise every GRU layer with the sampled latent code.
        h_rnn = z_latent.unsqueeze(0).repeat(self.num_layers, 1, 1)

        mu_preds = []
        logvar_preds = []

        # Initial forecast directly from the latent, before consuming any decoder input.
        h_last = h_rnn[-1]
        mu_0 = self.head_mu(h_last).view(B, self.output_len, self.output_dim)
        logvar_0 = self.head_logvar(h_last).view(B, self.output_len, self.output_dim)
        mu_preds.append(mu_0.unsqueeze(1))
        logvar_preds.append(logvar_0.unsqueeze(1))

        # One refined forecast per decoder input step, giving L + 1 forecasts in total.
        for t in range(L):
            h_for_meta = h_rnn[-1]
            x_prime, _ = transform_block(h_for_meta,
                                         x_l_seq[:, t], x_t_seq[:, t],
                                         x_w_seq[:, t], x_s_seq[:, t],
                                         epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)
            x_prime = x_prime.unsqueeze(1)
            out_t, h_rnn = self.rnn(x_prime, h_rnn)

            mu_t = self.head_mu(out_t.squeeze(1)).view(B, self.output_len, self.output_dim)
            logvar_t = self.head_logvar(out_t.squeeze(1)).view(B, self.output_len, self.output_dim)

            mu_preds.append(mu_t.unsqueeze(1))
            logvar_preds.append(logvar_t.unsqueeze(1))

        mu_preds = torch.cat(mu_preds, dim=1)          # (B, L + 1, output_len, output_dim)
        logvar_preds = torch.cat(logvar_preds, dim=1)  # (B, L + 1, output_len, output_dim)

        return mu_preds, logvar_preds


class VariationalSeq2Seq_meta(nn.Module):
    """Variational seq2seq forecaster: encodes to a latent Gaussian, samples, and decodes mean and variance."""

    def __init__(self, xprime_dim, input_dim, hidden_size, latent_size,
                 output_len, output_dim=1, num_layers=1, dropout=0.1, num_experts=3):
        super().__init__()

        self.transform_enc = MetaTransformBlock(
            xprime_dim=xprime_dim,
            num_experts=num_experts,
            input_dim=input_dim,
            hidden_size=hidden_size
        )

        # The decoder's transform is gated on the decoder state, which has size latent_size.
        self.transform_dec = MetaTransformBlock(
            xprime_dim=xprime_dim,
            num_experts=num_experts,
            input_dim=input_dim,
            hidden_size=latent_size
        )

        self.encoder = VariationalEncoder_meta(
            xprime_dim=xprime_dim,
            hidden_size=hidden_size,
            latent_size=latent_size,
            num_layers=num_layers,
            dropout=dropout
        )

        self.decoder = VariationalDecoder_meta_predvar(
            xprime_dim=xprime_dim,
            latent_size=latent_size,
            output_len=output_len,
            output_dim=output_dim,
            num_layers=num_layers,
            dropout=dropout
        )

    def reparameterize(self, mu, logvar):
        # Standard reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self,
                enc_l, enc_t, enc_w, enc_s,
                dec_l, dec_t, dec_w, dec_s,
                epoch=None, top_k=None, warmup_epochs=0):

        mu, logvar = self.encoder(enc_l, enc_t, enc_w, enc_s,
                                  transform_block=self.transform_enc,
                                  epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)

        z = self.reparameterize(mu, logvar)

        mu_preds, logvar_preds = self.decoder(dec_l, dec_t, dec_w, dec_s,
                                              z_latent=z,
                                              transform_block=self.transform_dec,
                                              epoch=epoch, top_k=top_k, warmup_epochs=warmup_epochs)

        return mu_preds, logvar_preds, mu, logvar
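

# Usage sketch for the variational model, mirroring _demo_seq2seq_meta. All sizes
# are illustrative assumptions; the returned mu/logvar pair over the latent is what
# a KL term in the training loss would typically use.
def _demo_variational_seq2seq_meta():
    B, T_enc, T_dec = 4, 24, 12
    model = VariationalSeq2Seq_meta(xprime_dim=8, input_dim=1, hidden_size=64,
                                    latent_size=32, output_len=12, output_dim=1)
    enc = [torch.randn(B, T_enc, 1) for _ in range(4)]  # load, temp, workday, season
    dec = [torch.randn(B, T_dec, 1) for _ in range(4)]
    mu_preds, logvar_preds, mu, logvar = model(*enc, *dec, epoch=10, top_k=2, warmup_epochs=5)
    print(mu_preds.shape, logvar_preds.shape)  # (4, 13, 12, 1) each
    print(mu.shape, logvar.shape)              # (4, 32) each


if __name__ == "__main__":
    _demo_gating_topk()
    _demo_meta_transform_block()
    _demo_seq2seq_meta()
    _demo_variational_seq2seq_meta()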