RabbitMqClient.kt

  1. package com.hexagontk.messaging.rabbitmq

  2. import com.codahale.metrics.MetricRegistry
  3. import com.hexagontk.http.parseQueryString
  4. import com.hexagontk.core.*
  5. import com.hexagontk.helpers.retry
  6. import com.hexagontk.serialization.SerializationFormat
  7. import com.rabbitmq.client.*
  8. import com.rabbitmq.client.AMQP.BasicProperties
  9. import com.rabbitmq.client.impl.StandardMetricsCollector

  10. import java.io.Closeable
  11. import java.lang.Runtime.getRuntime
  12. import java.lang.System.Logger
  13. import java.net.URI
  14. import java.nio.charset.Charset.defaultCharset
  15. import java.util.UUID.randomUUID
  16. import java.util.concurrent.ArrayBlockingQueue
  17. import java.util.concurrent.Executors.newFixedThreadPool
  18. import kotlin.reflect.KClass

  19. /**
  20.  * Rabbit client.
  21.  *
  22.  * * TODO Review if channel handling is still needed in Java 4.1.x version
  23.  * * TODO Add metrics
  24.  * * TODO Ordered shutdown
  25.  */
  26. class RabbitMqClient(
  27.     private val connectionFactory: ConnectionFactory,
  28.     private val poolSize: Int = getRuntime().availableProcessors(),
  29.     private val serializationFormat: SerializationFormat
  30. ) : Closeable {

  31.     internal companion object {

  32.         private fun <T> setVar(value: T?, setter: (T) -> Unit) {
  33.             if (value != null)
  34.                 setter(value)
  35.         }

  36.         internal fun createConnectionFactory(uri: URI): ConnectionFactory {
  37.             require(uri.toString().isNotBlank())

  38.             val cf = ConnectionFactory()
  39.             cf.setUri(uri)

  40.             val queryParameters = parseQueryString(uri.query ?: "").values
  41.             val params = queryParameters.filterNot { (_, v) -> v.text.isBlank() }
  42.             fun value(name: String): String? = params[name]?.text
  43.             val automaticRecovery = value("automaticRecovery")?.toBoolean()
  44.             val recoveryInterval = value("recoveryInterval")?.toLong()
  45.             val shutdownTimeout = value("shutdownTimeout")?.toInt()
  46.             val heartbeat = value("heartbeat")?.toInt()
  47.             val metricsCollector = StandardMetricsCollector(MetricRegistry())

  48.             setVar(automaticRecovery) { cf.isAutomaticRecoveryEnabled = it }
  49.             setVar(recoveryInterval) { cf.networkRecoveryInterval = it }
  50.             setVar(shutdownTimeout) { cf.shutdownTimeout = it }
  51.             setVar(heartbeat) { cf.requestedHeartbeat = it }
  52.             setVar(metricsCollector) { cf.metricsCollector = it }

  53.             return cf
  54.         }
  55.     }

  56.     private val log: Logger = loggerOf(this::class)
  57.     private val args = hashMapOf<String, Any>()

  58.     @Volatile private var count: Int = 0
  59.     private val threadPool = newFixedThreadPool(poolSize) { Thread(it, "rabbitmq-" + count++) }
  60.     private var connection: Connection? = connectionFactory.newConnection()
  61.     private val metrics: Metrics = Metrics(connectionFactory.metricsCollector as StandardMetricsCollector)
  62.     private val listener = ConnectionListener()

  63.     /** . */
  64.     constructor (uri: URI, serializationFormat: SerializationFormat) :
  65.         this(createConnectionFactory(uri), serializationFormat = serializationFormat)

  66.     /** . */
  67.     val connected: Boolean get() = connection?.isOpen ?: false

  68.     /** @see Closeable.close */
  69.     override fun close() {
  70.         connection?.removeShutdownListener(listener)
  71.         (connection as? Recoverable)?.removeRecoveryListener(listener)
  72.         connection?.close()
  73.         connection = null
  74.         metrics.report()
  75.         log.info { "RabbitMQ client closed" }
  76.     }

  77.     /** . */
  78.     fun declareQueue(name: String) {
  79.         args["x-max-length-bytes"] = 1048576  // max queue length
  80.         withChannel { it.queueDeclare(name, false, false, false, args) }
  81.     }

  82.     /** . */
  83.     fun deleteQueue(name: String) {
  84.         withChannel { it.queueDelete(name) }
  85.     }

  86.     /** . */
  87.     fun bindExchange(exchange: String, exchangeType: String, routingKey: String, queue: String) {
  88.         withChannel {
  89.             it.queueDeclare(queue, false, false, false, null)
  90.             it.queuePurge(queue)
  91.             it.exchangeDeclare(exchange, exchangeType, false, false, false, null)
  92.             it.queueBind(queue, exchange, routingKey)
  93.         }
  94.     }

  95.     /** . */
  96.     fun <T : Any> consume(
  97.         exchange: String,
  98.         routingKey: String,
  99.         type: KClass<T>,
  100.         decoder: (Map<String, *>) -> T,
  101.         handler: (T) -> Unit,
  102.     ) {

  103.         withChannel {
  104.             it.queueDeclare(routingKey, false, false, false, null)
  105.             it.queuePurge(routingKey)
  106.             it.queueBind(routingKey, exchange, routingKey)
  107.         }
  108.         consume(routingKey, type, decoder, handler)
  109.     }

  110.     fun <T : Any, R : Any> consume(
  111.         queueName: String, type: KClass<T>, decoder: (Map<String, *>) -> T, handler: (T) -> R
  112.     ) {
  113.         val channel = createChannel()
  114.         val callback = Handler(
  115.             connectionFactory,
  116.             channel,
  117.             threadPool,
  118.             type,
  119.             handler,
  120.             serializationFormat = serializationFormat,
  121.             decoder
  122.         )
  123.         channel.basicConsume(queueName, false, callback)
  124.         log.info { "Consuming messages in $queueName" }
  125.     }

  126.     /**
  127.      * Tries to get a channel for five times. If it does not succeed it throws an
  128.      * IllegalStateException.
  129.      *
  130.      * @return A new channel.
  131.      */
  132.     private fun createChannel(): Channel =
  133.         retry(times = 3, delay = 50) {
  134.             if (connection?.isOpen != true) {
  135.                 connection = connectionFactory.newConnection()
  136.                 connection?.addShutdownListener(listener)
  137.                 (connection as Recoverable).addRecoveryListener(listener)
  138.                 log.warn { "Rabbit connection RESTORED" }
  139.             }
  140.             val channel = connection?.createChannel() ?: fail
  141.             channel.basicQos(poolSize)
  142.             channel.addShutdownListener(listener)
  143.             (channel as Recoverable).addRecoveryListener(listener)
  144.             channel
  145.         }

  146.     private fun <T> withChannel(callback: (Channel) -> T): T {
  147.         var channel: Channel? = null
  148.         try {
  149.             channel = createChannel()
  150.             return callback(channel)
  151.         }
  152.         finally {
  153.             if (channel != null && channel.isOpen)
  154.                 channel.close()
  155.         }
  156.     }

  157.     fun publish(queue: String, message: String, correlationId: String? = null) =
  158.         publish("", queue, message, correlationId)

  159.     fun publish(
  160.         exchange: String,
  161.         routingKey: String,
  162.         message: String,
  163.         correlationId: String? = null) {

  164.         withChannel { channel ->
  165.             publish(channel, exchange, routingKey, null, message, correlationId, null)
  166.         }
  167.     }

  168.     private fun publish(
  169.         channel: Channel,
  170.         exchange: String,
  171.         routingKey: String,
  172.         encoding: String?,
  173.         message: String,
  174.         correlationId: String?,
  175.         replyQueueName: String?) {

  176.         val builder = BasicProperties.Builder()

  177.         if (!correlationId.isNullOrBlank())
  178.             builder.correlationId(correlationId)

  179.         if (!replyQueueName.isNullOrBlank())
  180.             builder.replyTo(replyQueueName)

  181.         if (!encoding.isNullOrBlank())
  182.             builder.contentEncoding(encoding)

  183.         val props = builder.build()

  184.         val charset = if (encoding == null) defaultCharset() else charset(encoding)
  185.         channel.basicPublish(exchange, routingKey, props, message.toByteArray(charset))

  186.         log.debug {
  187.             """
  188.             EXCHANGE: $exchange ROUTING KEY: $routingKey
  189.             REPLY TO: $replyQueueName CORRELATION ID: $correlationId
  190.             BODY:
  191.             $message""".trimIndent()
  192.         }
  193.     }

  194.     fun call(requestQueue: String, message: String): String =
  195.         withChannel {
  196.             val correlationId = randomUUID().toString()
  197.             val replyQueueName = it.queueDeclare().queue
  198.             val charset = defaultCharset().name()

  199.             publish(it, "", requestQueue, charset, message, correlationId, replyQueueName)

  200.             val response = ArrayBlockingQueue<String>(1)
  201.             val consumer = object : DefaultConsumer(it) {
  202.                 override fun handleDelivery(
  203.                     consumerTag: String?,
  204.                     envelope: Envelope?,
  205.                     properties: BasicProperties?,
  206.                     body: ByteArray?) {

  207.                     if (properties?.correlationId == correlationId)
  208.                         response.offer(String(body ?: byteArrayOf()))
  209.                 }

  210.                 override fun handleCancelOk(consumerTag: String) {
  211.                     log.debug { "Explicit cancel for the consumer $consumerTag" }
  212.                 }
  213.             }

  214.             val consumerTag = it.basicConsume(replyQueueName, true, consumer)

  215.             val result: String = response.take() // Wait until there is an element in the array blocking queue
  216.             it.basicCancel(consumerTag)
  217.             result
  218.         }
  219. }