diff --git a/fastcore/__init__.py b/fastcore/__init__.py
index bde00312..5b601886 100644
--- a/fastcore/__init__.py
+++ b/fastcore/__init__.py
@@ -1 +1 @@
-__version__ = "1.4.6"
+__version__ = "1.5.0"
diff --git a/fastcore/_nbdev.py b/fastcore/_nbdev.py
index 13acbd97..20e0a973 100644
--- a/fastcore/_nbdev.py
+++ b/fastcore/_nbdev.py
@@ -227,10 +227,8 @@
"do_request": "03b_net.ipynb",
"start_server": "03b_net.ipynb",
"start_client": "03b_net.ipynb",
- "lenient_issubclass": "04_dispatch.ipynb",
- "sorted_topologically": "04_dispatch.ipynb",
- "TypeDispatch": "04_dispatch.ipynb",
- "DispatchReg": "04_dispatch.ipynb",
+ "FastFunction": "04_dispatch.ipynb",
+ "FastDispatcher": "04_dispatch.ipynb",
"typedispatch": "04_dispatch.ipynb",
"retain_meta": "04_dispatch.ipynb",
"default_set_meta": "04_dispatch.ipynb",
diff --git a/fastcore/dispatch.py b/fastcore/dispatch.py
index a498f667..de6f8b0d 100644
--- a/fastcore/dispatch.py
+++ b/fastcore/dispatch.py
@@ -4,154 +4,96 @@
from __future__ import annotations
-__all__ = ['lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'typedispatch', 'cast',
- 'retain_meta', 'default_set_meta', 'retain_type', 'retain_types', 'explode_types']
+__all__ = ['FastFunction', 'FastDispatcher', 'typedispatch', 'cast', 'retain_meta', 'default_set_meta', 'retain_type',
+ 'retain_types', 'explode_types']
# Cell
#nbdev_comment from __future__ import annotations
from .imports import *
from .foundation import *
from .utils import *
+from .meta import delegates
from collections import defaultdict
+from plum import Function, Dispatcher
# Cell
-def lenient_issubclass(cls, types):
- "If possible return whether `cls` is a subclass of `types`, otherwise return False."
- if cls is object and types is not object: return False # treat `object` as highest level
- try: return isinstance(cls, types) or issubclass(cls, types)
- except: return False
+def _eval_annotations(f):
+ "Evaluate future annotations before passing to plum to support backported union operator `|`"
+ f = copy_func(f)
+ for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v
+ return f
# Cell
-def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
- "Return a new list containing all items from the iterable sorted topologically"
- l,res = L(list(iterable)),[]
- for _ in range(len(l)):
- t = l.reduce(lambda x,y: y if cmp(y,x) else x)
- res.append(t), l.remove(t)
- return res[::-1] if reverse else res
+def _pt_repr(o):
+ "Concise repr of plum types"
+ n = type(o).__name__
+ if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]"
+ if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'
+ if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'
+ if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'
+ if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'
+ if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))
+ assert len(o.get_types()) == 1
+ return o.get_types()[0].__name__
# Cell
-def _chk_defaults(f, ann):
- pass
-# Implementation removed until we can figure out how to do this without `inspect` module
-# try: # Some callables don't have signatures, so ignore those errors
-# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]
-# if any(p.default!=inspect.Parameter.empty for p in params):
-# warn(f"{f.__name__} has default params. These will be ignored.")
-# except ValueError: pass
-
-# Cell
-def _p2_anno(f):
- "Get the 1st 2 annotations of `f`, defaulting to `object`"
- hints = type_hints(f)
- ann = [o for n,o in hints.items() if n!='return']
- if callable(f): _chk_defaults(f, ann)
- while len(ann)<2: ann.append(object)
- return ann[:2]
+class FastFunction(Function):
+ def __repr__(self):
+ return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
+ for s, (f, r) in self.methods.items())
-# Cell
-class _TypeDict:
- def __init__(self): self.d,self.cache = {},{}
-
- def _reset(self):
- self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}
- self.cache = {}
-
- def add(self, t, f):
- "Add type `t` and function `f`"
- if not isinstance(t, tuple): t = tuple(L(union2tuple(t)))
- for t_ in t: self.d[t_] = f
- self._reset()
-
- def all_matches(self, k):
- "Find first matching type that is a super-class of `k`"
- if k not in self.cache:
- types = [f for f in self.d if lenient_issubclass(k,f)]
- self.cache[k] = [self.d[o] for o in types]
- return self.cache[k]
-
- def __getitem__(self, k):
- "Find first matching type that is a super-class of `k`"
- res = self.all_matches(k)
- return res[0] if len(res) else None
-
- def __repr__(self): return self.d.__repr__()
- def first(self): return first(self.d.values())
+ @delegates(Function.dispatch)
+ def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs)
-# Cell
-class TypeDispatch:
- "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
- def __init__(self, funcs=(), bases=()):
- self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
- for o in L(funcs): self.add(o)
- self.inst = None
- self.owner = None
-
- def add(self, f):
- "Add type `t` and function `f`"
- if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)
- else: a0,a1 = _p2_anno(f)
- t = self.funcs.d.get(a0)
- if t is None:
- t = _TypeDict()
- self.funcs.add(a0, t)
- t.add(a1, f)
-
- def first(self):
- "Get first function in ordered dict of type:func."
- return self.funcs.first().first()
-
- def returns(self, x):
- "Get the return type of annotation of `x`."
- return anno_ret(self[type(x)])
-
- def _attname(self,k): return getattr(k,'__name__',str(k))
- def __repr__(self):
- r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}'
- for k in self.funcs.d for l,v in self.funcs[k].d.items()]
- r = r + [o.__repr__() for o in self.bases]
- return '\n'.join(r)
-
- def __call__(self, *args, **kwargs):
- ts = L(args).map(type)[:2]
- f = self[tuple(ts)]
- if not f: return args[0]
- if isinstance(f, staticmethod): f = f.__func__
- elif self.inst is not None: f = MethodType(f, self.inst)
- elif self.owner is not None: f = MethodType(f, self.owner)
- return f(*args, **kwargs)
-
- def __get__(self, inst, owner):
- self.inst = inst
- self.owner = owner
- return self
-
- def __getitem__(self, k):
- "Find first matching type that is a super-class of `k`"
- k = L(k)
- while len(k)<2: k.append(object)
- r = self.funcs.all_matches(k[0])
- for t in r:
- o = t[k[1]]
- if o is not None: return o
- for base in self.bases:
- res = base[k]
- if res is not None: return res
- return None
+ def __getitem__(self, ts):
+ "Return the most-specific matching method with fewest parameters"
+ ts = L(ts)
+ nargs = min(len(o) for o in self.methods.keys())
+ while len(ts) < nargs: ts.append(object)
+ return self.invoke(*ts)
# Cell
-class DispatchReg:
- "A global registry for `TypeDispatch` objects keyed by function name"
- def __init__(self): self.d = defaultdict(TypeDispatch)
- def __call__(self, f):
- if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'
- else: nm = f'{f.__qualname__}'
- if isinstance(f, classmethod): f=f.__func__
- self.d[nm].add(f)
- return self.d[nm]
-
-typedispatch = DispatchReg()
+class FastDispatcher(Dispatcher):
+ def _get_function(self, method, owner):
+ "Adapted from `Dispatcher._get_function` to use `FastFunction`"
+ name = method.__name__
+ if owner:
+ if owner not in self._classes: self._classes[owner] = {}
+ namespace = self._classes[owner]
+ else: namespace = self._functions
+ if name not in namespace: namespace[name] = FastFunction(method, owner=owner)
+ return namespace[name]
+
+ @delegates(Dispatcher.__call__, but='method')
+ def __call__(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs)
+
+ def _to(self, cls, nm, f, **kwargs):
+ nf = copy_func(f)
+ nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner
+ pf = self(nf, **kwargs)
+ # plum uses __set_name__ to resolve a plum.Function's owner
+ # since we assign after class creation, __set_name__ must be called directly
+ # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
+ pf.__set_name__(cls, nm)
+ pf = pf.resolve()
+ setattr(cls, nm, pf)
+ return pf
+
+ def to(self, cls):
+ "Decorator: dispatch `f` to `cls.f`"
+ def _inner(f, **kwargs):
+ nm = f.__name__
+ # check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on
+ if nm in cls.__dict__:
+ pf = getattr(cls, nm)
+ if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf, **kwargs)
+ pf.dispatch(f)
+ else: pf = self._to(cls, nm, f, **kwargs)
+ return pf
+ return _inner
+
+typedispatch = FastDispatcher()
# Cell
#nbdev_comment _all_=['cast']
diff --git a/fastcore/imports.py b/fastcore/imports.py
index 086c899f..12665e59 100644
--- a/fastcore/imports.py
+++ b/fastcore/imports.py
@@ -1,5 +1,6 @@
import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum
+from copy import copy
from operator import itemgetter,attrgetter
from warnings import warn
from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple
@@ -14,6 +15,15 @@
MethodDescriptorType = type(str.join)
from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace
+#Patch autoreload (if its loaded) to work with plum
+try: from IPython import get_ipython
+except ImportError: pass
+else:
+ ip = get_ipython()
+ if ip is not None and 'IPython.extensions.storemagic' in ip.extension_manager.loaded:
+ from plum.autoreload import activate
+ activate()
+
NoneType = type(None)
string_classes = (str,bytes)
diff --git a/fastcore/transform.py b/fastcore/transform.py
index b8bba612..d5c86160 100644
--- a/fastcore/transform.py
+++ b/fastcore/transform.py
@@ -9,6 +9,7 @@
from .utils import *
from .dispatch import *
import inspect
+from plum import add_conversion_method
# Cell
_tfm_methods = 'encodes','decodes','setups'
@@ -16,29 +17,20 @@
def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)
class _TfmDict(dict):
- def __setitem__(self, k, v):
- if not _is_tfm_method(k, v): return super().__setitem__(k,v)
- if k not in self: super().__setitem__(k,TypeDispatch())
- self[k].add(v)
+ def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v)
# Cell
class _TfmMeta(type):
def __new__(cls, name, bases, dict):
- res = super().__new__(cls, name, bases, dict)
- for nm in _tfm_methods:
- base_td = [getattr(b,nm,None) for b in bases]
- if nm in res.__dict__: getattr(res,nm).bases = base_td
- else: setattr(res, nm, TypeDispatch(bases=base_td))
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
+ res = super().__new__(cls, name, bases, dict)
res.__signature__ = inspect.signature(res.__init__)
return res
def __call__(cls, *args, **kwargs):
f = first(args)
n = getattr(f, '__name__', None)
- if _is_tfm_method(n, f):
- getattr(cls,n).add(f)
- return f
+ if _is_tfm_method(n, f): return typedispatch.to(cls)(f)
obj = super().__call__(*args, **kwargs)
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
# instances of cls, fix it
@@ -67,13 +59,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
self.init_enc = enc or dec
if not self.init_enc: return
- self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()
+ def identity(x): return x
+ for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity))
if enc:
- self.encodes.add(enc)
+ self.encodes.dispatch(enc)
self.order = getattr(enc,'order',self.order)
if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values()))
self._name = _get_name(enc)
- if dec: self.decodes.add(dec)
+ if dec: self.decodes.dispatch(dec)
@property
def name(self): return getattr(self, '_name', _get_name(self))
@@ -92,13 +85,24 @@ def _call(self, fn, x, split_idx=None, **kwargs):
def _do_call(self, f, x, **kwargs):
if not _is_tuple(x):
if f is None: return x
- ret = f.returns(x) if hasattr(f,'returns') else None
- return retain_type(f(x, **kwargs), x, ret)
+ ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]
+ _, ret = f.resolve_method(*ts)
+ ret = ret._type
+ # plum reads empty return annotation as object, retain_type expects it as None
+ if ret is object: ret = None
+ return retain_type(f(x,**kwargs), x, ret)
res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
return retain_type(res, x)
+ def encodes(self, x): return x
+ def decodes(self, x): return x
+ def setups(self, dl): return dl
add_docs(Transform, decode="Delegate to decodes
to undo transform", setup="Delegate to setups
to set up transform")
+# Cell
+#Implement the Transform convention that a None return annotation disables conversion
+add_conversion_method(object, NoneType, lambda x: x)
+
# Cell
class InplaceTransform(Transform):
"A `Transform` that modifies in-place and just returns whatever it's passed"
diff --git a/nbs/01_basics.ipynb b/nbs/01_basics.ipynb
index 21d6fb6c..756ca23d 100644
--- a/nbs/01_basics.ipynb
+++ b/nbs/01_basics.ipynb
@@ -804,7 +804,7 @@
{
"data": {
"text/markdown": [
- "
noop
[source]noop
[source]noop
(**`x`**=*`None`*, **\\*`args`**, **\\*\\*`kwargs`**)\n",
"\n",
@@ -840,7 +840,7 @@
{
"data": {
"text/markdown": [
- "noops
[source]noops
[source]noops
(**`self`**, **`x`**=*`None`*, **\\*`args`**, **\\*\\*`kwargs`**)\n",
"\n",
@@ -5637,7 +5637,7 @@
{
"data": {
"text/markdown": [
- "ipython_shell
[source]ipython_shell
[source]ipython_shell
()\n",
"\n",
@@ -5663,7 +5663,7 @@
{
"data": {
"text/markdown": [
- "in_ipython
[source]in_ipython
[source]in_ipython
()\n",
"\n",
@@ -5689,7 +5689,7 @@
{
"data": {
"text/markdown": [
- "in_colab
[source]in_colab
[source]in_colab
()\n",
"\n",
@@ -5715,7 +5715,7 @@
{
"data": {
"text/markdown": [
- "in_jupyter
[source]in_jupyter
[source]in_jupyter
()\n",
"\n",
@@ -5741,7 +5741,7 @@
{
"data": {
"text/markdown": [
- "in_notebook
[source]in_notebook
[source]in_notebook
()\n",
"\n",
diff --git a/nbs/04_dispatch.ipynb b/nbs/04_dispatch.ipynb
index 08b6bcbe..6bf65d2d 100644
--- a/nbs/04_dispatch.ipynb
+++ b/nbs/04_dispatch.ipynb
@@ -20,8 +20,10 @@
"from fastcore.imports import *\n",
"from fastcore.foundation import *\n",
"from fastcore.utils import *\n",
+ "from fastcore.meta import delegates\n",
"\n",
- "from collections import defaultdict"
+ "from collections import defaultdict\n",
+ "from plum import Function, Dispatcher"
]
},
{
@@ -41,42 +43,18 @@
"source": [
"# Type dispatch\n",
"\n",
- "> Basic single and dual parameter dispatch"
+ "> Multiple dispatch, extending [plum](https://github.com/wesselb/plum)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Helpers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#export\n",
- "def lenient_issubclass(cls, types):\n",
- " \"If possible return whether `cls` is a subclass of `types`, otherwise return False.\"\n",
- " if cls is object and types is not object: return False # treat `object` as highest level\n",
- " try: return isinstance(cls, types) or issubclass(cls, types)\n",
- " except: return False"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "assert not lenient_issubclass(typing.Collection, list)\n",
- "assert lenient_issubclass(list, typing.Collection)\n",
- "assert lenient_issubclass(typing.Collection, object)\n",
- "assert lenient_issubclass(typing.List, typing.Collection)\n",
- "assert not lenient_issubclass(typing.Collection, typing.List)\n",
- "assert not lenient_issubclass(object, typing.Callable)"
+ "Type dispatch, or [multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based on the input types it receives. This is a prominent feature in some programming languages like [Julia](https://docs.julialang.org/en/v1/manual/methods/).\n",
+ "\n",
+ "Type dispatch allows you to have a common API for functions that do similar tasks. This is especially useful in data science, where the same operation (e.g. normalize, categorize) requires an implementation that depends on its input type (e.g. numpy array, pandas dataframe, pytorch tensor).\n",
+ "\n",
+ "Fastcore uses and extends the wonderful [plum](https://github.com/wesselb/plum) library's implementation of multiple dispatch for Python. Be sure to view their [informative documentation](https://github.com/wesselb/plum#basic-usage) as well."
]
},
{
@@ -86,34 +64,11 @@
"outputs": [],
"source": [
"#export\n",
- "def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):\n",
- " \"Return a new list containing all items from the iterable sorted topologically\"\n",
- " l,res = L(list(iterable)),[]\n",
- " for _ in range(len(l)):\n",
- " t = l.reduce(lambda x,y: y if cmp(y,x) else x)\n",
- " res.append(t), l.remove(t)\n",
- " return res[::-1] if reverse else res"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "td = [3, 1, 2, 5]\n",
- "test_eq(sorted_topologically(td), [1, 2, 3, 5])\n",
- "test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "td = {int:1, numbers.Number:2, numbers.Integral:3}\n",
- "test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])"
+ "def _eval_annotations(f):\n",
+ " \"Evaluate future annotations before passing to plum to support backported union operator `|`\"\n",
+ " f = copy_func(f)\n",
+ " for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v\n",
+ " return f"
]
},
{
@@ -122,26 +77,13 @@
"metadata": {},
"outputs": [],
"source": [
- "td = [numbers.Integral, tuple, list, int, dict]\n",
- "td = sorted_topologically(td, cmp=lenient_issubclass)\n",
- "assert td.index(int) < td.index(numbers.Integral)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#export\n",
- "def _chk_defaults(f, ann):\n",
- " pass\n",
- "# Implementation removed until we can figure out how to do this without `inspect` module\n",
- "# try: # Some callables don't have signatures, so ignore those errors\n",
- "# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]\n",
- "# if any(p.default!=inspect.Parameter.empty for p in params):\n",
- "# warn(f\"{f.__name__} has default params. These will be ignored.\")\n",
- "# except ValueError: pass"
+ "#hide\n",
+ "def f(x:int|str) -> float: pass\n",
+ "test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})\n",
+ "def f(x:(int,str)) -> float: pass\n",
+ "test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})\n",
+ "def f(x): pass\n",
+ "test_eq(_eval_annotations(f).__annotations__, {})"
]
},
{
@@ -151,13 +93,17 @@
"outputs": [],
"source": [
"#export\n",
- "def _p2_anno(f):\n",
- " \"Get the 1st 2 annotations of `f`, defaulting to `object`\"\n",
- " hints = type_hints(f)\n",
- " ann = [o for n,o in hints.items() if n!='return']\n",
- " if callable(f): _chk_defaults(f, ann)\n",
- " while len(ann)<2: ann.append(object)\n",
- " return ann[:2]"
+ "def _pt_repr(o):\n",
+ " \"Concise repr of plum types\"\n",
+ " n = type(o).__name__\n",
+ " if n == 'Tuple': return f\"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]\"\n",
+ " if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'\n",
+ " if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'\n",
+ " if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'\n",
+ " if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'\n",
+ " if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))\n",
+ " assert len(o.get_types()) == 1\n",
+ " return o.get_types()[0].__name__"
]
},
{
@@ -167,116 +113,26 @@
"outputs": [],
"source": [
"#hide\n",
- "def _f(a): pass\n",
- "test_eq(_p2_anno(_f), (object,object))\n",
- "def _f(a, b): pass\n",
- "test_eq(_p2_anno(_f), (object,object))\n",
- "def _f(a:None, b)->str: pass\n",
- "test_eq(_p2_anno(_f), (NoneType,object))\n",
- "def _f(a:str, b)->float: pass\n",
- "test_eq(_p2_anno(_f), (str,object))\n",
- "def _f(a:None, b:str)->float: pass\n",
- "test_eq(_p2_anno(_f), (NoneType,str))\n",
- "def _f(a:int, b:int)->float: pass\n",
- "test_eq(_p2_anno(_f), (int,int))\n",
- "def _f(self, a:int, b:int): pass\n",
- "test_eq(_p2_anno(_f), (int,int))\n",
- "def _f(a:int, b:str)->float: pass\n",
- "test_eq(_p2_anno(_f), (int,str))\n",
- "test_eq(_p2_anno(attrgetter('foo')), (object,object))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "([object, object], [int, object])"
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "#hide\n",
- "# Disabled until _chk_defaults fixed\n",
- "# def _f(x:int,y:int=10): pass\n",
- "# test_warns(lambda: _p2_anno(_f))\n",
- "def _f(x:int,y=10): pass\n",
- "_p2_anno(None),_p2_anno(_f)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## TypeDispatch"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Type dispatch, or [Multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based upon the input types it recevies. This is a prominent feature in some programming languages like Julia. For example, this is a [conceptual example](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia) of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:\n",
- "\n",
- "```julia\n",
- "collide_with(x::Asteroid, y::Asteroid) = ... \n",
- "# deal with asteroid hitting asteroid\n",
- "\n",
- "collide_with(x::Asteroid, y::Spaceship) = ... \n",
- "# deal with asteroid hitting spaceship\n",
- "\n",
- "collide_with(x::Spaceship, y::Asteroid) = ... \n",
- "# deal with spaceship hitting asteroid\n",
+ "from typing import Dict, List, Iterable, Sequence, Tuple\n",
+ "from plum.type import VarArgs, ptype\n",
"\n",
- "collide_with(x::Spaceship, y::Spaceship) = ... \n",
- "# deal with spaceship hitting spaceship\n",
- "```\n",
- "\n",
- "Type dispatch can be especially useful in data science, where you might allow different input types (i.e. numpy arrays and pandas dataframes) to function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.\n",
- "\n",
- "The `TypeDispatch` class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, which ensures that the proper function is called when passed inputs."
+ "test_eq(_pt_repr(ptype(int)), 'int')\n",
+ "test_eq(_pt_repr(ptype(Union[int, str])), 'int|str')\n",
+ "test_eq(_pt_repr(ptype(Tuple[int, str])), 'tuple[int,str]')\n",
+ "test_eq(_pt_repr(ptype(List[int])), 'list[int]')\n",
+ "test_eq(_pt_repr(ptype(Sequence[int])), 'Sequence[int]')\n",
+ "test_eq(_pt_repr(ptype(Iterable[int])), 'Iterable[int]')\n",
+ "test_eq(_pt_repr(ptype(Dict[str, int])), 'dict[str,int]')\n",
+ "test_eq(_pt_repr(ptype(VarArgs[str])), 'VarArgs[str]')\n",
+ "test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])),\n",
+ " 'dict[tuple[int|str,float],list[tuple[object]]]')"
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "#export\n",
- "class _TypeDict:\n",
- " def __init__(self): self.d,self.cache = {},{}\n",
- "\n",
- " def _reset(self):\n",
- " self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}\n",
- " self.cache = {}\n",
- "\n",
- " def add(self, t, f):\n",
- " \"Add type `t` and function `f`\"\n",
- " if not isinstance(t, tuple): t = tuple(L(union2tuple(t)))\n",
- " for t_ in t: self.d[t_] = f\n",
- " self._reset()\n",
- "\n",
- " def all_matches(self, k):\n",
- " \"Find first matching type that is a super-class of `k`\"\n",
- " if k not in self.cache:\n",
- " types = [f for f in self.d if lenient_issubclass(k,f)]\n",
- " self.cache[k] = [self.d[o] for o in types]\n",
- " return self.cache[k]\n",
- "\n",
- " def __getitem__(self, k):\n",
- " \"Find first matching type that is a super-class of `k`\"\n",
- " res = self.all_matches(k)\n",
- " return res[0] if len(res) else None\n",
- "\n",
- " def __repr__(self): return self.d.__repr__()\n",
- " def first(self): return first(self.d.values())"
+ "## FastFunction -"
]
},
{
@@ -286,92 +142,34 @@
"outputs": [],
"source": [
"#export\n",
- "class TypeDispatch:\n",
- " \"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`\"\n",
- " def __init__(self, funcs=(), bases=()):\n",
- " self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))\n",
- " for o in L(funcs): self.add(o)\n",
- " self.inst = None\n",
- " self.owner = None\n",
- "\n",
- " def add(self, f):\n",
- " \"Add type `t` and function `f`\"\n",
- " if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)\n",
- " else: a0,a1 = _p2_anno(f)\n",
- " t = self.funcs.d.get(a0)\n",
- " if t is None:\n",
- " t = _TypeDict()\n",
- " self.funcs.add(a0, t)\n",
- " t.add(a1, f)\n",
- "\n",
- " def first(self):\n",
- " \"Get first function in ordered dict of type:func.\"\n",
- " return self.funcs.first().first()\n",
- "\n",
- " def returns(self, x):\n",
- " \"Get the return type of annotation of `x`.\"\n",
- " return anno_ret(self[type(x)])\n",
- "\n",
- " def _attname(self,k): return getattr(k,'__name__',str(k))\n",
+ "class FastFunction(Function):\n",
" def __repr__(self):\n",
- " r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, \"__name__\", type(v).__name__)}'\n",
- " for k in self.funcs.d for l,v in self.funcs[k].d.items()]\n",
- " r = r + [o.__repr__() for o in self.bases]\n",
- " return '\\n'.join(r)\n",
- "\n",
- " def __call__(self, *args, **kwargs):\n",
- " ts = L(args).map(type)[:2]\n",
- " f = self[tuple(ts)]\n",
- " if not f: return args[0]\n",
- " if isinstance(f, staticmethod): f = f.__func__\n",
- " elif self.inst is not None: f = MethodType(f, self.inst)\n",
- " elif self.owner is not None: f = MethodType(f, self.owner)\n",
- " return f(*args, **kwargs)\n",
+ " return '\\n'.join(f\"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}\"\n",
+ " for s, (f, r) in self.methods.items())\n",
"\n",
- " def __get__(self, inst, owner):\n",
- " self.inst = inst\n",
- " self.owner = owner\n",
- " return self\n",
+ " @delegates(Function.dispatch)\n",
+ " def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs)\n",
"\n",
- " def __getitem__(self, k):\n",
- " \"Find first matching type that is a super-class of `k`\"\n",
- " k = L(k)\n",
- " while len(k)<2: k.append(object)\n",
- " r = self.funcs.all_matches(k[0])\n",
- " for t in r:\n",
- " o = t[k[1]]\n",
- " if o is not None: return o\n",
- " for base in self.bases:\n",
- " res = base[k]\n",
- " if res is not None: return res\n",
- " return None"
+ " def __getitem__(self, ts):\n",
+ " \"Return the most-specific matching method with fewest parameters\"\n",
+ " ts = L(ts)\n",
+ " nargs = min(len(o) for o in self.methods.keys())\n",
+ " while len(ts) < nargs: ts.append(object)\n",
+ " return self.invoke(*ts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "To demonstrate how `TypeDispatch` works, we define a set of functions that accept a variety of input types, specified with different type annotations:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def f2(x:int, y:float): return x+y #int and float for 2nd arg\n",
- "def f_nin(x:numbers.Integral)->int: return x+1 #integral numeric\n",
- "def f_ni2(x:int): return x #integer\n",
- "def f_bll(x:bool|list): return x #bool or list\n",
- "def f_num(x:numbers.Number): return x #Number (root of numerics) "
+ "`FastFunction` extends `plum.Function` with the following functionality."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can optionally initialize `TypeDispatch` with a list of functions we want to search. Printing an instance of `TypeDispatch` will display convenient mapping of types -> functions:"
+ "`FastFunction` has a concise `repr`:"
]
},
{
@@ -382,12 +180,7 @@
{
"data": {
"text/plain": [
- "(bool,object) -> f_bll\n",
- "(int,object) -> f_ni2\n",
- "(Integral,object) -> f_nin\n",
- "(Number,object) -> f_num\n",
- "(list,object) -> f_bll\n",
- "(object,object) -> NoneType"
+ "f: (int) -> float"
]
},
"execution_count": null,
@@ -396,39 +189,16 @@
}
],
"source": [
- "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n",
- "t"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note that only the first two arguments are used for `TypeDispatch`. If your function only contains one argument, the second parameter will be shown as `object`. If you pass `None` into `TypeDispatch`, then this will be displayed as `(object, object) -> NoneType`.\n",
- "\n",
- "`TypeDispatch` is a dictionary-like object, which means that you can retrieve a function by the associated type annotation. For example, the statement:\n",
- "\n",
- "```py\n",
- "t[float]\n",
- "```\n",
- "Will return `f_num` because that is the matching function that has a type annotation that is a super-class of of `float` - `numbers.Number`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "assert issubclass(float, numbers.Number)\n",
- "test_eq(t[float], f_num)"
+ "def f(x: int) -> float: pass\n",
+ "f = FastFunction(f).dispatch(f)\n",
+ "f"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "The same is true for other types as well:"
+ "`FastFunction` supports fastcore's backport of the `|` operator on types:"
]
},
{
@@ -437,169 +207,20 @@
"metadata": {},
"outputs": [],
"source": [
- "test_eq(t[np.int32], f_nin)\n",
- "test_eq(t[bool], f_bll)\n",
- "test_eq(t[list], f_bll)\n",
- "test_eq(t[np.int32], f_nin)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "If you try to get a type that doesn't match, `TypeDispatch` will return `None`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "test_eq(t[str], None)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/markdown": [
- "TypeDispatch.add
[source]TypeDispatch.add
(**`f`**)\n",
- "\n",
- "Add type `t` and function `f`"
- ],
- "text/plain": [
- "TypeDispatch.__call__
[source]TypeDispatch.__call__
(**\\*`args`**, **\\*\\*`kwargs`**)\n",
- "\n",
- "Call self as a function."
- ],
- "text/plain": [
- "TypeDispatch.returns
[source]TypeDispatch.returns
(**`x`**)\n",
- "\n",
- "Get the return type of annotation of `x`."
- ],
- "text/plain": [
- "decodes
to undo transform\", setup=\"Delegate to setups
to set up transform\")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#export\n",
+ "#Implement the Transform convention that a None return annotation disables conversion\n",
+ "add_conversion_method(object, NoneType, lambda x: x)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -1060,8 +1071,7 @@
"data": {
"text/plain": [
"A:\n",
- "encodes: (object,object) -> noop\n",
- "decodes: (object,object) -> noop"
+ "encodes: noop: (object,VarArgs[object]) -> objectdecodes: noop: (object,VarArgs[object]) -> object"
]
},
"execution_count": null,
@@ -1091,8 +1101,7 @@
"data": {
"text/plain": [
"A -- {'a': 1, 'b': 2}:\n",
- "encodes: (object,object) -> noop\n",
- "decodes: "
+ "encodes: noop: (object,VarArgs[object]) -> objectdecodes: decodes: (object,object) -> object"
]
},
"execution_count": null,
@@ -1978,13 +1987,6 @@
"from nbdev.export import notebook2script\n",
"notebook2script()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/settings.ini b/settings.ini
index 625c459c..b5e60ec8 100644
--- a/settings.ini
+++ b/settings.ini
@@ -7,7 +7,7 @@ author = Jeremy Howard and Sylvain Gugger
author_email = infos@fast.ai
copyright = fast.ai
branch = master
-version = 1.4.6
+version = 1.5.0
min_python = 3.7
audience = Developers
language = English
diff --git a/setup.py b/setup.py
index 1fa9526b..aea17800 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@
min_python = cfg['min_python']
lic = licenses[cfg['license']]
-requirements = ['pip', 'packaging']
+requirements = ['pip', 'packaging', 'plum-dispatch>=1.6']
if cfg.get('requirements'): requirements += cfg.get('requirements','').split()
if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split()
dev_requirements = (cfg.get('dev_requirements') or '').split()