
[MetaSchedule][Hexagon] conv2d produces different results after tuning #294

Open
psrivas2 opened this issue Dec 2, 2022 · 4 comments

psrivas2 commented Dec 2, 2022

The following PrimFunc produces different results after tuning on Hexagon.

@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
    # function attr dict
    T.func_attr({"global_symbol": "conv2d", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
            T.writes(conv2d_nhwc[nn, yy, xx, ff])
            with T.init():
                conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
            conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]
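For context, here is a rough sketch of the kind of MetaSchedule flow that produces the transformed function shown next. It is not the reporter's actual setup: the target string, trial budget, and the ms.tune_tir / ms.tir_integration.compile_tir entry points are assumptions based on recent TVM main, and a real Hexagon run additionally needs a Hexagon-specific builder and runner, omitted here.

import tvm
from tvm import meta_schedule as ms

# Assumed target; the reporter's actual Hexagon target flags are unknown.
target = tvm.target.Target("hexagon", host="llvm")
mod = tvm.IRModule({"conv2d": conv2d})

# Tune the PrimFunc and pick the best schedule from the tuning database.
database = ms.tune_tir(
    mod=mod,
    target=target,
    work_dir="./tune_conv2d",  # hypothetical scratch directory
    max_trials_global=64,      # assumed trial budget
)
sch = ms.tir_integration.compile_tir(database, mod, target)
print(sch.mod)  # the tuned PrimFunc, e.g. the one shown below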

After tuning, the PrimFunc is transformed to:

@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
    # body
    # with T.block("root")
    conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 64], dtype="float16")
    for i0_0_i1_0_i2_0_fused in T.parallel(196, annotations={"pragma_auto_unroll_max_step":T.int64(512), "pragma_unroll_explicit":T.int64(1)}):
        for i3_0 in T.serial(1):
            for i0_1_init, i1_1_init, i2_1_init, i3_1_init, i0_2_init, i1_2_init, i2_2_init in T.grid(1, 2, 16, 1, 1, 2, 1):
                for i3_2_fused_init in T.vectorized(64):
                    with T.block("conv2d_nhwc_init"):
                        nn = T.axis.spatial(1, i0_1_init + i0_2_init)
                        yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1_init * 2 + i1_2_init)
                        xx = T.axis.spatial(112, i2_2_init + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1_init)
                        ff = T.axis.spatial(64, i3_0 * 64 + i3_1_init * 64 + i3_2_fused_init)
                        T.reads()
                        T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
                        T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
                        conv2d_nhwc_global[nn, yy, xx, ff] = T.float16(0)
            for i4_0, i5_0, i6_0 in T.grid(1, 7, 1):
                for i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i0_2, i1_2, i2_2 in T.grid(1, 2, 16, 1, 7, 1, 3, 1, 2, 1):
                    for i3_2_fused in T.vectorized(64):
                        with T.block("conv2d_nhwc_update"):
                            nn = T.axis.spatial(1, i0_1 + i0_2)
                            yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1 * 2 + i1_2)
                            xx = T.axis.spatial(112, i2_2 + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1)
                            ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 64 + i3_2_fused)
                            ry = T.axis.reduce(7, i4_0 * 7 + i4_1)
                            rx = T.axis.reduce(7, i5_0 + i5_1)
                            rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                            T.reads(conv2d_nhwc_global[nn, yy, xx, ff], lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                            T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
                            T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
                            conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]
                for ax0, ax1, ax2 in T.grid(1, 4, 16):
                    for ax3_fused in T.vectorized(64):
                        with T.block("conv2d_nhwc_global"):
                            v0 = T.axis.spatial(1, ax0)
                            v1 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + ax1)
                            v2 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused % 7 * 16 + ax2)
                            v3 = T.axis.spatial(64, ax3_fused)
                            T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
                            T.writes(conv2d_nhwc[v0, v1, v2, v3])
                            conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]

The two PrimFuncs produce different results on Hexagon hardware. This needs to be investigated.
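A minimal sketch (not from the issue) of one way to check this numerically: build both versions for the same target and compare their outputs on identical random inputs. Here before_mod and after_mod are hypothetical IRModules wrapping the untuned and tuned PrimFuncs above; on CPU the target would be "llvm", while a Hexagon run would go through the Hexagon RPC session rather than plain tvm.build.

import numpy as np
import tvm

rng = np.random.default_rng(0)
lv1_np = rng.uniform(-1.0, 1.0, (1, 230, 230, 3)).astype("float16")
w_np = rng.uniform(-1.0, 1.0, (7, 7, 3, 64)).astype("float16")

def run_conv2d(mod, target="llvm"):
    # Build the module and run its conv2d entry on the fixed inputs above.
    lib = tvm.build(mod, target=target)
    dev = tvm.device(target, 0)
    out = tvm.nd.array(np.zeros((1, 112, 112, 64), dtype="float16"), dev)
    lib["conv2d"](tvm.nd.array(lv1_np, dev), tvm.nd.array(w_np, dev), out)
    return out.numpy()

before = run_conv2d(before_mod)  # untuned PrimFunc
after = run_conv2d(after_mod)    # tuned PrimFunc
np.testing.assert_allclose(before, after, rtol=1e-2, atol=1e-2)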

YuchenJin (Collaborator) commented:

Thanks @psrivas2 for reporting the issue!

Two questions that could help us know more about the context:

  • Is it Hexagon-specific? That is, if we tune conv2d on CPU and GPU, do the incorrect results also appear?
  • Is it only conv2d? That is, if we tune other kernels on Hexagon, do the tuned kernels give different results from the untuned ones?


psrivas2 commented Dec 3, 2022

First, it is Hexagon-specific: on CPU the tuned kernel's output is the same as the untuned output.
Second, I have only observed this behavior for this specific kernel. For example, after fusion, resnet has around 31 PrimFuncs. Out of those 31, only the one PrimFunc that contained the above block as one of its fused operations produced results different from the untuned PrimFunc.

psrivas2 self-assigned this Dec 3, 2022

psrivas2 commented Dec 3, 2022

In addition, this is definitely an incorrect transformation of the untuned PrimFunc, as the two PrimFuncs shown above give different results even on CPU.

psrivas2 commented:

I think I have narrowed it down to the reordering of loops.

On Hexagon, the following two modules, which differ only in the order of loops i3 and i4, produce different numeric results: the maximum difference in values is 0.5 and the mean difference is 0.0708. This only happens for the fp16 dtype. (A sketch that reproduces these statistics follows the reordered module below.)

@tvm.script.ir_module
class TuningBug:
    @T.prim_func
    def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
        # body
        # with T.block("root")
        for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
            with T.block("conv2d_nhwc"):
                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                T.writes(conv2d_nhwc[nn, yy, xx, ff])
                with T.init():
                    conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
                conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])

    @R.function
    def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
        with R.dataflow():
            gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
            R.output(gv)
        return gv

Reordering loops i3 and i4:

import tvm

# `mod` is the TuningBug IRModule defined above.
sch = tvm.tir.Schedule(mod)
b0 = sch.get_block("conv2d_nhwc", func_name="conv2d")
i0, i1, i2, i3, i4, i5, i6 = sch.get_loops(b0)
sch.reorder(i4, i3)
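As a usage note (not part of the issue), sch.mod now holds the rewritten module and can be printed to verify the new loop order:

# The T.grid extents should now read (1, 112, 112, 7, 64, 7, 3).
sch.mod.show()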

The modified module looks like this:

@tvm.script.ir_module
class TuningBug:
    @T.prim_func
    def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
        # body
        # with T.block("root")
        for i0, i1, i2, i4, i3, i5, i6 in T.grid(1, 112, 112, 7, 64, 7, 3):
            with T.block("conv2d_nhwc"):
                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                T.writes(conv2d_nhwc[nn, yy, xx, ff])
                with T.init():
                    conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
                conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])

    @R.function
    def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
        with R.dataflow():
            gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
            R.output(gv)
        return gv
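A hypothetical way to reproduce the reported statistics, reusing the run_conv2d helper sketched earlier. This assumes the two conv2d PrimFuncs are wrapped in TIR-only IRModules mod_original and mod_reordered, and that the same flow is run through the Hexagon session to match the reported numbers.

out_a = run_conv2d(mod_original)   # loop order ..., i3, i4, ...
out_b = run_conv2d(mod_reordered)  # loop order ..., i4, i3, ...
diff = np.abs(out_a.astype("float32") - out_b.astype("float32"))
print("max diff:", diff.max())    # reported on Hexagon: 0.5
print("mean diff:", diff.mean())  # reported on Hexagon: 0.0708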
