-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #620 from openforcefield/parameter-handler-plugins
Add Mechanism to Register Handlers via Entrypoints
- Loading branch information
Showing
6 changed files
with
219 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Contains a set of 'plugin' classes to enable testing of the plugin system.""" | ||
from openforcefield.typing.engines.smirnoff import ParameterHandler, ParameterIOHandler | ||
|
||
|
||
class CustomHandler(ParameterHandler): | ||
_TAGNAME = 'CustomHandler' | ||
|
||
|
||
class CustomIOHandler(ParameterIOHandler): | ||
_FORMAT = 'JSON' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
""" | ||
Test classes and function in module openforcefield.typing.engines.smirnoff.plugins | ||
""" | ||
import pkg_resources | ||
import pytest | ||
|
||
from openforcefield.typing.engines.smirnoff import ForceField | ||
from openforcefield.typing.engines.smirnoff.plugins import load_handler_plugins | ||
|
||
|
||
@pytest.yield_fixture() | ||
def mock_entry_point_plugins(): | ||
"""Registers a fake parameter handler and io handler with the | ||
entry point plugin system. | ||
Notes | ||
----- | ||
This function is based on `this stack overflow answer | ||
<https://stackoverflow.com/a/48666503/11808960>`_ | ||
""" | ||
|
||
previous_entries = pkg_resources.working_set.entries | ||
previous_entry_keys = pkg_resources.working_set.entry_keys | ||
previous_by_key = pkg_resources.working_set.by_key | ||
|
||
# Create a fake distribution to insert into the global working_set | ||
distribution = pkg_resources.Distribution(__file__) | ||
|
||
# Create the fake entry point definitions. These include a parameter handler | ||
# which is supported, and an io parameter handler which should be skipped. | ||
handler_entry_point = pkg_resources.EntryPoint.parse( | ||
"CustomHandler = openforcefield.tests.plugins:CustomHandler", | ||
dist=distribution | ||
) | ||
io_handler_entry_point = pkg_resources.EntryPoint.parse( | ||
"CustomIOHandler = openforcefield.tests.plugins:CustomIOHandler", | ||
dist=distribution | ||
) | ||
|
||
# Add the mapping to the fake EntryPoint | ||
distribution._ep_map = { | ||
"openff.toolkit.plugins.handlers": { | ||
"CustomHandler": handler_entry_point, | ||
"CustomIOHandler": io_handler_entry_point | ||
} | ||
} | ||
|
||
# Add the fake distribution to the global working_set | ||
pkg_resources.working_set.add(distribution, "CustomHandler") | ||
pkg_resources.working_set.add(distribution, "CustomIOHandler") | ||
|
||
yield | ||
|
||
pkg_resources.working_set.entries = previous_entries | ||
pkg_resources.working_set.entry_keys = previous_entry_keys | ||
pkg_resources.working_set.by_key = previous_by_key | ||
|
||
|
||
def test_force_field_custom_handler(mock_entry_point_plugins): | ||
"""Tests a force field can make use of a custom parameter handler registered | ||
through the entrypoint plugin system. | ||
""" | ||
|
||
# Construct a simple FF which only uses the custom handler. | ||
force_field_contents = "\n".join( | ||
[ | ||
"<?xml version='1.0' encoding='ASCII'?>", | ||
"<SMIRNOFF version='0.3' aromaticity_model='OEAroModel_MDL'>", | ||
" <CustomHandler version='0.3'></CustomHandler>", | ||
"</SMIRNOFF>" | ||
] | ||
) | ||
|
||
# An exception should be raised when plugins aren't allowed. | ||
with pytest.raises(KeyError) as error_info: | ||
ForceField(force_field_contents) | ||
|
||
assert ( | ||
"Cannot find a registered parameter handler class for tag 'CustomHandler'" in error_info.value.args[0] | ||
) | ||
|
||
# Otherwise the FF should be created as expected. | ||
force_field = ForceField(force_field_contents, load_plugins=True) | ||
|
||
parameter_handler = force_field.get_parameter_handler("CustomHandler") | ||
assert parameter_handler is not None | ||
assert parameter_handler.__class__.__name__ == "CustomHandler" | ||
|
||
|
||
def test_load_handler_plugins(mock_entry_point_plugins): | ||
"""Tests that parameter handlers can be registered as plugins. | ||
""" | ||
|
||
registered_plugins = load_handler_plugins() | ||
|
||
assert len(registered_plugins) == 1 | ||
assert registered_plugins[0].__name__ == "CustomHandler" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
"""This module defines functions for loading parameter handler and parser classes which | ||
have been registered through the `entrypoint plugin system <https://packaging.python. | ||
org/guides/creating-and-discovering-plugins/#using-package-metadata>`_. | ||
.. warning :: | ||
This feature is experimental and may be removed / altered in future versions. | ||
Currently only ``ParameterHandler`` classes can be registered via the plugin | ||
system. This is possible by registering the handler class through an entry | ||
point in your projects ``setup.py`` file:: | ||
setup( | ||
... | ||
entry_points={ | ||
'openff.toolkit.plugins.handlers': ['CustomHandler = myapp:CustomHandler'] | ||
}, | ||
... | ||
) | ||
where in this example your package is named ``myapp`` and contains a class which | ||
inherits from ``ParameterHandler`` named ``CustomHandler``. | ||
""" | ||
import logging | ||
|
||
from openforcefield.typing.engines.smirnoff.parameters import ParameterHandler | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
SUPPORTED_PLUGIN_NAMES = ["handlers"] # io_handlers could be supported in future. | ||
|
||
|
||
def _load_handler_plugins(handler_name, expected_type): | ||
"""Loads parameter handler plugins of a specified type which have been registered | ||
with the ``entrypoint`` plugin system. | ||
Parameters | ||
---------- | ||
handler_name: str | ||
The name of the hander plugin. This can currently be any of the names | ||
listed in ``SUPPORTED_PLUGIN_NAMES``. | ||
expected_type: type | ||
The expected class type of the plugin. E.g. when loading parameter io | ||
handler plugins the expected class type is ``ParameterIOHandler``. Any | ||
classes not matching the expected type will be skipped. | ||
""" | ||
import pkg_resources | ||
|
||
discovered_plugins = [] | ||
|
||
if handler_name not in SUPPORTED_PLUGIN_NAMES: | ||
raise NotImplementedError() | ||
|
||
for entry_point in pkg_resources.iter_entry_points( | ||
f"openff.toolkit.plugins.{handler_name}" | ||
): | ||
|
||
try: | ||
discovered_plugins.append(entry_point.load()) | ||
except ImportError: | ||
logger.exception(f"Could not load the {entry_point} plugin") | ||
continue | ||
|
||
valid_plugins = [] | ||
|
||
for discovered_plugin in discovered_plugins: | ||
|
||
if not issubclass(discovered_plugin, expected_type): | ||
|
||
logger.info( | ||
f"The {discovered_plugin.__name__} object has been registered as a " | ||
f"{handler_name} plugin, but does not inherit from " | ||
f"{expected_type.__name__}. This plugin will be skipped." | ||
) | ||
continue | ||
|
||
valid_plugins.append(discovered_plugin) | ||
|
||
return valid_plugins | ||
|
||
|
||
def load_handler_plugins(): | ||
"""Loads any ``ParameterHandler`` class plugins which have been registered through | ||
the ``entrypoint`` plugin system. | ||
Returns | ||
------- | ||
list of type of ParameterHandler | ||
The registered ``ParameterHandler`` plugins. | ||
""" | ||
return _load_handler_plugins("handlers", ParameterHandler) |