Skip to content

Commit

Permalink
an example of how to create a custom type from scratch. (#1113)
Browse files Browse the repository at this point in the history
* teach the sql-entity-graph to not create its own edges for`#[pg_extern]`s with a `requires = [...]` list.  We instead want it require only the entities specified by the user.
  • Loading branch information
eeeebbbbrrrr authored Apr 19, 2023
1 parent 66f17eb commit 15572a7
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 1 deletion.
202 changes: 202 additions & 0 deletions pgrx-examples/custom_types/src/hexint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
use pgrx::pg_sys::{Datum, Oid};
use pgrx::pgrx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
};
use pgrx::prelude::*;
use pgrx::{rust_regtypein, StringInfo};
use std::error::Error;
use std::ffi::CStr;
use std::fmt::{Display, Formatter};

/// A `BIGINT`-style type that is always represented as hexadecimal.
///
/// # Examples
///
/// ```sql
/// custom_types=# SELECT
/// 42::hexint,
/// '0xdeadbeef'::hexint,
/// '0Xdeadbeef'::hexint::bigint,
/// 2147483647::hexint,
/// 9223372036854775807::hexint,
/// '0xbadc0ffee'::hexint::numeric;
/// hexint | hexint | int8 | hexint | hexint | numeric
/// --------+------------+------------+------------+--------------------+-------------
/// 0x2a | 0xdeadbeef | 3735928559 | 0x7fffffff | 0x7fffffffffffffff | 50159747054
/// ```
#[repr(transparent)]
#[derive(
Copy,
Clone,
Debug,
Ord,
PartialOrd,
Eq,
PartialEq,
Hash,
PostgresEq,
PostgresOrd,
PostgresHash
)]
struct HexInt {
value: i64,
}

impl Display for HexInt {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
// format ourselves as a `0x`-prefixed hexadecimal string
write!(f, "0x{:x}", self.value)
}
}

unsafe impl SqlTranslatable for HexInt {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
// this is what the SQL type is called when used in a function argument position
Ok(SqlMapping::As("hexint".into()))
}

fn return_sql() -> Result<Returns, ReturnsError> {
// this is what the SQL type is called when used in a function return type position
Ok(Returns::One(SqlMapping::As("hexint".into())))
}
}

impl FromDatum for HexInt {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _: Oid) -> Option<Self>
where
Self: Sized,
{
if is_null {
None
} else {
Some(HexInt { value: datum.value() as _ })
}
}
}

impl IntoDatum for HexInt {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.value))
}

fn type_oid() -> Oid {
rust_regtypein::<Self>()
}
}

/// Input function for `HexInt`. Parses any valid "radix(16)" text string, with or without a leading
/// `0x` (or `0X`) into a `HexInt` type. Parse errors are returned and handled by pgrx
///
/// `requires = [ "shell_type" ]` indicates that the CREATE FUNCTION SQL for this function must happen
/// *after* the SQL for the "shell_type" block.
#[pg_extern(immutable, parallel_safe, requires = [ "shell_type" ])]
fn hexint_in(input: &CStr) -> Result<HexInt, Box<dyn Error>> {
Ok(HexInt {
value: i64::from_str_radix(
// ignore `0x` or `0X` prefixes
input.to_str()?.trim_start_matches("0x").trim_start_matches("0X"),
16,
)?,
})
}

/// Output function for `HexInt`. Returns a cstring with the underlying `HexInt` value formatted
/// as hexadecimal, prefixed with `0x`.
///
/// `requires = [ "shell_type" ]` indicates that the CREATE FUNCTION SQL for this function must happen
/// *after* the SQL for the "shell_type" block.
#[pg_extern(immutable, parallel_safe, requires = [ "shell_type" ])]
fn hexint_out<'a>(value: HexInt) -> &'a CStr {
let mut s = StringInfo::new();
s.push_str(&value.to_string());
s.into()
}

//
// cast functions
//

#[pg_extern(immutable, parallel_safe)]
fn hexint_to_int(hexint: HexInt) -> Result<i32, Box<dyn Error>> {
Ok(hexint.value.try_into()?)
}

#[pg_extern(immutable, parallel_safe)]
fn hexint_to_numeric(hexint: HexInt) -> Result<AnyNumeric, Box<dyn Error>> {
Ok(hexint.value.try_into()?)
}

