Skip to content

Commit

Permalink
add throttle and config
Browse files Browse the repository at this point in the history
Co-Authored-By: Johannes Rudolph <[email protected]>

revert code format change

Update Http2ServerSettings.scala

Update Http2ServerSpec.scala

rework test - still needs proper asserts

refactor tests

scalafmt
  • Loading branch information
pjfanning committed Nov 8, 2023
1 parent 6a767e5 commit efd0bd8
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 126 deletions.
5 changes: 5 additions & 0 deletions http-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ pekko.http {
# Fail the connection if a sent ping is not acknowledged within this timeout.
# When zero the ping-interval is used, if set the value must be evenly divisible by less than or equal to the ping-interval.
ping-timeout = 0s

# Configure the throttle for Reset Frames (https://github.com/apache/incubator-pekko-http/issues/332)
resets-throttle-cost = 100
resets-throttle-burst = 100
resets-throttle-interval = 1s
}

websocket {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import pekko.http.scaladsl.settings.{
ParserSettings,
ServerSettings
}
import pekko.stream.{ BidiShape, Graph, StreamTcpException }
import pekko.stream.{ BidiShape, Graph, StreamTcpException, ThrottleMode }
import pekko.stream.TLSProtocol._
import pekko.stream.scaladsl.{ BidiFlow, Flow, Keep, Source }
import pekko.util.ByteString
Expand Down Expand Up @@ -127,6 +127,7 @@ private[http] object Http2Blueprint {
serverDemux(settings.http2Settings, initialDemuxerSettings, upgraded) atop
FrameLogger.logFramesIfEnabled(settings.http2Settings.logFrames) atop // enable for debugging
hpackCoding(masterHttpHeaderParser, settings.parserSettings) atop
rapidResetMitigation(settings.http2Settings) atop
framing(log) atop
errorHandling(log) atop
idleTimeoutIfConfigured(settings.idleTimeout)
Expand Down Expand Up @@ -198,6 +199,20 @@ private[http] object Http2Blueprint {
Flow[FrameEvent].map(FrameRenderer.render).prepend(Source.single(Http2Protocol.ClientConnectionPreface)),
Flow[ByteString].via(new Http2FrameParsing(shouldReadPreface = false, log)))

private def rapidResetMitigation(
settings: Http2ServerSettings): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, NotUsed] = {
def frameCost(event: FrameEvent): Int = event match {
case _: FrameEvent.DataFrame => 0
case _: FrameEvent.WindowUpdateFrame => 0 // TODO: should we throttle these?
case _ => 1
}

BidiFlow.fromFlows(
Flow[FrameEvent],
Flow[FrameEvent].throttle(settings.resetsThrottleCost, settings.resetsThrottleInterval,
settings.resetsThrottleBurst, frameCost, ThrottleMode.Enforcing))
}

/**
* Runs hpack encoding and decoding. Incoming frames that are processed are HEADERS and CONTINUATION.
* Outgoing frame is ParsedHeadersFrame.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ trait Http2ServerSettings {

def getPingTimeout: Duration = Duration.ofMillis(pingTimeout.toMillis)
def withPingTimeout(timeout: Duration): Http2ServerSettings = withPingTimeout(timeout.toMillis.millis)

def getResetsThrottleCost(): Int = resetsThrottleCost
def getResetsThrottleBurst(): Int = resetsThrottleBurst

def getResetsThrottleInterval: Duration = Duration.ofMillis(resetsThrottleInterval.toMillis)

def withResetsThrottleInterval(interval: Duration): Http2ServerSettings =
withResetsThrottleInterval(interval.toMillis.millis)
}
object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
def create(config: Config): Http2ServerSettings = scaladsl.settings.Http2ServerSettings(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ trait Http2ServerSettings extends javadsl.settings.Http2ServerSettings with Http
def pingTimeout: FiniteDuration
def withPingTimeout(timeout: FiniteDuration): Http2ServerSettings = copy(pingTimeout = timeout)

def resetsThrottleCost: Int
def withResetsThrottleCost(cost: Int) = copy(resetsThrottleCost = cost)

def resetsThrottleBurst: Int
def withResetsThrottleBurst(burst: Int) = copy(resetsThrottleBurst = burst)

def resetsThrottleInterval: FiniteDuration
def withResetsThrottleInterval(interval: FiniteDuration) = copy(resetsThrottleInterval = interval)

@InternalApi
private[http] def internalSettings: Option[Http2InternalServerSettings]
@InternalApi
Expand All @@ -124,6 +133,9 @@ object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
logFrames: Boolean,
pingInterval: FiniteDuration,
pingTimeout: FiniteDuration,
resetsThrottleCost: Int,
resetsThrottleBurst: Int,
resetsThrottleInterval: FiniteDuration,
internalSettings: Option[Http2InternalServerSettings])
extends Http2ServerSettings {
require(maxConcurrentStreams >= 0, "max-concurrent-streams must be >= 0")
Expand Down Expand Up @@ -151,6 +163,9 @@ object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
logFrames = c.getBoolean("log-frames"),
pingInterval = c.getFiniteDuration("ping-interval"),
pingTimeout = c.getFiniteDuration("ping-timeout"),
resetsThrottleCost = c.getInt("resets-throttle-cost"),
resetsThrottleBurst = c.getInt("resets-throttle-burst"),
resetsThrottleInterval = c.getFiniteDuration("resets-throttle-interval"),
None // no possibility to configure internal settings with config
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, which was derived from Akka.
*/

/*
* Copyright (C) 2018-2022 Lightbend Inc. <https://www.lightbend.com>
*/

package org.apache.pekko.http.impl.engine.http2

import org.apache.pekko
import pekko.http.impl.engine.http2.Http2Protocol.FrameType
import pekko.http.impl.engine.http2.framing.FrameRenderer
import pekko.util.ByteStringBuilder
import org.scalatest.concurrent.Eventually

import java.nio.ByteOrder

/**
* This tests the http2 server support for rapid resets.
*/
class Http2ServerResetSpec extends Http2SpecWithMaterializer("""
pekko.http.server.remote-address-header = on
pekko.http.server.http2.log-frames = on
""")
with Eventually {
override def failOnSevereMessages: Boolean = false

"The Http/2 server implementation" should {
"cancel connection during rapid reset attack".inAllStagesStopped(new TestSetup with RequestResponseProbes {
implicit val bigEndian: ByteOrder = ByteOrder.BIG_ENDIAN
val bb = new ByteStringBuilder
bb.putInt(0)
val rstFrame = FrameRenderer.renderFrame(FrameType.RST_STREAM, ByteFlag.Zero, 1, bb.result())
val longFrame = Seq.fill(1000)(rstFrame).reduce(_ ++ _)
try {
network.sendBytes(longFrame)
} catch {
case assertionError: AssertionError =>
assertionError.getMessage should include("message CancelSubscription")
assertionError.getMessage should include("org.apache.pekko.stream.RateExceededException")
}
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,33 @@ package org.apache.pekko.http.impl.engine.http2

import org.apache.pekko
import pekko.NotUsed
import pekko.event.Logging
import pekko.http.impl.engine.http2.FrameEvent._
import pekko.http.impl.engine.http2.Http2Protocol.{ ErrorCode, Flags, FrameType, SettingIdentifier }
import pekko.http.impl.engine.http2.framing.FrameRenderer
import pekko.http.impl.engine.server.{ HttpAttributes, ServerTerminator }
import pekko.http.impl.engine.ws.ByteStringSinkProbe
import pekko.http.impl.util.PekkoSpecWithMaterializer

Check warning on line 22 in http2-tests/src/test/scala/org/apache/pekko/http/impl/engine/http2/Http2ServerSpec.scala

View workflow job for this annotation

GitHub Actions / Compile and test (2.12, 11)

Unused import

Check warning on line 22 in http2-tests/src/test/scala/org/apache/pekko/http/impl/engine/http2/Http2ServerSpec.scala

View workflow job for this annotation

GitHub Actions / Compile and test (2.12, 8)

Unused import
import pekko.http.impl.util.LogByteStringTools
import pekko.http.scaladsl.Http
import pekko.http.scaladsl.client.RequestBuilding.Get
import pekko.http.scaladsl.model._
import pekko.http.scaladsl.model.headers.{ CacheDirectives, RawHeader }
import pekko.http.scaladsl.settings.ServerSettings
import pekko.stream.Attributes
import pekko.stream.Attributes.LogLevels
import pekko.stream.OverflowStrategy
import pekko.stream.scaladsl.{ BidiFlow, Flow, Keep, Sink, Source, SourceQueueWithComplete }
import pekko.stream.testkit.TestPublisher.{ ManualProbe, Probe }
import pekko.stream.scaladsl.{ BidiFlow, Flow, Source, SourceQueueWithComplete }
import pekko.stream.testkit.TestPublisher.ManualProbe
import pekko.stream.testkit.scaladsl.StreamTestKit
import pekko.stream.testkit.{ TestPublisher, TestSubscriber }
import pekko.stream.testkit.TestPublisher
import pekko.testkit._
import pekko.util.{ ByteString, ByteStringBuilder }
import pekko.util.ByteString

import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.PatienceConfiguration.Timeout

import java.net.InetSocketAddress
import java.nio.ByteOrder
import javax.net.ssl.SSLContext

import scala.annotation.nowarn
import scala.collection.immutable
import scala.concurrent.duration._
import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
import scala.concurrent.{ Await, Promise }

/**
* This tests the http2 server protocol logic.
Expand All @@ -59,7 +52,7 @@ import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
* * if applicable: provide application-level response
* * validate the produced response frames
*/
class Http2ServerSpec extends PekkoSpecWithMaterializer("""
class Http2ServerSpec extends Http2SpecWithMaterializer("""
pekko.http.server.remote-address-header = on
pekko.http.server.http2.log-frames = on
""")
Expand Down Expand Up @@ -1785,119 +1778,7 @@ class Http2ServerSpec extends PekkoSpecWithMaterializer("""
dataStream.expectCancellation()
terminated.futureValue
})
"Not get unresponsive during attack".inAssertAllStagesStopped(new TestSetup with RequestResponseProbes {
implicit val bigEndian: ByteOrder = ByteOrder.BIG_ENDIAN
val bb = new ByteStringBuilder
bb.putInt(0)
val rstFrame = FrameRenderer.renderFrame(FrameType.RST_STREAM, ByteFlag.Zero, 1, bb.result())
val longFrame = Seq.fill(10000)(rstFrame).reduce(_ ++ _)
println(s"Size: ${longFrame.size}")
(0 to 100).foreach { _ =>
val start = System.nanoTime()
network.sendBytes(longFrame)
val end = System.nanoTime()
val s = (end - start).toFloat / 1000000000f
println(
f"Latency: ${(end - start) / 1000000.0}%.2f ms throughput ${longFrame.size.toFloat / s / 1000 / 1000}%.2f MB/s")
}
})
}
}

implicit class InWithStoppedStages(name: String) {
def inAssertAllStagesStopped(runTest: => TestSetup) =
name in StreamTestKit.assertAllStagesStopped {
val setup = runTest

// force connection to shutdown (in case it is an invalid state)
setup.network.fromNet.sendError(new RuntimeException)
setup.network.toNet.cancel()

// and then assert that all stages, substreams in particular, are stopped
}
}

protected /* To make ByteFlag warnings go away */ abstract class TestSetupWithoutHandshake {
implicit def ec: ExecutionContext = system.dispatcher

private val framesOut: Http2FrameProbe = Http2FrameProbe()
private val toNet = framesOut.plainDataProbe
private val fromNet = TestPublisher.probe[ByteString]()

def handlerFlow: Flow[HttpRequest, HttpResponse, NotUsed]

// hook to modify server, for example add attributes
def modifyServer(server: BidiFlow[HttpResponse, ByteString, ByteString, HttpRequest, ServerTerminator]) = server

// hook to modify server settings
def settings: ServerSettings = ServerSettings(system).withServerHeader(None)

final def theServer: BidiFlow[HttpResponse, ByteString, ByteString, HttpRequest, ServerTerminator] =
modifyServer(Http2Blueprint.serverStack(settings, system.log, telemetry = NoOpTelemetry,
dateHeaderRendering = Http().dateHeaderRendering))
.atop(LogByteStringTools.logByteStringBidi("network-plain-text").addAttributes(
Attributes(LogLevels(Logging.DebugLevel, Logging.DebugLevel, Logging.DebugLevel))))

val serverTerminator =
handlerFlow
.joinMat(theServer)(Keep.right)
.join(Flow.fromSinkAndSource(toNet.sink, Source.fromPublisher(fromNet)))
.withAttributes(Attributes.inputBuffer(1, 1))
.run()

val network = new NetworkSide(fromNet, toNet, framesOut) with Http2FrameHpackSupport
}

class NetworkSide(val fromNet: Probe[ByteString], val toNet: ByteStringSinkProbe, val framesOut: Http2FrameProbe)
extends WindowTracking {
override def frameProbeDelegate = framesOut

def sendBytes(bytes: ByteString): Unit = fromNet.sendNext(bytes)

}

/** Basic TestSetup that has already passed the exchange of the connection preface */
abstract class TestSetup(initialClientSettings: Setting*) extends TestSetupWithoutHandshake {
network.sendBytes(Http2Protocol.ClientConnectionPreface)
network.expectSETTINGS()

network.sendFrame(SettingsFrame(immutable.Seq.empty ++ initialClientSettings))
network.expectSettingsAck()
}

/** Provides the user handler flow as `requestIn` and `responseOut` probes for manual stream interaction */
trait RequestResponseProbes extends TestSetupWithoutHandshake {
private lazy val requestIn = TestSubscriber.probe[HttpRequest]()
private lazy val responseOut = TestPublisher.probe[HttpResponse]()

def handlerFlow: Flow[HttpRequest, HttpResponse, NotUsed] =
Flow.fromSinkAndSource(Sink.fromSubscriber(requestIn), Source.fromPublisher(responseOut))

lazy val user = new UserSide(requestIn, responseOut)

def expectGracefulCompletion(): Unit = {
network.toNet.expectComplete()
user.requestIn.expectComplete()
}
}

class UserSide(val requestIn: TestSubscriber.Probe[HttpRequest], val responseOut: TestPublisher.Probe[HttpResponse]) {
def expectRequest(): HttpRequest = requestIn.requestNext().removeAttribute(Http2.streamId)
def expectRequestRaw(): HttpRequest = requestIn.requestNext() // TODO, make it so that internal headers are not listed in `headers` etc?
def emitResponse(streamId: Int, response: HttpResponse): Unit =
responseOut.sendNext(response.addAttribute(Http2.streamId, streamId))

}

/** Provides the user handler flow as a handler function */
trait HandlerFunctionSupport extends TestSetupWithoutHandshake {
def parallelism: Int = 2
def handler: HttpRequest => Future[HttpResponse] =
_ => Future.successful(HttpResponse())

def handlerFlow: Flow[HttpRequest, HttpResponse, NotUsed] =
Http2Blueprint.handleWithStreamIdHeader(parallelism)(handler)
}

def bytes(num: Int, byte: Byte): ByteString = ByteString(Array.fill[Byte](num)(byte))
}
Loading

0 comments on commit efd0bd8

Please sign in to comment.