Skip to content

Commit

Permalink
Merge branch 'main' into use-rpc-solana
Browse files Browse the repository at this point in the history
  • Loading branch information
Funkatronics authored Jul 1, 2024
2 parents a55a00e + e5a9205 commit cc9047e
Show file tree
Hide file tree
Showing 13 changed files with 409 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ val rpcResponse = rpcDriver.makeRequest(rpcRequest, JsonElement.serializer())
```

<!-- TAG_VERSION -->
[badge-latest-release]: https://img.shields.io/badge/latest--release-0.2.4-blue.svg?style=flat
[badge-latest-release]: https://img.shields.io/badge/latest--release-0.2.5-blue.svg?style=flat
[badge-license]: https://img.shields.io/badge/license-Apache%20License%202.0-blue.svg?style=flat

<!-- TAG_DEPENDENCIES -->
Expand Down
3 changes: 2 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-
kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" }
ktor-client-core = { group = "io.ktor", name = "ktor-client-core", version.ref = "ktor" }
ktor-client-cio = { group = "io.ktor", name = "ktor-client-cio", version.ref = "ktor" }
multimult = { group = "io.github.funkatronics", name = "multimult", version = "0.2.2" }
multimult = { group = "io.github.funkatronics", name = "multimult", version = "0.2.3" }
rpc-core = { group = "com.solanamobile", name = "rpc-core", version.ref = "rpcCore" }
rpc-ktordriver = { group = "com.solanamobile", name = "rpc-ktordriver", version.ref = "rpcCore" }
rpc-solana = { group = "com.solanamobile", name = "rpc-solana", version.ref = "rpcCore" }
salkt = { group = "io.github.funkatronics", name = "salkt", version = "0.1.0" }

[plugins]
android-library = { id = "com.android.library", version.ref = "androidGradlePlugin" }
Expand Down
1 change: 1 addition & 0 deletions solana/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ kotlin {
implementation(libs.kotlinx.serialization.json)
implementation(libs.borsh)
implementation(libs.multimult)
implementation(libs.salkt)
}
}
val commonTest by getting {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ object MemoProgram {
listOf(AccountMeta(account, true, true)),
memo.encodeToByteArray()
)

override val programId = PROGRAM_ID
}
41 changes: 41 additions & 0 deletions solana/src/commonMain/kotlin/com/solana/programs/Program.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.solana.programs

import com.funkatronics.hash.Sha256
import com.funkatronics.salt.isOnCurve
import com.solana.publickey.ProgramDerivedAddress
import com.solana.publickey.PublicKey
import com.solana.publickey.SolanaPublicKey
import kotlin.jvm.JvmStatic

interface Program {
val programId: SolanaPublicKey

suspend fun createDerivedAddress(seeds: List<ByteArray>) =
createDerivedAddress(seeds, programId)

suspend fun findDerivedAddress(seeds: List<ByteArray>) =
findDerivedAddress(seeds, programId)

companion object {
@JvmStatic
suspend fun findDerivedAddress(seeds: List<ByteArray>, programId: PublicKey) =
ProgramDerivedAddress.find(seeds, programId)

@JvmStatic
suspend fun createDerivedAddress(seeds: List<ByteArray>, programId: PublicKey): Result<SolanaPublicKey> {
val address = Sha256.hash(
seeds.foldIndexed(ByteArray(0)) { i, a, s ->
require(s.size <= 32) { "Seed length must be <= 32 bytes" }; a + s
} + programId.bytes + "ProgramDerivedAddress".encodeToByteArray()
)

if (address.isOnCurve()) {
return Result.failure(
IllegalArgumentException("Invalid seeds, address must fall off curve")
)
}

return Result.success(SolanaPublicKey(address))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,6 @@ object SystemProgram {
encodeSerializableValue(ByteArraySerializer(), programId.bytes)
}.borshEncodedBytes
)

override val programId = PROGRAM_ID
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.solana.publickey

import com.funkatronics.salt.isOnCurve

suspend fun PublicKey.isOnCurve() = bytes.isOnCurve()
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.solana.publickey

import com.funkatronics.salt.isOnCurve
import com.solana.programs.Program
import kotlin.jvm.JvmStatic

class ProgramDerivedAddress private constructor(bytes: ByteArray, val nonce: UByte) : SolanaPublicKey(bytes) {

private constructor(publicKey: PublicKey, nonce: UByte) : this(publicKey.bytes, nonce)

companion object {
@JvmStatic
suspend fun find(seeds: List<ByteArray>, programId: PublicKey): Result<ProgramDerivedAddress> {
for (bump in 255 downTo 0) {
val result = Program.createDerivedAddress(seeds + byteArrayOf(bump.toByte()), programId)
if (result.isSuccess) return result.map { ProgramDerivedAddress(it, bump.toUByte()) }
}
return Result.failure(Error("Unable to find valid derived address for provided seeds"))
}

@JvmStatic
suspend fun create(bytes: ByteArray, nonce: UByte): ProgramDerivedAddress {
require(!bytes.isOnCurve()) { "Provided public key is not a PDA, address must be off Ed25519 curve" }
return ProgramDerivedAddress(bytes, nonce)
}

@JvmStatic
suspend fun create(publicKey: PublicKey, nonce: UByte) =
ProgramDerivedAddress(publicKey.bytes, nonce)
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
package com.solana.publickey

import com.funkatronics.encoders.Base58
import com.funkatronics.kborsh.BorshDecoder
import com.funkatronics.kborsh.BorshEncoder
import com.solana.serialization.ByteStringSerializer
import com.solana.serialization.TransactionDecoder
import com.solana.serialization.TransactionEncoder
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder

@Serializable(with=SolanaPublicKeySerializer::class)
open class SolanaPublicKey(override val bytes: ByteArray) : PublicKey {
open class SolanaPublicKey(final override val bytes: ByteArray) : PublicKey {

init {
check (bytes.size == PUBLIC_KEY_LENGTH)
Expand All @@ -28,16 +35,42 @@ open class SolanaPublicKey(override val bytes: ByteArray) : PublicKey {
override fun equals(other: Any?): Boolean {
return (other is PublicKey) && this.bytes.contentEquals(other.bytes)
}

override fun hashCode(): Int = bytes.contentHashCode()

override fun toString() = base58()
}

object SolanaPublicKeySerializer : KSerializer<SolanaPublicKey> {
private val delegate = ByteStringSerializer(SolanaPublicKey.PUBLIC_KEY_LENGTH)
override val descriptor: SerialDescriptor = delegate.descriptor
private val borshDelegate = ByteStringSerializer(SolanaPublicKey.PUBLIC_KEY_LENGTH)
private val jsonDelegate = String.serializer()
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("SolanaPublicKey")

override fun deserialize(decoder: Decoder): SolanaPublicKey =
SolanaPublicKey(decoder.decodeSerializableValue(delegate))
when (decoder) {
is BorshDecoder, is TransactionDecoder ->
SolanaPublicKey(decoder.decodeSerializableValue(borshDelegate))
is JsonDecoder ->
SolanaPublicKey.from(decoder.decodeSerializableValue(jsonDelegate))
else ->
runCatching {
SolanaPublicKey.from(decoder.decodeSerializableValue(jsonDelegate))
}.getOrElse {
SolanaPublicKey(decoder.decodeSerializableValue(borshDelegate))
}
}

override fun serialize(encoder: Encoder, value: SolanaPublicKey) {
encoder.encodeSerializableValue(delegate, value.bytes)
}
override fun serialize(encoder: Encoder, value: SolanaPublicKey) =
when (encoder) {
is BorshEncoder, is TransactionEncoder ->
encoder.encodeSerializableValue(borshDelegate, value.bytes)
is JsonDecoder ->
encoder.encodeSerializableValue(jsonDelegate, value.base58())
else ->
runCatching {
encoder.encodeSerializableValue(jsonDelegate, value.base58())
}.getOrElse {
encoder.encodeSerializableValue(borshDelegate, value.bytes)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.solana.serialization

import com.funkatronics.hash.Sha256
import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.serializer

open class DiscriminatorSerializer<T>(val discriminator: ByteArray, serializer: KSerializer<T>)
: KSerializer<T> {

private val accountSerializer = serializer
override val descriptor: SerialDescriptor = accountSerializer.descriptor

override fun serialize(encoder: Encoder, value: T) {
discriminator.forEach { encoder.encodeByte(it) }
accountSerializer.serialize(encoder, value)
}

override fun deserialize(decoder: Decoder): T {
ByteArray(discriminator.size).map { decoder.decodeByte() }
return accountSerializer.deserialize(decoder)
}
}

open class AnchorDiscriminatorSerializer<T>(namespace: String, ixName: String,
serializer: KSerializer<T>)
: DiscriminatorSerializer<T>(buildDiscriminator(namespace, ixName), serializer) {
companion object {
private fun buildDiscriminator(namespace: String, ixName: String) =
Sha256.hash("$namespace:$ixName".encodeToByteArray()).sliceArray(0 until 8)
}
}

class AnchorInstructionSerializer<T>(namespace: String, ixName: String, serializer: KSerializer<T>)
: AnchorDiscriminatorSerializer<T>(namespace, ixName, serializer) {
constructor(ixName: String, serializer: KSerializer<T>) : this("global", ixName, serializer)
}

inline fun <reified A> AnchorInstructionSerializer(namespace: String, ixName: String) =
AnchorInstructionSerializer<A>(namespace, ixName, serializer())

inline fun <reified A> AnchorInstructionSerializer(ixName: String) =
AnchorInstructionSerializer<A>(ixName, serializer())
80 changes: 80 additions & 0 deletions solana/src/commonTest/kotlin/com/solana/programs/ProgramTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package com.solana.programs

import com.solana.publickey.SolanaPublicKey
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

class ProgramTests {

@Test
fun `createProgramAddress returns failure for on chain pubkey`() = runTest {
// given
val seeds = listOf("helloWorld".encodeToByteArray(), byteArrayOf(255.toByte()))
val program = object : Program {
override val programId = SolanaPublicKey.from("11111111111111111111111111111111")
}

// when
val result = program.createDerivedAddress(seeds)

// then
assertTrue { result.isFailure }
}

@Test
fun `createProgramAddress returns expected pubkey for nonce`() = runTest {
// given
val seeds = listOf("helloWorld".encodeToByteArray(), byteArrayOf(252.toByte()))
val expectedPublicKey = SolanaPublicKey.from("THfBMgduMonjaNsCisKa7Qz2cBoG1VCUYHyso7UXYHH")
val program = object : Program {
override val programId = SolanaPublicKey.from("11111111111111111111111111111111")
}

// when
val result = program.createDerivedAddress(seeds)

// then
assertTrue { result.isSuccess }
assertEquals(expectedPublicKey, result.getOrNull()!!)
}

@Test
fun `findProgramAddress returns expected pubkey and nonce`() = runTest {
// given
val seeds = listOf<ByteArray>()
val expectedBump = 255.toUByte()
val expectedPublicKey = SolanaPublicKey.from("Cu7NwqCXSmsR5vgGA3Vw9uYVViPi3kQvkbKByVQ8nPY9")
val program = object : Program {
override val programId = SolanaPublicKey.from("11111111111111111111111111111111")
}

// when
val result = program.findDerivedAddress(seeds)

// then
assertTrue { result.isSuccess }
assertEquals(expectedPublicKey, result.getOrNull()!!)
assertEquals(expectedBump, result.getOrNull()!!.nonce)
}

@Test
fun `findProgramAddress returns expected pubkey and nonce for seeds`() = runTest {
// given
val seeds = listOf("helloWorld".encodeToByteArray())
val expectedBump = 254.toUByte()
val expectedPublicKey = SolanaPublicKey.from("46GZzzetjCURsdFPb7rcnspbEMnCBXe9kpjrsZAkKb6X")
val program = object : Program {
override val programId = SolanaPublicKey.from("11111111111111111111111111111111")
}

// when
val result = program.findDerivedAddress(seeds)

// then
assertTrue { result.isSuccess }
assertEquals(expectedPublicKey, result.getOrNull()!!)
assertEquals(expectedBump, result.getOrNull()!!.nonce)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.solana.serialization

import com.funkatronics.hash.Sha256
import com.funkatronics.kborsh.Borsh
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToByteArray
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

class AnchorDiscriminatorSerializerTests {

@Test
fun `discriminator is first 8 bytes of identifier hash`() {
// given
val namespace = "test"
val ixName = "testInstruction"
val data = "data"
val expectedDiscriminator = Sha256.hash(
"$namespace:$ixName".encodeToByteArray()
).sliceArray(0..7)

// when
val serialized = Borsh.encodeToByteArray(AnchorInstructionSerializer(namespace, ixName), data)

// then
assertContentEquals(expectedDiscriminator, serialized.sliceArray(0..7))
}

@Test
fun `data is serialized after 8 byte identifier hash`() {
// given
val ixName = "testInstruction"
val data = "data"
val expectedEncodedData = Borsh.encodeToByteArray(data)

// when
val serialized = Borsh.encodeToByteArray(AnchorInstructionSerializer(ixName), data)

// then
assertContentEquals(expectedEncodedData, serialized.sliceArray(8 until serialized.size))
}

@Test
fun `serialize and deserialize data struct`() {
// given
@Serializable data class TestData(val name: String, val number: Int, val boolean: Boolean)
val ixName = "testInstruction"
val data = TestData("testName", 12345678, true)
val expectedEncodedData = Borsh.encodeToByteArray(data)

// when
val serialized = Borsh.encodeToByteArray(AnchorInstructionSerializer(ixName), data)
val deserialized: TestData = Borsh.decodeFromByteArray(AnchorInstructionSerializer(ixName), serialized)

// then
assertContentEquals(expectedEncodedData, serialized.sliceArray(8 until serialized.size))
assertEquals(data, deserialized)
}
}
Loading

0 comments on commit cc9047e

Please sign in to comment.