CorsCallback.kt

package com.hexagonkt.http.server.callbacks

import com.hexagonkt.core.text.Glob
import com.hexagonkt.http.model.*
import com.hexagonkt.http.model.HttpMethod.Companion.ALL
import com.hexagonkt.http.model.HttpMethod.OPTIONS
import com.hexagonkt.http.model.HttpStatusType.SUCCESS
import com.hexagonkt.http.handlers.HttpContext

/**
 * HTTP CORS (Cross-Origin Resource Sharing) callback. It holds the CORS settings and applies them
 * to each handled [HttpContext].
 *
 * Requests without an origin are passed through untouched. Requests whose origin or method is not
 * allowed are answered with a forbidden response and are NOT passed to the next handler.
 *
 * @param allowedOrigin Regular expression the request origin must match entirely to be accepted.
 * @param allowedMethods Methods accepted for actual requests and for the method announced in the
 *  `access-control-request-method` header of pre-flight requests.
 * @param allowedHeaders Header names accepted in the `access-control-request-headers` header of
 *  pre-flight requests. An empty set disables this check.
 * @param exposedHeaders Header names used to build the `access-control-expose-headers` response
 *  header. If empty, that header is not added.
 * @param supportCredentials If true, `access-control-allow-credentials` is added to the response
 *  and the request origin is echoed back instead of the `*` wildcard.
 * @param preFlightStatus Status used for successful pre-flight responses. Its type must be
 *  [SUCCESS], otherwise the constructor fails with [IllegalArgumentException].
 * @param preFlightMaxAge Value of the `access-control-max-age` header in pre-flight responses. The
 *  header is only added if the value is greater than zero.
 */
