CorsCallback.kt

  1. package com.hexagonkt.http.server.callbacks

  2. import com.hexagonkt.core.text.Glob
  3. import com.hexagonkt.http.model.*
  4. import com.hexagonkt.http.model.HttpMethod.Companion.ALL
  5. import com.hexagonkt.http.model.HttpMethod.OPTIONS
  6. import com.hexagonkt.http.model.HttpStatusType.SUCCESS
  7. import com.hexagonkt.http.handlers.HttpContext

  8. /**
  9.  * HTTP CORS callback. It holds info for CORS.
  10.  */
  11. class CorsCallback(
  12.     private val allowedOrigin: Regex,
  13.     private val allowedMethods: Set<HttpMethod> = ALL,
  14.     private val allowedHeaders: Set<String> = emptySet(),
  15.     private val exposedHeaders: Set<String> = emptySet(),
  16.     private val supportCredentials: Boolean = true,
  17.     private val preFlightStatus: HttpStatus = NO_CONTENT_204,
  18.     private val preFlightMaxAge: Long = 0
  19. ) : (HttpContext) -> HttpContext {

  20.     private companion object {
  21.         const val ALLOW_ORIGIN = "access-control-allow-origin"
  22.         const val ALLOW_CREDENTIALS = "access-control-allow-credentials"
  23.         const val REQUEST_METHOD = "access-control-request-method"
  24.         const val EXPOSE_HEADERS = "access-control-expose-headers"
  25.         const val REQUEST_HEADERS = "access-control-request-headers"
  26.         const val ALLOW_HEADERS = "access-control-allow-headers"
  27.         const val MAX_AGE = "access-control-max-age"
  28.     }

  29.     constructor(
  30.         allowedOrigin: String = "*",
  31.         allowedMethods: Set<HttpMethod> = ALL,
  32.         allowedHeaders: Set<String> = emptySet(),
  33.         exposedHeaders: Set<String> = emptySet(),
  34.         supportCredentials: Boolean = true,
  35.         preFlightStatus: HttpStatus = NO_CONTENT_204,
  36.         preFlightMaxAge: Long = 0) :
  37.         this(
  38.             Glob(allowedOrigin).regex,
  39.             allowedMethods,
  40.             allowedHeaders,
  41.             exposedHeaders,
  42.             supportCredentials,
  43.             preFlightStatus,
  44.             preFlightMaxAge
  45.         )

  46.     init {
  47.         val preFlightStatusType = preFlightStatus.type

  48.         require(preFlightStatusType == SUCCESS) {
  49.             "Preflight Status must be a success status: $preFlightStatusType"
  50.         }
  51.     }

  52.     override fun invoke(context: HttpContext): HttpContext =
  53.         context.simpleRequest().let {
  54.             if (context.request.method == OPTIONS) it.preFlightRequest()
  55.             else it
  56.         }.let {
  57.             if (it.response.status != FORBIDDEN_403) it.next()
  58.             else it
  59.         }

  60.     private fun allowOrigin(origin: String): Boolean =
  61.         allowedOrigin.matches(origin)

  62.     private fun accessControlAllowOrigin(origin: String): String =
  63.         if (allowedOrigin.pattern == ".*" && !supportCredentials) "*"
  64.         else origin

  65.     private fun HttpContext.simpleRequest(): HttpContext {
  66.         val origin = request.origin() ?: return this
  67.         if (!allowOrigin(origin))
  68.             return forbidden("Not allowed origin: $origin")

  69.         val accessControlAllowOrigin = accessControlAllowOrigin(origin)
  70.         var h = response.headers + Header(ALLOW_ORIGIN, accessControlAllowOrigin)

  71.         if (accessControlAllowOrigin != "*")
  72.             h += Header("vary", "Origin")

  73.         if (supportCredentials)
  74.             h += Header(ALLOW_CREDENTIALS, true)

  75.         val accessControlRequestMethod = request.headers[REQUEST_METHOD]
  76.         if (request.method == OPTIONS && accessControlRequestMethod != null)
  77.             return badRequest()

  78.         if (request.method !in allowedMethods)
  79.             return forbidden("Not allowed method: ${request.method}")

  80.         if (exposedHeaders.isNotEmpty()) {
  81.             val requestHeaderNames = request.headers.httpFields.keys.toSet()
  82.             val requestHeaders = requestHeaderNames.filter { it in exposedHeaders }

  83.             h += Header(EXPOSE_HEADERS, requestHeaders.joinToString(","))
  84.         }

  85.         return send(preFlightStatus, headers = h)
  86.     }

  87.     private fun HttpContext.preFlightRequest(): HttpContext {

  88.         val methodHeader = request.headers[REQUEST_METHOD]?.value as? String
  89.         val requestMethod = methodHeader
  90.             ?: return forbidden("$REQUEST_METHOD required header not found")

  91.         val method = HttpMethod.valueOf(requestMethod)
  92.         if (method !in allowedMethods)
  93.             return forbidden("Not allowed method: $method")

  94.         val accessControlRequestHeaders = request.headers[REQUEST_HEADERS]?.value as? String

  95.         var h = response.headers

  96.         if (accessControlRequestHeaders != null) {
  97.             val allowedHeaders = accessControlRequestHeaders
  98.                 .split(",")
  99.                 .map { it.trim() }
  100.                 .all { it in allowedHeaders }

  101.             if (!allowedHeaders && this@CorsCallback.allowedHeaders.isNotEmpty())
  102.                 return forbidden("Not allowed headers")

  103.             val headers = this@CorsCallback.allowedHeaders
  104.             val requestHeaders = headers.ifEmpty { request.headers.httpFields.keys.toSet() }
  105.             h += Header(ALLOW_HEADERS, requestHeaders.joinToString(","))
  106.         }

  107.         h += Header(REQUEST_METHOD, allowedMethods.joinToString(","))

  108.         if (preFlightMaxAge > 0)
  109.             h += Header(MAX_AGE, preFlightMaxAge.toString())

  110.         val origin = request.origin() ?: ""
  111.         return when {
  112.             allowOrigin(origin) && origin.isBlank() ->
  113.                 send(preFlightStatus, headers = h)
  114.             allowOrigin(origin) ->
  115.                 send(preFlightStatus, headers = h + Header(ALLOW_ORIGIN, accessControlAllowOrigin(origin)))
  116.             !allowOrigin(origin) && origin.isNotBlank() ->
  117.                 forbidden("Not allowed origin: $origin")
  118.             else ->
  119.                 forbidden("Forbidden pre-flight request")
  120.         }
  121.     }
  122. }