diff --git a/lang/syn/src/codegen/accounts/constraints.rs b/lang/syn/src/codegen/accounts/constraints.rs index 387c713dd3..75f4a00f58 100644 --- a/lang/syn/src/codegen/accounts/constraints.rs +++ b/lang/syn/src/codegen/accounts/constraints.rs @@ -169,7 +169,7 @@ pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macr let field = &f.ident; let name_str = field.to_string(); let ty_decl = f.ty_decl(); - let from_account_info = f.from_account_info_unchecked(None); + let from_account_info = f.from_account_info(None, false); quote! { let #field: #ty_decl = { let mut __data: &[u8] = &#field.try_borrow_data()?; @@ -340,7 +340,8 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma }; // Convert from account info to account context wrapper type. - let from_account_info = f.from_account_info_unchecked(Some(&c.kind)); + let from_account_info = f.from_account_info(Some(&c.kind), true); + let from_account_info_unchecked = f.from_account_info(Some(&c.kind), false); // PDA bump seeds. let (find_pda, seeds_with_bump) = match &c.seeds { @@ -408,7 +409,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma anchor_spl::token::initialize_account(cpi_ctx)?; } - let pa: #ty_decl = #from_account_info; + let pa: #ty_decl = #from_account_info_unchecked; if #if_needed { if pa.mint != #mint.key() { return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintTokenMint).with_account_name(#name_str).with_pubkeys((pa.mint, #mint.key()))); @@ -443,7 +444,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma let cpi_ctx = anchor_lang::context::CpiContext::new(cpi_program, cpi_accounts); anchor_spl::associated_token::create(cpi_ctx)?; } - let pa: #ty_decl = #from_account_info; + let pa: #ty_decl = #from_account_info_unchecked; if #if_needed { if pa.mint != #mint.key() { return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintTokenMint).with_account_name(#name_str).with_pubkeys((pa.mint, #mint.key()))); @@ -496,7 +497,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma let cpi_ctx = anchor_lang::context::CpiContext::new(cpi_program, accounts); anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.key(), #freeze_authority)?; } - let pa: #ty_decl = #from_account_info; + let pa: #ty_decl = #from_account_info_unchecked; if #if_needed { if pa.mint_authority != anchor_lang::solana_program::program_option::COption::Some(#owner.key()) { return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintMintMintAuthority).with_account_name(#name_str)); @@ -548,16 +549,19 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma // Create the account. Always do this in the event // if needed is not specified or the system program is the owner. - if !#if_needed || actual_owner == &anchor_lang::solana_program::system_program::ID { + let pa: #ty_decl = if !#if_needed || actual_owner == &anchor_lang::solana_program::system_program::ID { // Define the payer variable. #payer // CPI to the system program to create. #create_account - } - // Convert from account info to account context wrapper type. - let pa: #ty_decl = #from_account_info; + // Convert from account info to account context wrapper type. + #from_account_info_unchecked + } else { + // Convert from account info to account context wrapper type. + #from_account_info + }; // Assert the account was created correctly. if #if_needed { diff --git a/lang/syn/src/lib.rs b/lang/syn/src/lib.rs index ac291b6bb9..7bde237cea 100644 --- a/lang/syn/src/lib.rs +++ b/lang/syn/src/lib.rs @@ -281,20 +281,33 @@ impl Field { // TODO: remove the option once `CpiAccount` is completely removed (not // just deprecated). - pub fn from_account_info_unchecked(&self, kind: Option<&InitKind>) -> proc_macro2::TokenStream { + pub fn from_account_info( + &self, + kind: Option<&InitKind>, + checked: bool, + ) -> proc_macro2::TokenStream { let field = &self.ident; let container_ty = self.container_ty(); + let owner_addr = match &kind { + None => quote! { program_id }, + Some(InitKind::Program { .. }) => quote! { + program_id + }, + _ => quote! { + &anchor_spl::token::ID + }, + }; match &self.ty { Ty::AccountInfo => quote! { #field.to_account_info() }, Ty::UncheckedAccount => { quote! { UncheckedAccount::try_from(#field.to_account_info()) } } Ty::Account(AccountTy { boxed, .. }) => { - if *boxed { + let stream = if checked { quote! { - Box::new(#container_ty::try_from_unchecked( + #container_ty::try_from( &#field, - )?) + )? } } else { quote! { @@ -302,30 +315,61 @@ impl Field { &#field, )? } + }; + if *boxed { + quote! { + Box::new(#stream) + } + } else { + stream } } Ty::CpiAccount(_) => { - quote! { - #container_ty::try_from_unchecked( - &#field, - )? + if checked { + quote! { + #container_ty::try_from( + &#field, + )? + } + } else { + quote! { + #container_ty::try_from_unchecked( + &#field, + )? + } + } + } + Ty::AccountLoader(_) => { + if checked { + quote! { + #container_ty::try_from( + &#field, + )? + } + } else { + quote! { + #container_ty::try_from_unchecked( + #owner_addr, + &#field, + )? + } } } _ => { - let owner_addr = match &kind { - None => quote! { program_id }, - Some(InitKind::Program { .. }) => quote! { - program_id - }, - _ => quote! { - &anchor_spl::token::ID - }, - }; - quote! { - #container_ty::try_from_unchecked( - #owner_addr, - &#field, - )? + if checked { + quote! { + #container_ty::try_from( + #owner_addr, + &#field, + )? + } + } else { + quote! { + #container_ty::try_from_unchecked( + #owner_addr, + &#field, + )? + } } } } diff --git a/tests/misc/Anchor.toml b/tests/misc/Anchor.toml index 30bc243fb6..d80f984a9d 100644 --- a/tests/misc/Anchor.toml +++ b/tests/misc/Anchor.toml @@ -5,13 +5,7 @@ wallet = "~/.config/solana/id.json" [programs.localnet] misc = "3TEqcc8xhrhdspwbvoamUJe2borm4Nr72JxL66k6rgrh" misc2 = "HmbTLCmaGvZhKnn1Zfa1JVnp7vkMV4DYVxPLWBVoN65L" - -[[test.genesis]] -address = "FtMNMKp9DZHKWUyVAsj3Q5QV8ow4P3fUPP7ZrWEQJzKr" -program = "./target/deploy/misc.so" +init_if_needed = "BZoppwWi6jMnydnUBEJzotgEXHwLr3b3NramJgZtWeF2" [workspace] exclude = ["programs/shared"] - -[scripts] -test = "yarn run ts-mocha -t 1000000 tests/*.ts" diff --git a/tests/misc/ci.sh b/tests/misc/ci.sh index 6648ddc57f..a08b0f4938 100755 --- a/tests/misc/ci.sh +++ b/tests/misc/ci.sh @@ -7,5 +7,5 @@ # a validator with a version < 1.9, so it can test # whether anchor's rent-exemption checks work for # legacy accounts which dont have to be rent-exempt -rm ./tests/misc.ts -mv miscNonRentExempt.ts ./tests/miscNonRentExempt.ts +rm ./tests/misc/misc.ts +mv miscNonRentExempt.ts ./tests/misc/miscNonRentExempt.ts diff --git a/tests/misc/miscNonRentExempt.ts b/tests/misc/miscNonRentExempt.ts index 50b70c1f25..df91d50472 100644 --- a/tests/misc/miscNonRentExempt.ts +++ b/tests/misc/miscNonRentExempt.ts @@ -6,7 +6,7 @@ import { SystemProgram, SYSVAR_RENT_PUBKEY, } from "@solana/web3.js"; -import { Misc } from "../target/types/misc"; +import { Misc } from "../../target/types/misc"; const { assert } = require("chai"); describe("miscNonRentExempt", () => { diff --git a/tests/misc/programs/init-if-needed/Cargo.toml b/tests/misc/programs/init-if-needed/Cargo.toml new file mode 100644 index 0000000000..8843138b0b --- /dev/null +++ b/tests/misc/programs/init-if-needed/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "init-if-needed" +version = "0.1.0" +description = "Created with Anchor" +edition = "2018" + +[lib] +crate-type = ["cdylib", "lib"] +name = "init_if_needed" + +[features] +no-entrypoint = [] +no-idl = [] +no-log-ix-name = [] +cpi = ["no-entrypoint"] +default = [] + +[dependencies] +anchor-lang = { path = "../../../../lang", features = ["init-if-needed"] } diff --git a/tests/misc/programs/init-if-needed/Xargo.toml b/tests/misc/programs/init-if-needed/Xargo.toml new file mode 100644 index 0000000000..475fb71ed1 --- /dev/null +++ b/tests/misc/programs/init-if-needed/Xargo.toml @@ -0,0 +1,2 @@ +[target.bpfel-unknown-unknown.dependencies.std] +features = [] diff --git a/tests/misc/programs/init-if-needed/src/lib.rs b/tests/misc/programs/init-if-needed/src/lib.rs new file mode 100644 index 0000000000..ea7abe1feb --- /dev/null +++ b/tests/misc/programs/init-if-needed/src/lib.rs @@ -0,0 +1,61 @@ +use anchor_lang::prelude::*; + +declare_id!("BZoppwWi6jMnydnUBEJzotgEXHwLr3b3NramJgZtWeF2"); + +#[program] +pub mod init_if_needed { + use super::*; + + // _val only used to make tx different so that it doesn't result + // in dup tx error + pub fn initialize(ctx: Context, _val: u8) -> Result<()> { + ctx.accounts.acc.val = 1000; + Ok(()) + } + + pub fn second_initialize(ctx: Context, _val: u8) -> Result<()> { + ctx.accounts.acc.other_val = 2000; + Ok(()) + } + + pub fn close(ctx: Context) -> Result<()> { + ctx.accounts.acc.val = 5000; + Ok(()) + } +} + +#[account] +pub struct MyData { + pub val: u64, +} + +#[account] +pub struct OtherData { + pub other_val: u64, +} + +#[derive(Accounts)] +pub struct Initialize<'info> { + #[account(init_if_needed, payer = payer, space = 8 + 8)] + pub acc: Account<'info, MyData>, + #[account(mut)] + pub payer: Signer<'info>, + pub system_program: Program<'info, System>, +} + +#[derive(Accounts)] +pub struct SecondInitialize<'info> { + #[account(init, payer = payer, space = 8 + 8)] + pub acc: Account<'info, OtherData>, + #[account(mut)] + pub payer: Signer<'info>, + pub system_program: Program<'info, System>, +} + +#[derive(Accounts)] +pub struct Close<'info> { + #[account(mut, close = receiver)] + pub acc: Account<'info, MyData>, + #[account(mut)] + pub receiver: UncheckedAccount<'info>, +} diff --git a/tests/misc/tests/init-if-needed/Test.toml b/tests/misc/tests/init-if-needed/Test.toml new file mode 100644 index 0000000000..9eb55d60cc --- /dev/null +++ b/tests/misc/tests/init-if-needed/Test.toml @@ -0,0 +1,2 @@ +[scripts] +test = "yarn run ts-mocha -t 1000000 ./tests/init-if-needed/*.ts" diff --git a/tests/misc/tests/init-if-needed/init-if-needed.ts b/tests/misc/tests/init-if-needed/init-if-needed.ts new file mode 100644 index 0000000000..dbeaf3c1b5 --- /dev/null +++ b/tests/misc/tests/init-if-needed/init-if-needed.ts @@ -0,0 +1,101 @@ +import * as anchor from "@project-serum/anchor"; +import { AnchorError, Program } from "@project-serum/anchor"; +import { InitIfNeeded } from "../../target/types/init_if_needed"; +import { SystemProgram, LAMPORTS_PER_SOL } from "@solana/web3.js"; +import { expect } from "chai"; + +describe("init-if-needed", () => { + anchor.setProvider(anchor.AnchorProvider.env()); + + const program = anchor.workspace.InitIfNeeded as Program; + + it("init_if_needed should reject a CLOSED discriminator if init is NOT NEEDED", async () => { + const account = anchor.web3.Keypair.generate(); + + await program.methods + .initialize(1) + .accounts({ + acc: account.publicKey, + }) + .signers([account]) + .rpc(); + + const oldState = await program.account.myData.fetch(account.publicKey); + expect(oldState.val.toNumber()).to.equal(1000); + + // This initialize call should fail because the account has the account discriminator + // set to the CLOSED one + try { + await program.methods + .initialize(5) + .accounts({ + acc: account.publicKey, + }) + .signers([account]) + .preInstructions([ + await program.methods + .close() + .accounts({ + acc: account.publicKey, + receiver: program.provider.wallet.publicKey, + }) + .instruction(), + SystemProgram.transfer({ + fromPubkey: program.provider.wallet.publicKey, + toPubkey: account.publicKey, + lamports: 1 * LAMPORTS_PER_SOL, + }), + ]) + .rpc(); + } catch (_err) { + expect(_err).to.be.instanceOf(AnchorError); + const err: AnchorError = _err; + expect(err.error.errorCode.code).to.equal("AccountDiscriminatorMismatch"); + } + }); + + it("init_if_needed should reject a discriminator of a different account if init is NOT NEEDED", async () => { + const account = anchor.web3.Keypair.generate(); + console.log("account: ", account.publicKey.toBase58()); + const otherAccount = anchor.web3.Keypair.generate(); + console.log("otherAccount: ", otherAccount.publicKey.toBase58()); + + await program.methods + .initialize(1) + .accounts({ + acc: account.publicKey, + }) + .signers([account]) + .rpc(); + + const oldState = await program.account.myData.fetch(account.publicKey); + expect(oldState.val.toNumber()).to.equal(1000); + + await program.methods + .secondInitialize(1) + .accounts({ + acc: otherAccount.publicKey, + }) + .signers([otherAccount]) + .rpc(); + + const secondState = await program.account.otherData.fetch( + otherAccount.publicKey + ); + expect(secondState.otherVal.toNumber()).to.equal(2000); + + try { + await program.methods + .initialize(3) + .accounts({ + acc: otherAccount.publicKey, + }) + .signers([otherAccount]) + .rpc(); + } catch (_err) { + expect(_err).to.be.instanceOf(AnchorError); + const err: AnchorError = _err; + expect(err.error.errorCode.code).to.equal("AccountDiscriminatorMismatch"); + } + }); +}); diff --git a/tests/misc/tests/misc/Test.toml b/tests/misc/tests/misc/Test.toml new file mode 100644 index 0000000000..bcacd6b4b0 --- /dev/null +++ b/tests/misc/tests/misc/Test.toml @@ -0,0 +1,6 @@ +[[test.genesis]] +address = "FtMNMKp9DZHKWUyVAsj3Q5QV8ow4P3fUPP7ZrWEQJzKr" +program = "../../target/deploy/misc.so" + +[scripts] +test = "yarn run ts-mocha -t 1000000 ./tests/misc/*.ts" diff --git a/tests/misc/tests/misc.ts b/tests/misc/tests/misc/misc.ts similarity index 99% rename from tests/misc/tests/misc.ts rename to tests/misc/tests/misc/misc.ts index cc28474acc..3f86dbbb5d 100644 --- a/tests/misc/tests/misc.ts +++ b/tests/misc/tests/misc/misc.ts @@ -11,12 +11,12 @@ import { Token, ASSOCIATED_TOKEN_PROGRAM_ID, } from "@solana/spl-token"; -import { Misc } from "../target/types/misc"; -import { Misc2 } from "../target/types/misc2"; +import { Misc } from "../../target/types/misc"; +import { Misc2 } from "../../target/types/misc2"; const utf8 = anchor.utils.bytes.utf8; const { assert, expect } = require("chai"); const nativeAssert = require("assert"); -const miscIdl = require("../target/idl/misc.json"); +const miscIdl = require("../../target/idl/misc.json"); describe("misc", () => { // Configure the client to use the local cluster. @@ -160,11 +160,9 @@ describe("misc", () => { "Program data: jvbowsvlmkcJAAAA", "Program data: zxM5neEnS1kBAgMEBQYHCAkK", "Program data: g06Ei2GL1gIBAgMEBQYHCAkKCw==", - "Program 3TEqcc8xhrhdspwbvoamUJe2borm4Nr72JxL66k6rgrh consumed 3983 of 1400000 compute units", - "Program 3TEqcc8xhrhdspwbvoamUJe2borm4Nr72JxL66k6rgrh success", ]; - assert.deepStrictEqual(expectedRaw, resp.raw); + assert.deepStrictEqual(expectedRaw, resp.raw.slice(0, -2)); assert.strictEqual(resp.events[0].name, "E1"); assert.strictEqual(resp.events[0].data.data, 44); assert.strictEqual(resp.events[1].name, "E2");