Skip to content

Commit

Permalink
[lang] [refactor] deprecate @boardcast_if_scalar, all use @binary and @…
Browse files Browse the repository at this point in the history
…unary (#943)

* [skip ci] add test

* fix linalg (hopefully don't harm performance)

* better fix

* [skip ci] fix typo

* fix asin dom err

* enhanced pow

* [skip ci] reduce pow optimization from 100 to 50

* [skip ci] balance test load pressure

* share common_ops between Expr and Matrix

* [skip ci] add comments about the l-value problem

* [skip ci] fix opengl pow test by adding fast_pow for rhs is int

* [skip ci] add comment

* [skip ci] nit comment

* delete

* add comment

* [skip ci] Apply suggestions from code review

Co-authored-by: Yuanming Hu <[email protected]>

* nit comment

* fix abs(power)

Co-authored-by: Yuanming Hu <[email protected]>
  • Loading branch information
archibate and yuanming-hu authored May 12, 2020
1 parent e0dfdd7 commit 3b79511
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 212 deletions.
60 changes: 60 additions & 0 deletions python/taichi/lang/common_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
class TaichiOperations:
def __neg__(self):
import taichi as ti
return ti.neg(self)

def __abs__(self):
import taichi as ti
return ti.abs(self)

def __add__(self, other):
import taichi as ti
return ti.add(self, other)

def __radd__(self, other):
import taichi as ti
return ti.add(other, self)

def __sub__(self, other):
import taichi as ti
return ti.sub(self, other)

def __rsub__(self, other):
import taichi as ti
return ti.sub(other, self)

def __mul__(self, other):
import taichi as ti
return ti.mul(self, other)

def __rmul__(self, other):
import taichi as ti
return ti.mul(other, self)

def __truediv__(self, other):
import taichi as ti
return ti.truediv(self, other)

def __rtruediv__(self, other):
import taichi as ti
return ti.truediv(other, self)

def __floordiv__(self, other):
import taichi as ti
return ti.floordiv(self, other)

def __rfloordiv__(self, other):
import taichi as ti
return ti.floordiv(other, self)

def __mod__(self, other):
import taichi as ti
return ti.mod(self, other)

def __pow__(self, other, modulo=None):
import taichi as ti
return ti.pow(self, other)

def __rpow__(self, other, modulo=None):
import taichi as ti
return ti.pow(other, self)
96 changes: 9 additions & 87 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .core import taichi_lang_core
from .util import *
from .common_ops import TaichiOperations
import traceback


# Scalar, basic data type
class Expr:
class Expr(TaichiOperations):
materialize_layout_callback = None
layout_materialized = False

Expand Down Expand Up @@ -46,70 +47,23 @@ def stack_info():
# remove the confusing last line
return '\n'.join(raw.split('\n')[:-3]) + '\n'

def __add__(self, other):
other = Expr(other)
return Expr(taichi_lang_core.expr_add(self.ptr, other.ptr),
tb=self.stack_info())

__radd__ = __add__

def __neg__(self):
return Expr(taichi_lang_core.expr_neg(self.ptr), tb=self.stack_info())

def __sub__(self, other):
other = Expr(other)
return Expr(taichi_lang_core.expr_sub(self.ptr, other.ptr),
tb=self.stack_info())

def __rsub__(self, other):
other = Expr(other)
return Expr(taichi_lang_core.expr_sub(other.ptr, self.ptr))

def __mul__(self, other):
if is_taichi_class(other) and hasattr(other, '__rmul__'):
return other.__rmul__(self)
else:
other = Expr(other)
return Expr(taichi_lang_core.expr_mul(self.ptr, other.ptr))

__rmul__ = __mul__

def __truediv__(self, other):
return Expr(taichi_lang_core.expr_truediv(self.ptr, Expr(other).ptr))

def __rtruediv__(self, other):
return Expr(taichi_lang_core.expr_truediv(Expr(other).ptr, self.ptr))

def __floordiv__(self, other):
return Expr(taichi_lang_core.expr_floordiv(self.ptr, Expr(other).ptr))

def __rfloordiv__(self, other):
return Expr(taichi_lang_core.expr_floordiv(Expr(other).ptr, self.ptr))

def __mod__(self, other):
other = Expr(other)
quotient = Expr(taichi_lang_core.expr_floordiv(self.ptr, other.ptr))
multiply = Expr(taichi_lang_core.expr_mul(other.ptr, quotient.ptr))
return Expr(taichi_lang_core.expr_sub(self.ptr, multiply.ptr))

def __iadd__(self, other):
self.atomic_add(other)

def __isub__(self, other):
self.atomic_sub(other)

def __imul__(self, other):
self.assign(Expr(taichi_lang_core.expr_mul(self.ptr, other.ptr)))
import taichi as ti
self.assign(ti.mul(self, other))

def __itruediv__(self, other):
self.assign(
Expr(taichi_lang_core.expr_truediv(self.ptr,
Expr(other).ptr)))
import taichi as ti
self.assign(ti.truediv(self, other))

def __ifloordiv__(self, other):
self.assign(
Expr(taichi_lang_core.expr_floordiv(self.ptr,
Expr(other).ptr)))
import taichi as ti
self.assign(ti.floordiv(self, other))

def __iand__(self, other):
self.atomic_and(other)
Expand All @@ -120,6 +74,7 @@ def __ior__(self, other):
def __ixor__(self, other):
self.atomic_xor(other)

# TODO: move to ops.py: ti.cmp_le
def __le__(self, other):
other = Expr(other)
return Expr(taichi_lang_core.expr_cmp_le(self.ptr, other.ptr))
Expand Down Expand Up @@ -314,39 +269,6 @@ def fill(self, val):
from .meta import fill_tensor
fill_tensor(self, val)

def __rpow__(self, power, modulo=None):
# Python will try Matrix.__pow__ first so we don't have to worry whether `power` is `Matrix`
return Expr(power).__pow__(self, modulo)

def __pow__(self, power, modulo=None):
import taichi as ti
if ti.is_taichi_class(power):
return power.element_wise_binary(lambda x, y: pow(y, x), self)
if not isinstance(power, int) or abs(power) > 100:
return Expr(taichi_lang_core.expr_pow(self.ptr, Expr(power).ptr))
if power == 0:
return Expr(1)
negative = power < 0
power = abs(power)
tmp = self
ret = None
while power:
if power & 1:
if ret is None:
ret = tmp
else:
ret = ti.expr_init(ret * tmp)
tmp = ti.expr_init(tmp * tmp)
power >>= 1
if negative:
return 1 / ret
else:
return ret

def __abs__(self):
import taichi as ti
return ti.abs(self)

def __ti_int__(self):
import taichi as ti
return ti.cast(self, ti.get_runtime().default_ip)
Expand Down
112 changes: 8 additions & 104 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numbers
import numpy as np
from .util import to_numpy_type, to_pytorch_type
from .common_ops import TaichiOperations


def broadcast_if_scalar(func):
Expand All @@ -15,7 +16,7 @@ def broadcasted(self, other, *args, **kwargs):
return broadcasted


class Matrix:
class Matrix(TaichiOperations):
is_taichi_class = True

def __init__(self,
Expand All @@ -42,12 +43,14 @@ def __init__(self,
assert row.n == rows[
0].n, "input vectors must be the same shape"
self.m = rows[0].n
# l-value copy:
self.entries = [row(i) for row in rows for i in range(row.n)]
elif isinstance(rows[0], list):
for row in rows:
assert len(row) == len(
rows[0]), "input lists must be the same shape"
self.m = len(rows[0])
# l-value copy:
self.entries = [x for row in rows for x in row]
else:
raise Exception(
Expand Down Expand Up @@ -168,113 +171,12 @@ def __matmul__(self, other):
ret(i, j).assign(ret(i, j) + self(i, k) * other(k, j))
return ret

@broadcast_if_scalar
def __pow__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(self(i, j)**other(i, j))
return ret

@broadcast_if_scalar
def __rpow__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(other(i, j)**self(i, j))
return ret

@broadcast_if_scalar
def __div__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(self(i, j) / other(i, j))
return ret

@broadcast_if_scalar
def __rtruediv__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(other(i, j) / self(i, j))
return ret

def broadcast(self, scalar):
ret = Matrix(self.n, self.m, empty=True)
for i in range(self.n * self.m):
ret.entries[i] = scalar
return ret

@broadcast_if_scalar
def __truediv__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(self(i, j) / other(i, j))
return ret

@broadcast_if_scalar
def __floordiv__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(self(i, j) // other(i, j))
return ret

@broadcast_if_scalar
def __mul__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(self(i, j) * other(i, j))
return ret

__rmul__ = __mul__

@broadcast_if_scalar
def __add__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(self(i, j) + other(i, j))
return ret

__radd__ = __add__

@broadcast_if_scalar
def __sub__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(self(i, j) - other(i, j))
return ret

def __neg__(self):
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(-self(i, j))
return ret

@broadcast_if_scalar
def __rsub__(self, other):
assert self.n == other.n and self.m == other.m
ret = Matrix(self.n, self.m)
for i in range(self.n):
for j in range(self.m):
ret(i, j).assign(other(i, j) - self(i, j))
return ret

def linearize_entry_id(self, *args):
assert 1 <= len(args) <= 2
if len(args) == 1 and isinstance(args[0], (list, tuple)):
Expand Down Expand Up @@ -382,7 +284,7 @@ def abs(self):

def trace(self):
assert self.n == self.m
sum = self(0, 0)
sum = expr.Expr(self(0, 0))
for i in range(1, self.n):
sum = sum + self(i, i)
return sum
Expand All @@ -393,8 +295,9 @@ def inverse(self):
return Matrix([1 / self(0, 0)])
elif self.n == 2:
inv_det = impl.expr_init(1.0 / self.determinant(self))
# Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323
return inv_det * Matrix([[self(1, 1), -self(0, 1)],
[-self(1, 0), self(0, 0)]])
[-self(1, 0), self(0, 0)]]).variable()
elif self.n == 3:
n = 3
import taichi as ti
Expand Down Expand Up @@ -527,6 +430,7 @@ def diag(dim, val):
def loop_range(self):
return self.entries[0]

# TODO
@broadcast_if_scalar
def augassign(self, other, op):
if not isinstance(other, Matrix):
Expand Down
Loading

0 comments on commit 3b79511

Please sign in to comment.