Skip to content

Commit

Permalink
refactor: update & fix unity bindgen (#2631)
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo authored Nov 5, 2024
1 parent cf1d99f commit 447ba4f
Showing 1 changed file with 64 additions and 26 deletions.
90 changes: 64 additions & 26 deletions crates/dojo/bindgen/src/plugins/unity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,50 +244,82 @@ namespace {namespace} {{
let mut sorted_enums = tokens.enums.clone();
sorted_enums.sort_by(compare_tokens_by_type_name);

// Process structs first
for token in &sorted_structs {
if handled_tokens.contains_key(&token.type_path()) {
continue;
}

handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned());

// first index is our model struct
if token.type_name() == naming::get_name_from_tag(&model.tag) {
model_struct = Some(token.to_composite().unwrap());
continue;
}
}

let model_struct = model_struct.expect("model struct not found");

// Handle struct dependencies
let struct_keys: Vec<String> = handled_tokens
.iter()
.filter(|(_, s)| {
model_struct.inners.iter().any(|inner| {
s.r#type == CompositeType::Struct
&& check_token_in_recursively(&inner.token, &s.type_name())
&& inner.token.type_name() != "ByteArray"
})
})
.map(|(k, _)| k.clone())
.collect();

out += UnityPlugin::format_struct(token.to_composite().unwrap()).as_str();
for key in struct_keys {
if let Some(s) = handled_tokens.remove(&key) {
out += UnityPlugin::format_struct(&s).as_str();
}
}

// Process enums
for token in &sorted_enums {
if handled_tokens.contains_key(&token.type_path()) {
continue;
}

handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned());
out += UnityPlugin::format_enum(token.to_composite().unwrap()).as_str();
}

out += "\n";
// Handle enum dependencies
let enum_keys: Vec<String> = handled_tokens
.iter()
.filter(|(_, s)| {
model_struct.inners.iter().any(|inner| {
s.r#type == CompositeType::Enum
&& check_token_in_recursively(&inner.token, &s.type_name())
})
})
.map(|(k, _)| k.clone())
.collect();

out += UnityPlugin::format_model(
&get_namespace_from_tag(&model.tag),
model_struct.expect("model struct not found"),
)
.as_str();
for key in enum_keys {
if let Some(s) = handled_tokens.remove(&key) {
out += UnityPlugin::format_enum(&s).as_str();
}
}

out += "\n";
out +=
UnityPlugin::format_model(&get_namespace_from_tag(&model.tag), model_struct).as_str();

out
}

// Formats a system into a C# method used by the contract class
// Handled tokens should be a list of all structs and enums used by the contract
// Such as a set of referenced tokens from a model
fn format_system(system: &Function, handled_tokens: &HashMap<String, Composite>) -> String {
fn format_system(system: &Function) -> String {
fn handle_arg_recursive(
arg_name: &str,
token: &Token,
handled_tokens: &HashMap<String, Composite>,
// variant name
// if its an enum variant data
enum_variant: Option<String>,
Expand All @@ -304,8 +336,6 @@ namespace {namespace} {{

match token {
Token::Composite(t) => {
let t = handled_tokens.get(&t.type_path).unwrap_or(t);

// Need to flatten the struct members.
match t.r#type {
CompositeType::Struct if t.type_name() == "ByteArray" => vec![(
Expand All @@ -319,7 +349,6 @@ namespace {namespace} {{
tokens.extend(handle_arg_recursive(
&format!("{}.{}", arg_name, f.name),
&f.token,
handled_tokens,
enum_variant.clone(),
));
});
Expand Down Expand Up @@ -360,7 +389,6 @@ namespace {namespace} {{
} else {
field.token.clone()
},
handled_tokens,
Some(field.name.clone()),
))
});
Expand All @@ -375,7 +403,6 @@ namespace {namespace} {{
let inner = handle_arg_recursive(
&format!("{arg_name}Item"),
&array.inner,
handled_tokens,
enum_variant.clone(),
);

Expand Down Expand Up @@ -416,7 +443,6 @@ namespace {namespace} {{
handle_arg_recursive(
&format!("{}.Item{}", arg_name, idx + 1),
token,
handled_tokens,
enum_variant.clone(),
)
})
Expand All @@ -441,7 +467,7 @@ namespace {namespace} {{
.inputs
.iter()
.flat_map(|(name, token)| {
let tokens = handle_arg_recursive(name, token, handled_tokens, None);
let tokens = handle_arg_recursive(name, token, None);

tokens
.iter()
Expand Down Expand Up @@ -477,7 +503,7 @@ namespace {namespace} {{
return await account.ExecuteRaw(new dojo.Call[] {{
new dojo.Call{{
to = contractAddress,
to = new FieldElement(contractAddress).Inner,
selector = \"{system_name}\",
calldata = calldata.ToArray()
}}
Expand Down Expand Up @@ -505,11 +531,7 @@ namespace {namespace} {{
// Will format the contract into a C# class and
// all systems into C# methods
// Handled tokens should be a list of all structs and enums used by the contract
fn handle_contract(
&self,
contract: &DojoContract,
handled_tokens: &HashMap<String, Composite>,
) -> String {
fn handle_contract(&self, contract: &DojoContract) -> String {
let mut out = String::new();
out += UnityPlugin::generated_header().as_str();
out += UnityPlugin::contract_imports().as_str();
Expand All @@ -519,7 +541,7 @@ namespace {namespace} {{
.iter()
// we assume systems dont have outputs
.filter(|s| s.to_function().unwrap().get_output_kind() as u8 == FunctionOutputKind::NoOutput as u8)
.map(|system| UnityPlugin::format_system(system.to_function().unwrap(), handled_tokens))
.map(|system| UnityPlugin::format_system(system.to_function().unwrap()))
.collect::<Vec<String>>()
.join("\n\n ");

Expand All @@ -543,6 +565,22 @@ public class {} : MonoBehaviour {{
}
}

fn check_token_in_recursively(token: &Token, type_name: &str) -> bool {
match token {
Token::Composite(composite) => {
if composite.type_name() == type_name {
return true;
}
composite.inners.iter().any(|inner| check_token_in_recursively(&inner.token, type_name))
}
Token::Array(array) => check_token_in_recursively(&array.inner, type_name),
Token::Tuple(tuple) => {
tuple.inners.iter().any(|inner| check_token_in_recursively(inner, type_name))
}
_ => token.type_name() == type_name,
}
}

#[async_trait]
impl BuiltinPlugin for UnityPlugin {
async fn generate_code(&self, data: &DojoData) -> BindgenResult<HashMap<PathBuf, Vec<u8>>> {
Expand Down Expand Up @@ -572,7 +610,7 @@ impl BuiltinPlugin for UnityPlugin {
let contracts_path = Path::new(&format!("Contracts/{}.gen.cs", name)).to_owned();

println!("Generating contract: {}", name);
let code = self.handle_contract(contract, &handled_tokens);
let code = self.handle_contract(contract);

out.insert(contracts_path, code.as_bytes().to_vec());
}
Expand Down

0 comments on commit 447ba4f

Please sign in to comment.