# image_search.py (13 KB)
  1. import faiss
  2. import numpy as np
  3. from typing import List, Tuple, Optional
  4. # import torch
  5. from pymongo import MongoClient
  6. import datetime
  7. import time
  8. import requests
  9. import config
  10. import os
  11. def RequestAliyunFc(url):
  12. # 定义URL
  13. api = config.ALIYUN_FUNCTION_URL # 替换为实际的URL
  14. print("func api:",api)
  15. # 定义参数
  16. payload = {'url': url}
  17. print("payload", payload)
  18. # headers = {
  19. # 'Content-Type': 'application/json' # 明确指定请求头为JSON
  20. # }
  21. # 发送POST请求
  22. response = requests.post(api, json=payload)
  23. # 检查请求是否成功
  24. if response.status_code == 200:
  25. # 获取返回的JSON数据
  26. json_data = response.json()
  27. return json_data.get("vector")
  28. print(f'请求失败,状态码: {response.text}')
  29. return None
  30. class ImageSearchEngine:
  31. def __init__(self):
  32. mongodburi = os.getenv("MONGODB")
  33. if not mongodburi:
  34. mongodburi = "mongodb://root:hangzhou_manage_mat_2025@localhost:37018/"
  35. # 添加mongodb
  36. self.mongo_client = MongoClient(mongodburi) # MongoDB 连接字符串
  37. self.mongo_db = self.mongo_client["faiss_index"] # 数据库名称
  38. self.mongo_collection = self.mongo_db["mat_vectors"] # 集合名称
  39. self.mongo_collection.create_index([("product_id", 1)], unique=True)
  40. self.mongo_collection.create_index([("faiss_id", 1)], unique=True)
  41. # 初始化一个id生成计数
  42. self.faiss_id_max = 0
  43. self.dimension = 2048
  44. # 检查GPU是否可用(仅用于PyTorch模型)
  45. # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  46. # print(f"使用设备: {self.device}")
  47. # 改为支持删除的索引
  48. base_index = faiss.IndexFlatL2(self.dimension)
  49. self.index = faiss.IndexIDMap(base_index)
  50. # 尝试加载现有索引,如果不存在则创建新索引
  51. if self._load_index():
  52. print("成功加载现有索引")
  53. def _batch_generator(self, cursor, batch_size):
  54. """从MongoDB游标中分批生成数据"""
  55. batch = []
  56. for doc in cursor:
  57. batch.append(doc)
  58. if len(batch) == batch_size:
  59. yield batch
  60. batch = []
  61. if batch:
  62. yield batch
  63. def _process_image(self, image_url: str):
  64. try:
  65. # 这里调用阿里云云函数处理图片
  66. return RequestAliyunFc(image_url)
  67. except Exception as e:
  68. print(f"处理图片时出错: {e}")
  69. return None
  70. def add_image_from_url(self, image_path: str, product_id: str) -> bool:
  71. """从URL添加图片到索引。
  72. Args:
  73. url: 图片URL
  74. product_id: 图片对应的商品ID
  75. Returns:
  76. 添加成功返回True,失败返回False
  77. """
  78. try:
  79. # 使用原有的特征提取逻辑
  80. feature = self._process_image(image_path)
  81. if feature is None:
  82. print("无法提取特征")
  83. return False
  84. # 转换为numpy数组并添加到索引
  85. # feature_np = feature.cpu().numpy().reshape(1, -1).astype('float32')
  86. feature_np = np.array(feature).reshape(1, -1).astype('float32')
  87. idx = self.faiss_id_max + 1
  88. print(f"当前: idx { idx }")
  89. if not isinstance(idx, int) or idx <= 0:
  90. print("ID生成失败")
  91. return False
  92. self.faiss_id_max = idx
  93. #走upsert的逻辑
  94. # 向数据库写入记录
  95. record = {
  96. "faiss_id": idx,
  97. "product_id": product_id,
  98. "vector": feature_np.flatten().tolist(), # 将numpy数组转为列表
  99. "created_at": datetime.datetime.utcnow() # 记录创建时间
  100. }
  101. self.mongo_collection.insert_one(record)
  102. # 为向量设置ID并添加到Faiss索引
  103. self.index.add_with_ids(feature_np, np.array([idx], dtype='int64'))
  104. print(f"已添加图片: product_id: {product_id}, faiss_id: {idx}")
  105. return True
  106. except Exception as e:
  107. print(f"添加图片时出错: {e}")
  108. return False
  109. def get_product_id_by_faiss_id(self, faiss_id: int) -> Optional[str]:
  110. """根据 faiss_id 查找 MongoDB 中的 product_id。
  111. Args:
  112. faiss_id: Faiss 索引中的 ID
  113. Returns:
  114. 对应的 product_id,如果未找到则返回 None
  115. """
  116. try:
  117. faiss_id = int(faiss_id)
  118. # 检查 faiss_id 是否有效
  119. if faiss_id < 0:
  120. print(f"无效的 faiss_id: {faiss_id}")
  121. return None
  122. # 查询 MongoDB
  123. query = {"faiss_id": faiss_id}
  124. record = self.mongo_collection.find_one(query)
  125. # 检查是否找到记录
  126. if record is None:
  127. print(f"未找到 faiss_id 为 {faiss_id} 的记录")
  128. return None
  129. # 返回 product_id
  130. product_id = record.get("product_id")
  131. if product_id is None:
  132. print(f"记录中缺少 product_id 字段: {record}")
  133. return None
  134. return str(product_id) # 确保返回字符串类型
  135. except Exception as e:
  136. print(f"查询 faiss_id 为 {faiss_id} 的记录时出错: {e}")
  137. return None
  138. def search(self, image_path: str = None, top_k: int = 5) -> List[Tuple[str, float]]:
  139. try:
  140. if image_path is None:
  141. print("搜索图片下载失败!")
  142. return []
  143. feature = self._process_image(image_path)
  144. if feature is None:
  145. print("无法提取查询图片的特征")
  146. return []
  147. # 将特征转换为numpy数组
  148. fn = np.array(feature)
  149. feature_np = fn.reshape(1, -1).astype('float32')
  150. start_vector_time = time.time()
  151. # 搜索最相似的图片
  152. distances, indices = self.index.search(feature_np, min(top_k, self.index.ntotal))
  153. end_vector_time = time.time()
  154. print(f"搜索vector耗时: {end_vector_time - start_vector_time}")
  155. start_other_time = time.time()
  156. # 返回结果
  157. results = []
  158. for faiss_id, dist in zip(indices[0], distances[0]):
  159. if faiss_id == -1: # Faiss返回-1表示无效结果
  160. continue
  161. # 将距离转换为相似度分数(0-1之间,1表示完全相似)
  162. similarity = 1.0 / (1.0 + dist)
  163. # 根据faiss_id获取product_id
  164. print(f"搜索结果->faiss_id: { faiss_id }")
  165. product_id = self.get_product_id_by_faiss_id(faiss_id)
  166. if product_id:
  167. results.append((product_id, similarity))
  168. end_other_time = time.time()
  169. print(f"查询结果耗时: {end_other_time - start_other_time}")
  170. return results
  171. except Exception as e:
  172. print(f"搜索图片时出错: {e}")
  173. return []
  174. def _load_index(self) -> bool:
  175. """从数据库分批加载数据并初始化faiss_id_max"""
  176. try:
  177. # 配置参数
  178. BATCH_SIZE = 10000
  179. # 获取文档总数
  180. total_docs = self.mongo_collection.count_documents({})
  181. if total_docs == 0:
  182. print("数据库为空,跳过索引加载")
  183. return True # 空数据库不算错误
  184. # 用于跟踪最大ID(兼容空数据情况)
  185. max_faiss_id = -1
  186. # 分批加载数据
  187. cursor = self.mongo_collection.find({}).batch_size(BATCH_SIZE)
  188. for batch in self._batch_generator(cursor, BATCH_SIZE):
  189. # 处理批次数据
  190. batch_vectors = []
  191. batch_ids = []
  192. current_max = -1
  193. for doc in batch:
  194. try:
  195. # 数据校验
  196. if len(doc['vector']) != self.dimension:
  197. continue
  198. if not isinstance(doc['faiss_id'], int):
  199. continue
  200. # 提取数据
  201. faiss_id = int(doc['faiss_id'])
  202. vector = doc['vector']
  203. print(f"load faiss_id :{ faiss_id }")
  204. # 更新最大值
  205. if faiss_id > current_max:
  206. current_max = faiss_id
  207. # 收集数据
  208. batch_vectors.append(vector)
  209. batch_ids.append(faiss_id)
  210. except Exception as e:
  211. print(f"文档处理异常: {str(e)}")
  212. continue
  213. # 批量添加到索引
  214. if batch_vectors:
  215. vectors_np = np.array(batch_vectors, dtype='float32')
  216. ids_np = np.array(batch_ids, dtype='int64')
  217. self.index.add_with_ids(vectors_np, ids_np)
  218. # 更新全局最大值
  219. if current_max > max_faiss_id:
  220. max_faiss_id = current_max
  221. print(f"向量总数: {self.index.ntotal}")
  222. # 设置初始值(如果已有更大值则保留)
  223. if max_faiss_id != -1:
  224. new_id = max_faiss_id
  225. self.faiss_id_max = new_id
  226. print(f"ID计数器初始化完成,当前值: {new_id}")
  227. return True
  228. except Exception as e:
  229. print(f"索引加载失败: {str(e)}")
  230. return False
  231. def clear(self) -> bool:
  232. """清除所有索引和 MongoDB 中的记录。
  233. Returns:
  234. 清除成功返回 True,失败返回 False
  235. """
  236. try:
  237. # 检查索引是否支持重置操作
  238. if not hasattr(self.index, "reset"):
  239. print("当前索引不支持重置操作")
  240. return False
  241. # 重置 Faiss 索引
  242. self.index.reset()
  243. print("已清除 Faiss 索引中的所有向量")
  244. # 删除 MongoDB 中的所有记录
  245. result = self.mongo_collection.delete_many({})
  246. print(f"已从 MongoDB 中删除 {result.deleted_count} 条记录")
  247. self.faiss_id_max = 0
  248. return True
  249. except Exception as e:
  250. print(f"清除索引时出错: {e}")
  251. return False
  252. def remove_by_product_id(self, product_id: str) -> bool:
  253. """通过 product_id 删除向量索引和数据库记录。
  254. Args:
  255. product_id: 要删除的商品 ID
  256. Returns:
  257. 删除成功返回 True,失败返回 False
  258. """
  259. try:
  260. # 检查 product_id 是否有效
  261. if not product_id or not isinstance(product_id, str):
  262. print(f"无效的 product_id: {product_id}")
  263. return False
  264. # 查询 MongoDB 获取 faiss_id
  265. query = {"product_id": product_id}
  266. record = self.mongo_collection.find_one(query)
  267. # 检查是否找到记录
  268. if record is None:
  269. print(f"未找到 product_id 为 {product_id} 的记录")
  270. return False
  271. # 提取 faiss_id
  272. faiss_id = record.get("faiss_id")
  273. if faiss_id is None:
  274. print(f"记录中缺少 faiss_id 字段: {record}")
  275. return False
  276. # 删除 Faiss 索引中的向量
  277. if isinstance(self.index, faiss.IndexIDMap):
  278. # 检查 faiss_id 是否在索引中
  279. # ids = self.index.id_map.at(1) # 获取所有 ID
  280. # if faiss_id not in ids:
  281. # print(f"faiss_id {faiss_id} 不在索引中")
  282. # return False
  283. # 删除向量
  284. self.index.remove_ids(np.array([faiss_id], dtype='int64'))
  285. print(f"已从 Faiss 索引中删除 faiss_id: {faiss_id}")
  286. else:
  287. print("当前索引不支持删除操作")
  288. return False
  289. # 删除 MongoDB 中的记录
  290. result = self.mongo_collection.delete_one({"faiss_id": faiss_id})
  291. if result.deleted_count == 1:
  292. print(f"已从 MongoDB 中删除 faiss_id: {faiss_id}")
  293. return True
  294. else:
  295. print(f"未找到 faiss_id 为 {faiss_id} 的记录")
  296. return False
  297. except Exception as e:
  298. print(f"删除 product_id 为 {product_id} 的记录时出错: {e}")
  299. return False
  300. def get_index_size(self) -> int:
  301. """获取索引中的图片数量。
  302. Returns:
  303. 索引中的图片数量
  304. """
  305. return len(self.image_paths)