MongoDbStore.kt

  1. package com.hexagontk.store.mongodb

  2. import com.hexagontk.core.fail
  3. import com.hexagontk.core.filterNotEmpty
  4. import com.hexagontk.core.toLocalDateTime
  5. import com.hexagontk.store.Store
  6. import com.mongodb.ConnectionString
  7. import com.mongodb.client.FindIterable
  8. import com.mongodb.client.MongoClients
  9. import com.mongodb.client.MongoCollection
  10. import com.mongodb.client.MongoDatabase
  11. import com.mongodb.client.model.Filters
  12. import com.mongodb.client.model.ReplaceOptions
  13. import com.mongodb.client.model.Updates
  14. import org.bson.BsonBinary
  15. import org.bson.BsonString
  16. import org.bson.Document
  17. import org.bson.conversions.Bson
  18. import org.bson.types.Binary
  19. import java.net.URL
  20. import java.nio.ByteBuffer
  21. import java.time.LocalDateTime
  22. import java.time.ZoneId
  23. import java.time.ZoneOffset
  24. import java.util.*
  25. import kotlin.reflect.KClass
  26. import kotlin.reflect.KProperty1

  /**
   * [Store] implementation that keeps instances of [T] as documents of a MongoDB collection.
   *
   * Instances are turned into maps with [encoder] and rebuilt with [decoder]. The value of the
   * [key] property is written to the document's `_id` field (the entry named after the property is
   * removed from the stored document) and `_id` is copied back under the property's name before
   * decoding.
   *
   * @param T Type of the stored instances.
   * @param K Type of the key property.
   * @property type Class of the stored instances.
   * @property key Property that holds the key of each instance.
   * @property name Name of the collection; defaults to the simple name of [type].
   * @property encoder Converts an instance into a map of field names to values.
   * @property decoder Builds an instance out of a map of field names to values.
   */
  class MongoDbStore<T : Any, K : Any>(
      override val type: KClass<T>,
      override val key: KProperty1<T, K>,
      private val database: MongoDatabase,
      override val name: String = type.java.simpleName,
      override val encoder: (T) -> Map<String, *>,
      override val decoder: (Map<String, *>) -> T,
  ) : Store<T, K> {

      companion object {
          /** Document field that holds the key of each stored instance. */
          private const val ID_FIELD: String = "_id"

          /** Separates the field name from the comparison operator in filter keys (`age:gt`). */
          private const val OPERATOR_SEPARATOR: String = ":"

          /**
           * Returns the database referenced by a MongoDB connection string.
           *
           * The database name is resolved before the client is created, so a URL without a
           * database fails (through [fail]) without leaving an open client behind.
           *
           * NOTE(review): the client created here is not closed by this class; callers only get
           * the [MongoDatabase] handle.
           *
           * @param url MongoDB connection string. It must include a database name.
           * @return Database named in [url].
           */
          fun database(url: String): MongoDatabase {
              val connection = ConnectionString(url)
              val databaseName = connection.database ?: fail
              return MongoClients.create(connection).getDatabase(databaseName)
          }
      }

      /** Collection called [name] in the supplied database; all operations run against it. */
      val collection: MongoCollection<Document> = this.database.getCollection(name)

      /**
       * Creates a store connecting to the database referenced by a connection string.
       *
       * @param type Class of the stored instances.
       * @param key Property that holds the key of each instance.
       * @param url MongoDB connection string, resolved with [database].
       * @param name Name of the collection; defaults to the simple name of [type].
       * @param encoder Converts an instance into a map of field names to values.
       * @param decoder Builds an instance out of a map of field names to values.
       */
      constructor(
          type: KClass<T>,
          key: KProperty1<T, K>,
          url: String,
          name: String = type.java.simpleName,
          encoder: (T) -> Map<String, *>,
          decoder: (Map<String, *>) -> T,
      ) :
          this(type, key, database(url), name, encoder, decoder)

      /**
       * Inserts [instance] as a new document.
       *
       * @return Key of the inserted instance.
       */
      override fun insertOne(instance: T): K {
          collection.insertOne(map(instance))
          return key.get(instance)
      }

      /**
       * Inserts all [instances] with a single driver call.
       *
       * @return Keys of the inserted instances, in the same order. An empty input returns an empty
       *  list without calling the driver.
       */
      override fun insertMany(instances: List<T>): List<K> {
          // The driver rejects an empty document list, there is nothing to insert anyway
          if (instances.isEmpty())
              return emptyList()

          collection.insertMany(instances.map { instance -> map(instance) })
          return instances.map { key.get(it) }
      }

      /**
       * Replaces the document whose `_id` is the key of [instance], inserting it if it is missing.
       *
       * @return Key of [instance] if a new document was inserted, `null` if an existing document
       *  was replaced.
       */
      override fun saveOne(instance: T): K? {
          val instanceKey = key.get(instance)
          val options = ReplaceOptions().upsert(true)
          val result = collection.replaceOne(createKeyFilter(instanceKey), map(instance), options)

          // 'upsertedId' is only set on insertion. The inserted '_id' comes from the instance key,
          // so the typed key is returned instead of casting the driver's BSON value to 'K'
          return if (result.upsertedId == null) null else instanceKey
      }

      /**
       * Applies [saveOne] to each of the [instances], one call per instance.
       *
       * @return Result of [saveOne] for every instance, in the same order.
       */
      override fun saveMany(instances: List<T>): List<K?> =
          instances.map { saveOne(it) }

      /**
       * Replaces the stored document whose `_id` is the key of [instance].
       *
       * @return `true` if exactly one document matched the key.
       */
      override fun replaceOne(instance: T): Boolean {
          val replacement = map(instance)
          val result = collection.replaceOne(createKeyFilter(key.get(instance)), replacement)
          // *NOTE* that 'modifiedCount' returns 0 for matched records with unchanged update values
          return result.matchedCount == 1L
      }

      /**
       * Applies [replaceOne] to each of the [instances], one call per instance.
       *
       * @return Instances for which a stored document was matched.
       */
      override fun replaceMany(instances: List<T>): List<T> =
          instances.filter { replaceOne(it) }

      /**
       * Sets the given fields on the document identified by [key].
       *
       * @param key Key of the document to update.
       * @param updates Field names mapped to their new values. Entries with `null` values are
       *  skipped. Must not be empty.
       * @return `true` if exactly one document matched the key.
       * @throws IllegalArgumentException If [updates] is empty.
       */
      override fun updateOne(key: K, updates: Map<String, *>): Boolean {
          require(updates.isNotEmpty()) { "Updates must not be empty" }
          val result = collection.updateOne(createKeyFilter(key), createUpdate(updates))
          // *NOTE* that 'modifiedCount' returns 0 for matched records with unchanged update values
          return result.matchedCount == 1L
      }

      /**
       * Sets the given fields on every document matching [filter].
       *
       * @param filter Conditions the documents must fulfill (see [createFilter]).
       * @param updates Field names mapped to their new values. Entries with `null` values are
       *  skipped. Must not be empty.
       * @return Number of documents matched by the filter.
       * @throws IllegalArgumentException If [updates] is empty.
       */
      override fun updateMany(filter: Map<String, *>, updates: Map<String, *>): Long {
          require(updates.isNotEmpty()) { "Updates must not be empty" }
          val result = collection.updateMany(createFilter(filter), createUpdate(updates))
          // *NOTE* that 'modifiedCount' returns 0 for matched records with unchanged update values
          return result.matchedCount
      }

      /**
       * Deletes the document identified by [id].
       *
       * @return `true` if exactly one document was deleted.
       */
      override fun deleteOne(id: K): Boolean =
          collection.deleteOne(createKeyFilter(id)).deletedCount == 1L

      /**
       * Deletes every document matching [filter].
       *
       * @return Number of deleted documents.
       */
      override fun deleteMany(filter: Map<String, *>): Long =
          collection.deleteMany(createFilter(filter)).deletedCount

      /**
       * Loads the instance identified by [key].
       *
       * @return Decoded instance, or `null` if no document has that key.
       */
      override fun findOne(key: K): T? =
          collection
              .find(createKeyFilter(key))
              .first()
              ?.filterNotEmpty()
              ?.let { fromStore(it) }

      /**
       * Loads some fields of the document identified by [key].
       *
       * @param key Key of the document to load.
       * @param fields Names of the fields to fetch (see [createProjection]).
       * @return Fetched fields with their values converted by [fromStore], or `null` if no
       *  document has that key.
       */
      override fun findOne(key: K, fields: List<String>): Map<String, *>? =
          collection
              .find(createKeyFilter(key))
              .projection(createProjection(fields))
              .first()
              ?.filterNotEmpty()
              ?.mapValues { fromStore(it.value) }

      /**
       * Loads the instances matching [filter].
       *
       * @param filter Conditions the documents must fulfill (see [createFilter]).
       * @param limit Maximum number of results, no limit is set if `null`.
       * @param skip Number of leading results to skip, none are skipped if `null`.
       * @param sort Field names mapped to their sort direction (see [createSort]).
       * @return Decoded instances.
       */
      override fun findMany(
          filter: Map<String, *>,
          limit: Int?,
          skip: Int?,
          sort: Map<String, Boolean>
      ): List<T> {

          val query = collection.find(createFilter(filter)).sort(createSort(sort))

          pageQuery(limit, query, skip)

          return query
              .into(ArrayList<Document>())
              .map { document -> fromStore(document.filterNotEmpty()) }
      }

      /**
       * Loads some fields of the documents matching [filter].
       *
       * @param filter Conditions the documents must fulfill (see [createFilter]).
       * @param fields Names of the fields to fetch (see [createProjection]).
       * @param limit Maximum number of results, no limit is set if `null`.
       * @param skip Number of leading results to skip, none are skipped if `null`.
       * @param sort Field names mapped to their sort direction (see [createSort]).
       * @return One map per document with the fetched fields converted by [fromStore].
       */
      override fun findMany(
          filter: Map<String, *>,
          fields: List<String>,
          limit: Int?,
          skip: Int?,
          sort: Map<String, Boolean>
      ): List<Map<String, *>> {

          val query = collection
              .find(createFilter(filter))
              .projection(createProjection(fields))
              .sort(createSort(sort))

          pageQuery(limit, query, skip)

          // Empty values are dropped before conversion (same order as the single document
          // 'findOne'), so absent values never reach 'fromStore'
          return query
              .into(ArrayList<Document>())
              .map { document -> document.filterNotEmpty().mapValues { fromStore(it.value) } }
      }

      /**
       * Counts the documents matching [filter].
       *
       * @return Number of matching documents.
       */
      override fun count(filter: Map<String, *>): Long =
          collection.countDocuments(createFilter(filter))

      /** Drops the whole collection. */
      override fun drop() {
          collection.drop()
      }

      /** Applies the optional [limit] and [skip] to [query]; `null` values leave it untouched. */
      private fun pageQuery(limit: Int?, query: FindIterable<Document>, skip: Int?) {
          if (limit != null)
              query.limit(limit)

          if (skip != null)
              query.skip(skip)
      }

      /** Converts [instance] into the document to be stored. */
      private fun map(instance: T): Document = Document(toStore(instance))

      /** Creates the filter selecting the document whose `_id` equals [key]. */
      private fun createKeyFilter(key: K): Bson = Filters.eq(ID_FIELD, key)

      /** Returns the document field for [field]: `_id` for the key property, [field] otherwise. */
      private fun storeField(field: String): String =
          if (field == key.name) ID_FIELD else field

      /**
       * Creates a query filter from a map of conditions.
       *
       * Empty entries are discarded. The remaining ones are combined with a logical 'and'; if
       * there are none, an empty document is returned.
       *
       * @param filter Conditions keyed by `field` or `field:operator` (see [createCondition]).
       */
      private fun createFilter(filter: Map<String, *>): Bson {
          val conditions = filter
              .filterNotEmpty()
              .map { (field, value) -> createCondition(field, value) }

          return if (conditions.isEmpty()) Document() else Filters.and(conditions)
      }

      /**
       * Creates the condition for a single filter entry.
       *
       * List values ignore the operator: a single element list is an equality check against that
       * element, any other list is an 'in' check. For other values the supported operators are
       * `gt`, `gte`, `lt`, `lte` and `re` (regular expression built from the value's text). A
       * missing or unknown operator results in an equality check.
       *
       * @param field Filter key: a field name optionally followed by `:` and an operator.
       * @param value Value to compare the field with.
       */
      private fun createCondition(field: String, value: Any): Bson {
          val fieldParts = field.split(OPERATOR_SEPARATOR)
          val documentField = storeField(fieldParts.first())
          val operator = fieldParts.getOrNull(1)

          return when {
              value is List<*> ->
                  // 'first' is only called on single element lists, so it can not fail
                  if (value.size == 1) Filters.eq(documentField, value.first())
                  else Filters.`in`(documentField, value)
              else ->
                  when (operator) {
                      "gt" -> Filters.gt(documentField, value)
                      "gte" -> Filters.gte(documentField, value)
                      "lt" -> Filters.lt(documentField, value)
                      "lte" -> Filters.lte(documentField, value)
                      "re" -> Filters.regex(documentField, value.toString())
                      else -> Filters.eq(documentField, value)
                  }
          }
      }

      /** Creates a 'set' update for every entry of [update] with a non `null` value. */
      private fun createUpdate(update: Map<String, *>): Bson =
          Updates.combine(
              update.mapNotNull { (field, value) ->
                  value?.let { Updates.set(field, toStore(it)) }
              }
          )

      /**
       * Creates the projection that fetches only [fields]. The `_id` field is always excluded.
       * An empty list results in an empty projection document.
       */
      private fun createProjection(fields: List<String>): Bson =
          if (fields.isEmpty()) Document()
          else fields.associateWith { 1 }.toDocument().append(ID_FIELD, 0)

      /**
       * Creates the sort document for [fields]: `true` is written as `-1` and `false` as `1`.
       * The key property name is translated to `_id`, the field where keys are stored.
       */
      private fun createSort(fields: Map<String, Boolean>): Bson =
          fields
              .map { (field, descending) -> storeField(field) to (if (descending) -1 else 1) }
              .toMap()
              .toDocument()

      /** Wraps this map in a [Document]. */
      private fun Map<String, *>.toDocument() = Document(this)

      /**
       * Encodes [instance] into the map to be stored: the key value is placed in `_id`, the entry
       * named after the key property is removed, empty values are dropped and the remaining
       * values are converted with [toStore].
       */
      private fun toStore(instance: T): Map<String, Any> =
          (encoder(instance) + (ID_FIELD to key.get(instance)) - key.name)
              .filterNotEmpty()
              .mapValues { toStore(it.value) }

      /**
       * Decodes a stored document: the `_id` value is copied under the key property name, empty
       * values are dropped, the remaining values are converted with [fromStore] and the result is
       * passed to [decoder].
       */
      private fun fromStore(map: Map<String, Any>): T =
          (map + (key.name to map[ID_FIELD]))
              .filterNotEmpty()
              .mapValues { fromStore(it.value) }
              .let(decoder)

      /**
       * Converts a value read from the database: binary values become byte arrays, BSON strings
       * become strings and dates become local date times. Collections and maps are converted
       * element by element (keeping `null` elements), any other value is returned as is.
       */
      private fun fromStore(value: Any): Any = when (value) {
          is Binary -> value.data
          is BsonBinary -> value.data
          is BsonString -> value.value
          is Date -> value.toLocalDateTime()
          is Iterable<*> -> value.map { i -> i?.let { fromStore(it) } }
          is Map<*, *> -> value.mapValues { v -> v.value?.let { fromStore(it) } }
          else -> value
      }

      /**
       * Converts a value before storing it: enums become their names, byte arrays and byte
       * buffers become binary values, URLs become strings and local date times are shifted from
       * the system's zone to UTC. Collections and maps are converted element by element (keeping
       * `null` elements), any other value is returned as is.
       */
      private fun toStore(value: Any): Any = when (value) {
          is Enum<*> -> value.name
          is ByteArray -> BsonBinary(value)
          is ByteBuffer -> BsonBinary(value.toByteArray())
          is URL -> value.toString()
          is LocalDateTime -> value
              .atZone(ZoneId.systemDefault())
              .withZoneSameInstant(ZoneOffset.UTC)
              .toLocalDateTime()
          is Iterable<*> -> value.map { i -> i?.let { toStore(it) } }
          is Map<*, *> -> value.mapValues { v -> v.value?.let { toStore(it) } }
          else -> value
      }

      /**
       * Returns the bytes of this buffer. Buffers with an accessible backing array return that
       * array; other buffers (direct or read only ones, for which `array()` throws) have their
       * remaining bytes copied without moving the buffer's position.
       */
      private fun ByteBuffer.toByteArray(): ByteArray =
          if (hasArray()) array()
          else ByteArray(remaining()).also { duplicate().get(it) }
  }