-
Notifications
You must be signed in to change notification settings - Fork 10
/
featurizer.py
94 lines (79 loc) · 3.39 KB
/
featurizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# -*- coding: utf-8 -*-
"""
@Time:Created on 2020/4/27 14:19
@author: LiFan Chen
@Filename: featurizer.py
@Software: PyCharm
"""
import numpy as np
from rdkit import Chem
from tape import TAPETokenizer
import torch
# Length of the vector produced by atom_features() with its default flags
# (explicit_H=False, use_chirality=True):
# symbol 10 + degree 7 + formal charge/radical electrons 2
# + hybridization 6 + aromatic 1 + total Hs 5 + chirality 3.
num_atom_feat = 10 + 7 + 2 + 6 + 1 + 5 + 3
def one_of_k_encoding(x, allowable_set):
    """Return a one-hot list of bools marking which element of *allowable_set* equals *x*.

    Unlike ``one_of_k_encoding_unk``, a value outside the set is rejected
    rather than mapped to a fallback slot.

    Args:
        x: Value to encode.
        allowable_set: Ordered sequence of permitted values; its order fixes
            the position of each flag in the output.

    Returns:
        A list of ``bool`` with one entry per element of *allowable_set*.

    Raises:
        ValueError: If *x* is not in *allowable_set*.
    """
    if x not in allowable_set:
        # ValueError subclasses Exception, so callers that previously caught
        # the generic Exception still catch this.
        raise ValueError(
            "input {0} not in allowable set {1}".format(x, allowable_set))
    return [x == s for s in allowable_set]
def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode *x* against *allowable_set*, using the last slot for unknowns.

    Any value not found in *allowable_set* is treated as if it were the final
    element (conventionally an ``'other'`` bucket), so this never rejects input.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]
def atom_features(atom, explicit_H=False, use_chirality=True):
    """Build the feature list for a single RDKit atom.

    With the default flags the list has 34 entries, in this order:
    atom symbol (10) + degree (7) + formal charge and radical electrons (2)
    + hybridization (6) + aromaticity (1) + total hydrogens (5)
    + chirality (3: CIP 'R', CIP 'S', chirality-possible flag).

    Args:
        atom: An RDKit atom object.
        explicit_H: When True, skip the total-hydrogen one-hot (5 entries),
            intended for data sets with explicit hydrogens (e.g. QM8, QM9).
        use_chirality: When True, append the 3 chirality entries.

    Returns:
        A list mixing bools and ints.

    Raises:
        Whatever ``one_of_k_encoding`` raises when the atom's degree is
        outside 0..6 (degree is encoded strictly, without an 'other' slot).
    """
    symbol = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'other']  # 10-dim
    degree = [0, 1, 2, 3, 4, 5, 6]  # 7-dim
    hybridizationType = [Chem.rdchem.HybridizationType.SP,
                         Chem.rdchem.HybridizationType.SP2,
                         Chem.rdchem.HybridizationType.SP3,
                         Chem.rdchem.HybridizationType.SP3D,
                         Chem.rdchem.HybridizationType.SP3D2,
                         'other']  # 6-dim
    results = (one_of_k_encoding_unk(atom.GetSymbol(), symbol)
               + one_of_k_encoding(atom.GetDegree(), degree)
               + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
               + one_of_k_encoding_unk(atom.GetHybridization(), hybridizationType)
               + [atom.GetIsAromatic()])  # 10+7+2+6+1 = 26
    # With explicit hydrogens (QM8, QM9) avoid calling `GetTotalNumHs`.
    if not explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                  [0, 1, 2, 3, 4])  # 26+5 = 31
    if use_chirality:
        # Atoms without an assigned CIP label simply lack the '_CIPCode'
        # property; test for it explicitly instead of catching every exception
        # (the previous bare `except:` also swallowed KeyboardInterrupt etc.).
        if atom.HasProp('_CIPCode'):
            cip = one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        else:
            cip = [False, False]
        results = results + cip + [atom.HasProp('_ChiralityPossible')]  # 31+3 = 34
    return results
def adjacent_matrix(mol):
    """Return ``Chem.GetAdjacencyMatrix(mol)`` wrapped in a new NumPy array.

    Args:
        mol: An RDKit molecule.
    """
    return np.array(Chem.GetAdjacencyMatrix(mol))
def mol_features(smiles):
    """Featurize a SMILES string into an atom-feature matrix and adjacency matrix.

    Args:
        smiles: SMILES string of the compound.

    Returns:
        ``(atom_feat, adj_matrix)``: ``atom_feat`` is a float array of shape
        ``(num_atoms, num_atom_feat)`` whose row ``i`` holds
        ``atom_features`` of the atom with index ``i``; ``adj_matrix`` is
        the result of ``adjacent_matrix(mol)``.

    Raises:
        RuntimeError: If RDKit cannot parse *smiles*.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception as err:
        # Presumably only hit for input the RDKit bindings reject outright
        # (e.g. a non-string argument); keep the original error as the cause.
        raise RuntimeError("SMILES cannot been parsed!") from err
    # MolFromSmiles reports an unparsable SMILES by returning None rather
    # than raising, so the except clause above alone never catches it.
    if mol is None:
        raise RuntimeError("SMILES cannot been parsed!")
    atom_feat = np.zeros((mol.GetNumAtoms(), num_atom_feat))
    for atom in mol.GetAtoms():
        atom_feat[atom.GetIdx(), :] = atom_features(atom)
    adj_matrix = adjacent_matrix(mol)
    return atom_feat, adj_matrix
def seq_cat(prot, tokenizer):
    """Encode the protein sequence *prot* into token ids.

    Delegates to ``tokenizer.encode`` and returns its result unchanged.
    """
    return tokenizer.encode(prot)
def featurizer(smiles, sequence, tokenizer=None):
    """Featurize one compound/protein pair.

    Args:
        smiles: SMILES string of the compound.
        sequence: Protein amino-acid sequence.
        tokenizer: Optional tokenizer exposing ``encode``. When omitted, a
            ``TAPETokenizer(vocab='iupac')`` is built, as before. Pass one in to
            avoid rebuilding it on every call when featurizing many pairs.

    Returns:
        A tuple of three single-element lists
        ``(compounds, adjacencies, proteins)``: the atom-feature matrix and
        adjacency matrix from ``mol_features``, and the protein token ids as
        a 1-D ``int64`` NumPy array.

    Raises:
        RuntimeError: Propagated from ``mol_features`` for unparsable SMILES.
    """
    if tokenizer is None:
        tokenizer = TAPETokenizer(vocab='iupac')
    atom_feature, adj = mol_features(smiles)
    token_ids = seq_cat(sequence, tokenizer)
    # Only a plain int64 array of token ids is needed here; building it
    # directly is equivalent to wrapping the ids in a (1, L) torch tensor,
    # squeezing dim 0 and converting back to NumPy. No autograd is involved.
    protein = np.array(token_ids, dtype=np.int64)
    return [atom_feature], [adj], [protein]
def _main():
    """Placeholder script entry point; it only prints 0."""
    print(0)


if __name__ == "__main__":
    _main()