Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: use enum op traits for floats + conversions #755

Merged
merged 1 commit into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ mod test {
use super::*;

fn test_registry() -> ExtensionRegistry {
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap()
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap()
}

#[test]
Expand Down
164 changes: 116 additions & 48 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,131 @@
//! Conversions between integer and floating-point values.

use smol_str::SmolStr;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::{
extension::{prelude::sum_with_error, ExtensionId, ExtensionSet},
extension::{
prelude::sum_with_error,
simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError},
ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc,
},
ops::{custom::ExtensionOp, OpName},
type_row,
types::{FunctionType, PolyFuncType},
types::{FunctionType, PolyFuncType, TypeArg},
Extension,
};

use super::int_types::int_tv;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
use lazy_static::lazy_static;

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let ftoi_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]),
);

let itof_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]),
);

let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
]),
);
extension
.add_op(
"trunc_u".into(),
"float to unsigned int".to_owned(),
ftoi_sig.clone(),
)
.unwrap();
extension
.add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig)
.unwrap();
extension
.add_op(
"convert_u".into(),
"unsigned int to float".to_owned(),
itof_sig.clone(),
)
.unwrap();
extension
.add_op(
"convert_s".into(),
"signed int to float".to_owned(),
itof_sig,
)
.unwrap();

extension
/// Extensiop for conversions between floats and integers.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs, non_camel_case_types)]
pub enum ConvertOpDef {
trunc_u,
trunc_s,
convert_u,
convert_s,
}

impl MakeOpDef for ConvertOpDef {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
crate::extension::simple_op::try_from_name(op_def.name())
}

fn signature(&self) -> SignatureFunc {
use ConvertOpDef::*;
match self {
trunc_s | trunc_u => PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]),
),

convert_s | convert_u => PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]),
),
}
.into()
}

fn description(&self) -> String {
use ConvertOpDef::*;
match self {
trunc_u => "float to unsigned int",
trunc_s => "float to signed int",
convert_u => "unsigned int to float",
convert_s => "signed int to float",
}
.to_string()
}
}

/// Concrete convert operation with integer width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
def: ConvertOpDef,
width: u64,
}

impl OpName for ConvertOpType {
fn name(&self) -> SmolStr {
self.def.name()
}
}

impl MakeExtensionOp for ConvertOpType {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = ConvertOpDef::from_def(ext_op.def())?;
let width = match *ext_op.args() {
[TypeArg::BoundedNat { n }] => n,
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(Self { def, width })
}

fn type_args(&self) -> Vec<crate::types::TypeArg> {
vec![TypeArg::BoundedNat { n: self.width }]
}
}

lazy_static! {
/// Extension for conversions between integers and floats.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
]),
);

ConvertOpDef::load_all_ops(&mut extension).unwrap();

extension
};

/// Registry of extensions required to validate integer operations.
pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
super::int_types::EXTENSION.to_owned(),
super::float_types::EXTENSION.to_owned(),
EXTENSION.to_owned(),
])
.unwrap();
}

impl MakeRegisteredOp for ConvertOpType {
fn extension_id(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
&CONVERT_OPS_REGISTRY
}
}

#[cfg(test)]
Expand All @@ -66,7 +134,7 @@ mod test {

#[test]
fn test_conversions_extension() {
let r = extension();
let r = &EXTENSION;
assert_eq!(r.name() as &str, "arithmetic.conversions");
assert_eq!(r.types().count(), 0);
for (name, _) in r.operations() {
Expand Down
Loading