Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http (fix, breaking): RPCContext.current.getThreadLocal interface change to avoid unsafe type cast #3548

Merged
merged 6 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ case class FinagleRPCContext(request: Request) extends RPCContext {
FinagleBackend.setThreadLocal(key, value)
}

override def getThreadLocal[A](key: String): Option[A] = {
override def getThreadLocal(key: String): Option[Any] = {
FinagleBackend.getThreadLocal(key)
}

override def getThreadLocalUnsafe[A](key: String): Option[A] = {
getThreadLocal(key).map(_.asInstanceOf[A])
}

override def httpRequest: HttpMessage.Request = {
request.toHttpRequest
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ThreadLocalStorageTest extends AirSpec {

@Endpoint(path = "/rpc-context")
def rpcContext: String = {
RPCContext.current.getThreadLocal[String]("client_id").getOrElse("unknown")
RPCContext.current.getThreadLocal("client_id").map(_.toString).getOrElse("unknown")
}

@Endpoint(path = "/rpc-header")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package wvlet.airframe.http.grpc

import io.grpc.*
import wvlet.airframe.http.internal.TLSSupport
import wvlet.airframe.http.{Http, HttpMessage, RPCContext, RPCEncoding}
import wvlet.log.LogSupport

Expand Down Expand Up @@ -57,16 +58,9 @@ case class GrpcContext(
metadata: Metadata,
descriptor: MethodDescriptor[_, _]
) extends RPCContext
with TLSSupport
with LogSupport {

// Grpc doesn't provide a mutable thread-local stage, so create our own TLS here.
private lazy val tls =
ThreadLocal.withInitial[collection.mutable.Map[String, Any]](() => mutable.Map.empty[String, Any])

private def storage: collection.mutable.Map[String, Any] = {
tls.get()
}

// Return the accept header
def accept: String = metadata.accept
def encoding: RPCEncoding = accept match {
Expand All @@ -79,11 +73,11 @@ case class GrpcContext(
}

override def setThreadLocal[A](key: String, value: A): Unit = {
storage.put(key, value)
setTLS(key, value)
}

override def getThreadLocal[A](key: String): Option[A] = {
storage.get(key).asInstanceOf[Option[A]]
override def getThreadLocal(key: String): Option[Any] = {
getTLS(key)
}

override def httpRequest: HttpMessage.Request = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ trait DemoApi extends LogSupport {

def getRPCContext: Option[String] = {
val ctx = RPCContext.current
ctx.getThreadLocal[String]("client_id")
ctx.getThreadLocal("client_id").map(_.toString)
}

def getRequest: Request = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ package wvlet.airframe.http.netty

import wvlet.airframe.http.HttpMessage.{Request, Response}
import wvlet.airframe.http.*
import wvlet.airframe.http.internal.TLSSupport
import wvlet.airframe.rx.Rx
import wvlet.log.LogSupport

import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.util.{Failure, Success}

object NettyBackend extends HttpBackend[Request, Response, Rx] with LogSupport { self =>
object NettyBackend extends HttpBackend[Request, Response, Rx] with TLSSupport with LogSupport { self =>
private val rxBackend = new RxNettyBackend

override protected implicit val httpRequestAdapter: HttpRequestAdapter[Request] =
Expand Down Expand Up @@ -89,21 +90,16 @@ object NettyBackend extends HttpBackend[Request, Response, Rx] with LogSupport {
f.toRx.map(body)
}

private lazy val tls =
ThreadLocal.withInitial[collection.mutable.Map[String, Any]](() => mutable.Map.empty[String, Any])

private def storage: collection.mutable.Map[String, Any] = tls.get()

override def withThreadLocalStore(request: => Rx[Response]): Rx[Response] = {
//
request
}

override def setThreadLocal[A](key: String, value: A): Unit = {
storage.put(key, value)
setTLS(key, value)
}

override def getThreadLocal[A](key: String): Option[A] = {
storage.get(key).asInstanceOf[Option[A]]
getTLS(key).map(_.asInstanceOf[A])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ package wvlet.airframe.http.netty

import wvlet.airframe.http.HttpMessage.Request
import wvlet.airframe.http.RPCContext
import wvlet.airframe.http.internal.TLSSupport

class NettyRPCContext(val httpRequest: Request) extends RPCContext {
override def setThreadLocal[A](key: String, value: A): Unit = {
NettyBackend.setThreadLocal(key, value)
}
override def getThreadLocal[A](key: String): Option[A] = {
NettyBackend.getThreadLocal(key)
}
import scala.collection.mutable

class NettyRPCContext(val httpRequest: Request) extends RPCContext with TLSSupport {
override def setThreadLocal[A](key: String, value: A): Unit = setTLS(key, value)
override def getThreadLocal(key: String): Option[Any] = getTLS(key)
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,25 @@ class NettyBackendTest extends AirSpec {
val key = ULID.newULIDString

test("must be None by default") {
NettyBackend.getThreadLocal[Int](key) shouldBe None
NettyBackend.getThreadLocal(key) shouldBe None
}

test("store different content for each thread") {
NettyBackend.setThreadLocal[Int](key, 123)
NettyBackend.setThreadLocal(key, 123)

var valueInThread: Option[Int] = None

val t = new Thread {
override def run(): Unit = {
NettyBackend.getThreadLocal[Int](key) shouldBe None
NettyBackend.setThreadLocal[Int](key, 456)
valueInThread = NettyBackend.getThreadLocal[Int](key)
NettyBackend.getThreadLocal(key) shouldBe None
NettyBackend.setThreadLocal(key, 456)
valueInThread = NettyBackend.getThreadLocal(key)
}
}
t.start()
t.join()

NettyBackend.getThreadLocal[Int](key) shouldBe Some(123)
NettyBackend.getThreadLocal(key) shouldBe Some(123)
valueInThread shouldBe Some(456)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ object NettyLoggingTest extends AirSpec {

@RPC
class MyRPC extends LogSupport {
private var requestCount = 0

def hello(): Unit = {
RPCContext.current.setThreadLocal("user", "xxxx_yyyy")
debug("hello rpc")
if (requestCount == 0) {
RPCContext.current.setThreadLocal("user", "xxxx_yyyy")
}
requestCount += 1
trace("hello rpc")
}
}

Expand All @@ -46,7 +51,7 @@ object NettyLoggingTest extends AirSpec {
.withName("log-test-server")
.withExtraLogEntries { () =>
val m = ListMap.newBuilder[String, Any]
RPCContext.current.getThreadLocal[String]("user").foreach { v =>
RPCContext.current.getThreadLocal("user").foreach { v =>
m += "user" -> v
}
m += ("custom_log_entry" -> "test")
Expand All @@ -67,12 +72,20 @@ object NettyLoggingTest extends AirSpec {

test("add server custom log") { (syncClient: SyncClient) =>
syncClient.send(Http.POST("/wvlet.airframe.http.netty.NettyLoggingTest.MyRPC/hello"))
val logEntry = serverLogger.getLogs.head
val logs = serverLogger.getLogs
val logEntry = logs(0)
debug(logEntry)
logEntry shouldContain ("server_name" -> "log-test-server")
logEntry shouldContain ("custom_log_entry" -> "test")
logEntry shouldContain ("user" -> "xxxx_yyyy")

test("do not set TLS in the second request") {
syncClient.send(Http.POST("/wvlet.airframe.http.netty.NettyLoggingTest.MyRPC/hello"))
val l = serverLogger.getLogs(1)
debug(l)
l shouldNotContain ("user" -> "xxxx_yyyy")
}

test("add client custom log") {
val clientLogEntry = clientLogger.getLogs.head
debug(clientLogEntry)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package wvlet.airframe.http.internal

import wvlet.airframe.http.{RPCContext, EmptyRPCContext}
import wvlet.airframe.http.{EmptyRPCContext, RPCContext}

object LocalRPCContext {
private val localContext = new ThreadLocal[RPCContext]()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.internal

import scala.collection.mutable

/**
* Thread-local storage support
*/
private[http] trait TLSSupport {
private lazy val tls = ThreadLocal.withInitial[mutable.Map[String, Any]](() => mutable.Map.empty[String, Any])
private def tlsStorage(): mutable.Map[String, Any] = tls.get()

def setTLS(key: String, value: Any): Unit = tlsStorage().put(key, value)
def getTLS(key: String): Option[Any] = tlsStorage().get(key)
}
20 changes: 16 additions & 4 deletions airframe-http/src/main/scala/wvlet/airframe/http/RPCContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ trait RPCContext {
def httpRequest: HttpMessage.Request

def rpcCallContext: Option[RPCCallContext] = {
getThreadLocal[RPCCallContext](HttpBackend.TLS_KEY_RPC)
getThreadLocal(HttpBackend.TLS_KEY_RPC) match {
case Some(c: RPCCallContext) => Some(c)
case _ => None
}
}

/**
Expand All @@ -52,10 +55,19 @@ trait RPCContext {
* Get a thread-local variable that is available only within the request scope. The type must be specified
* explicitly.
* @param key
* @tparam A
* @return
*/
def getThreadLocal[A](key: String): Option[A]
@deprecated("Use getThreadLocal(key: String): Any instead", "24.5.0")
def getThreadLocalUnsafe[A](key: String): Option[A] = {
getThreadLocal(key).map(_.asInstanceOf[A])
}

/**
* Get a thread-local variable that is available only within the request scope.
* @param key
* @return
*/
def getThreadLocal(key: String): Option[Any]
}

/**
Expand All @@ -65,7 +77,7 @@ object EmptyRPCContext extends RPCContext {
override def setThreadLocal[A](key: String, value: A): Unit = {
// no-op
}
override def getThreadLocal[A](key: String): Option[A] = {
override def getThreadLocal(key: String): Option[Any] = {
// no-op
None
}
Expand Down
6 changes: 3 additions & 3 deletions docs/airframe-http.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ val server = Netty.server
// Add a custom log entry
m += "application_version" -> "1.0"
// Add a thread-local parameter to the log
RPCContext.current.getThreadLocal[String]("user_id").map { uid =>
RPCContext.current.getThreadLocal("user_id").map { uid =>
m += "user_id" -> uid
}
m.result
}
// [optional] Disable server-side logging (log/http_server.json)
.noLogging
// Add a custom MessageCodec mapping
// [optional] Add a custom MessageCodec mapping
.withCustomCodec{ case s: Surface.of[MyClass] => ... }

server.start { server =>
Expand Down Expand Up @@ -372,7 +372,7 @@ object AuthLogFilter extends RxHttpFilter with LogSupport {
def apply(request: Request, next: RxHttpEndpoint): Rx[Response] = {
next(request).map { response =>
// Read the thread-local parameter set in the context(request)
RPCContext.current.getThreadLocal[String]("user_id").map { uid =>
RPCContext.current.getThreadLocal("user_id").map { uid =>
info(s"user_id: ${uid}")
}
response
Expand Down
4 changes: 2 additions & 2 deletions docs/airframe-rpc.md
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ String "100" will be translated into an Int value `100` automatically.

### RPCContext

Since Airframe 22.8.0, airframe-rpc introduced `RPCContext` for reading and writing the thread-local storage, and referencing the original HTTP request:
Since Airframe 22.8.0, airframe-rpc introduced `RPCContext.current` for reading and writing the thread-local storage, and referencing the original HTTP request:

```scala
import wvlet.airframe.http._
Expand All @@ -456,7 +456,7 @@ import wvlet.airframe.http._
trait MyAPI {
def hello: String = {
// Read the thread-local storage
val userName = RPCContext.current.getThreadLocal[String]("context_user")
val userName = RPCContext.current.getThreadLocal("context_user")
s"Hello ${userName}"
}

Expand Down
Loading