class CorsCallback(
    private val allowedOrigin: Regex,
    private val allowedMethods: Set<HttpMethod> = ALL,
    private val allowedHeaders: Set<String> = emptySet(),
    private val exposedHeaders: Set<String> = emptySet(),
    private val supportCredentials: Boolean = true,
    private val preFlightStatus: HttpStatus = NO_CONTENT_204,
    private val preFlightMaxAge: Long = 0
) : (HttpContext) -> HttpContext {

    private companion object {
        const val ALLOW_ORIGIN = "access-control-allow-origin"
        const val ALLOW_CREDENTIALS = "access-control-allow-credentials"
        const val ALLOW_METHODS = "access-control-allow-methods"
        const val REQUEST_METHOD = "access-control-request-method"
        const val EXPOSE_HEADERS = "access-control-expose-headers"
        const val REQUEST_HEADERS = "access-control-request-headers"
        const val ALLOW_HEADERS = "access-control-allow-headers"
        const val MAX_AGE = "access-control-max-age"
    }

    /**
     * Create a CORS callback from a glob pattern instead of a regular expression. The glob is
     * converted to a [Regex] with [Glob]; the rest of the parameters are passed as they are to the
     * primary constructor.
     */
    constructor(
        allowedOrigin: String = "*",
        allowedMethods: Set<HttpMethod> = ALL,
        allowedHeaders: Set<String> = emptySet(),
        exposedHeaders: Set<String> = emptySet(),
        supportCredentials: Boolean = true,
        preFlightStatus: HttpStatus = NO_CONTENT_204,
        preFlightMaxAge: Long = 0) :
        this(
            Glob(allowedOrigin).regex,
            allowedMethods,
            allowedHeaders,
            exposedHeaders,
            supportCredentials,
            preFlightStatus,
            preFlightMaxAge
        )

    init {
        val preFlightStatusType = preFlightStatus.type

        require(preFlightStatusType == SUCCESS) {
            "Preflight Status must be a success status: $preFlightStatusType"
        }
    }

    /**
     * Process a request: add the CORS headers, handle the pre-flight if the method is `OPTIONS`,
     * and call the next handler unless the resulting status is `FORBIDDEN_403`.
     *
     * @param context Context of the request being handled.
     * @return Context holding the CORS response data (after the next handlers, if they were run).
     */
    override fun invoke(context: HttpContext): HttpContext =
        context.simpleRequest().let {
            if (context.request.method == OPTIONS) it.preFlightRequest()
            else it
        }.let {
            if (it.response.status != FORBIDDEN_403) it.next()
            else it
        }

    /** Check if the whole [origin] matches the allowed origin pattern. */
    private fun allowOrigin(origin: String): Boolean =
        allowedOrigin.matches(origin)

    /**
     * Value for the `access-control-allow-origin` header: the `*` wildcard only when any origin is
     * allowed and credentials are not supported, the request [origin] otherwise.
     */
    private fun accessControlAllowOrigin(origin: String): String =
        if (allowedOrigin.pattern == ".*" && !supportCredentials) "*"
        else origin

    /**
     * Add the headers common to every CORS response (allowed origin, `vary` and credentials) and,
     * for non pre-flight requests, check the method and add the exposed headers.
     *
     * The response status is left as it is: [preFlightStatus] is reserved for pre-flight responses.
     */
    private fun HttpContext.simpleRequest(): HttpContext {
        // Requests without origin are not CORS requests: nothing to do
        val origin = request.origin() ?: return this
        if (!allowOrigin(origin))
            return forbidden("Not allowed origin: $origin")

        val accessControlAllowOrigin = accessControlAllowOrigin(origin)
        var h = response.headers + Header(ALLOW_ORIGIN, accessControlAllowOrigin)

        // The response depends on the request origin when it is echoed back
        if (accessControlAllowOrigin != "*")
            h += Header("vary", "Origin")

        if (supportCredentials)
            h += Header(ALLOW_CREDENTIALS, true)

        // Pre-flight request: keep the headers gathered so far, the rest of the processing is done
        // by `preFlightRequest` (returning an error here would drop those headers)
        val accessControlRequestMethod = request.headers[REQUEST_METHOD]
        if (request.method == OPTIONS && accessControlRequestMethod != null)
            return send(response.status, headers = h)

        if (request.method !in allowedMethods)
            return forbidden("Not allowed method: ${request.method}")

        if (exposedHeaders.isNotEmpty()) {
            // NOTE(review): only exposed headers also present in the request are listed, confirm
            // this is intended (the header is meant to describe response headers)
            val requestHeaderNames = request.headers.httpFields.keys.toSet()
            val requestHeaders = requestHeaderNames.filter { it in exposedHeaders }

            h += Header(EXPOSE_HEADERS, requestHeaders.joinToString(","))
        }

        return send(response.status, headers = h)
    }

    /**
     * Handle a pre-flight (`OPTIONS`) request: validate the requested method and headers, and
     * answer with [preFlightStatus] plus the allowed headers, allowed methods and max age headers.
     * Any failed check results in a forbidden response.
     */
    private fun HttpContext.preFlightRequest(): HttpContext {

        val requestMethod = (request.headers[REQUEST_METHOD]?.value as? String)
            ?: return forbidden("$REQUEST_METHOD required header not found")

        // The header comes from the client: an unknown method name must not raise an exception
        val method =
            try {
                HttpMethod.valueOf(requestMethod)
            }
            catch (e: IllegalArgumentException) {
                return forbidden("Not allowed method: $requestMethod")
            }

        if (method !in allowedMethods)
            return forbidden("Not allowed method: $method")

        val accessControlRequestHeaders = request.headers[REQUEST_HEADERS]?.value as? String

        var h = response.headers

        if (accessControlRequestHeaders != null) {
            val configuredHeaders = this@CorsCallback.allowedHeaders
            val allRequestedHeadersAllowed = accessControlRequestHeaders
                .split(",")
                .map { it.trim() }
                .all { it in configuredHeaders }

            // An empty set of allowed headers means that headers are not restricted
            if (!allRequestedHeadersAllowed && configuredHeaders.isNotEmpty())
                return forbidden("Not allowed headers")

            val requestHeaders =
                configuredHeaders.ifEmpty { request.headers.httpFields.keys.toSet() }
            h += Header(ALLOW_HEADERS, requestHeaders.joinToString(","))
        }

        val methods = allowedMethods.joinToString(",")
        // Header defined by the CORS protocol to announce the allowed methods
        h += Header(ALLOW_METHODS, methods)
        // Legacy header, kept for backward compatibility with existing clients of this callback
        h += Header(REQUEST_METHOD, methods)

        if (preFlightMaxAge > 0)
            h += Header(MAX_AGE, preFlightMaxAge.toString())

        val origin = request.origin() ?: ""
        val originAllowed = allowOrigin(origin)
        return when {
            !originAllowed && origin.isNotBlank() ->
                forbidden("Not allowed origin: $origin")
            !originAllowed ->
                forbidden("Forbidden pre-flight request")
            origin.isBlank() ->
                send(preFlightStatus, headers = h)
            else ->
                send(preFlightStatus, headers = h + Header(ALLOW_ORIGIN, accessControlAllowOrigin(origin)))
        }
    }
}