llvm: Initialize builder context and jit engines on demand #2188

Merged · 6 commits · Nov 8, 2021
97 changes: 37 additions & 60 deletions psyneulink/core/llvm/__init__.py
@@ -17,7 +17,6 @@

from llvmlite import ir

from . import builtins
from . import codegen
from .builder_context import *
from .builder_context import _all_modules, _convert_llvm_ir_to_ctype
@@ -47,15 +46,11 @@ class ExecutionMode(enum.Flag):


def _compiled_modules() -> Set[ir.Module]:
if ptx_enabled:
return _cpu_engine.compiled_modules | _ptx_engine.compiled_modules
return _cpu_engine.compiled_modules
return set().union(*(e.compiled_modules for e in _get_engines()))


def _staged_modules() -> Set[ir.Module]:
if ptx_enabled:
return _cpu_engine.staged_modules | _ptx_engine.staged_modules
return _cpu_engine.staged_modules
return set().union(*(e.staged_modules for e in _get_engines()))


def _llvm_build(target_generation=_binary_generation + 1):
@@ -68,9 +63,8 @@ def _llvm_build(target_generation=_binary_generation + 1):
if "compile" in debug_env:
print("STAGING GENERATION: {} -> {}".format(_binary_generation, target_generation))

_cpu_engine.stage_compilation(_modules)
if ptx_enabled:
_ptx_engine.stage_compilation(_modules)
for e in _get_engines():
e.stage_compilation(_modules)
_modules.clear()

# update binary generation
@@ -84,6 +78,12 @@ def __init__(self, name: str):
self.__c_func = None
self.__cuda_kernel = None

# Make sure builder context is initialized
LLVMBuilderContext.get_current()

# Compile any pending modules
_llvm_build(LLVMBuilderContext._llvm_generation)

# Function signature
# We could skip compilation if the function is in _compiled_models,
# but that happens rarely
@@ -106,6 +106,9 @@ def __init__(self, name: str):
@property
def c_func(self):
if self.__c_func is None:
# This assumes there are potential staged modules.
# The engine had to be instantiated to have staged modules,
# so it's safe to access it directly
_cpu_engine.compile_staged()
ptr = _cpu_engine._engine.get_function_address(self.name)
self.__c_func = self.__c_func_type(ptr)
@@ -138,79 +141,53 @@ def cuda_wrap_call(self, *args, threads=1, block_size=128):
@staticmethod
@functools.lru_cache(maxsize=32)
def from_obj(obj, *, tags:frozenset=frozenset()):
name = LLVMBuilderContext.get_global().gen_llvm_function(obj, tags=tags).name
name = LLVMBuilderContext.get_current().gen_llvm_function(obj, tags=tags).name
return LLVMBinaryFunction.get(name)

@staticmethod
@functools.lru_cache(maxsize=32)
def get(name: str):
_llvm_build(LLVMBuilderContext._llvm_generation)
return LLVMBinaryFunction(name)

def get_multi_run(self):
try:
multirun_llvm = _find_llvm_function(self.name + "_multirun")
except ValueError:
function = _find_llvm_function(self.name)
with LLVMBuilderContext.get_global() as ctx:
with LLVMBuilderContext.get_current() as ctx:
multirun_llvm = codegen.gen_multirun_wrapper(ctx, function)

return LLVMBinaryFunction.get(multirun_llvm.name)


_cpu_engine = cpu_jit_engine()
if ptx_enabled:
_ptx_engine = ptx_jit_engine()


# Initialize builtins
def init_builtins():
start = time.perf_counter()
with LLVMBuilderContext.get_global() as ctx:
# Numeric
builtins.setup_pnl_intrinsics(ctx)
builtins.setup_csch(ctx)
builtins.setup_coth(ctx)
builtins.setup_tanh(ctx)
builtins.setup_is_close(ctx)

# PRNG
builtins.setup_mersenne_twister(ctx)

# Matrix/Vector
builtins.setup_vxm(ctx)
builtins.setup_vxm_transposed(ctx)
builtins.setup_vec_add(ctx)
builtins.setup_vec_sum(ctx)
builtins.setup_mat_add(ctx)
builtins.setup_vec_sub(ctx)
builtins.setup_mat_sub(ctx)
builtins.setup_vec_hadamard(ctx)
builtins.setup_mat_hadamard(ctx)
builtins.setup_vec_scalar_mult(ctx)
builtins.setup_mat_scalar_mult(ctx)
builtins.setup_mat_scalar_add(ctx)

finish = time.perf_counter()

if "time_stat" in debug_env:
print("Time to setup PNL builtins: {}".format(finish - start))
_cpu_engine = None
_ptx_engine = None

def cleanup():
_cpu_engine.clean_module()
_cpu_engine.staged_modules.clear()
_cpu_engine.compiled_modules.clear()
def _get_engines():
global _cpu_engine
if _cpu_engine is None:
_cpu_engine = cpu_jit_engine()

global _ptx_engine
if ptx_enabled:
_ptx_engine.clean_module()
_ptx_engine.staged_modules.clear()
_ptx_engine.compiled_modules.clear()
if _ptx_engine is None:
_ptx_engine = ptx_jit_engine()
return [_cpu_engine, _ptx_engine]

return [_cpu_engine]



def cleanup():
global _cpu_engine
_cpu_engine = None
global _ptx_engine
_ptx_engine = None

_modules.clear()
_all_modules.clear()

LLVMBinaryFunction.get.cache_clear()
LLVMBinaryFunction.from_obj.cache_clear()
init_builtins()


init_builtins()
LLVMBuilderContext.clear_global()
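For readers skimming the diff, the module-level change above amounts to a lazy-initialization pattern: the JIT engines are no longer constructed at import time, `_get_engines()` creates them the first time they are needed, and `cleanup()` simply drops the references so the next request rebuilds them. A minimal, self-contained sketch of that pattern (simplified, hypothetical names — not the actual PsyNeuLink classes):

```python
# Simplified sketch of on-demand engine initialization (hypothetical names,
# not the real PsyNeuLink jit_engine classes).
from typing import List, Optional

class Engine:
    def __init__(self) -> None:
        self.staged_modules: set = set()
        self.compiled_modules: set = set()

_cpu_engine: Optional[Engine] = None
_ptx_engine: Optional[Engine] = None
ptx_enabled = False  # assumed module-level switch, mirroring the real one

def _get_engines() -> List[Engine]:
    # Create engines only when first requested.
    global _cpu_engine, _ptx_engine
    if _cpu_engine is None:
        _cpu_engine = Engine()
    if ptx_enabled:
        if _ptx_engine is None:
            _ptx_engine = Engine()
        return [_cpu_engine, _ptx_engine]
    return [_cpu_engine]

def cleanup() -> None:
    # Dropping the references is enough; the next _get_engines() call
    # re-creates the engines on demand.
    global _cpu_engine, _ptx_engine
    _cpu_engine = None
    _ptx_engine = None
```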
59 changes: 51 additions & 8 deletions psyneulink/core/llvm/builder_context.py
@@ -17,12 +17,15 @@
import numpy as np
import os
import re
import time
from typing import Set
import weakref

from psyneulink.core.scheduling.time import Time, TimeScale
from psyneulink.core.globals.sampleiterator import SampleIterator
from psyneulink.core.globals.utilities import ContentAddressableList
from psyneulink.core import llvm as pnlvm

from . import codegen
from . import helpers
from .debug import debug_env
Expand All @@ -40,7 +43,7 @@ def module_count():
if "stat" in debug_env:
print("Total LLVM modules: ", len(_all_modules))
print("Total structures generated: ", _struct_count)
s = LLVMBuilderContext.get_global()
s = LLVMBuilderContext.get_current()
print("Total generations by global context: {}".format(s._llvm_generation))
print("Object cache in global context: {} hits, {} misses".format(s._stats["cache_requests"] - s._stats["cache_misses"], s._stats["cache_misses"]))
for stat in ("input", "output", "param", "state", "data"):
@@ -82,14 +85,15 @@ def wrapper(bctx, obj):


class LLVMBuilderContext:
__global_context = None
__current_context = None
__uniq_counter = 0
_llvm_generation = 0
int32_ty = ir.IntType(32)
float_ty = ir.DoubleType()
default_float_ty = ir.DoubleType()
bool_ty = ir.IntType(1)

def __init__(self):
def __init__(self, float_ty):
assert LLVMBuilderContext.__current_context is None
self._modules = []
self._cache = weakref.WeakKeyDictionary()
self._stats = { "cache_misses":0,
@@ -101,6 +105,9 @@ def __init__(self):
"input_structs_generated":0,
"output_structs_generated":0,
}
self.float_ty = float_ty
self.init_builtins()
LLVMBuilderContext.__current_context = self

def __enter__(self):
module = ir.Module(name="PsyNeuLinkModule-" + str(LLVMBuilderContext._llvm_generation))
@@ -120,17 +127,53 @@ def module(self):
return self._modules[-1]

@classmethod
def get_global(cls):
if cls.__global_context is None:
cls.__global_context = LLVMBuilderContext()
return cls.__global_context
def get_current(cls):
if cls.__current_context is None:
return LLVMBuilderContext(cls.default_float_ty)
return cls.__current_context

@classmethod
def clear_global(cls):
cls.__current_context = None

@classmethod
def get_unique_name(cls, name: str):
cls.__uniq_counter += 1
name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
return name + '_' + str(cls.__uniq_counter)

def init_builtins(self):
start = time.perf_counter()
with self as ctx:
# Numeric
pnlvm.builtins.setup_pnl_intrinsics(ctx)
pnlvm.builtins.setup_csch(ctx)
pnlvm.builtins.setup_coth(ctx)
pnlvm.builtins.setup_tanh(ctx)
pnlvm.builtins.setup_is_close(ctx)

# PRNG
pnlvm.builtins.setup_mersenne_twister(ctx)

# Matrix/Vector
pnlvm.builtins.setup_vxm(ctx)
pnlvm.builtins.setup_vxm_transposed(ctx)
pnlvm.builtins.setup_vec_add(ctx)
pnlvm.builtins.setup_vec_sum(ctx)
pnlvm.builtins.setup_mat_add(ctx)
pnlvm.builtins.setup_vec_sub(ctx)
pnlvm.builtins.setup_mat_sub(ctx)
pnlvm.builtins.setup_vec_hadamard(ctx)
pnlvm.builtins.setup_mat_hadamard(ctx)
pnlvm.builtins.setup_vec_scalar_mult(ctx)
pnlvm.builtins.setup_mat_scalar_mult(ctx)
pnlvm.builtins.setup_mat_scalar_add(ctx)

finish = time.perf_counter()

if "time_stat" in debug_env:
print("Time to setup PNL builtins: {}".format(finish - start))

def get_builtin(self, name: str, args=[], function_type=None):
if name in _builtin_intrinsics:
return self.import_llvm_function(_BUILTIN_PREFIX + name)
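The builder_context.py side of the change turns the old global context into an on-demand singleton with a configurable float type: `get_current()` constructs a context with `default_float_ty` the first time it is called, and `clear_global()` (used by `cleanup()`) resets it. A rough sketch of that shape, with hypothetical names and assuming only llvmlite is installed:

```python
# Simplified sketch of an on-demand builder-context singleton with a
# configurable float type (hypothetical class; the real LLVMBuilderContext
# also stages modules and sets up builtins).
from llvmlite import ir

class BuilderContext:
    __current = None
    default_float_ty = ir.DoubleType()

    def __init__(self, float_ty):
        assert BuilderContext.__current is None, "only one live context"
        self.float_ty = float_ty
        BuilderContext.__current = self

    @classmethod
    def get_current(cls):
        # Create a context with the default float type on first use.
        if cls.__current is None:
            return BuilderContext(cls.default_float_ty)
        return cls.__current

    @classmethod
    def clear_global(cls):
        cls.__current = None

ctx = BuilderContext.get_current()          # created on demand
assert BuilderContext.get_current() is ctx  # subsequent calls reuse it
BuilderContext.clear_global()               # e.g. from cleanup()
single = BuilderContext(ir.FloatType())     # explicit non-default float type
```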
2 changes: 1 addition & 1 deletion psyneulink/core/llvm/execution.py
@@ -320,7 +320,7 @@ def _bin_func_multirun(self):

def _set_bin_node(self, node):
assert node in self._composition._all_nodes
wrapper = builder_context.LLVMBuilderContext.get_global().get_node_wrapper(self._composition, node)
wrapper = builder_context.LLVMBuilderContext.get_current().get_node_wrapper(self._composition, node)
self.__bin_func = pnlvm.LLVMBinaryFunction.from_obj(
wrapper, tags=self.__tags.union({"node_wrapper"}))

4 changes: 2 additions & 2 deletions psyneulink/core/llvm/jit_engine.py
@@ -87,7 +87,7 @@ def _cpu_jit_constructor():
__pass_manager_builder.populate(__cpu_pass_manager)

# And an execution engine with a builtins backing module
builtins_module = _generate_cpu_builtins_module(LLVMBuilderContext.float_ty)
builtins_module = _generate_cpu_builtins_module(LLVMBuilderContext.get_current().float_ty)
if "llvm" in debug_env:
with open(builtins_module.name + '.parse.ll', 'w') as dump_file:
dump_file.write(str(builtins_module))
@@ -286,7 +286,7 @@ def __init__(self, tm):
self._target_machine = tm

# -dc option tells the compiler that the code will be used for linking
self._generated_builtins = pycuda.compiler.compile(_ptx_builtin_source.format(type=str(LLVMBuilderContext.float_ty)), target='cubin', options=['-dc'])
self._generated_builtins = pycuda.compiler.compile(_ptx_builtin_source.format(type=str(LLVMBuilderContext.get_current().float_ty)), target='cubin', options=['-dc'])

def set_object_cache(cache):
pass
2 changes: 1 addition & 1 deletion tests/llvm/test_builtins_intrinsics.py
@@ -23,7 +23,7 @@ def test_builtin_op(benchmark, op, args, builtin, result, func_mode):
f = pnlvm.LLVMBinaryFunction.get(builtin)
elif func_mode == 'PTX':
wrap_name = builtin + "_test_wrapper"
with pnlvm.LLVMBuilderContext.get_global() as ctx:
with pnlvm.LLVMBuilderContext.get_current() as ctx:
intrin = ctx.import_llvm_function(builtin)
wrap_args = (*intrin.type.pointee.args,
intrin.type.pointee.return_type.as_pointer())
4 changes: 2 additions & 2 deletions tests/llvm/test_builtins_matrix.py
@@ -121,7 +121,7 @@ def ex():
def test_dot_llvm_constant_dim(benchmark, mode):
custom_name = None

with pnlvm.LLVMBuilderContext() as ctx:
with pnlvm.LLVMBuilderContext.get_current() as ctx:
custom_name = ctx.get_unique_name("vxsqm")
double_ptr_ty = ctx.float_ty.as_pointer()
func_ty = ir.FunctionType(ir.VoidType(), (double_ptr_ty, double_ptr_ty, double_ptr_ty))
@@ -180,7 +180,7 @@ def ex():
def test_dot_transposed_llvm_constant_dim(benchmark, mode):
custom_name = None

with pnlvm.LLVMBuilderContext() as ctx:
with pnlvm.LLVMBuilderContext.get_current() as ctx:
custom_name = ctx.get_unique_name("vxsqm")
double_ptr_ty = ctx.float_ty.as_pointer()
func_ty = ir.FunctionType(ir.VoidType(), (double_ptr_ty, double_ptr_ty, double_ptr_ty))
4 changes: 2 additions & 2 deletions tests/llvm/test_custom_func.py
@@ -36,7 +36,7 @@ def test_fixed_dimensions__pnl_builtin_vxm(mode):

custom_name = None

with pnlvm.LLVMBuilderContext() as ctx:
with pnlvm.LLVMBuilderContext.get_current() as ctx:
custom_name = ctx.get_unique_name("vxsqm")
double_ptr_ty = ctx.convert_python_struct_to_llvm_ir(1.0).as_pointer()
func_ty = ir.FunctionType(ir.VoidType(), (double_ptr_ty, double_ptr_ty, double_ptr_ty))
@@ -76,7 +76,7 @@ def test_fixed_dimensions__pnl_builtin_vxm(mode):
], ids=lambda x: str(x.dtype))
def test_integer_broadcast(mode, val):
custom_name = None
with pnlvm.LLVMBuilderContext() as ctx:
with pnlvm.LLVMBuilderContext.get_current() as ctx:
custom_name = ctx.get_unique_name("broadcast")
int_ty = ctx.convert_python_struct_to_llvm_ir(val)
int_array_ty = ir.ArrayType(int_ty, 8)