CorsCallback.kt
- package com.hexagonkt.http.server.callbacks
- import com.hexagonkt.core.text.Glob
- import com.hexagonkt.http.model.*
- import com.hexagonkt.http.model.HttpMethod.Companion.ALL
- import com.hexagonkt.http.model.HttpMethod.OPTIONS
- import com.hexagonkt.http.model.HttpStatusType.SUCCESS
- import com.hexagonkt.http.handlers.HttpContext
/**
 * HTTP callback implementing Cross-Origin Resource Sharing (CORS).
 *
 * Requests with an origin (`request.origin()`) are validated against the allowed origin and, for
 * non-preflight requests, the allowed methods; accepted requests get the CORS response headers
 * added. `OPTIONS` requests are additionally processed as preflight requests. When the resulting
 * response is `403 Forbidden` the handler chain is not continued; otherwise [HttpContext.next] is
 * called.
 *
 * @param allowedOrigin Regular expression the request origin must match.
 * @param allowedMethods Methods accepted for cross-origin requests (defaults to `ALL`).
 * @param allowedHeaders Request headers accepted in preflight requests (compared ignoring case).
 *  When empty, any header requested by the client is accepted and echoed back.
 * @param exposedHeaders Response header names sent in `Access-Control-Expose-Headers`.
 * @param supportCredentials If `true`, `Access-Control-Allow-Credentials` is sent and the request
 *  origin is echoed back instead of `*`.
 * @param preFlightStatus Status returned for accepted preflight requests. Must be a success status.
 * @param preFlightMaxAge Value of `Access-Control-Max-Age`; the header is omitted if not positive.
 */
