ServletFilter.kt

  1. package com.hexagontk.http.server.servlet

  2. import com.hexagontk.core.info
  3. import com.hexagontk.core.loggerOf
  4. import com.hexagontk.core.media.TEXT_PLAIN
  5. import com.hexagontk.core.toText
  6. import com.hexagontk.http.handlers.bodyToBytes
  7. import com.hexagontk.http.handlers.HttpHandler
  8. import com.hexagontk.http.model.HttpResponse
  9. import com.hexagontk.http.model.HttpResponsePort
  10. import jakarta.servlet.*
  11. import jakarta.servlet.http.Cookie
  12. import jakarta.servlet.http.HttpFilter
  13. import jakarta.servlet.http.HttpServletRequest
  14. import jakarta.servlet.http.HttpServletResponse
  15. import java.lang.System.Logger

  /**
   * Jakarta Servlet filter that serves requests with a Hexagon [HttpHandler].
   *
   * The handler is split per HTTP method once, at construction time, and each request is
   * dispatched to the handler registered for its method. The filter is terminal: it writes the
   * response itself and never invokes the rest of the [FilterChain].
   *
   * @param pathHandler Handler used to serve every request reaching this filter.
   */
  class ServletFilter(pathHandler: HttpHandler) : HttpFilter() {

      private companion object {
          val logger: Logger = loggerOf(ServletFilter::class)
      }

      /**
       * Per-method handlers, keyed by `toString()` of the keys returned by [HttpHandler.byMethod]
       * and looked up with [HttpServletRequest.getMethod].
       */
      private val handlers: Map<String, HttpHandler> =
          pathHandler.byMethod().mapKeys { it.key.toString() }

      /**
       * Stores the filter configuration in the parent `GenericFilter` and logs the filter name,
       * servlet context path and init parameters (names and values).
       *
       * @param filterConfig Configuration supplied by the servlet container.
       */
      override fun init(filterConfig: FilterConfig) {
          // GenericFilter only keeps the config (used by getFilterConfig, getServletContext,
          // getInitParameter...) when its own init(FilterConfig) runs.
          super.init(filterConfig)

          val filterName = filterConfig.filterName
          val parameters = filterConfig.initParameterNames.toList().joinToString(", ") { name ->
              "$name = ${filterConfig.getInitParameter(name)}"
          }
          // NOTE(review): init parameter values are logged verbatim; keep secrets out of them.
          logger.info {
              """'$filterName' Servlet filter initialized.
                |  * Context path: ${filterConfig.servletContext.contextPath}
                |  * Parameters: $parameters
              """.trimMargin()
          }
      }

      /** Logs that the filter has been taken out of service. */
      override fun destroy() {
          logger.info { "Servlet filter destroyed" }
      }

      /**
       * Serves [request] with the handler for its HTTP method and writes the result to [response].
       *
       * [chain] is intentionally ignored: this filter produces the response itself.
       */
      override fun doFilter(
          request: HttpServletRequest, response: HttpServletResponse, chain: FilterChain) {
          doFilter(request, response)
      }

      /**
       * Runs the handler registered for the request method and copies its result to [response].
       * When no handler matches, an empty [HttpResponse] is written instead.
       *
       * Building the request adapter and running the handler happen outside the `try` block, so
       * exceptions there propagate to the container. Failures while writing the response are
       * handled by [writeError]. The output stream is flushed in every case.
       */
      private fun doFilter(request: HttpServletRequest, response: HttpServletResponse) {
          val requestAdapter = ServletRequestAdapterSync(request)
          val handlerResponse = handlers[request.method]
              ?.process(requestAdapter)
              ?.response
              ?: HttpResponse()

          try {
              responseToServlet(requestAdapter.protocol.secure, handlerResponse, response)
              response.outputStream.write(bodyToBytes(handlerResponse.body))
          }
          catch (e: Exception) {
              writeError(e, response)
          }
          finally {
              response.outputStream.flush()
          }
      }

      /**
       * Reports [e], raised while writing a handler response, as a `500` plain-text response.
       *
       * The error is always logged. Once the response is committed, its status and headers can
       * no longer change, so nothing more is written; appending the error text would corrupt the
       * body already sent. Otherwise the response is reset first, which discards the headers,
       * cookies and buffered body copied from the failed response.
       */
      private fun writeError(e: Exception, response: HttpServletResponse) {
          logger.log(Logger.Level.ERROR, "Error writing HTTP response", e)
          if (response.isCommitted) return

          response.reset()
          response.contentType = TEXT_PLAIN.fullType
          response.status = 500
          // NOTE(review): e.toText() is sent to the client and may expose internal details
          // (confirm what toText() includes, e.g. stack traces).
          response.outputStream.write(bodyToBytes(e.toText()))
      }

      /**
       * Copies headers, cookies, content type and status from [response] into [servletResponse].
       *
       * The content type, when present, is applied with `setContentType` after the headers, so it
       * takes precedence over any `content-type` header.
       *
       * @param secureRequest Value of the request adapter's `protocol.secure`. When `false`,
       *   cookies flagged as secure are not sent.
       */
      private fun responseToServlet(
          secureRequest: Boolean,
          response: HttpResponsePort,
          servletResponse: HttpServletResponse
      ) {
          response.headers.all.forEach { (name, values) ->
              values.forEach { value -> servletResponse.addHeader(name, value.text) }
          }

          response.cookies
              .filter { httpCookie -> secureRequest || !httpCookie.secure }
              .forEach { httpCookie ->
                  val cookie = Cookie(httpCookie.name, httpCookie.value).apply {
                      // NOTE(review): toInt() truncates if maxAge is a Long outside Int range.
                      maxAge = httpCookie.maxAge.toInt()
                      secure = httpCookie.secure
                      path = httpCookie.path
                      isHttpOnly = httpCookie.httpOnly
                      httpCookie.domain?.let { cookieDomain -> domain = cookieDomain }
                  }
                  servletResponse.addCookie(cookie)
              }

          response.contentType?.let { servletResponse.contentType = it.text }
          servletResponse.status = response.status
      }
  }