Skip to content

Commit

Permalink
fix #1652
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieuancelin committed Jul 6, 2023
1 parent 5d46d00 commit 37746e6
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 77 deletions.
2 changes: 1 addition & 1 deletion otoroshi/app/models/wasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ case class WasmPlugin(
override def theDescription: String = description
override def theTags: Seq[String] = tags
override def theMetadata: Map[String, String] = metadata
def pool()(implicit env: Env): WasmVmPool = WasmVmPool.forPlugin(id)
def pool()(implicit env: Env): WasmVmPool = WasmVmPool.forPlugin(this)
}

object WasmPlugin {
Expand Down
129 changes: 67 additions & 62 deletions otoroshi/app/next/plugins/wasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,93 +328,98 @@ class WasmAccessValidator extends NgAccessValidator {
override def defaultConfigObject: Option[NgPluginConfig] = WasmConfig().some

override def access(ctx: NgAccessContext)(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = {

val config = ctx
.cachedConfig(internalName)(WasmConfig.format)
.getOrElse(WasmConfig())

if (config.source.kind == WasmSourceKind.Local) {
val vmfc = if (config.source.kind == WasmSourceKind.Local) {
val localPlugin = env.proxyState.wasmPlugin(config.source.path).get
val localConfig = localPlugin.config
localPlugin.pool()
.getPooledVm()
.flatMap { vm =>
vm.call(WasmFunctionParameters.ExtismFuntionCall(config.functionName.orElse(localConfig.functionName).getOrElse("access"), ctx.wasmJson.stringify), None)
.flatMap {
case Right(res) =>
val response = Json.parse(res._1)
AttrsHelper.updateAttrs(ctx.attrs, response)
val result = (response \ "result").asOpt[Boolean].getOrElse(false)
if (result) {
NgAccess.NgAllowed.vfuture
} else {
val error = (response \ "error").asOpt[JsObject].getOrElse(Json.obj())
Errors
.craftResponseResult(
(error \ "message").asOpt[String].getOrElse("An error occured"),
Results.Status((error \ "status").asOpt[Int].getOrElse(403)),
ctx.request,
None,
None,
attrs = ctx.attrs,
maybeRoute = ctx.route.some
)
.map(r => NgAccess.NgDenied(r))
}
case Left(err) =>
localPlugin.pool().getPooledVm().map(vm => (vm, localConfig))
} else {
config.pool().getPooledVm().map(vm => (vm, config))
}

vmfc.flatMap {
case (vm, localConfig) =>
vm.call(WasmFunctionParameters.ExtismFuntionCall(config.functionName.orElse(localConfig.functionName).getOrElse("access"), ctx.wasmJson.stringify), None)
.flatMap {
case Right(res) =>
val response = Json.parse(res._1)
AttrsHelper.updateAttrs(ctx.attrs, response)
val result = (response \ "result").asOpt[Boolean].getOrElse(false)
if (result) {
NgAccess.NgAllowed.vfuture
} else {
val error = (response \ "error").asOpt[JsObject].getOrElse(Json.obj())
Errors
.craftResponseResult(
(err \ "error").asOpt[String].getOrElse("An error occured"),
Results.Status(400),
(error \ "message").asOpt[String].getOrElse("An error occured"),
Results.Status((error \ "status").asOpt[Int].getOrElse(403)),
ctx.request,
None,
None,
attrs = ctx.attrs,
maybeRoute = ctx.route.some
)
.map(r => NgAccess.NgDenied(r))
}
.andThen {
case _ => vm.release()
}
}
} else {
WasmUtils
.execute(config, "access", ctx.wasmJson, ctx.attrs.some, None)
.flatMap {
case Right(res) =>
val response = Json.parse(res)
AttrsHelper.updateAttrs(ctx.attrs, response)
val result = (response \ "result").asOpt[Boolean].getOrElse(false)
if (result) {
NgAccess.NgAllowed.vfuture
} else {
val error = (response \ "error").asOpt[JsObject].getOrElse(Json.obj())
}
case Left(err) =>
Errors
.craftResponseResult(
(error \ "message").asOpt[String].getOrElse("An error occured"),
Results.Status((error \ "status").asOpt[Int].getOrElse(403)),
(err \ "error").asOpt[String].getOrElse("An error occured"),
Results.Status(400),
ctx.request,
None,
None,
attrs = ctx.attrs,
maybeRoute = ctx.route.some
)
.map(r => NgAccess.NgDenied(r))
}
case Left(err) =>
Errors
.craftResponseResult(
(err \ "error").asOpt[String].getOrElse("An error occured"),
Results.Status(400),
ctx.request,
None,
None,
attrs = ctx.attrs,
maybeRoute = ctx.route.some
)
.map(r => NgAccess.NgDenied(r))
}
.andThen {
case _ => vm.release()
}
}
}
//} else {
// WasmUtils
// .execute(config, "access", ctx.wasmJson, ctx.attrs.some, None)
// .flatMap {
// case Right(res) =>
// val response = Json.parse(res)
// AttrsHelper.updateAttrs(ctx.attrs, response)
// val result = (response \ "result").asOpt[Boolean].getOrElse(false)
// if (result) {
// NgAccess.NgAllowed.vfuture
// } else {
// val error = (response \ "error").asOpt[JsObject].getOrElse(Json.obj())
// Errors
// .craftResponseResult(
// (error \ "message").asOpt[String].getOrElse("An error occured"),
// Results.Status((error \ "status").asOpt[Int].getOrElse(403)),
// ctx.request,
// None,
// None,
// attrs = ctx.attrs,
// maybeRoute = ctx.route.some
// )
// .map(r => NgAccess.NgDenied(r))
// }
// case Left(err) =>
// Errors
// .craftResponseResult(
// (err \ "error").asOpt[String].getOrElse("An error occured"),
// Results.Status(400),
// ctx.request,
// None,
// None,
// attrs = ctx.attrs,
// maybeRoute = ctx.route.some
// )
// .map(r => NgAccess.NgDenied(r))
// }
//}
}
}

Expand Down
41 changes: 29 additions & 12 deletions otoroshi/app/wasm/pool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.extism.sdk.manifest.{Manifest, MemoryOptions}
import org.extism.sdk.otoroshi._
import org.extism.sdk.wasm.WasmSourceResolver
import otoroshi.env.Env
import otoroshi.models.WasmPlugin
import otoroshi.utils.syntax.implicits._
import otoroshi.wasm.CacheableWasmScript.CachedWasmScript
import otoroshi.wasm.proxywasm.VmData
Expand Down Expand Up @@ -118,7 +119,7 @@ case class WasmVm(index: Int,
if (killAtRelease.get()) {
queue.offer(WasmVmAction.WasmVmKillAction)
} else {
pool.release(this, index)
pool.release(this)
}
}

Expand Down Expand Up @@ -156,18 +157,28 @@ case class WasmVmPoolAction(promise: Promise[WasmVm], options: WasmVmInitOptions
}

object WasmVmPool {

private[wasm] val logger = Logger("otoroshi-wasm-vm-pool")
private[wasm] val engine = new OtoroshiEngine()
private val instances = new TrieMap[String, WasmVmPool]()
def forPlugin(id: String, maxCalls: Int = Int.MaxValue)(implicit env: Env): WasmVmPool = {
instances.getOrUpdate(id) {
new WasmVmPool(id, None, maxCalls, env)

def forPlugin(plugin: WasmPlugin)(implicit env: Env): WasmVmPool = {
instances.getOrUpdate(plugin.id) {
new WasmVmPool(plugin.id, None, env)
}
}

def forConfig(config: => WasmConfig)(implicit env: Env): WasmVmPool = {
// TODO: not sure if it works well with updated config
instances.getOrUpdate(config.source.cacheKey) {
new WasmVmPool(config.source.cacheKey, config.some, env)
}
}

private[wasm] def removePlugin(id: String): Unit = instances.remove(id)
}

class WasmVmPool(pluginId: String, optConfig: Option[WasmConfig], maxCalls: Int, val env: Env) {
class WasmVmPool(stableId: => String, optConfig: => Option[WasmConfig], val env: Env) {

WasmVmPool.logger.debug("new WasmVmPool")

Expand All @@ -192,8 +203,8 @@ class WasmVmPool(pluginId: String, optConfig: Option[WasmConfig], maxCalls: Int,
wasmConfig() match {
case None =>
destroyCurrent()
WasmVmPool.removePlugin(pluginId)
Future.failed(new RuntimeException(s"No more plugin ${pluginId}"))
WasmVmPool.removePlugin(stableId)
Future.failed(new RuntimeException(s"No more plugin ${stableId}"))
case Some(wcfg) => {
val changed = hasChanged(wcfg)
val available = hasAvailableVm(wcfg)
Expand Down Expand Up @@ -259,13 +270,13 @@ class WasmVmPool(pluginId: String, optConfig: Option[WasmConfig], maxCalls: Int,
val vmDataRef = new AtomicReference[VmData](null)
val addedFunctions = options.addHostFunctions(vmDataRef)
val functions: Array[OtoroshiHostFunction[_ <: OtoroshiHostUserData]] = if (options.importDefaultHostFunctions) {
HostFunctions.getFunctions(config, pluginId, None)(env, env.otoroshiExecutionContext) ++ addedFunctions
HostFunctions.getFunctions(config, stableId, None)(env, env.otoroshiExecutionContext) ++ addedFunctions
} else {
addedFunctions.toArray[OtoroshiHostFunction[_ <: OtoroshiHostUserData]]
}
val memories = LinearMemories.getMemories(config)
val instance = template.instantiate(engine, functions, memories, config.wasi)
val vm = WasmVm(index, maxCalls, options.resetMemory, instance, vmDataRef, memories, functions, this)
val vm = WasmVm(index, options.maxCalls, options.resetMemory, instance, vmDataRef, memories, functions, this)
availableVms.offer(vm)
creatingRef.compareAndSet(true, false)
}
Expand All @@ -284,7 +295,7 @@ class WasmVmPool(pluginId: String, optConfig: Option[WasmConfig], maxCalls: Int,
}
}

private[wasm] def release(vm: WasmVm, index: Int): Unit = synchronized {
private[wasm] def release(vm: WasmVm): Unit = synchronized {
availableVms.synchronized {
availableVms.offer(vm)
inUseVms.remove(vm)
Expand All @@ -298,7 +309,7 @@ class WasmVmPool(pluginId: String, optConfig: Option[WasmConfig], maxCalls: Int,
}

private def wasmConfig(): Option[WasmConfig] = {
optConfig.orElse(env.proxyState.wasmPlugin(pluginId).map(_.config))
optConfig.orElse(env.proxyState.wasmPlugin(stableId).map(_.config))
}

private def hasAvailableVm(plugin: WasmConfig): Boolean = availableVms.size() > 0 && (inUseVms.size < plugin.instances)
Expand Down Expand Up @@ -333,12 +344,18 @@ class WasmVmPool(pluginId: String, optConfig: Option[WasmConfig], maxCalls: Int,
}
}

case class WasmVmInitOptions(importDefaultHostFunctions: Boolean = true, resetMemory: Boolean = true, addHostFunctions: (AtomicReference[VmData]) => Seq[OtoroshiHostFunction[_ <: OtoroshiHostUserData]] = _ => Seq.empty)
case class WasmVmInitOptions(
importDefaultHostFunctions: Boolean = true,
resetMemory: Boolean = true,
maxCalls: Int = Int.MaxValue,
addHostFunctions: (AtomicReference[VmData]) => Seq[OtoroshiHostFunction[_ <: OtoroshiHostUserData]] = _ => Seq.empty
)

object WasmVmInitOptions {
def empty(): WasmVmInitOptions = WasmVmInitOptions(
importDefaultHostFunctions = true,
resetMemory = true,
maxCalls = Int.MaxValue,
addHostFunctions = _ => Seq.empty
)
}
4 changes: 2 additions & 2 deletions otoroshi/app/wasm/proxywasm/coraza.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class CorazaPlugin(wasm: WasmConfig, val config: CorazaWafConfig, key: String, e
private lazy val contextId = new AtomicInteger(0)
private lazy val state =
new ProxyWasmState(CorazaPlugin.rootContextIds.incrementAndGet(), contextId, Some((l, m) => logCallback(l, m)), env)
private lazy val pool: WasmVmPool = new WasmVmPool(key, wasm.some, 2000, env)
private lazy val pool: WasmVmPool = new WasmVmPool(key, wasm.some, env)

def logCallback(level: org.slf4j.event.Level, msg: String): Unit = {
CorazaTrailEvent(level, msg).toAnalytics()
Expand Down Expand Up @@ -313,7 +313,7 @@ class CorazaPlugin(wasm: WasmConfig, val config: CorazaWafConfig, key: String, e
}

def start(attrs: TypedMap): Future[Unit] = {
pool.getPooledVm(WasmVmInitOptions(false, true, createFunctions)).flatMap { vm =>
pool.getPooledVm(WasmVmInitOptions(false, true, 2000, createFunctions)).flatMap { vm =>
val data = VmData.withRules(rules)
attrs.put(otoroshi.wasm.proxywasm.CorazaPluginKeys.CorazaWasmVmKey -> vm)
vm.finitialize {
Expand Down
1 change: 1 addition & 0 deletions otoroshi/app/wasm/wasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ case class WasmConfig(
instances: Int = 1,
authorizations: WasmAuthorizations = WasmAuthorizations()
) extends NgPluginConfig {
def pool()(implicit env: Env): WasmVmPool = WasmVmPool.forConfig(this)
def json: JsValue = Json.obj(
"source" -> source.json,
"memoryPages" -> memoryPages,
Expand Down

0 comments on commit 37746e6

Please sign in to comment.