VerifySpecCallback.kt

  1. package com.hexagonkt.rest.tools.openapi

  2. import com.atlassian.oai.validator.OpenApiInteractionValidator
  3. import com.atlassian.oai.validator.OpenApiInteractionValidator.createForInlineApiSpecification
  4. import com.atlassian.oai.validator.model.Request
  5. import com.atlassian.oai.validator.model.Request.Method
  6. import com.atlassian.oai.validator.model.Response
  7. import com.atlassian.oai.validator.model.SimpleRequest
  8. import com.atlassian.oai.validator.model.SimpleResponse
  9. import com.atlassian.oai.validator.report.ValidationReport
  10. import com.atlassian.oai.validator.report.ValidationReport.Message
  11. import com.hexagonkt.http.handlers.HttpCallback
  12. import com.hexagonkt.http.handlers.HttpContext
  13. import com.hexagonkt.http.model.ContentType
  14. import com.hexagonkt.http.model.HttpMethod
  15. import com.hexagonkt.http.model.HttpMethod.*
  16. import java.net.URL
  17. import kotlin.jvm.optionals.getOrNull

  /**
   * Callback that verifies server calls comply with a given OpenAPI spec.
   *
   * The spec is read once, when the callback is created, and used to build the validator shared by
   * all calls. For every call the request is validated, the remaining handlers are executed and
   * the resulting response is validated. If either validation reports errors, the call is answered
   * with a 'bad request' whose body lists the validation messages.
   *
   * @param spec Location of the spec used to validate HTTP calls.
   */
  class VerifySpecCallback(spec: URL) : HttpCallback {

      // The spec text is only needed to build the validator, so it is not kept as a property.
      private val validator: OpenApiInteractionValidator =
          createForInlineApiSpecification(spec.readText()).build()

      /**
       * Validate the incoming request, run the next handlers and validate the produced response.
       *
       * @param context Context of the call being processed.
       * @return The context returned by the next handlers, or that context turned into a
       *  'bad request' (with the validation messages) when the request or the response have errors.
       */
      override fun invoke(context: HttpContext): HttpContext {
          // The request report is computed before the rest of the handlers are executed
          val requestReport = validator.validateRequest(request(context))

          val result = context.next()

          val resultMethod = method(result.method)
          val responseReport = validator.validateResponse(result.path, resultMethod, response(result))

          // Response messages go first in the merged report, followed by the request ones
          val callReport = responseReport.merge(requestReport)

          return if (callReport.hasErrors()) result.badRequest(message(callReport))
          else result
      }

      /**
       * Build the error text for a report: an 'Invalid call:' header followed by one line per
       * distinct message.
       *
       * @param report Validation report holding the messages to render.
       * @return Text with all the different messages of the report.
       */
      private fun message(report: ValidationReport): String {
          val messages = report.messages.map(::messageToText).distinct()
          return messages.joinToString(MESSAGE_PREFIX, "Invalid call:$MESSAGE_PREFIX")
      }

      /**
       * Render a single validation message as text.
       *
       * @param entry Validation message to render.
       * @return Line with the level, key, context (operation and location, when available), message,
       *  additional information and nested messages of the entry.
       */
      private fun messageToText(entry: Message): String {
          val level = entry.level
          val key = entry.key
          val context = entry.context
              .map { c ->
                  // Operation and location are optional, missing values are rendered as empty text
                  val op = c.apiOperation
                      .getOrNull()
                      ?.let { ao ->
                          val method = ao.method
                          val apiPath = ao.apiPath
                          "$method ${apiPath.normalised()}"
                      }
                      ?: ""

                  val loc = c.location.getOrNull()?.name ?: ""

                  "$op $loc"
              }
              .orElse("")

          val message = entry.message
          val additionalInfo = entry.additionalInfo
          val nestedMessages = entry.nestedMessages

          return "$level: $key [$context] $message $additionalInfo $nestedMessages"
      }

      /**
       * Convert the request of a call into the validator's request model.
       *
       * @param context Context holding the request to convert.
       * @return Validator request with the call's method, path, body (only if not empty), content
       *  type, headers, accepted types, authorization and query parameters.
       */
      private fun request(context: HttpContext): Request {
          val request = context.request
          // NOTE(review): the third argument presumably toggles case sensitivity in the validator's
          // request model; confirm against the SimpleRequest.Builder documentation
          val builder = SimpleRequest.Builder(method(context.method), context.path, true)

          // The body text is obtained once and reused for both the check and the builder
          val body = request.bodyString()
          if (body.isNotEmpty())
              builder.withBody(body)

          request.contentType?.text?.let(builder::withContentType)
          request.headers.httpFields.values.forEach { builder.withHeader(it.name, it.strings()) }
          request.accept.map(ContentType::text).forEach(builder::withAccept)
          request.authorization?.text?.let(builder::withAuthorization)
          request.queryParameters.httpFields.values.forEach {
              builder.withQueryParam(it.name, it.strings())
          }

          return builder.build()
      }

      /**
       * Convert the response of a call into the validator's response model.
       *
       * @param context Context holding the response to convert.
       * @return Validator response with the call's status code, body, content type and headers.
       */
      private fun response(context: HttpContext): Response {
          val response = context.response
          val builder = SimpleResponse.Builder(context.status.code)

          // Unlike in requests, the body is always set, even when it is empty
          builder.withBody(response.bodyString())

          response.contentType?.text?.let(builder::withContentType)
          response.headers.httpFields.values.forEach { builder.withHeader(it.name, it.strings()) }

          return builder.build()
      }

      /**
       * Map a toolkit HTTP method to the validator's equivalent one.
       *
       * @param method HTTP method to convert.
       * @return Validator method with the same name.
       */
      private fun method(method: HttpMethod): Method =
          // No 'else' branch: the compiler reports any HTTP method left unmapped
          when (method) {
              GET -> Method.GET
              HEAD -> Method.HEAD
              POST -> Method.POST
              PUT -> Method.PUT
              DELETE -> Method.DELETE
              TRACE -> Method.TRACE
              OPTIONS -> Method.OPTIONS
              PATCH -> Method.PATCH
          }

      private companion object {
          /** Separator placed before each validation message in the error text. */
          const val MESSAGE_PREFIX: String = "\n- "
      }
  }