milvus_helpers.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import sys
  2. from config import MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION, METRIC_TYPE
  3. from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
  4. from logs import LOGGER
  5. class MilvusHelper:
  6. """
  7. MilvusHelper class to manager the Milvus Collection.
  8. Args:
  9. host (`str`):
  10. Milvus server Host.
  11. port (`str|int`):
  12. Milvus server port.
  13. ...
  14. """
  15. def __init__(self, host=MILVUS_HOST, port=MILVUS_PORT):
  16. try:
  17. self.collection = None
  18. connections.connect(host=host, port=port)
  19. LOGGER.debug(f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}")
  20. except Exception as e:
  21. LOGGER.error(f"Failed to connect Milvus: {e}")
  22. sys.exit(1)
  23. def set_collection(self, collection_name):
  24. try:
  25. self.collection = Collection(name=collection_name)
  26. except Exception as e:
  27. LOGGER.error(f"Failed to load data to Milvus: {e}")
  28. sys.exit(1)
  29. def has_collection(self, collection_name):
  30. # Return if Milvus has the collection
  31. try:
  32. return utility.has_collection(collection_name)
  33. except Exception as e:
  34. LOGGER.error(f"Failed to load data to Milvus: {e}")
  35. sys.exit(1)
  36. def create_collection(self, collection_name):
  37. # Create milvus collection if not exists
  38. try:
  39. if not self.has_collection(collection_name):
  40. field1 = FieldSchema(name='id', dtype=DataType.VARCHAR, descrition='id to image', max_length=500,
  41. is_primary=True, auto_id=False)
  42. field2 = FieldSchema(name='path', dtype=DataType.VARCHAR, descrition='path to image', max_length=500,
  43. is_primary=False, auto_id=False)
  44. field3 = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, descrition="image embedding vectors",
  45. dim=VECTOR_DIMENSION, is_primary=False)
  46. schema = CollectionSchema(fields=[field1, field2,field3], description="collection_name: "+collection_name)
  47. self.collection = Collection(name=collection_name, schema=schema)
  48. self.create_index(collection_name)
  49. LOGGER.debug(f"Create Milvus collection: {collection_name}")
  50. else:
  51. self.set_collection(collection_name)
  52. return "OK"
  53. except Exception as e:
  54. LOGGER.error(f"Failed to load data to Milvus: {e}")
  55. sys.exit(1)
  56. def insert(self, collection_name, id,path, vectors):
  57. # Batch insert vectors to milvus collection
  58. try:
  59. data = [id,path, vectors]
  60. self.set_collection(collection_name)
  61. mr = self.collection.insert(data)
  62. ids = mr.primary_keys
  63. self.collection.load()
  64. LOGGER.debug(
  65. f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows")
  66. return ids
  67. except Exception as e:
  68. LOGGER.error(f"Failed to load data to Milvus: {e}")
  69. sys.exit(1)
  70. def create_index(self, collection_name):
  71. # Create IVF_FLAT index on milvus collection
  72. try:
  73. self.set_collection(collection_name)
  74. default_index = {"index_type": "IVF_SQ8", "metric_type": METRIC_TYPE, "params": {"nlist": 16384}}
  75. status = self.collection.create_index(field_name="embedding", index_params=default_index)
  76. if not status.code:
  77. LOGGER.debug(
  78. f"Successfully create index in collection:{collection_name} with param:{default_index}")
  79. return status
  80. else:
  81. raise Exception(status.message)
  82. except Exception as e:
  83. LOGGER.error(f"Failed to create index: {e}")
  84. sys.exit(1)
  85. def delete_collection(self, collection_name):
  86. # Delete Milvus collection
  87. try:
  88. self.set_collection(collection_name)
  89. self.collection.drop()
  90. LOGGER.debug("Successfully drop collection!")
  91. return "ok"
  92. except Exception as e:
  93. LOGGER.error(f"Failed to drop collection: {e}")
  94. sys.exit(1)
  95. def search_vectors(self, collection_name, vectors, top_k):
  96. # Search vector in milvus collection
  97. try:
  98. self.set_collection(collection_name)
  99. search_params = {"metric_type": METRIC_TYPE, "params": {"nprobe": 16}}
  100. res = self.collection.search(vectors, anns_field="embedding", param=search_params, limit=top_k)
  101. LOGGER.debug(f"Successfully search in collection: {res}")
  102. return res
  103. except Exception as e:
  104. LOGGER.error(f"Failed to search vectors in Milvus: {e}")
  105. sys.exit(1)
  106. def count(self, collection_name):
  107. # Get the number of milvus collection
  108. try:
  109. self.set_collection(collection_name)
  110. num = self.collection.num_entities
  111. LOGGER.debug(f"Successfully get the num:{num} of the collection:{collection_name}")
  112. return num
  113. except Exception as e:
  114. LOGGER.error(f"Failed to count vectors in Milvus: {e}")
  115. sys.exit(1)