Skip to content

Commit

Permalink
[autodiff] Print more specific error message that autodiff does not s…
Browse files Browse the repository at this point in the history
…upport to_numpy (#5630)
  • Loading branch information
PhrygianGates authored Aug 6, 2022
1 parent bee0e63 commit aefa42b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
10 changes: 10 additions & 0 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,11 @@ class MakeAdjoint : public ADTransform {

void visit(GlobalLoadStmt *stmt) override {
// issue global store to adjoint
if (stmt->src->is<ExternalPtrStmt>()) {
TI_ERROR(
"Importing data from external array (such as numpy array) not "
"supported in AutoDiff for now")
}
GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();
TI_ASSERT(src->width() == 1);
auto snodes = src->snodes;
Expand All @@ -1008,6 +1013,11 @@ class MakeAdjoint : public ADTransform {

void visit(GlobalStoreStmt *stmt) override {
// erase and replace with global load adjoint
if (stmt->dest->is<ExternalPtrStmt>()) {
TI_ERROR(
"Exporting data to external array (such as numpy array) not "
"supported in AutoDiff for now")
}
GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
TI_ASSERT(dest->width() == 1);
auto snodes = dest->snodes;
Expand Down
35 changes: 35 additions & 0 deletions tests/python/test_ad_external_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import pytest

import taichi as ti
from tests import test_utils


@test_utils.test()
def test_to_numpy():
a = ti.field(dtype=float, shape=(), needs_grad=True)
loss = ti.field(dtype=float, shape=(), needs_grad=True)

def func():
b = a.to_numpy()

with pytest.raises(RuntimeError) as e:
with ti.ad.Tape(loss):
func()
assert 'Exporting data to external array (such as numpy array) not supported in AutoDiff for now' in e.value.args[
0]


@test_utils.test()
def test_from_numpy():
a = ti.field(dtype=float, shape=(), needs_grad=True)
loss = ti.field(dtype=float, shape=(), needs_grad=True)

def func():
a.from_numpy(np.asarray(1))

with pytest.raises(RuntimeError) as e:
with ti.ad.Tape(loss):
func()
assert 'Importing data from external array (such as numpy array) not supported in AutoDiff for now' in e.value.args[
0]

0 comments on commit aefa42b

Please sign in to comment.