diff --git a/docs/lang/articles/advanced/data_oriented_class.md b/docs/lang/articles/advanced/data_oriented_class.md index 81b2ba46a2fe0..d0deeca91374a 100644 --- a/docs/lang/articles/advanced/data_oriented_class.md +++ b/docs/lang/articles/advanced/data_oriented_class.md @@ -71,16 +71,12 @@ ti.init() @ti.data_oriented class Calc: def __init__(self): - self.x = ti.field(dtype=ti.f32, shape=16) - self.y = ti.field(dtype=ti.f32, shape=4) + self.x = ti.field(dtype=ti.f32, shape=8) @ti.kernel def func(self, temp: ti.template()): for i in range(8): - temp[i] = self.x[i * 2] + self.x[i * 2 + 1] - - for i in range(4): - self.y[i] = max(temp[i * 2], temp[i * 2 + 1]) + temp[i] = self.x[i * 2] def call_func(self): fb = ti.FieldsBuilder() @@ -171,65 +167,13 @@ ti.init() @ti.data_oriented class Array2D: - def __init__(self, n, m, increment): - self.n = n - self.m = m - self.val = ti.field(ti.f32) - self.total = ti.field(ti.f32) - self.increment = float(increment) - ti.root.dense(ti.ij, (self.n, self.m)).place(self.val) - ti.root.place(self.total) + def __init__(self, n): + self.arr = ti.Vector([0.] * n) @staticmethod @ti.func def clamp(x): # Clamp to [0, 1) - return max(0., min(1 - 1e-6, x)) - - @ti.kernel - def inc(self): - for i, j in self.val: - ti.atomic_add(self.val[i, j], self.increment) - - @ti.kernel - def inc2(self, increment: ti.i32): - for i, j in self.val: - ti.atomic_add(self.val[i, j], increment) - - @ti.kernel - def reduce(self): - for i, j in self.val: - ti.atomic_add(self.total[None], self.val[i, j] * 4) - -arr = Array2D(2, 2, 3) - -double_total = ti.field(ti.f32, shape=()) - -ti.root.lazy_grad() - -arr.inc() -arr.inc.grad() -print(arr.val[0, 0]) # 3 -arr.inc2(4) -print(arr.val[0, 0]) # 7 - -with ti.ad.Tape(loss=arr.total): - arr.reduce() - -for i in range(arr.n): - for j in range(arr.m): - print(arr.val.grad[i, j]) # 4 - -@ti.kernel -def double(): - double_total[None] = 2 * arr.total[None] - -with ti.ad.Tape(loss=double_total): - arr.reduce() - double() - -for i in range(arr.n): - for j in range(arr.m): - print(arr.val.grad[i, j]) # 8 + return max(0, min(1, x)) ``` `classmethod` example: