Skip to content

Commit

Permalink
Merge pull request #446 from seoklab/python/fmt/cif-perf
Browse files Browse the repository at this point in the history
feat(python/fmt/cif): add utility function to convert DDL2 CIF to dict
  • Loading branch information
jnooree authored Jan 6, 2025
2 parents 41b0a4d + 29ae430 commit 81f797b
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 25 deletions.
106 changes: 87 additions & 19 deletions python/src/nuri/fmt/cif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
#include <vector>

#include <absl/base/nullability.h>
#include <absl/container/flat_hash_map.h>
#include <absl/log/absl_check.h>
#include <absl/strings/match.h>
#include <absl/strings/str_cat.h>
#include <absl/strings/strip.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl/filesystem.h>

Expand All @@ -30,14 +33,13 @@ namespace python_internal {
namespace {
namespace fs = std::filesystem;

pyt::Dict<py::str, pyt::Optional<py::str>>
cif_table_row(const internal::CifTable &table, const pyt::List<py::str> &keys,
int row) {
pyt::Dict<py::str, pyt::Optional<py::str>> data;
pyt::List<pyt::Optional<py::str>> cif_table_row(const internal::CifTable &table,
int row) {
pyt::List<pyt::Optional<py::str>> data;
for (int i = 0; i < table.cols(); ++i) {
const internal::CifValue &val = table.data()[row][i];
data[keys[i]] = val.is_null() ? py::none().cast<py::object>()
: py::str(*val).cast<py::object>();
data.append(val.is_null() ? py::none().cast<py::object>()
: py::str(*val).cast<py::object>());
}
return data;
}
Expand All @@ -47,27 +49,23 @@ cif_table_row(const internal::CifTable &table, const pyt::List<py::str> &keys,
class PyCifTableIterator
: public PyIterator<PyCifTableIterator, const internal::CifTable> {
public:
PyCifTableIterator(const internal::CifTable &table,
const pyt::List<py::str> &keys)
: Parent(table), keys_(keys) { }
using Parent::Parent;

static auto bind(py::module &m) {
return Parent::bind(m, "_CifTableIterator");
}

pyt::Dict<py::str, pyt::Optional<py::str>>
deref(const internal::CifTable &table, int row) const {
return cif_table_row(table, keys_, row);
}

private:
friend Parent;

static pyt::List<pyt::Optional<py::str>>
deref(const internal::CifTable &table, int row) {
return cif_table_row(table, row);
}

static size_t size_of(const internal::CifTable &table) {
return table.size();
}

pyt::List<py::str> keys_;
};

class PyCifTable {
Expand All @@ -77,12 +75,12 @@ class PyCifTable {
keys_.append(key);
}

PyCifTableIterator iter() const { return PyCifTableIterator(*table_, keys_); }
PyCifTableIterator iter() const { return PyCifTableIterator(*table_); }

pyt::Dict<py::str, pyt::Optional<py::str>> get(int row) const {
pyt::List<pyt::Optional<py::str>> get(int row) const {
row = py_check_index(static_cast<int>(table_->size()), row,
"CifTable row index out of range");
return cif_table_row(*table_, keys_, row);
return cif_table_row(*table_, row);
}

size_t size() const { return table_->size(); }
Expand Down Expand Up @@ -139,6 +137,8 @@ class PyCifFrame {
return py::cast(PyCifTable((*frame_)[it.begin()->second.first]));
}

const internal::CifFrame &cpp() const { return *frame_; }

private:
absl::Nonnull<const internal::CifFrame *> frame_;
};
Expand Down Expand Up @@ -224,6 +224,66 @@ void bind_opaque_vector(py::module &m, const char *name, const char *onerror) {
return absl::StrCat("<", name, " of ", self.size(), " tables>");
});
}

pyt::Dict<py::str, pyt::List<pyt::Dict<py::str, pyt::Optional<py::str>>>>
cif_ddl2_frame_as_dict(const PyCifFrame &frame) {
absl::flat_hash_map<
std::string_view,
std::pair<std::vector<py::str>, std::vector<std::vector<py::object>>>>
grouped;

std::vector<std::string_view> parent_keys;
std::vector<decltype(grouped)::iterator> slots;
for (const internal::CifTable &table: frame.cpp()) {
parent_keys.clear();
parent_keys.reserve(table.cols());
for (std::string_view key: table.keys()) {
std::string_view pk = key.substr(0, key.find('.'));
std::string_view sk = key.substr(pk.size() + 1);

pk = absl::StripPrefix(pk, "_");

parent_keys.push_back(pk);
grouped[pk].first.push_back(sk);
}

slots.clear();
slots.reserve(table.cols());
for (std::string_view pk: parent_keys)
slots.push_back(grouped.find(pk));

for (const auto &row: table) {
for (int i = 0; i < table.cols(); ++i) {
auto it = slots[i];
ABSL_DCHECK(it != grouped.end());

auto &data = it->second.second;

if (data.empty() || data.back().size() == it->second.first.size())
data.emplace_back().reserve(it->second.first.size());

const internal::CifValue &val = row[i];
data.back().push_back(val.is_null() ? py::none().cast<py::object>()
: py::str(*val).cast<py::object>());
}
}
}

py::dict tagged;
for (auto &group: grouped) {
py::str pk(group.first);
py::list rows;
for (const auto &row: group.second.second) {
py::dict entry;
for (int i = 0; i < row.size(); ++i)
entry[group.second.first[i]] = row[i];
rows.append(entry);
}
tagged[pk] = rows;
}

return tagged;
}
} // namespace

void bind_cif(py::module &m) {
Expand Down Expand Up @@ -268,6 +328,14 @@ Create a parser object from a CIF file path.
:param path: The path to the CIF file.
:return: A parser object that can be used to iterate over the blocks in the file.
)doc")
.def("cif_ddl2_frame_as_dict", &cif_ddl2_frame_as_dict, py::arg("frame"),
R"doc(
Convert a CIF frame to a dictionary of lists of dictionaries.
:param frame: The CIF frame to convert.
:return: A dictionary of lists of dictionaries, where the keys are the parent
keys and the values are the rows of the table.
)doc");
}
} // namespace python_internal
Expand Down
30 changes: 24 additions & 6 deletions python/test/fmt/cif_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from nuri.fmt import read_cif
from nuri.fmt import cif_ddl2_frame_as_dict, read_cif