class CorsCallback(
    private val allowedOrigin: Regex,
    private val allowedMethods: Set<HttpMethod> = ALL,
    private val allowedHeaders: Set<String> = emptySet(),
    private val exposedHeaders: Set<String> = emptySet(),
    private val supportCredentials: Boolean = true,
    private val preFlightStatus: HttpStatus = NO_CONTENT_204,
    private val preFlightMaxAge: Long = 0
) : (HttpContext) -> HttpContext {

    /**
     * Creates a callback whose allowed origins are given as a glob pattern (converted with [Glob]).
     * The remaining parameters are the same as in the primary constructor.
     */
    constructor(
        allowedOrigin: String = "*",
        allowedMethods: Set<HttpMethod> = ALL,
        allowedHeaders: Set<String> = emptySet(),
        exposedHeaders: Set<String> = emptySet(),
        supportCredentials: Boolean = true,
        preFlightStatus: HttpStatus = NO_CONTENT_204,
        preFlightMaxAge: Long = 0
    ) : this(
        Glob(allowedOrigin).regex,
        allowedMethods,
        allowedHeaders,
        exposedHeaders,
        supportCredentials,
        preFlightStatus,
        preFlightMaxAge
    )

    init {
        val preFlightStatusType = preFlightStatus.type
        require(preFlightStatusType == SUCCESS) {
            "Preflight Status must be a success status: $preFlightStatusType"
        }
    }

    /**
     * Applies the CORS checks to [context] (plus the preflight checks for `OPTIONS` requests) and
     * continues the handler chain unless the outcome is a `403 Forbidden` response.
     */
    override fun invoke(context: HttpContext): HttpContext {
        val checked = context.simpleRequest()
        val result =
            if (context.request.method == OPTIONS) checked.preFlightRequest()
            else checked

        return if (result.response.status != FORBIDDEN_403) result.next() else result
    }

    /** Whether [origin] matches the configured allowed-origin regular expression. */
    private fun allowOrigin(origin: String): Boolean =
        allowedOrigin.matches(origin)

    /**
     * Value for `Access-Control-Allow-Origin`: `*` only when the allowed-origin pattern is exactly
     * `.*` and credentials are not supported; the request [origin] otherwise.
     */
    private fun accessControlAllowOrigin(origin: String): String =
        if (allowedOrigin.pattern == ".*" && !supportCredentials) "*"
        else origin

    /** Returns the method named [name] (surrounding blanks ignored), or `null` if it is unknown. */
    private fun parseMethod(name: String): HttpMethod? =
        try {
            HttpMethod.valueOf(name.trim())
        }
        catch (e: IllegalArgumentException) {
            null
        }

    /** Whether the [requested] header name is in [allowedHeaders] (header names ignore case). */
    private fun isAllowedHeader(requested: String): Boolean =
        allowedHeaders.any { it.equals(requested, ignoreCase = true) }

    /**
     * Validates the origin (and method, for non-preflight requests) and adds the common CORS
     * response headers. Requests without origin are returned untouched. The response status is
     * kept: it is set by later handlers (or by [preFlightRequest] for preflights).
     */
    private fun HttpContext.simpleRequest(): HttpContext {
        val origin = request.origin() ?: return this
        if (!allowOrigin(origin))
            return forbidden("Not allowed origin: $origin")

        val accessControlAllowOrigin = accessControlAllowOrigin(origin)
        var h = response.headers + Header(ALLOW_ORIGIN, accessControlAllowOrigin)

        // The echoed origin differs per request, so caches must key responses on `Origin`
        if (accessControlAllowOrigin != "*")
            h += Header("vary", "Origin")

        if (supportCredentials)
            h += Header(ALLOW_CREDENTIALS, true)

        // Preflights are validated in `preFlightRequest`, but they still need the headers above
        // (e.g. credentials support), so they are kept instead of being discarded
        val accessControlRequestMethod = request.headers[REQUEST_METHOD]
        if (request.method == OPTIONS && accessControlRequestMethod != null)
            return send(response.status, headers = h)

        if (request.method !in allowedMethods)
            return forbidden("Not allowed method: ${request.method}")

        // Expose headers name *response* headers, hence the configured list is sent as is
        if (exposedHeaders.isNotEmpty())
            h += Header(EXPOSE_HEADERS, exposedHeaders.joinToString(","))

        return send(response.status, headers = h)
    }

    /**
     * Validates a preflight request (requested method, requested headers and origin) and, when it
     * is accepted, responds with [preFlightStatus] and the preflight CORS headers.
     */
    private fun HttpContext.preFlightRequest(): HttpContext {
        val requestMethod = request.headers[REQUEST_METHOD]?.value as? String
            ?: return forbidden("$REQUEST_METHOD required header not found")

        // The header is client controlled: unknown methods are rejected instead of failing
        val method = parseMethod(requestMethod)
        if (method == null || method !in allowedMethods)
            return forbidden("Not allowed method: ${method ?: requestMethod}")

        var h = response.headers

        val accessControlRequestHeaders = request.headers[REQUEST_HEADERS]?.value as? String
        if (accessControlRequestHeaders != null) {
            val requestedHeaders = accessControlRequestHeaders
                .split(",")
                .map { it.trim() }
                .filter { it.isNotEmpty() }

            val headersAllowed = allowedHeaders.isEmpty() || requestedHeaders.all(::isAllowedHeader)
            if (!headersAllowed)
                return forbidden("Not allowed headers")

            // With no configured headers every requested header is accepted, so echo them back
            val responseHeaders: Collection<String> =
                if (allowedHeaders.isEmpty()) requestedHeaders
                else allowedHeaders

            if (responseHeaders.isNotEmpty())
                h += Header(ALLOW_HEADERS, responseHeaders.joinToString(","))
        }

        h += Header(ALLOW_METHODS, allowedMethods.joinToString(","))

        if (preFlightMaxAge > 0)
            h += Header(MAX_AGE, preFlightMaxAge.toString())

        val origin = request.origin() ?: ""
        val originAllowed = allowOrigin(origin)
        return when {
            !originAllowed && origin.isNotBlank() ->
                forbidden("Not allowed origin: $origin")
            !originAllowed ->
                forbidden("Forbidden pre-flight request")
            origin.isBlank() ->
                send(preFlightStatus, headers = h)
            else -> {
                val allowOriginHeader = Header(ALLOW_ORIGIN, accessControlAllowOrigin(origin))
                send(preFlightStatus, headers = h + allowOriginHeader)
            }
        }
    }

    private companion object {
        const val ALLOW_ORIGIN = "access-control-allow-origin"
        const val ALLOW_CREDENTIALS = "access-control-allow-credentials"
        const val ALLOW_METHODS = "access-control-allow-methods"
        const val REQUEST_METHOD = "access-control-request-method"
        const val EXPOSE_HEADERS = "access-control-expose-headers"
        const val REQUEST_HEADERS = "access-control-request-headers"
        const val ALLOW_HEADERS = "access-control-allow-headers"
        const val MAX_AGE = "access-control-max-age"
    }
}