Skip to content

Commit

Permalink
[Doc] Update data_oriented_class.md (#6181)
Browse files Browse the repository at this point in the history
  • Loading branch information
neozhaoliang authored Sep 28, 2022
1 parent eed9081 commit 1a1819f
Showing 1 changed file with 5 additions and 61 deletions.
66 changes: 5 additions & 61 deletions docs/lang/articles/advanced/data_oriented_class.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,12 @@ ti.init()
@ti.data_oriented
class Calc:
def __init__(self):
self.x = ti.field(dtype=ti.f32, shape=16)
self.y = ti.field(dtype=ti.f32, shape=4)
self.x = ti.field(dtype=ti.f32, shape=8)

@ti.kernel
def func(self, temp: ti.template()):
for i in range(8):
temp[i] = self.x[i * 2] + self.x[i * 2 + 1]

for i in range(4):
self.y[i] = max(temp[i * 2], temp[i * 2 + 1])
temp[i] = self.x[i * 2]

def call_func(self):
fb = ti.FieldsBuilder()
Expand Down Expand Up @@ -171,65 +167,13 @@ ti.init()

@ti.data_oriented
class Array2D:
def __init__(self, n, m, increment):
self.n = n
self.m = m
self.val = ti.field(ti.f32)
self.total = ti.field(ti.f32)
self.increment = float(increment)
ti.root.dense(ti.ij, (self.n, self.m)).place(self.val)
ti.root.place(self.total)
def __init__(self, n):
self.arr = ti.Vector([0.] * n)

@staticmethod
@ti.func
def clamp(x): # Clamp to [0, 1)
return max(0., min(1 - 1e-6, x))

@ti.kernel
def inc(self):
for i, j in self.val:
ti.atomic_add(self.val[i, j], self.increment)

@ti.kernel
def inc2(self, increment: ti.i32):
for i, j in self.val:
ti.atomic_add(self.val[i, j], increment)

@ti.kernel
def reduce(self):
for i, j in self.val:
ti.atomic_add(self.total[None], self.val[i, j] * 4)

arr = Array2D(2, 2, 3)

double_total = ti.field(ti.f32, shape=())

ti.root.lazy_grad()

arr.inc()
arr.inc.grad()
print(arr.val[0, 0]) # 3
arr.inc2(4)
print(arr.val[0, 0]) # 7

with ti.ad.Tape(loss=arr.total):
arr.reduce()

for i in range(arr.n):
for j in range(arr.m):
print(arr.val.grad[i, j]) # 4

@ti.kernel
def double():
double_total[None] = 2 * arr.total[None]

with ti.ad.Tape(loss=double_total):
arr.reduce()
double()

for i in range(arr.n):
for j in range(arr.m):
print(arr.val.grad[i, j]) # 8
return max(0, min(1, x))
```

`classmethod` example:
Expand Down

0 comments on commit 1a1819f

Please sign in to comment.