def test_read_cif(test_data: Path):
Expand All @@ -30,10 +30,10 @@ def test_read_cif(test_data: Path):

for table in frame:
assert len(table) == 1
assert table[0]["_entry.id"] == "1A8O"
assert table[0][0] == "1A8O"
break

assert frame[0][0]["_entry.id"] == "1A8O"
assert frame[0][0][0] == "1A8O"

atom_site = frame.prefix_search_first("_atom_site.")
assert atom_site is not None
Expand Down Expand Up @@ -70,12 +70,30 @@ def test_read_cif(test_data: Path):
assert len(atom_site) == 644

row = atom_site[0]
assert row["_atom_site.type_symbol"] == "N"
assert row["_atom_site.label_alt_id"] is None
assert row[2] == "N" # _atom_site.type_symbol
assert row[4] is None # _atom_site.label_alt_id

for row in atom_site:
assert row["_atom_site.id"] == "1"
assert row[1] == "1" # _atom_site.id
break

nonexistent = frame.prefix_search_first("_foobar.")
assert nonexistent is None


def test_convert_ddl2_cif(test_data: Path):
cif = test_data / "1a8o.cif"

blocks = list(read_cif(cif))
assert len(blocks) == 1
frame = blocks[0].data

ddl = cif_ddl2_frame_as_dict(frame)

assert ddl["entry"][0]["id"] == "1A8O"

atom_site = ddl["atom_site"]
assert len(atom_site) == 644

assert atom_site[0]["type_symbol"] == "N"
assert atom_site[0]["label_alt_id"] is None

0 comments on commit 81f797b

Please sign in to comment.