NettyHttpServer.kt

  1. package com.hexagontk.http.server.netty

  2. import com.hexagontk.core.fieldsMapOf
  3. import com.hexagontk.core.security.createKeyManagerFactory
  4. import com.hexagontk.core.security.createTrustManagerFactory
  5. import com.hexagontk.http.HttpFeature
  6. import com.hexagontk.http.HttpFeature.*
  7. import com.hexagontk.http.SslSettings
  8. import com.hexagontk.http.model.HttpProtocol
  9. import com.hexagontk.http.model.HttpProtocol.*
  10. import com.hexagontk.http.server.HttpServer
  11. import com.hexagontk.http.server.HttpServerPort
  12. import com.hexagontk.http.server.HttpServerSettings
  13. import com.hexagontk.http.handlers.HttpHandler
  14. import io.netty.bootstrap.ServerBootstrap
  15. import io.netty.channel.*
  16. import io.netty.channel.nio.NioIoHandler
  17. import io.netty.channel.socket.nio.NioServerSocketChannel
  18. import io.netty.handler.codec.http.*
  19. import io.netty.handler.ssl.ClientAuth.OPTIONAL
  20. import io.netty.handler.ssl.ClientAuth.REQUIRE
  21. import io.netty.handler.ssl.SslContext
  22. import io.netty.handler.ssl.SslContextBuilder
  23. import java.net.InetSocketAddress
  24. import java.util.concurrent.Executor
  25. import java.util.concurrent.Executors.newVirtualThreadPerTaskExecutor
  26. import java.util.concurrent.TimeUnit.SECONDS
  27. import javax.net.ssl.KeyManagerFactory
  28. import javax.net.ssl.TrustManagerFactory

  29. /**
  30.  * Implements [HttpServerPort] using Netty [Channel].
  31.  */
  32. open class NettyHttpServer(
  33.     private val bossGroupThreads: Int = 1,
  34.     private val workerGroupThreads: Int = 0,
  35.     private val executor: Executor? = newVirtualThreadPerTaskExecutor(),
  36.     private val soBacklog: Int = 4 * 1_024,
  37.     private val soReuseAddr: Boolean = true,
  38.     private val soKeepAlive: Boolean = true,
  39.     private val shutdownQuietSeconds: Long = 0,
  40.     private val shutdownTimeoutSeconds: Long = 0,
  41.     private val keepAliveHandler: Boolean = true,
  42.     private val httpAggregatorHandler: Boolean = true,
  43.     private val chunkedHandler: Boolean = true,
  44.     private val enableWebsockets: Boolean = true,
  45. ) : HttpServerPort {

  46.     private var nettyChannel: Channel? = null
  47.     private var bossEventLoop: MultithreadEventLoopGroup? = null
  48.     private var workerEventLoop: MultithreadEventLoopGroup? = null

  49.     constructor() : this(
  50.         bossGroupThreads = 1,
  51.         workerGroupThreads = 0,
  52.         executor = newVirtualThreadPerTaskExecutor(),
  53.         soBacklog = 4 * 1_024,
  54.         soReuseAddr = true,
  55.         soKeepAlive = true,
  56.         shutdownQuietSeconds = 0,
  57.         shutdownTimeoutSeconds = 0,
  58.         keepAliveHandler = true,
  59.         httpAggregatorHandler = true,
  60.         chunkedHandler = true,
  61.         enableWebsockets = true,
  62.     )

  63.     override fun runtimePort(): Int =
  64.         (nettyChannel?.localAddress() as? InetSocketAddress)?.port
  65.             ?: error("Error fetching runtime port")

  66.     override fun started() =
  67.         nettyChannel?.isOpen ?: false

  68.     override fun startUp(server: HttpServer) {
  69.         val bossGroup = groupSupplier(bossGroupThreads)
  70.         val workerGroup =
  71.             if (workerGroupThreads < 0) bossGroup
  72.             else groupSupplier(workerGroupThreads)

  73.         try {
  74.             val settings = server.settings
  75.             val sslSettings = settings.sslSettings
  76.             val handlers: Map<HttpMethod, HttpHandler> =
  77.                 server.handler
  78.                     .byMethod()
  79.                     .mapKeys { HttpMethod.valueOf(it.key.toString()) }

  80.             val nettyServer = serverBootstrapSupplier(bossGroup, workerGroup)
  81.                 .childHandler(createInitializer(sslSettings, handlers, settings))

  82.             val address = settings.bindAddress
  83.             val port = settings.bindPort
  84.             val future = nettyServer.bind(address, port).sync()

  85.             nettyChannel = future.channel()
  86.             bossEventLoop = bossGroup
  87.             workerEventLoop = workerGroup
  88.         }
  89.         catch (_: Exception) {
  90.             bossGroup.shutdownGracefully()
  91.             workerGroup.shutdownGracefully()
  92.         }
  93.     }

  94.     open fun groupSupplier(it: Int): MultithreadEventLoopGroup =
  95.         MultiThreadIoEventLoopGroup(NioIoHandler.newFactory())

  96.     open fun serverBootstrapSupplier(
  97.         bossGroup: MultithreadEventLoopGroup,
  98.         workerGroup: MultithreadEventLoopGroup,
  99.     ): ServerBootstrap =
  100.         ServerBootstrap().group(bossGroup, workerGroup)
  101.             .channel(NioServerSocketChannel::class.java)
  102.             .option(ChannelOption.SO_BACKLOG, soBacklog)
  103.             .option(ChannelOption.SO_REUSEADDR, soReuseAddr)
  104.             .childOption(ChannelOption.SO_KEEPALIVE, soKeepAlive)
  105.             .childOption(ChannelOption.SO_REUSEADDR, soReuseAddr)

  106.     private fun createInitializer(
  107.         sslSettings: SslSettings?,
  108.         handlers: Map<HttpMethod, HttpHandler>,
  109.         settings: HttpServerSettings
  110.     ) =
  111.         when {
  112.             sslSettings != null -> sslInitializer(sslSettings, handlers, settings)
  113.             else -> HttpChannelInitializer(
  114.                 handlers,
  115.                 executor,
  116.                 settings,
  117.                 keepAliveHandler,
  118.                 httpAggregatorHandler,
  119.                 chunkedHandler,
  120.                 enableWebsockets,
  121.             )
  122.         }

  123.     private fun sslInitializer(
  124.         sslSettings: SslSettings,
  125.         handlers: Map<HttpMethod, HttpHandler>,
  126.         settings: HttpServerSettings
  127.     ): HttpsChannelInitializer =
  128.         HttpsChannelInitializer(
  129.             handlers,
  130.             sslContext(sslSettings),
  131.             sslSettings,
  132.             executor,
  133.             settings,
  134.             keepAliveHandler,
  135.             httpAggregatorHandler,
  136.             chunkedHandler,
  137.             enableWebsockets,
  138.         )

  139.     private fun sslContext(sslSettings: SslSettings): SslContext {
  140.         val keyManager = keyManagerFactory(sslSettings)

  141.         val sslContextBuilder = SslContextBuilder
  142.             .forServer(keyManager)
  143.             .clientAuth(if (sslSettings.clientAuth) REQUIRE else OPTIONAL)

  144.         val trustManager = trustManagerFactory(sslSettings)

  145.         return if (trustManager == null) sslContextBuilder.build()
  146.             else sslContextBuilder.trustManager(trustManager).build()
  147.     }

  148.     private fun trustManagerFactory(sslSettings: SslSettings): TrustManagerFactory? {
  149.         val trustStoreUrl = sslSettings.trustStore ?: return null
  150.         return createTrustManagerFactory(trustStoreUrl, sslSettings.trustStorePassword)
  151.     }

  152.     private fun keyManagerFactory(sslSettings: SslSettings): KeyManagerFactory {
  153.         val keyStoreUrl = sslSettings.keyStore ?: error("")
  154.         return createKeyManagerFactory(keyStoreUrl, sslSettings.keyStorePassword)
  155.     }

  156.     override fun shutDown() {
  157.         workerEventLoop
  158.             ?.shutdownGracefully(shutdownQuietSeconds, shutdownTimeoutSeconds, SECONDS)?.sync()
  159.         bossEventLoop
  160.             ?.shutdownGracefully(shutdownQuietSeconds, shutdownTimeoutSeconds, SECONDS)?.sync()

  161.         nettyChannel = null
  162.         bossEventLoop = null
  163.         workerEventLoop = null
  164.     }

  165.     override fun supportedProtocols(): Set<HttpProtocol> =
  166.         setOf(HTTP, HTTPS, HTTP2)

  167.     override fun supportedFeatures(): Set<HttpFeature> =
  168.         setOf(ZIP, COOKIES, MULTIPART, WEBSOCKETS, SSE)

  169.     override fun options(): Map<String, *> =
  170.         fieldsMapOf(
  171.             NettyHttpServer::bossGroupThreads to bossGroupThreads,
  172.             NettyHttpServer::workerGroupThreads to workerGroupThreads,
  173.             NettyHttpServer::executor to executor,
  174.             NettyHttpServer::soBacklog to soBacklog,
  175.             NettyHttpServer::soKeepAlive to soKeepAlive,
  176.             NettyHttpServer::shutdownQuietSeconds to shutdownQuietSeconds,
  177.             NettyHttpServer::shutdownTimeoutSeconds to shutdownTimeoutSeconds,
  178.             NettyHttpServer::keepAliveHandler to keepAliveHandler,
  179.             NettyHttpServer::httpAggregatorHandler to httpAggregatorHandler,
  180.             NettyHttpServer::chunkedHandler to chunkedHandler,
  181.             NettyHttpServer::enableWebsockets to enableWebsockets,
  182.         )
  183. }