Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experimental: Add minimal rust backend. #405

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ jobs:
libgmp-dev \
pandoc

- name: Install rust
run: |
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

# NOTE(brad): libunwind-dev is a broken dependency of libgoogle-glog-dev, itself
# a dependency of ceres. Without this step on jammy, apt-get install libgoogle-glog-dev
# would fail. If this step could be removed and still have the build succeed, it should.
Expand Down
3 changes: 3 additions & 0 deletions symforce/codegen/backends/rust/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
***THIS MODULE IS EXPERIMENTAL***

Backend for Rust. This currently only supports vector/matrices inputs and outputs, we do not have geo or cam types for Rust yet.
12 changes: 12 additions & 0 deletions symforce/codegen/backends/rust/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

from pathlib import Path

__doc__ = (Path(__file__).parent / "README.rst").read_text()

from .rust_code_printer import RustCodePrinter
from .rust_code_printer import ScalarType
from .rust_config import RustConfig
151 changes: 151 additions & 0 deletions symforce/codegen/backends/rust/rust_code_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

from enum import Enum

import sympy
from sympy.core.singleton import S
from sympy.core.numbers import Rational
from sympy.codegen.ast import float32
from sympy.codegen.ast import float64
from sympy.codegen.ast import real
from sympy.printing.rust import RustCodePrinter as SympyRustCodePrinter

from symforce import typing as T


class ScalarType(Enum):
FLOAT = float32
DOUBLE = float64


class RustCodePrinter(SympyRustCodePrinter):
"""
SymForce code printer for Rust. Based on the SymPy Rust printer.
"""

def __init__(
self,
scalar_type: ScalarType,
settings: T.Optional[T.Dict[str, T.Any]] = None,
override_methods: T.Optional[T.Dict[sympy.Function, str]] = None,
) -> None:
super().__init__(dict(settings or {}))

self.scalar_type = scalar_type.value
self.override_methods = override_methods or {}
for expr, name in self.override_methods.items():
self._set_override_methods(expr, name)

def _set_override_methods(self, expr: sympy.Function, name: str) -> None:
method_name = f"_print_{str(expr)}"

def _print_expr(expr: sympy.Expr) -> str:
expr_string = ", ".join(map(self._print, expr.args))
return f"{name}({expr_string})"

setattr(self, method_name, _print_expr)

def _print(self, expr: sympy.Expr, **kwargs):
# For whatever reason S.Zero is not a sympy.Integer, so we need to handle it separately
# by returning "0.0" instead of "0" to avoid compilation errors.
if expr == S.Zero:
return "0.0"
return super()._print(expr, **kwargs)

def _print_Integer(self, expr: sympy.Integer) -> T.Any:
"""
Customizations:
* Cast all integers to either f32 or f64 because Rust does not have implicit casting
and needs to know the type of the literal at compile time. We assume that we are only
ever operating on floats in SymForce which should make this safe.
"""
if self.scalar_type is float32:
return f"{expr.p}_f32"
if self.scalar_type is float64:
return f"{expr.p}_f64"
assert False, f"Scalar type {self.scalar_type} not supported"

def _print_Pow(self, expr):

if expr.exp.is_rational:
power = self._print_Rational(expr.exp)
func = "powf"
return f"{self._print(expr.base)}.{func}({power})"
else:
power = self._print(expr.exp)

if expr.exp.is_integer:
func = "powi"
else:
func = "powf"

return f"{expr.base}.{func}({power})"

def _print_ImaginaryUnit(self, expr: sympy.Expr) -> str:
"""
Customizations:
* Print 1i instead of I
* Cast to Scalar, since the literal is of type std::complex<double>
"""
return "Scalar(1i)"


def _print_Float(self, flt: sympy.Float) -> T.Any:
"""
Customizations:
* Cast all literals to Scalar at compile time instead of using a suffix at codegen time
matte1 marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.scalar_type is float32:
return f"{super()._print_Float(flt)}_f32"
if self.scalar_type is float64:
return f"{super()._print_Float(flt)}_f64"

raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")

def _print_Pi(self, expr, _type=False):
if self.scalar_type is float32:
return f"core::f32::consts::PI"
if self.scalar_type is float64:
return f"core::f64::consts::PI"

def _print_Max(self, expr: sympy.Max) -> str:
"""
Customizations:
* The first argument calls the max method on the second argument.
"""
return "{}.max({})".format(self._print(expr.args[0]), self._print(expr.args[1]))

def _print_Min(self, expr: sympy.Min) -> str:
"""
Customizations:
* The first argument calls the min method on the second argument.
"""
return "{}.min({})".format(self._print(expr.args[0]), self._print(expr.args[1]))

def _print_log(self, expr: sympy.log) -> str:
"""
Customizations:
"""
return "{}.ln()".format(self._print(expr.args[0]))


def _print_Rational(self, expr):
p, q = int(expr.p), int(expr.q)

float_suffix = None
if self.scalar_type is float32:
float_suffix = 'f32'
elif self.scalar_type is float64:
float_suffix = 'f64'

return f"({p}_{float_suffix}/{q}_{float_suffix})"


def _print_Exp1(self, expr, _type=False):
if self.scalar_type is float32:
return 'core::f32::consts::E'
elif self.scalar_type is float64:
return 'core::f64::consts::E'
60 changes: 60 additions & 0 deletions symforce/codegen/backends/rust/rust_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
from dataclasses import dataclass
from pathlib import Path

import sympy
from sympy.printing.codeprinter import CodePrinter

from symforce import typing as T
from symforce.codegen.backends.rust import rust_code_printer
from symforce.codegen.codegen_config import CodegenConfig

CURRENT_DIR = Path(__file__).parent


@dataclass
class RustConfig(CodegenConfig):
"""
Code generation config for the Rust backend.

Args:
doc_comment_line_prefix: Prefix applied to each line in a docstring
line_length: Maximum allowed line length in docstrings; used for formatting docstrings.
scalar_type: The scalar type to use (float or double)
"""

doc_comment_line_prefix: str = "///"
line_length: int = 100
scalar_type: rust_code_printer.ScalarType = rust_code_printer.ScalarType.FLOAT
aaron-skydio marked this conversation as resolved.
Show resolved Hide resolved
use_eigen_types: bool = False

@classmethod
def backend_name(cls) -> str:
return "rust"

@classmethod
def template_dir(cls) -> Path:
return CURRENT_DIR / "templates"

def templates_to_render(self, generated_file_name: str) -> T.List[T.Tuple[str, str]]:
return [("function/FUNCTION.rs.jinja", f"{generated_file_name}.rs")]

def printer(self) -> CodePrinter:
kwargs: T.Mapping[str, T.Any] = {}
return rust_code_printer.RustCodePrinter(scalar_type=self.scalar_type, **kwargs)

def format_matrix_accessor(self, key: str, i: int, j: int, *, shape: T.Tuple[int, int]) -> str:
"""
Format accessor for matrix types.

Assumes matrices are row-major.
"""
RustConfig._assert_indices_in_bounds(i, j, shape)
if shape[1] == 1:
return f"{key}[{i}]"
if shape[0] == 1:
return f"{key}[{j}]"
return f"{key}[({i}, {j})]"
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ---------------------------------------------------------------------------- #}

{%- import "../util/util.jinja" as util with context -%}


pub mod {{ spec.namespace }} {

{% if spec.docstring %}
{{ util.print_docstring(spec.docstring) }}
{% endif %}
{{ util.function_declaration(spec) }} {
{{ util.expr_code(spec) }}
}

} // mod {{ spec.namespace }}
Loading
Loading