|
import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np

|
class MLP(nn.Module):
    """Simple two-layer perceptron that maps input features to class probabilities."""

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.soft(out)
        print('MLP output probabilities:\n', out)
        return out
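

# A minimal usage sketch for MLP (not part of the original module; the sizes
# below are arbitrary). It checks that the forward pass returns one probability
# row per input sample and that each row sums to 1.
def _demo_mlp():
    mlp = MLP(input_size=8, output_size=3, hidden_size=16)
    probs = mlp(torch.randn(4, 8))  # shape (4, 3)
    assert torch.allclose(probs.sum(dim=1), torch.ones(4))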
|
|
|
|
|
class Expert(nn.Module):
    """Wraps a pretrained encoder (anything exposing an ``encode`` method) and
    zero-pads its embeddings to a common width so experts can be combined."""

    def __init__(self, model, output_size, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.model = model
        self.output_size = output_size

    def forward(self, x):
        # An empty batch yields an empty (0, output_size) tensor.
        if len(x) == 0:
            return torch.empty(size=(0, self.output_size))

        out = self.model.encode(x)

        # Normalise the encoder's return type to a float32 tensor
        # (numpy arrays are a common encoder output, so they are handled too).
        if isinstance(out, pd.DataFrame):
            out = torch.tensor(out.values, dtype=torch.float32)
        elif isinstance(out, np.ndarray):
            out = torch.from_numpy(out).float()
        elif isinstance(out, list):
            out = torch.stack(out, dim=0)

        # Zero-pad the embedding dimension up to the shared output width;
        # assumes the encoder width does not exceed output_size.
        out = F.pad(out, pad=(0, self.output_size - out.shape[1], 0, 0), value=0)

        if self.verbose:
            print('Original embeddings:\n', out)

        return out
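

# A minimal usage sketch for Expert (not part of the original module). The
# DummyEncoder stands in for any pretrained model exposing an ``encode``
# method; its 5-dimensional output gets zero-padded up to output_size=8.
def _demo_expert():
    class DummyEncoder:
        def encode(self, texts):
            return [torch.rand(5) for _ in texts]

    expert = Expert(DummyEncoder(), output_size=8, verbose=False)
    emb = expert(['CCO', 'c1ccccc1'])  # shape (2, 8), last 3 columns are zero padding
    assert emb.shape == (2, 8)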
|
|
|
|
|
class Net(nn.Module):
    """Prediction head: two GELU blocks with residual (skip) connections,
    followed by a linear output layer."""

    def __init__(self, smiles_embed_dim, output_dim=2, dropout=0.2):
        super().__init__()
        self.desc_skip_connection = True
        self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()  # despite the name, this is a GELU activation
        self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()
        self.final = nn.Linear(smiles_embed_dim, output_dim)

    def forward(self, smiles_emb):
        x_out = self.fc1(smiles_emb)
        x_out = self.dropout1(x_out)
        x_out = self.relu1(x_out)

        if self.desc_skip_connection:
            x_out = x_out + smiles_emb

        z = self.fc2(x_out)
        z = self.dropout2(z)
        z = self.relu2(z)
        if self.desc_skip_connection:
            z = self.final(z + x_out)
        else:
            z = self.final(z)

        return z
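

# A minimal usage sketch for Net (not part of the original module; the embedding
# width of 16 is arbitrary). The head maps a batch of SMILES embeddings to
# output_dim logits per sample.
def _demo_net():
    net = Net(smiles_embed_dim=16, output_dim=2, dropout=0.2)
    logits = net(torch.randn(4, 16))  # shape (4, 2)
    assert logits.shape == (4, 2)


if __name__ == '__main__':
    _demo_mlp()
    _demo_expert()
    _demo_net()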