J'ai implémenté GraphConvLayer de DeepChem avec une couche personnalisée de Pytorch.
La dernière fois J'ai sorti un mini-lot en utilisant l'ensemble DataSet et DataLorder créé, et j'ai essayé de le transmettre à GraphConv et de le sortir.
import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np
class GraphConv(nn.Module):
    def __init__(self,
               in_channel,
               out_channel,
               min_deg=0,
               max_deg=10,
               activation=lambda x: x
               ):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.min_degree = min_deg
        self.max_degree = max_deg
        num_deg = 2 * self.max_degree + (1 - self.min_degree)
        self.W_list = [
            nn.Parameter(torch.Tensor(
                np.random.normal(size=(in_channel, out_channel))).double())
            for k in range(num_deg)]
        self.b_list = [
            nn.Parameter(torch.Tensor(np.zeros(out_channel)).double()) for k in range(num_deg)]
    def forward(self, atom_features, deg_slice, deg_adj_lists):
        #print("deg_adj_list")
        print(deg_adj_lists)
        W = iter(self.W_list)
        b = iter(self.b_list)
        # Sum all neighbors using adjacency matrix
        deg_summed = self.sum_neigh(atom_features, deg_adj_lists)
        # Get collection of modified atom features
        new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]
        for deg in range(1, self.max_degree + 1):
            # Obtain relevant atoms for this degree
            rel_atoms = deg_summed[deg - 1]
            # Get self atoms
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))
            # Apply hidden affine to relevant atoms and append
            rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
            self_out = torch.matmul(self_atoms, next(W)) + next(b)
            out = rel_out + self_out
            new_rel_atoms_collection[deg - self.min_degree] = out
        # Determine the min_deg=0 case
        if self.min_degree == 0:
            deg = 0
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))
            # Only use the self layer
            out = torch.matmul(self_atoms, next(W)) + next(b)
            new_rel_atoms_collection[deg - self.min_degree] = out
        # Combine all atoms back into the list
        #print(new_rel_atoms_collection)
        atom_features = torch.cat(new_rel_atoms_collection, 0)
        return atom_features
    def sum_neigh(self, atoms, deg_adj_lists):
        """Store the summed atoms by degree"""
        deg_summed = self.max_degree * [None]
        for deg in range(1, self.max_degree + 1):
            index = torch.tensor(deg_adj_lists[deg - 1], dtype=torch.int64)
            gathered_atoms = atoms[index]
            # Sum along neighbors as well as self, and store
            summed_atoms = torch.sum(gathered_atoms, 1)
            deg_summed[deg - 1] = summed_atoms
        return deg_summed
class GCNDataset(data.Dataset):
    def __init__(self, smiles_list, label_list):
        self.smiles_list = smiles_list
        self.label_list = label_list
    def __len__(self):
        return len(self.smiles_list)
    def __getitem__(self, index):
        return self.smiles_list[index], self.label_list[index]
def gcn_collate_fn(batch):
    from rdkit import Chem
    cmf = ConvMolFeaturizer()
    mols = []
    labels = []
    for sample, label in batch:
        mols.append(Chem.MolFromSmiles(sample))
        labels.append(torch.tensor(label))
    conv_mols = cmf.featurize(mols)
    multiConvMol = ConvMol.agglomerate_mols(conv_mols)
    atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
    deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
    membership = torch.tensor(multiConvMol.membership, dtype=torch.float64)
    deg_adj_lists = []
    for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
        deg_adj_lists.append(multiConvMol.get_deg_adjacency_lists()[i])
    return atom_feature, deg_slice, membership, deg_adj_lists,  labels
def main():
    dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
    dataloader = data.DataLoader(dataset, batch_size=3, shuffle=False, collate_fn =gcn_collate_fn)
    model = GraphConv(75, 20)
    for atom_feature, deg_slice, membership, deg_adj_lists, labels in dataloader:
        print("atom_feature")
        print(atom_feature)
        print("deg_slice")
        print(deg_slice)
        print("membership")
        print(membership)
        print("result")
        print(model(atom_feature, deg_slice, deg_adj_lists))
if __name__ == "__main__":
    main()
Oui, non. Pour l'instant, la forme résultante semble être le nombre d'atomes x 20 dimensions (75 dimensions compressées par convolution).
atom_feature
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.]], dtype=torch.float64)
deg_slice
tensor([[ 0.,  0.],
        [ 0.,  6.],
        [ 6.,  6.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.]], dtype=torch.float64)
membership
tensor([0., 0., 1., 1., 2., 2., 0., 1., 1., 2., 2., 2.], dtype=torch.float64)
result
tensor([[-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-1.6645,  6.3024,  0.6540, -0.7638,  5.3761, -6.3710, -0.3202,  1.3862,
          6.6121, -0.5707, -8.2441, -5.8404,  4.4354,  0.8659, -2.3474, -4.8642,
          8.3175,  0.1378, -4.6038, -3.9733],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721],
        [ 1.0006, -3.0494, -1.0774, -0.3946,  6.1658,  7.5366, -1.1302,  5.8955,
          8.6929, -0.0971, -3.9820,  1.1691,  2.7682,  2.3009, -3.1638, -3.4160,
         -5.4505, -1.0824,  1.1805, -3.3708],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721]], dtype=torch.float64,
       grad_fn=<CatBackward>)
        Recommended Posts