#[pg_extern(immutable, parallel_safe)]
fn int_to_hexint(bigint: i32) -> HexInt {
HexInt { value: bigint as _ }
}

#[pg_extern(immutable, parallel_safe)]
fn numeric_to_hexint(bigint: AnyNumeric) -> Result<HexInt, Box<dyn Error>> {
Ok(HexInt { value: bigint.try_into()? })
}

//
// handwritten sql
//

// creates the `hexint` shell type, which is essentially a type placeholder so that the
// input and output functions can be created
extension_sql!(
r#"
CREATE TYPE hexint; -- shell type
"#,
name = "shell_type",
bootstrap // declare this extension_sql block as the "bootstrap" block so that it happens first in sql generation
);

// create the actual type, specifying the input and output functions
extension_sql!(
r#"
CREATE TYPE hexint (
INPUT = hexint_in,
OUTPUT = hexint_out,
LIKE = int8
);
"#,
name = "concrete_type",
creates = [Type(HexInt)],
requires = ["shell_type", hexint_in, hexint_out], // so that we won't be created until the shell type and input and output functions have
);

// some convenient casts
extension_sql!(
r#"
CREATE CAST (bigint AS hexint) WITHOUT FUNCTION AS IMPLICIT;
CREATE CAST (hexint AS bigint) WITHOUT FUNCTION AS IMPLICIT;
CREATE CAST (hexint AS int) WITH FUNCTION hexint_to_int AS IMPLICIT;
CREATE CAST (int AS hexint) WITH FUNCTION int_to_hexint AS IMPLICIT;
CREATE CAST (hexint AS numeric) WITH FUNCTION hexint_to_numeric AS IMPLICIT;
CREATE CAST (numeric AS hexint) WITH FUNCTION numeric_to_hexint AS IMPLICIT;
"#,
name = "casts",
requires = [hexint_to_int, int_to_hexint, hexint_to_numeric, numeric_to_hexint]
);

#[cfg(not(feature = "no-schema-generation"))]
#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
use crate::hexint::HexInt;
use pgrx::prelude::*;
use std::error::Error;

#[pg_test]
fn test_hexint_input_func() -> Result<(), Box<dyn Error>> {
let value = Spi::get_one::<HexInt>("SELECT '0x2a'::hexint")?;
assert_eq!(value, Some(HexInt { value: 0x2a }));
Ok(())
}

#[pg_test]
fn test_hexint_output_func() -> Result<(), Box<dyn Error>> {
let value = Spi::get_one::<String>("SELECT 42::hexint::text")?;
assert_eq!(value, Some("0x2a".into()));
Ok(())
}
}
1 change: 1 addition & 0 deletions pgrx-examples/custom_types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Use of this source code is governed by the MIT license that can be found in the
mod complex;
mod fixed_size;
mod generic_enum;
mod hexint;
mod hstore_clone;
mod ordered;
mod rust_enum;
Expand Down
6 changes: 5 additions & 1 deletion pgrx-sql-entity-graph/src/pgrx_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ fn connect_externs(
) -> eyre::Result<()> {
for (item, &index) in externs {
let mut found_schema_declaration = false;
let mut has_explicit_requires = false;
for extern_attr in &item.extern_attrs {
match extern_attr {
crate::ExternArgs::Requires(requirements) => {
Expand All @@ -863,6 +864,7 @@ fn connect_externs(
triggers,
) {
graph.add_edge(*target, index, SqlGraphRelationship::RequiredBy);
has_explicit_requires = true;
} else {
return Err(eyre!("Could not find `requires` target: {:?}", requires));
}
Expand Down Expand Up @@ -933,7 +935,9 @@ fn connect_externs(
if let Some(_) = ext_item.has_sql_declared_entity(&SqlDeclared::Type(
arg.used_ty.full_path.to_string(),
)) {
graph.add_edge(*ext_index, index, SqlGraphRelationship::RequiredByArg);
if !has_explicit_requires {
graph.add_edge(*ext_index, index, SqlGraphRelationship::RequiredByArg);
}
} else if let Some(_) = ext_item.has_sql_declared_entity(&SqlDeclared::Enum(
arg.used_ty.full_path.to_string(),
)) {
Expand Down

0 comments on commit 15572a7

Please sign in to comment.