Skip to content

Commit

Permalink
feat(pollux): [ATL-2679] Improve Error Hanlding and Verification (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
CryptoKnightIOG authored Dec 12, 2022
1 parent d4010fb commit 6348e13
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import io.circe.generic.auto.*
import io.circe.parser.decode
import io.circe.syntax.*
import io.circe.{Decoder, Encoder, HCursor, Json}
import io.iohk.atala.pollux.vc.jwt.JWTVerification.{extractPublicKey, validateEncodedJwt}
import io.iohk.atala.pollux.vc.jwt.schema.{SchemaResolver, SchemaValidator}
import net.reactivecore.cjs.validator.Violation
import net.reactivecore.cjs.{DocumentValidator, Loader}
import pdi.jwt.*
import zio.ZIO.none
import zio.prelude.*
import zio.{IO, NonEmptyChunk, Task, ZIO}

Expand All @@ -24,75 +26,94 @@ import scala.util.{Failure, Success, Try}
object JWTVerification {
def validateEncodedJwt[T](jwt: JWT)(
didResolver: DidResolver
)(decoder: String => IO[String, T])(issuerDidExtractor: T => String): IO[String, Boolean] = {
val decodeJWT = ZIO
)(decoder: String => Validation[String, T])(issuerDidExtractor: T => String): IO[String, Validation[String, Unit]] = {
val decodeJWT = Validation
.fromTry(JwtCirce.decodeRawAll(jwt.value, JwtOptions(false, false, false)))
.mapError(_.getMessage)

val extractAlgorithm =
val extractAlgorithm: Validation[String, JwtAlgorithm] =
for {
decodedJwtTask <- decodeJWT
(header, _, _) = decodedJwtTask
algorithm <- Validation
.fromOptionWith("An algorithm must be specified in the header")(JwtCirce.parseHeader(header).algorithm)
.toZIO
} yield algorithm

val loadDidDocument =
val validatedIssuerDid: Validation[String, String] =
for {
decodedJwtTask <- decodeJWT
(_, claim, _) = decodedJwtTask
decodedClaim <- decoder(claim)
extractIssuerDid = issuerDidExtractor(decodedClaim)
resolvedDidDocument <- resolve(extractIssuerDid)(didResolver)
} yield resolvedDidDocument
} yield extractIssuerDid

val loadDidDocument =
ValidationUtils
.foreach(
validatedIssuerDid
.map(validIssuerDid => resolve(validIssuerDid)(didResolver))
)(identity)
.map(b => b.flatten)

for {
results <- loadDidDocument validatePar extractAlgorithm
(didDocument, algorithm) = results
verificationMethods <- extractVerificationMethods(didDocument, algorithm)
} yield validateEncodedJwt(jwt, verificationMethods)
loadDidDocument
.map(validatedDidDocument => {
for {
results <- Validation.validateWith(validatedDidDocument, extractAlgorithm)((didDocument, algorithm) =>
(didDocument, algorithm)
)
(didDocument, algorithm) = results
verificationMethods <- extractVerificationMethods(didDocument, algorithm)
validatedJwt <- validateEncodedJwt(jwt, verificationMethods)
} yield validatedJwt
})
}

def validateEncodedJwt(jwt: JWT, publicKey: PublicKey): Boolean =
JwtCirce.isValid(jwt.value, publicKey)
def validateEncodedJwt(jwt: JWT, publicKey: PublicKey): Validation[String, Unit] =
if JwtCirce.isValid(jwt.value, publicKey) then Validation.unit
else Validation.fail(s"Jwt[$jwt] not singed by $publicKey")

def validateEncodedJwt(jwt: JWT, verificationMethods: IndexedSeq[VerificationMethod]): Boolean = {
verificationMethods.exists(verificationMethod =>
toPublicKey(verificationMethod).exists(publicKey => validateEncodedJwt(jwt, publicKey))
)
def validateEncodedJwt(jwt: JWT, verificationMethods: IndexedSeq[VerificationMethod]): Validation[String, Unit] = {
verificationMethods
.map(verificationMethod => {
for {
publicKey <- extractPublicKey(verificationMethod)
signatureValidation <- validateEncodedJwt(jwt, publicKey)
} yield signatureValidation
})
.reduce((v1, v2) => v1.orElse(v2))
}

private def resolve(issuerDid: String)(didResolver: DidResolver): IO[String, DIDDocument] = {
private def resolve(issuerDid: String)(didResolver: DidResolver): IO[String, Validation[String, DIDDocument]] = {
didResolver
.resolve(issuerDid)
.flatMap(
.map(
_ match
case (didResolutionSucceeded: DIDResolutionSucceeded) =>
ZIO.succeed(didResolutionSucceeded.didDocument)
case (didResolutionFailed: DIDResolutionFailed) => ZIO.fail(didResolutionFailed.error.toString)
Validation.succeed(didResolutionSucceeded.didDocument)
case (didResolutionFailed: DIDResolutionFailed) => Validation.fail(didResolutionFailed.error.toString)
)
}

private def extractVerificationMethods(
didDocument: DIDDocument,
jwtAlgorithm: JwtAlgorithm
): IO[String, IndexedSeq[VerificationMethod]] = {
): Validation[String, IndexedSeq[VerificationMethod]] = {
Validation
.fromPredicateWith("No PublicKey to validate against found")(
didDocument.verificationMethod.filter(verification => verification.`type` == jwtAlgorithm.name)
)(_.nonEmpty)
.toZIO
}

// TODO Implement other key types
def toPublicKey(verificationMethod: VerificationMethod): Option[PublicKey] = {
for {
publicKeyJwk <- verificationMethod.publicKeyJwk
curve <- publicKeyJwk.crv
x <- publicKeyJwk.x.map(Base64URL.from)
y <- publicKeyJwk.y.map(Base64URL.from)
d <- publicKeyJwk.d.map(Base64URL.from)
} yield new ECKey.Builder(Curve.parse(curve), x, y).d(d).build().toPublicKey
def extractPublicKey(verificationMethod: VerificationMethod): Validation[String, PublicKey] = {
val maybePublicKey =
for {
publicKeyJwk <- verificationMethod.publicKeyJwk
curve <- publicKeyJwk.crv
x <- publicKeyJwk.x.map(Base64URL.from)
y <- publicKeyJwk.y.map(Base64URL.from)
d <- publicKeyJwk.d.map(Base64URL.from)
} yield new ECKey.Builder(Curve.parse(curve), x, y).d(d).build().toPublicKey
Validation.fromOptionWith("Unable to parse Public Key")(maybePublicKey)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.iohk.atala.pollux.vc.jwt

import zio.{Trace, ZIO}
import zio.prelude.{Validation, ZValidation}

object ValidationUtils {
final def foreach[R, E, W, VE, A, B](in: ZValidation[W, VE, A])(f: A => ZIO[R, E, B])(implicit
trace: Trace
): ZIO[R, E, ZValidation[W, VE, B]] =
in.fold(e => ZIO.succeed(Validation.failNonEmptyChunk(e)), f(_).map(Validation.succeed))

}
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ object CredentialPayloadValidation {

def validateCredentialSchema(
maybeCredentialSchema: Option[Json]
)(schemaToValidator: Json => Either[String, SchemaValidator]): Validation[String, Option[SchemaValidator]] = {
)(schemaToValidator: Json => Validation[String, SchemaValidator]): Validation[String, Option[SchemaValidator]] = {
maybeCredentialSchema.fold(Validation.succeed(Option.empty))(credentialSchema => {
Validation.fromEither(schemaToValidator(credentialSchema)).map(Some(_))
schemaToValidator(credentialSchema).map(Some(_))
})
}

Expand Down Expand Up @@ -179,7 +179,7 @@ object CredentialPayloadValidation {
) { (`@context`, `type`) => credentialPayload }

def validateSchema[C <: CredentialPayload](credentialPayload: C)(schemaResolver: SchemaResolver)(
schemaToValidator: Json => Either[String, SchemaValidator]
schemaToValidator: Json => Validation[String, SchemaValidator]
): IO[String, C] =
val validation =
for {
Expand Down Expand Up @@ -553,48 +553,64 @@ object JwtCredential {

def validateEncodedJWT(
jwt: JWT
)(didResolver: DidResolver): IO[String, Boolean] = {
)(didResolver: DidResolver): IO[String, Validation[String, Unit]] = {
JWTVerification.validateEncodedJwt(jwt)(didResolver: DidResolver)(claim =>
ZIO.fromEither(decode[JwtCredentialPayload](claim).left.map(_.toString))
Validation.fromEither(decode[JwtCredentialPayload](claim).left.map(_.toString))
)(_.iss)
}

def validateW3C(
payload: W3cVerifiableCredentialPayload
)(didResolver: DidResolver): IO[String, Boolean] = {
)(didResolver: DidResolver): IO[String, Validation[String, Unit]] = {
JWTVerification.validateEncodedJwt(payload.proof.jwt)(didResolver: DidResolver)(claim =>
ZIO.fromEither(decode[W3cCredentialPayload](claim).left.map(_.toString))
Validation.fromEither(decode[W3cCredentialPayload](claim).left.map(_.toString))
)(_.issuer.value)
}

def validateCredential(
verifiableCredentialPayload: VerifiableCredentialPayload
)(didResolver: DidResolver): IO[String, Validation[String, Unit]] = {
verifiableCredentialPayload match {
case (w3cVerifiableCredentialPayload: W3cVerifiableCredentialPayload) =>
JwtCredential.validateW3C(w3cVerifiableCredentialPayload)(didResolver)
case (jwtVerifiableCredentialPayload: JwtVerifiableCredentialPayload) =>
JwtCredential.validateEncodedJWT(jwtVerifiableCredentialPayload.jwt)(didResolver)
}
}

def validateJwtSchema(
jwt: JWT
)(schemaResolver: SchemaResolver)(
schemaToValidator: Json => Either[String, SchemaValidator]
): IO[String, Boolean] = {
val decodeJWT = ZIO
.fromTry(JwtCirce.decodeRawAll(jwt.value, JwtOptions(false, false, false)))
.mapError(_.getMessage)
schemaToValidator: Json => Validation[String, SchemaValidator]
): IO[String, Validation[String, Unit]] = {
val decodeJWT =
Validation.fromTry(JwtCirce.decodeRawAll(jwt.value, JwtOptions(false, false, false))).mapError(_.getMessage)

for {
decodedJwtTask <- decodeJWT
(_, claim, _) = decodedJwtTask
decodedClaim <- ZIO.fromEither(decode[JwtCredentialPayload](claim).left.map(_.toString))
validatedCredential <- CredentialPayloadValidation.validateSchema(decodedClaim)(schemaResolver)(
schemaToValidator
val validatedDecodedClaim: Validation[String, JwtCredentialPayload] =
for {
decodedJwtTask <- decodeJWT
(_, claim, _) = decodedJwtTask
decodedClaim <- Validation.fromEither(decode[JwtCredentialPayload](claim).left.map(_.toString))
} yield decodedClaim

ValidationUtils.foreach(
validatedDecodedClaim.map(decodedClaim =>
CredentialPayloadValidation.validateSchema(decodedClaim)(schemaResolver)(schemaToValidator)
)
} yield true
)(_.replicateZIODiscard(1))
}

def validateSchemaAndSignature(
jwt: JWT
)(didResolver: DidResolver)(schemaResolver: SchemaResolver)(
schemaToValidator: Json => Either[String, SchemaValidator]
): IO[String, Boolean] = {
schemaToValidator: Json => Validation[String, SchemaValidator]
): IO[String, Validation[String, Unit]] = {
for {
validatedJwtSchema <- validateJwtSchema(jwt)(schemaResolver)(schemaToValidator)
validateJwtSignature <- validateEncodedJWT(jwt)(didResolver)
} yield validatedJwtSchema && validateJwtSignature
} yield {
Validation.validateWith(validatedJwtSchema, validateJwtSignature)((a, _) => a)
}
}

def verifyDates(jwt: JWT, leeway: TemporalAmount)(implicit clock: Clock): Validation[String, Unit] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,51 +312,51 @@ object JwtPresentation {
JwtCirce.decodeRaw(jwt.value, publicKey).flatMap(decode[JwtPresentationPayload](_).toTry)
}

def validateEncodedJwt(jwt: JWT, publicKey: PublicKey): Boolean =
def validateEncodedJwt(jwt: JWT, publicKey: PublicKey): Validation[String, Unit] =
JWTVerification.validateEncodedJwt(jwt, publicKey)

def validateEncodedJWT(
jwt: JWT
)(didResolver: DidResolver): IO[String, Boolean] = {
)(didResolver: DidResolver): IO[String, Validation[String, Unit]] = {
JWTVerification.validateEncodedJwt(jwt)(didResolver: DidResolver)(claim =>
ZIO.fromEither(decode[JwtPresentationPayload](claim).left.map(_.toString))
Validation.fromEither(decode[JwtPresentationPayload](claim).left.map(_.toString))
)(_.iss)
}

def validateEncodedW3C(
jwt: JWT
)(didResolver: DidResolver): IO[String, Boolean] = {
)(didResolver: DidResolver): IO[String, Validation[String, Unit]] = {
JWTVerification.validateEncodedJwt(jwt)(didResolver: DidResolver)(claim =>
ZIO.fromEither(decode[W3cPresentationPayload](claim).left.map(_.toString))
Validation.fromEither(decode[W3cPresentationPayload](claim).left.map(_.toString))
)(_.holder)
}

def validateEnclosedCredentials(
jwt: JWT
)(didResolver: DidResolver): IO[List[String], Boolean] = {
def validateCredential(a: VerifiableCredentialPayload): IO[String, Boolean] = {
a match {
case (w3cVerifiableCredentialPayload: W3cVerifiableCredentialPayload) =>
JwtCredential.validateW3C(w3cVerifiableCredentialPayload)(didResolver)
case (jwtVerifiableCredentialPayload: JwtVerifiableCredentialPayload) =>
JwtCredential.validateEncodedJWT(jwtVerifiableCredentialPayload.jwt)(didResolver)
}
}
def validateCredentials(
decodedJwtPresentation: JwtPresentationPayload
): ZIO[Any, List[String], IndexedSeq[Boolean]] = {
ZIO.validatePar(decodedJwtPresentation.vp.verifiableCredential) { a =>
validateCredential(a)
}
}
)(didResolver: DidResolver): IO[List[String], Validation[String, Unit]] = {
val validateJwtPresentation = Validation.fromTry(decodeJwt(jwt)).mapError(_.toString)

val credentialValidationZIO =
ValidationUtils.foreach(
validateJwtPresentation
.map(validJwtPresentation => validateCredentials(validJwtPresentation)(didResolver))
)(identity)

val validatedCredentials =
credentialValidationZIO.map(validCredentialValidations => {
for {
decodedJwtPresentation <- ZIO.fromTry(decodeJwt(jwt)).mapError(error => error.toString :: Nil)
validatedCredentials <- validateCredentials(decodedJwtPresentation)
} yield validatedCredentials.forall(identity)
credentialValidations <- validCredentialValidations
_ <- Validation.validateAll(credentialValidations)
success <- Validation.unit
} yield success
})
}

validatedCredentials
def validateCredentials(
decodedJwtPresentation: JwtPresentationPayload
)(didResolver: DidResolver): ZIO[Any, List[String], IndexedSeq[Validation[String, Unit]]] = {
ZIO.validatePar(decodedJwtPresentation.vp.verifiableCredential) { a =>
JwtCredential.validateCredential(a)(didResolver)
}
}

def verifyDates(jwt: JWT, leeway: TemporalAmount)(implicit clock: Clock): Validation[String, Unit] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ class PlaceholderSchemaValidator extends SchemaValidator {
}

object PlaceholderSchemaValidator {
def fromSchema(schema: Json): Either[String, PlaceholderSchemaValidator] = Right(PlaceholderSchemaValidator())
def fromSchema(schema: Json): Validation[String, PlaceholderSchemaValidator] =
Validation.succeed(PlaceholderSchemaValidator())
}
1 change: 1 addition & 0 deletions pollux/project/build.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sbt.version=1.6.2

0 comments on commit 6348e13

Please sign in to comment.