milvus_helpers.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import sys
  2. from config import MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION, METRIC_TYPE
  3. from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
  4. from logs import LOGGER
  5. class MilvusHelper:
  6. """
  7. MilvusHelper class to manager the Milvus Collection.
  8. Args:
  9. host (`str`):
  10. Milvus server Host.
  11. port (`str|int`):
  12. Milvus server port.
  13. ...
  14. """
  15. def __init__(self, host=MILVUS_HOST, port=MILVUS_PORT):
  16. try:
  17. self.collection = None
  18. connections.connect(host=host, port=port)
  19. LOGGER.debug(f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}")
  20. except Exception as e:
  21. LOGGER.error(f"Failed to connect Milvus: {e}")
  22. sys.exit(1)
  23. def set_collection(self, collection_name):
  24. try:
  25. self.collection = Collection(name=collection_name)
  26. except Exception as e:
  27. LOGGER.error(f"Failed to load data to Milvus: {e}")
  28. sys.exit(1)
  29. def has_collection(self, collection_name):
  30. # Return if Milvus has the collection
  31. try:
  32. return utility.has_collection(collection_name)
  33. except Exception as e:
  34. LOGGER.error(f"Failed to load data to Milvus: {e}")
  35. sys.exit(1)
  36. def create_collection(self, collection_name):
  37. # Create milvus collection if not exists
  38. try:
  39. if not self.has_collection(collection_name):
  40. im_hash = FieldSchema(name='im_hash', dtype=DataType.VARCHAR, descrition='id to image', max_length=500,
  41. is_primary=True, auto_id=False)
  42. product_id = FieldSchema(name='product_id', dtype=DataType.VARCHAR, descrition='product_id to image', max_length=500,
  43. is_primary=False, auto_id=False)
  44. im_vector = FieldSchema(name="im_vector", dtype=DataType.FLOAT_VECTOR, descrition="image vectors",
  45. dim=VECTOR_DIMENSION, is_primary=False)
  46. schema = CollectionSchema(fields=[im_hash, product_id,im_vector], description="collection_name: "+collection_name)
  47. self.collection = Collection(name=collection_name, schema=schema)
  48. self.create_index(collection_name)
  49. LOGGER.debug(f"Create Milvus collection: {collection_name}")
  50. else:
  51. self.set_collection(collection_name)
  52. return "OK"
  53. except Exception as e:
  54. LOGGER.error(f"Failed to load data to Milvus: {e}")
  55. sys.exit(1)
  56. # 创建分区
  57. def create_partition(self, partition_name):
  58. # Create milvus collection if not exists
  59. try:
  60. if partition_name is not None:
  61. if not self.collection.has_partition(partition_name):
  62. self.collection.create_partition(partition_name)
  63. LOGGER.debug(f"Create Milvus partition: {partition_name}")
  64. # else:
  65. # self.set_collection(collection_name)
  66. return "OK"
  67. except Exception as e:
  68. LOGGER.error(f"Failed to load data to Milvus: {e}")
  69. sys.exit(1)
  70. def insert(self, collection_name,partition_name,im_hash,product_id,im_vector):
  71. # Batch insert vectors to milvus collection
  72. try:
  73. data = [im_hash,product_id,im_vector]
  74. print(data)
  75. self.set_collection(collection_name)
  76. self.create_partition(partition_name)
  77. mr = self.collection.insert(data,partition_name)
  78. ids = mr.primary_keys
  79. self.collection.load()
  80. LOGGER.debug(
  81. f"Insert vectors to Milvus in collection: {collection_name} with {len(im_vector)} rows")
  82. return ids
  83. except Exception as e:
  84. LOGGER.error(f"Failed to load data to Milvus: {e}")
  85. sys.exit(1)
  86. def create_index(self, collection_name):
  87. # Create IVF_FLAT index on milvus collection
  88. try:
  89. self.set_collection(collection_name)
  90. default_index = {"index_type": "IVF_FLAT", "metric_type": METRIC_TYPE, "params": {"nlist": 16384}}
  91. status = self.collection.create_index(field_name="im_vector", index_params=default_index)
  92. if not status.code:
  93. LOGGER.debug(
  94. f"Successfully create index in collection:{collection_name} with param:{default_index}")
  95. return status
  96. else:
  97. raise Exception(status.message)
  98. except Exception as e:
  99. LOGGER.error(f"Failed to create index: {e}")
  100. sys.exit(1)
  101. def delete_collection(self, collection_name):
  102. # Delete Milvus collection
  103. try:
  104. self.set_collection(collection_name)
  105. self.collection.drop()
  106. LOGGER.debug("Successfully drop collection!")
  107. return "ok"
  108. except Exception as e:
  109. LOGGER.error(f"Failed to drop collection: {e}")
  110. sys.exit(1)
  111. def delete_record(self, collection_name,partition_name,expr):
  112. # Delete Milvus collection
  113. try:
  114. self.set_collection(collection_name)
  115. res = self.collection.delete(expr,partition_name)
  116. LOGGER.debug("Successfully delete record")
  117. return res
  118. except Exception as e:
  119. LOGGER.error(f"Failed to delete record: {e}")
  120. sys.exit(1)
  121. def search_vectors(self, collection_name,partition_names,im_vector, top_k):
  122. try:
  123. self.set_collection(collection_name)
  124. search_params = {"metric_type": METRIC_TYPE, "params": {"nprobe": 16}}
  125. res = self.collection.search(
  126. im_vector,
  127. anns_field="im_vector",
  128. param=search_params,
  129. limit=top_k,
  130. expr= None,
  131. # expr= "product_id like \"63a6dd57cd3dd570bb943e81\"",
  132. partition_names=partition_names,
  133. output_fields=["product_id"])
  134. LOGGER.debug(f"Successfully search in collection: {res}")
  135. return res
  136. except Exception as e:
  137. LOGGER.error(f"Failed to search vectors in Milvus: {e}")
  138. sys.exit(1)
  139. def count(self, collection_name):
  140. # Get the number of milvus collection
  141. try:
  142. self.set_collection(collection_name)
  143. num = self.collection.num_entities
  144. LOGGER.debug(f"Successfully get the num:{num} of the collection:{collection_name}")
  145. return num
  146. except Exception as e:
  147. LOGGER.error(f"Failed to count vectors in Milvus: {e}")
  148. sys.exit(1)