ServletFilter.kt

  1. package com.hexagonkt.http.server.servlet

  2. import com.hexagonkt.core.logging.Logger
  3. import com.hexagonkt.core.media.TEXT_PLAIN
  4. import com.hexagonkt.core.toText
  5. import com.hexagonkt.http.handlers.bodyToBytes
  6. import com.hexagonkt.http.handlers.HttpHandler
  7. import com.hexagonkt.http.model.HttpResponse
  8. import com.hexagonkt.http.model.HttpResponsePort
  9. import jakarta.servlet.*
  10. import jakarta.servlet.http.Cookie
  11. import jakarta.servlet.http.HttpFilter
  12. import jakarta.servlet.http.HttpServletRequest
  13. import jakarta.servlet.http.HttpServletResponse

  /**
   * Servlet filter that serves requests with the HTTP handlers derived from [pathHandler].
   *
   * Each request is dispatched to the handler registered for its HTTP method and the handler's
   * response is copied to the servlet response. The received filter chain is never invoked, so the
   * request processing always ends in this filter.
   *
   * @param pathHandler Handler whose per-method handlers (as returned by `byMethod()`) are used to
   *  process the requests.
   */
  class ServletFilter(pathHandler: HttpHandler) : HttpFilter() {

      /** Handlers indexed by the text form of the HTTP method they are registered for. */
      private val handlers: Map<String, HttpHandler> =
          pathHandler.byMethod().mapKeys { it.key.toString() }

      /**
       * Stores the filter configuration in the base class and logs the filter name, context path
       * and init parameters.
       *
       * @param filterConfig Configuration supplied by the servlet container.
       */
      override fun init(filterConfig: FilterConfig) {
          // Required by the `GenericFilter` contract: keeps `getFilterConfig()` and related
          // accessors inherited from the base class usable
          super.init(filterConfig)

          val filterName = filterConfig.filterName
          val parameterNames = filterConfig.initParameterNames.toList().joinToString(", ") {
              "$it = ${filterConfig.getInitParameter(it)}"
          }
          logger.info {
              """'$filterName' Servlet filter initialized.
                |  * Context path: ${filterConfig.servletContext.contextPath}
                |  * Parameters: $parameterNames
              """.trimMargin()
          }
      }

      /** Logs that the filter has been taken out of service. */
      override fun destroy() {
          logger.info { "Servlet filter destroyed" }
      }

      /**
       * Processes the request with the registered handlers. [chain] is intentionally ignored: the
       * response is fully produced by this filter.
       *
       * @param request Servlet request to process.
       * @param response Servlet response the handler's result is written to.
       * @param chain Filter chain (not used).
       */
      override fun doFilter(
          request: HttpServletRequest, response: HttpServletResponse, chain: FilterChain) {
          doFilter(request, response)
      }

      /**
       * Runs the handler registered for the request's method and writes its response (headers,
       * cookies, content type, status and body) to the servlet response.
       *
       * If there is no handler for the method, an empty [HttpResponse] is written. If writing the
       * response fails, a 500 plain text response with the exception details is sent instead. The
       * output stream is always flushed.
       *
       * @param request Servlet request to process.
       * @param response Servlet response to fill.
       */
      private fun doFilter(request: HttpServletRequest, response: HttpServletResponse) {

          val requestAdapter = ServletRequestAdapterSync(request)
          val handlerResponse = handlers[request.method]
              ?.process(requestAdapter)
              ?.response
              ?: HttpResponse()

          try {
              responseToServlet(requestAdapter.protocol.secure, handlerResponse, response)
              response.outputStream.write(bodyToBytes(handlerResponse.body))
          }
          catch (e: Exception) {
              // Replace (not append) the content type: the handler's content type may have been
              // added already before the failure, and the error body is always plain text
              response.setHeader("content-type", TEXT_PLAIN.fullType)
              response.status = 500
              response.outputStream.write(e.toText().toByteArray())
          }
          finally {
              response.outputStream.flush()
          }
      }

      /**
       * Copies headers, cookies, content type and status code from a handler response to the
       * servlet response.
       *
       * @param secureRequest True if the request was received over a secure protocol. When false,
       *  cookies flagged as secure are not sent.
       * @param response Handler response to copy the data from.
       * @param servletResponse Servlet response to copy the data to.
       */
      private fun responseToServlet(
          secureRequest: Boolean,
          response: HttpResponsePort,
          servletResponse: HttpServletResponse
      ) {
          response.headers.values.forEach { (name, values) ->
              values.forEach { servletResponse.addHeader(name, it.toString()) }
          }

          response.cookies
              // Secure cookies are only sent back on secure requests
              .filter { secureRequest || !it.secure }
              .forEach { cookie ->
                  val servletCookie = Cookie(cookie.name, cookie.value).apply {
                      maxAge = cookie.maxAge.toInt()
                      secure = cookie.secure
                      path = cookie.path
                      isHttpOnly = cookie.httpOnly
                      cookie.domain?.let { d -> domain = d }
                  }
                  servletResponse.addCookie(servletCookie)
              }

          response.contentType?.let { servletResponse.addHeader("content-type", it.text) }
          servletResponse.status = response.status.code
      }

      private companion object {
          val logger: Logger = Logger(ServletFilter::class)
      }
  }