Skip to content

Commit

Permalink
Fix mypy types
Browse files Browse the repository at this point in the history
  • Loading branch information
matte1 committed Nov 14, 2024
1 parent 0d07664 commit 078400b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
25 changes: 14 additions & 11 deletions symforce/codegen/backends/rust/rust_code_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,18 @@ def _set_override_methods(self, expr: sympy.Function, name: str) -> None:
method_name = f"_print_{str(expr)}"

def _print_expr(expr: sympy.Expr) -> str:
expr_string = ", ".join(map(self._print, expr.args))
expr_string = ", ".join(map(self._print, expr.args)) # type: ignore
return f"{name}({expr_string})"

setattr(self, method_name, _print_expr)

def _print(self, expr: sympy.Expr, **kwargs):
def _print(self, expr: sympy.Expr, **kwargs: T.Any) -> str:
return super()._print(expr, **kwargs)

def _print_Zero(self, expr: sympy.Expr) -> str:
return "0.0"

def _print_Integer(self, expr: sympy.Integer) -> T.Any:
def _print_Integer(self, expr: sympy.Integer, _type: T.Any = None) -> T.Any:
"""
Customizations:
* Cast all integers to either f32 or f64 because Rust does not have implicit casting
Expand All @@ -67,8 +67,7 @@ def _print_Integer(self, expr: sympy.Integer) -> T.Any:
return f"{expr.p}_f64"
assert False, f"Scalar type {self.scalar_type} not supported"

def _print_Pow(self, expr):

def _print_Pow(self, expr: T.Any, rational: T.Any = None) -> str:
if expr.exp.is_rational:
power = self._print_Rational(expr.exp)
func = "powf"
Expand All @@ -92,7 +91,7 @@ def _print_ImaginaryUnit(self, expr: sympy.Expr) -> str:
return "Scalar(1i)"


def _print_Float(self, flt: sympy.Float) -> T.Any:
def _print_Float(self, flt: sympy.Float, _type: T.Any = None) -> T.Any:
"""
Customizations:
* Cast all literals to Scalar at compile time instead of using a suffix at codegen time
Expand All @@ -104,25 +103,27 @@ def _print_Float(self, flt: sympy.Float) -> T.Any:

raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")

def _print_Pi(self, expr, _type=False):
def _print_Pi(self, expr: T.Any, _type:bool=False) -> str:
if self.scalar_type is float32:
return f"core::f32::consts::PI"
if self.scalar_type is float64:
return f"core::f64::consts::PI"

raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")

def _print_Max(self, expr: sympy.Max) -> str:
"""
Customizations:
* The first argument calls the max method on the second argument.
"""
return "{}.max({})".format(self._print(expr.args[0]), self._print(expr.args[1]))
return "{}.max({})".format(self._print(expr.args[0]), self._print(expr.args[1])) # type: ignore

def _print_Min(self, expr: sympy.Min) -> str:
"""
Customizations:
* The first argument calls the min method on the second argument.
"""
return "{}.min({})".format(self._print(expr.args[0]), self._print(expr.args[1]))
return "{}.min({})".format(self._print(expr.args[0]), self._print(expr.args[1])) # type: ignore

def _print_log(self, expr: sympy.log) -> str:
"""
Expand All @@ -131,7 +132,7 @@ def _print_log(self, expr: sympy.log) -> str:
return "{}.ln()".format(self._print(expr.args[0]))


def _print_Rational(self, expr):
def _print_Rational(self, expr: sympy.Rational) -> str:
p, q = int(expr.p), int(expr.q)

float_suffix = None
Expand All @@ -143,8 +144,10 @@ def _print_Rational(self, expr):
return f"({p}_{float_suffix}/{q}_{float_suffix})"


def _print_Exp1(self, expr, _type=False):
def _print_Exp1(self, expr: T.Any, _type: bool=False) -> str:
if self.scalar_type is float32:
return 'core::f32::consts::E'
elif self.scalar_type is float64:
return 'core::f64::consts::E'

raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")
9 changes: 8 additions & 1 deletion symforce/codegen/backends/rust/rust_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,11 @@ def format_matrix_accessor(self, key: str, i: int, j: int, *, shape: T.Tuple[int
return f"{key}[{i}]"
if shape[0] == 1:
return f"{key}[{j}]"
return f"{key}[({i}, {j})]"
return f"{key}[({i}, {j})]"

@staticmethod
def format_eigen_lcm_accessor(key: str, i: int) -> str:
"""
Format accessor for eigen_lcm types.
"""
raise NotImplementedError("Rust does not support eigen_lcm")
3 changes: 2 additions & 1 deletion test/symforce_rust_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class SymforceRustCodegenTest(TestCase):

def test_codegen(self) -> None:
def rust_func(vec3: sf.V3, mat33: sf.M33) -> sf.Matrix31:
return mat33 * vec3
return sf.Matrix31(mat33 * vec3)

output_dir_base = self.make_output_dir("symforce_rust_codegen_test_")
output_dir_src = output_dir_base / "src"
Expand Down Expand Up @@ -147,5 +147,6 @@ def cargo_build(self, output_dir: Path) -> None:
if result.returncode != 0:
self.fail(f"cargo build failed:\n{result.stderr}")


if __name__ == "__main__":
SymforceRustCodegenTest.main()

0 comments on commit 078400b

Please sign in to comment.