Skip to content

Commit

Permalink
[TOPI][x86] Cascade lake support. (#4123)
Browse files Browse the repository at this point in the history
* [TOPI][x86] Cascade lake support.

* Jenkins test debug 1.

* Testing cascade lake alone.
  • Loading branch information
anijain2305 authored and tmoreau89 committed Oct 17, 2019
1 parent 3185e4a commit 972f019
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 80 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/qnn/op/legalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _is_int8_hw_support(target):
Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
and above.
"""
supported_arches = {'-mcpu=skylake-avx512',}
supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
return supported_arches.intersection(set(target.options))

# Collect the dtypes.
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ def model(self):
return opt.value[7:]
return 'unknown'

@property
def mcpu(self):
"""Returns the mcpu from the target if it exists."""
mcpu = ''
if self.options is not None:
for opt in self.options:
if 'mcpu' in opt:
mcpu = opt.split('=')[1]
return mcpu

def __enter__(self):
_api_internal._EnterTargetScope(self)
return self
Expand Down
4 changes: 2 additions & 2 deletions tests/python/contrib/test_gemm_acc16.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
import tvm
import numpy as np
from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int16
from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int16


def benchmark_fc_int8_acc16():
Expand All @@ -40,7 +40,7 @@ def verify(target="llvm -mcpu=skylake-avx512"):
ctx = tvm.context(target, 0)
X = tvm.placeholder((m, k), name='X', dtype="uint8")
W = tvm.placeholder((n, k), name='W', dtype="int8")
pc = dot_16x1x16_int8_int8_int16()
pc = dot_16x1x16_uint8_int8_int16()
ak = tvm.reduce_axis((0, k), name='k')

packedW = tvm.placeholder((n//128, 128*(k//2), 2), name='packedW', dtype="int8")
Expand Down
6 changes: 3 additions & 3 deletions tests/python/contrib/test_gemm_acc32_vnni.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import tvm
import numpy as np
from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32_vnni
from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32
from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32
import pytest


Expand All @@ -46,7 +46,7 @@ def verify(target="llvm -mcpu=cascadelake"):
return

ctx = tvm.context(target, 0)
pc = dot_16x1x16_int8_int8_int32_vnni()
pc = dot_16x1x16_uint8_int8_int32_cascadelake()
ak = tvm.reduce_axis((0, k), name='k')
packedW = tvm.placeholder(
(n // 16, 16 * (k // 4), 4), name='packedW', dtype="int8")
Expand Down
110 changes: 62 additions & 48 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,57 +576,71 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
assembly = lib.get_source("asm")
return assembly

# compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions
target = "llvm -mcpu=skylake-avx512"
name = "llvm.x86.avx512.pmaddubs.w.512"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
if llvm_id != 0:
fast_int8_dtypes = ('uint8', 'int8', 'int32')
# Sweep the input channels to check int8 robustness
for ic in range(1, 24):
asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm

for ic in range(1, 24):
asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm


# Sweep the output channels to check int8 robustness
for oc in range(2, 24):
asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW", kernel_layout='OIHW',
def _has_fast_int8_instructions(asm, target):
if 'skylake-avx512' in target:
return "pmaddubs" in asm
elif 'cascadelake' in target:
return "vpdpbusd" in asm
else:
assert False, "Target should be Skylake or Cascadelake"

# compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions
targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
llvm_version = tvm.codegen.llvm_version_major()
for target in targets:
if llvm_version >= 8:
fast_int8_dtypes = ('uint8', 'int8', 'int32')
# Sweep the input channels to check int8 robustness
# Input channels should be a multiple of 4 internally.
for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
assert _has_fast_int8_instructions(asm, target)

for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert _has_fast_int8_instructions(asm, target)


# Sweep the output channels to check int8 robustness
# Output channels should be a multiple of 16 internally.
for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
assert _has_fast_int8_instructions(asm, target)

for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert _has_fast_int8_instructions(asm, target)

# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm
assert _has_fast_int8_instructions(asm, target)

for oc in range(2, 24):
asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC", kernel_layout='HWIO',
asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm

# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm

asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('int8', 'int8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that intrinisic is not present in the assembly.
assert "pmaddubs" not in asm

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that intrinisic is not present in the assembly.
assert "pmaddubs" not in asm
assert _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('int8', 'int8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that intrinisic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)

# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that intrinisic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)

# Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout.
Expand Down
6 changes: 3 additions & 3 deletions topi/python/topi/x86/conv2d_avx_1x1.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ..nn.util import infer_pad, get_pad_tuple
from ..generic import conv2d as conv2d_generic
from ..util import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .util import get_fp32_len

def _fallback_schedule(cfg, wkl):
Expand Down Expand Up @@ -183,7 +183,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last,
int32_lanes=16,
intrin=dot_16x1x16_int8_int8_int32())
intrin=dot_16x1x16_uint8_int8_int32())


def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
Expand Down Expand Up @@ -282,7 +282,7 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor)
s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner)

pc = dot_16x1x16_int8_int8_int32()
pc = dot_16x1x16_uint8_int8_int32()
s[C].tensorize(oc_inner, pc)

if C != O:
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/x86/conv2d_avx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..nn.util import infer_pad
from ..generic import conv2d as conv2d_generic
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .util import get_fp32_len

def _fallback_schedule(cfg, wkl):
Expand Down Expand Up @@ -209,4 +209,4 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last,
int32_lanes=16,
intrin=dot_16x1x16_int8_int8_int32())
intrin=dot_16x1x16_uint8_int8_int32())
12 changes: 5 additions & 7 deletions topi/python/topi/x86/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,14 @@ def _is_int8_hw_support(data_dtype, kernel_dtype):
is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'

# 2) Check LLVM support
llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
is_llvm_support = llvm_id != 0
llvm_version = tvm.codegen.llvm_version_major()
is_llvm_support = llvm_version >= 8

# 3) Check target
target = tvm.target.current_target()
mcpu = tvm.target.current_target().mcpu
is_target_support = False
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
is_target_support = True
if mcpu == 'skylake-avx512' or mcpu == 'cascadelake':
is_target_support = True

return is_dtype_support and is_llvm_support and is_target_support

Expand Down
30 changes: 21 additions & 9 deletions topi/python/topi/x86/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,27 @@
import tvm


def dot_16x1x16_int8_int8_int32():
def dot_16x1x16_uint8_int8_int32():
"""Dispatch the most optimized intrin depending on the target"""
mcpu = tvm.target.current_target().mcpu

assert mcpu in ("skylake-avx512", "cascadelake"), \
"An old Intel machine that does not have fast Int8 support."
if mcpu == "skylake-avx512":
return dot_16x1x16_uint8_int8_int32_skylake()
# cascadelake
return dot_16x1x16_uint8_int8_int32_cascadelake()


def dot_16x1x16_uint8_int8_int32_skylake():
"""
Int8 dot product by every 4 elements using AVX512 Skylake instructions.
This function takes two arrays of int8 datatype -- data[4] and
This function takes two arrays of uint8 and int8 datatype -- data[4] and
kernel[16][4] -- and computes a dot product of data[4] with every
4 elements of kernels, resulting in output[16] of int32 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
void dot_16x1x16_uint8_int8_int32(uint8 data[4], int8 kernel[16][4],
int32 output[16]){
for (int i = 0; i < 16; i++){
output[i] = 0;
Expand Down Expand Up @@ -100,15 +112,15 @@ def _instr(index):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})


def dot_16x1x16_int8_int8_int16():
def dot_16x1x16_uint8_int8_int16():
"""
Int8 dot product by every 2 elements using AVX512 Skylake instructions.
This function takes two arrays of int8 datatype -- data[2] and
This function takes two arrays of uint8 and int8 datatype -- data[2] and
kernel[4][32][2] -- and computes a dot product of data[2] with every
2 elements of kernels, resulting in output[4][32] of int16 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_16x1x16_int8_int8_int16(int8 data[2], int8 kernel[32*4][2],
void dot_16x1x16_uint8_int8_int16(uint8 data[2], int8 kernel[32*4][2],
int16 output[32*4]){
for (int i = 0; i< 4; i++){
for (int j = 0; j < 32; j++){
Expand Down Expand Up @@ -182,15 +194,15 @@ def _instr(index):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})


def dot_16x1x16_int8_int8_int32_vnni():
def dot_16x1x16_uint8_int8_int32_cascadelake():
"""
Int8 dot product by every 4 elements using AVX512VNNI Cascade Lake instructions.
This function takes two arrays of int8 datatype -- data[4] and
This function takes two arrays of uint8 and int8 datatype -- data[4] and
kernel[16][4] -- and computes a dot product of data[4] with every
4 elements of kernels, resulting in output[16] of int32 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_16x1x16_int8_int8_int32_vnni(int8 data[4], int8 kernel[16][4],
void dot_16x1x16_uint8_int8_int32_cascadelake(uint8 data[4], int8 kernel[16][4],
int32 output[16]){
for (int i = 0; i < 16; i++){
output[i] = 0;
Expand Down
8 changes: 3 additions & 5 deletions topi/python/topi/x86/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
import tvm

def get_fp32_len():
mcpu = tvm.target.current_target().mcpu
fp32_vec_len = 8
target = tvm.target.current_target()
if target is not None:
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16
if mcpu == 'skylake-avx512' or mcpu == 'cascadelake':
fp32_vec_len = 16
return fp32_vec_len

0 comments on commit 972f019

Please sign in to comment.