diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index af42eb689298b..b6cd8f5f805c0 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -989,6 +989,11 @@ class MakeAdjoint : public ADTransform { void visit(GlobalLoadStmt *stmt) override { // issue global store to adjoint + if (stmt->src->is()) { + TI_ERROR( + "Importing data from external array (such as numpy array) not " + "supported in AutoDiff for now") + } GlobalPtrStmt *src = stmt->src->as(); TI_ASSERT(src->width() == 1); auto snodes = src->snodes; @@ -1008,6 +1013,11 @@ class MakeAdjoint : public ADTransform { void visit(GlobalStoreStmt *stmt) override { // erase and replace with global load adjoint + if (stmt->dest->is()) { + TI_ERROR( + "Exporting data to external array (such as numpy array) not " + "supported in AutoDiff for now") + } GlobalPtrStmt *dest = stmt->dest->as(); TI_ASSERT(dest->width() == 1); auto snodes = dest->snodes; diff --git a/tests/python/test_ad_external_array.py b/tests/python/test_ad_external_array.py new file mode 100644 index 0000000000000..fae7966f5f323 --- /dev/null +++ b/tests/python/test_ad_external_array.py @@ -0,0 +1,35 @@ +import numpy as np +import pytest + +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_to_numpy(): + a = ti.field(dtype=float, shape=(), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + def func(): + b = a.to_numpy() + + with pytest.raises(RuntimeError) as e: + with ti.ad.Tape(loss): + func() + assert 'Exporting data to external array (such as numpy array) not supported in AutoDiff for now' in e.value.args[ + 0] + + +@test_utils.test() +def test_from_numpy(): + a = ti.field(dtype=float, shape=(), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + def func(): + a.from_numpy(np.asarray(1)) + + with pytest.raises(RuntimeError) as e: + with ti.ad.Tape(loss): + func() + assert 'Importing data from external array (such as numpy array) not supported in AutoDiff for now' in e.value.args[ + 0]