# generate.py - seed a MongoDB collection and a FAISS index with random vectors.

  1. import pymongo
  2. import numpy as np
  3. import faiss
  4. import time
  5. from bson.objectid import ObjectId
  6. # MongoDB连接配置
  7. client = pymongo.MongoClient("mongodb://root:faiss_image_search@localhost:27017/")
  8. db = client["faiss_index"]
  9. collection = db["mat_vectors"]
  10. collection.create_index([("product_id", 1)], unique=True)
  11. collection.create_index([("faiss_id", 1)], unique=True)
  12. # FAISS配置
  13. dimension = 2048
  14. base_index = faiss.IndexFlatL2(dimension)
  15. index = faiss.IndexIDMap(base_index)
  16. # 生成随机向量
  17. def generate_random_vector(dimension):
  18. return np.random.random(dimension).astype('float32')
  19. # 插入100万条数据
  20. def insert_million_records():
  21. batch_size = 10000 # 每批插入的数据量
  22. total_records = 200000
  23. start_time = time.time()
  24. for i in range(0, total_records, batch_size):
  25. batch = []
  26. for j in range(batch_size):
  27. faiss_id = i + j
  28. vector = generate_random_vector(dimension)
  29. index.add_with_ids(np.array([vector]), np.array([faiss_id]))
  30. record = {
  31. "_id": ObjectId(),
  32. "product_id": ObjectId(),
  33. "faiss_id": faiss_id,
  34. "vector": vector.tolist()
  35. }
  36. batch.append(record)
  37. collection.insert_many(batch)
  38. print(f"Inserted {i + batch_size} records")
  39. end_time = time.time()
  40. print(f"Total time taken: {end_time - start_time} seconds")
  41. if __name__ == "__main__":
  42. insert_million_records()