Skip to content

Commit

Permalink
Merge pull request #93 from wells-wood-research/cbcaswap
Browse files Browse the repository at this point in the history
swap ca cb channels
  • Loading branch information
ChrisWellsWood authored Jan 29, 2024
2 parents ffea99e + 56c35a2 commit ce7431b
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 37 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "aposteriori"
version = "2.2.0"
version = "2.3.0"
requires-python = ">= 3.8"
readme = "README.md"
dependencies = [
Expand All @@ -16,4 +16,4 @@ dependencies = [
]

[project.scripts]
make-frame-dataset = "aposteriori.data_prep.cli:cli"
make-frame-dataset = "aposteriori.data_prep.cli:cli"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="aposteriori",
version="2.2.0",
version="2.3.0",
author="Wells Wood Research Group",
author_email="[email protected]",
description="A library for the voxelization of protein structures for protein design.",
Expand Down
2 changes: 1 addition & 1 deletion src/aposteriori/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ampal.data import ELEMENT_DATA

# Config paths
MAKE_FRAME_DATASET_VER = "2.2.0"
MAKE_FRAME_DATASET_VER = "2.3.0"
PROJECT_ROOT_DIR = pathlib.Path(__file__).parent
DATA_FOLDER = PROJECT_ROOT_DIR / "data"
DATA_FOLDER.mkdir(parents=True, exist_ok=True)
Expand Down
61 changes: 40 additions & 21 deletions src/aposteriori/data_prep/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,23 @@
@click.option(
"-ae",
"--atom_encoder",
type=click.Choice(["CNO", "CNOCB", "CNOCBCA", "CNOCBCAQ", "CNOCBCAP"]),
type=click.Choice(
[
"CNO",
"CNOCB",
"CNOCACB",
"CNOCBCA",
"CNOCACBQ",
"CNOCBCAQ",
"CNOCACBP",
"CNOCBCAP",
]
),
default="CNO",
required=True,
help=(
"Encodes atoms in different channels, depending on atom types. Default is CNO, other options are ´CNOCB´ and `CNOCBCA` to encode the Cb or Cb and Ca in different channels respectively."
"Encodes atoms in different channels, depending on atom types. Default is CNO, other options are ´CNOCB´ and `CNOCACB` to encode the Cb or Cb and Ca in different channels respectively. "
"Charged and polar versions can be used with CNOCACBQ and CNOCACBP respectively."
),
)
@click.option(
Expand Down Expand Up @@ -262,25 +274,32 @@ def cli(
f"{structure_file_folder} file not found. Did you specify the -d argument for the download file? If so, check your spelling."
)
sys.exit()
# Create Codec:
if atom_encoder == "CNO":
codec = Codec.CNO()
elif atom_encoder == "CNOCB":
codec = Codec.CNOCB()
elif atom_encoder == "CNOCBCA":
codec = Codec.CNOCBCA()
elif atom_encoder == "CNOCBCAQ":
codec = Codec.CNOCBCAQ()
elif atom_encoder == "CNOCBCAP":
codec = Codec.CNOCBCAP()
else:
assert atom_encoder in [
"CNO",
"CNOCB",
"CNOCBCA",
"CNOCBCAQ",
"CNOCBCAP",
], f"Expected encoder to be CNO, CNOCB, CNOCBCA, CNOCBCAQ, CNOCBCAP, but got {atom_encoder}"
# Mapping of current atom encoders to their corresponding Codec classes
current_codec_mapping = {
"CNO": Codec.CNO,
"CNOCB": Codec.CNOCB,
"CNOCACB": Codec.CNOCACB,
"CNOCACBQ": Codec.CNOCACBQ,
"CNOCACBP": Codec.CNOCACBP,
}

# List of deprecated encodings and their replacements
deprecated_encodings = {
"CNOCBCA": "CNOCACB",
"CNOCBCAQ": "CNOCACBQ",
"CNOCBCAP": "CNOCACBP",
}

# Create Codec based on atom_encoder
if atom_encoder in current_codec_mapping:
codec = current_codec_mapping[atom_encoder]()
elif atom_encoder in deprecated_encodings:
replacement = deprecated_encodings[atom_encoder]
codec = current_codec_mapping[replacement]()
warnings.warn(
f"{atom_encoder} encoding is deprecated and will be removed in future versions, "
f"atoms will be encoded as {replacement}"
)

make_frame_dataset(
structure_files=structure_files,
Expand Down
18 changes: 9 additions & 9 deletions src/aposteriori/data_prep/create_frame_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ def CNOCB(cls):
return cls(["C", "N", "O", "CB"])

@classmethod
def CNOCBCA(cls):
return cls(["C", "N", "O", "CB", "CA"])
def CNOCACB(cls):
return cls(["C", "N", "O", "CA", "CB"])

@classmethod
def CNOCBCAQ(cls):
return cls(["C", "N", "O", "CB", "CA", "Q"])
def CNOCACBQ(cls):
return cls(["C", "N", "O", "CA", "CB", "Q"])

@classmethod
def CNOCBCAP(cls):
return cls(["C", "N", "O", "CB", "CA", "P"])
def CNOCACBP(cls):
return cls(["C", "N", "O", "CA", "CB", "P"])

def encode_atom(self, atom_label: str) -> np.ndarray:
"""
Expand Down Expand Up @@ -726,13 +726,13 @@ def create_residue_frame(
# Check whether central atom is C:
if "CA" in codec.atomic_labels:
if voxels_as_gaussian:
np.testing.assert_array_less(frame[centre, centre, centre][4], 1)
np.testing.assert_array_less(frame[centre, centre, centre][3], 1)
assert (
0 < frame[centre, centre, centre][4] <= 1
0 < frame[centre, centre, centre][3] <= 1
), f"The central atom value should be between 0 and 1 but was {frame[centre, centre, centre][4]}"
else:
assert (
frame[centre, centre, centre][4] == 1
frame[centre, centre, centre][3] == 1
), f"The central atom should be Carbon, but it is {frame[centre, centre, centre]}."
else:
if voxels_as_gaussian:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_create_frame_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def test_make_frame_dataset_as_gaussian_cnocacbq():
test_file = TEST_DATA_DIR / "1ubq.pdb"
frame_edge_length = 18.0
voxels_per_side = 31
codec = cfds.Codec.CNOCBCAQ()
codec = cfds.Codec.CNOCACBQ()
ampal_1ubq = ampal.load_pdb(str(test_file))
ampal_1ubq2 = ampal.load_pdb(str(test_file))

Expand Down Expand Up @@ -437,7 +437,7 @@ def test_make_frame_dataset_as_gaussian_cnocacbp():
test_file = TEST_DATA_DIR / "1ubq.pdb"
frame_edge_length = 18.0
voxels_per_side = 31
codec = cfds.Codec.CNOCBCAP()
codec = cfds.Codec.CNOCACBP()

ampal_1ubq = ampal.load_pdb(str(test_file))
ampal_1ubq2 = ampal.load_pdb(str(test_file))
Expand Down Expand Up @@ -544,7 +544,7 @@ def test_cb_atom_filter(residue_number: int):
def test_add_gaussian_at_position():
main_matrix = np.zeros((5, 5, 5, 5), dtype=np.float)
modifiers_triple = (0, 0, 0)
codec = cfds.Codec.CNOCBCA()
codec = cfds.Codec.CNOCACB()

secondary_matrix, atom_idx = codec.encode_gaussian_atom("C", modifiers_triple)
atom_coord = (1, 1, 1)
Expand Down

0 comments on commit ce7431b

Please sign in to comment.