image_search.py

import faiss
import numpy as np
from PIL import Image
import io
import os
from typing import List, Tuple, Optional, Union
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from scipy import ndimage
import torch.nn.functional as F
from pymongo import MongoClient
import datetime
import time
import traceback  # needed by remove_by_product_id()


class ImageSearchEngine:
    def __init__(self):
        # MongoDB setup
        self.mongo_client = MongoClient("mongodb://root:faiss_image_search@localhost:27017/")  # MongoDB connection string
        self.mongo_db = self.mongo_client["faiss_index"]  # database name
        self.mongo_collection = self.mongo_db["mat_vectors"]  # collection name
        self.mongo_collection.create_index([("product_id", 1)], unique=True)
        self.mongo_collection.create_index([("faiss_id", 1)], unique=True)

        # Counter used to generate Faiss IDs
        self.faiss_id_max = 0

        # Check whether a GPU is available (used only for the PyTorch model)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Define the base preprocessing transforms
        self.base_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # convert to grayscale but keep 3 channels
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Load a pretrained ResNet model
        self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Remove the final fully connected layer
        self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Initialize the FAISS index (2048 is the ResNet50 feature dimension)
        self.dimension = 2048
        # self.index = faiss.IndexFlatL2(self.dimension)
        # Use an ID-mapped index so vectors can be removed later
        base_index = faiss.IndexFlatL2(self.dimension)
        self.index = faiss.IndexIDMap(base_index)

        # Try to load existing vectors; otherwise the new empty index is used
        if self._load_index():
            print("Existing index loaded successfully")
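
    # The MongoDB documents written by add_image_from_url() and read back by
    # _load_index() are expected to look roughly like this (illustrative example
    # only; the field values below are made up):
    #
    #   {
    #       "faiss_id": 42,                     # int ID shared with the Faiss index
    #       "product_id": "SKU-12345",          # unique product identifier
    #       "vector": [0.013, -0.087, ...],     # 2048-dim ResNet50 feature as a list
    #       "created_at": <UTC datetime>        # creation timestamp
    #   }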

    def _batch_generator(self, cursor, batch_size):
        """Yield documents from a MongoDB cursor in batches."""
        batch = []
        for doc in cursor:
            batch.append(doc)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
        # An Aliyun Cloud Function could be called here to process the image
        """Process a single image and extract its feature vector.

        Args:
            image_path: path to the image
        Returns:
            The extracted feature vector, or None if processing fails.
        """
        try:
            # Load the image
            image = Image.open(image_path)
            # Make sure the image is in RGB mode
            if image.mode != 'RGB':
                image = image.convert('RGB')

            start_ms_time = time.time()
            # Extract multi-scale features
            multi_scale_features = self._extract_multi_scale_features(image)
            end_ms_time = time.time()
            print(f"Multi-scale feature extraction took: {end_ms_time - start_ms_time} s")
            if multi_scale_features is None:
                return None

            start_sw_time = time.time()
            # Extract sliding-window features
            sliding_window_features = self._extract_sliding_window_features(image)
            end_sw_time = time.time()
            print(f"Sliding-window feature extraction took: {end_sw_time - start_sw_time} s")
            if sliding_window_features is None:
                return None

            # Combine the features (weighted average)
            combined_feature = multi_scale_features * 0.6 + sliding_window_features * 0.4
            # L2-normalize the combined feature
            combined_feature = F.normalize(combined_feature, p=2, dim=0)
            return combined_feature
        except Exception as e:
            print(f"Error while processing image: {e}")
            return None

    def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Multi-scale feature extraction driven by the original resolution (adaptive version).

        Args:
            image: PIL image object
        Returns:
            Multi-scale feature vector, or None if processing fails.
        """
        try:
            # Original image information
            orig_w, orig_h = image.size
            max_edge = max(orig_w, orig_h)
            aspect_ratio = orig_w / orig_h

            # Dynamic sizing strategy -------------------------------------------
            # Strategy 1: pick a base size from the longest edge
            base_size = min(max_edge, 3000)  # do not exceed the maximum size the model supports

            # Strategy 2: generate window sizes automatically (geometric progression)
            min_size = 224  # minimum feature size
            num_scales = 4  # fixed number of scales
            scale_factors = np.logspace(0, 1, num_scales, base=2)
            window_sizes = [int(base_size * f) for f in scale_factors]
            window_sizes = sorted({min(max(s, min_size), 3000) for s in window_sizes})

            # Strategy 3: adjust the sizes according to the aspect ratio
            if aspect_ratio > 1.5:  # wide image
                window_sizes = [int(s * aspect_ratio) for s in window_sizes]
            elif aspect_ratio < 0.67:  # tall image
                window_sizes = [int(s / aspect_ratio) for s in window_sizes]

            # Preprocessing ------------------------------------------------------
            # Pick the best base size (the power of two closest to the original size)
            base_size = 2 ** int(np.log2(base_size))
            base_transform = transforms.Compose([
                transforms.Resize((base_size, base_size),
                                  interpolation=transforms.InterpolationMode.LANCZOS),
                self.base_transform
            ])

            # Half precision for speed
            self.model.half()
            img_base = base_transform(image).unsqueeze(0).to(self.device).half()

            # Dynamic feature extraction ------------------------------------------
            features = []
            for size in window_sizes:
                # Resample while keeping the aspect ratio
                target_size = (int(size * aspect_ratio), size) if aspect_ratio > 1 else (size, int(size / aspect_ratio))
                # GPU interpolation: area for downsampling, bicubic for upsampling
                # (align_corners is only valid for bicubic, so it is passed conditionally)
                if size < base_size:
                    img_tensor = torch.nn.functional.interpolate(img_base, size=target_size, mode='area')
                else:
                    img_tensor = torch.nn.functional.interpolate(img_base, size=target_size, mode='bicubic', align_corners=False)

                # Adaptive normalization (preserves the statistics of the original image)
                if hasattr(self, 'adaptive_normalize'):
                    img_tensor = self.adaptive_normalize(img_tensor)

                # Mixed-precision inference
                with torch.no_grad(), torch.cuda.amp.autocast():
                    feature = self.model(img_tensor)
                features.append(feature.squeeze().float())

            # Dynamic weighting ---------------------------------------------------
            # Weight by size difference (the closer a scale is to the base size, the higher its weight)
            size_diffs = [abs(size - base_size) for size in window_sizes]
            weights = 1 / (torch.tensor(size_diffs, device=self.device) + 1e-6)
            weights = weights / weights.sum()

            # Weighted fusion
            final_feature = torch.stack([f * w for f, w in zip(features, weights)]).sum(dim=0)
            return final_feature
        except Exception as e:
            print(f"Adaptive multi-scale feature extraction failed: {e}")
            return None

    def _extract_multi_scale_features_bak(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Extract multi-scale features (previous implementation, kept for reference).

        Args:
            image: PIL image object
        Returns:
            Multi-scale feature vector, or None if processing fails.
        """
        try:
            features_list = []
            window_sizes = [256, 512, 1024, 1560, 2048, 2560, 3000]
            # Multi-scale transforms - more scales could be added
            # self.multi_scale_sizes = [224, 384, 512, 768, 1024, 1536, 2048, 3000]
            for size in window_sizes:
                # Resize the image
                transform = transforms.Compose([
                    transforms.Resize((size, size), interpolation=transforms.InterpolationMode.LANCZOS),
                    self.base_transform
                ])
                # Apply the transform
                img_tensor = transform(image).unsqueeze(0).to(self.device)
                # Extract features
                with torch.no_grad():
                    feature = self.model(img_tensor)
                features_list.append(feature.squeeze())

            # Weighted average, with larger scales getting higher weights
            weights = torch.linspace(1, 2, len(features_list)).to(self.device)
            weights = weights / weights.sum()
            weighted_features = torch.stack([f * w for f, w in zip(features_list, weights)])
            final_feature = weighted_features.sum(dim=0)
            return final_feature
        except Exception as e:
            print(f"Error while extracting multi-scale features: {e}")
            return None

    def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Optimised sliding-window feature extraction (dynamic windows + batched inference).

        Args:
            image: PIL image object
        Returns:
            Sliding-window feature vector, or None if processing fails.
        """
        try:
            # Original image information
            orig_w, orig_h = image.size
            aspect_ratio = orig_w / orig_h

            # Dynamic window configuration -----------------------------------------
            # Automatically pick key window sizes from the original size
            # (example logic; adjust to the actual data)
            max_dim = max(orig_w, orig_h)
            window_sizes = sorted({
                int(2 ** np.round(np.log2(max_dim * 0.1))),  # roughly 10% of the size
                int(2 ** np.floor(np.log2(max_dim * 0.5))),  # roughly 50% of the size
                int(2 ** np.ceil(np.log2(max_dim)))          # close to the original size
            } & {256, 512, 1024, 2048, 3000})  # intersect with the preset sizes

            # Adaptive stride (larger windows use larger absolute strides)
            stride_ratios = {256: 0.5, 512: 0.4, 1024: 0.3, 2048: 0.2, 3000: 0.15}

            # Preprocessing ----------------------------------------------------------
            # Build the base image at the largest window size
            max_win_size = max(window_sizes)
            base_size = (int(max_win_size * aspect_ratio), max_win_size) if aspect_ratio > 1 else \
                        (max_win_size, int(max_win_size / aspect_ratio))
            transform = transforms.Compose([
                transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.LANCZOS),
                self.base_transform
            ])
            base_img = transform(image).to(self.device)

            # Half precision for speed
            self.model.half()
            base_img = base_img.half()

            # Batched feature extraction ---------------------------------------------
            all_features = []
            for win_size in window_sizes:
                # Dynamic stride selection
                stride = int(win_size * stride_ratios.get(win_size, 0.3))

                # Generate window coordinates (windows never overrun the image edge)
                h, w = base_img.shape[1:]
                num_h = (h - win_size) // stride + 1
                num_w = (w - win_size) // stride + 1

                # Cap the number of windows (to avoid running out of GPU memory)
                MAX_WINDOWS = 32  # tune for the available GPU memory
                if num_h * num_w > MAX_WINDOWS:
                    stride = int(np.sqrt(h * w * win_size ** 2 / MAX_WINDOWS))
                    num_h = (h - win_size) // stride + 1
                    num_w = (w - win_size) // stride + 1

                # Crop the windows
                windows = []
                for i in range(num_h):
                    for j in range(num_w):
                        top = i * stride
                        left = j * stride
                        window = base_img[:, top:top + win_size, left:left + win_size]
                        windows.append(window)
                if not windows:
                    continue

                # Batched inference (chunked to avoid OOM)
                BATCH_SIZE = 8  # tune for the available GPU memory
                with torch.no_grad(), torch.cuda.amp.autocast():
                    for i in range(0, len(windows), BATCH_SIZE):
                        batch = torch.stack(windows[i:i + BATCH_SIZE])
                        features = self.model(batch)
                        all_features.append(features.cpu().float())  # move to CPU to free GPU memory

            # Feature fusion -----------------------------------------------------------
            if not all_features:
                return None
            # Each entry has shape (batch, 2048, 1, 1); flatten to (batch, 2048) before averaging
            final_feature = torch.cat([f.view(-1, self.dimension) for f in all_features], dim=0)
            final_feature = final_feature.mean(dim=0).to(self.device)
            return final_feature
        except Exception as e:
            print(f"Sliding-window feature extraction failed: {e}")
            return None

    def _extract_sliding_window_features_bak(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Extract features with a sliding window (previous implementation, kept for reference).

        Args:
            image: PIL image object
        Returns:
            Sliding-window feature vector, or None if processing fails.
        """
        try:
            window_sizes = [256, 512, 1024, 1560, 2048, 2560, 3000]
            stride_ratio = 0.25  # stride as a fraction of the window size
            features_list = []
            for window_size in window_sizes:
                # Resize the image while keeping the aspect ratio
                aspect_ratio = image.size[0] / image.size[1]
                if aspect_ratio > 1:
                    new_width = int(window_size * aspect_ratio)
                    new_height = window_size
                else:
                    new_width = window_size
                    new_height = int(window_size / aspect_ratio)
                transform = transforms.Compose([
                    transforms.Resize((new_height, new_width), interpolation=transforms.InterpolationMode.LANCZOS),
                    self.base_transform
                ])
                # Transform the image
                img_tensor = transform(image)
                # Compute the stride
                stride = int(window_size * stride_ratio)
                # Slide the window and extract features
                for i in range(0, img_tensor.size(1) - window_size + 1, stride):
                    for j in range(0, img_tensor.size(2) - window_size + 1, stride):
                        window = img_tensor[:, i:i + window_size, j:j + window_size].unsqueeze(0).to(self.device)
                        with torch.no_grad():
                            feature = self.model(window)
                        features_list.append(feature.squeeze())

            # If no features were extracted, return None
            if not features_list:
                return None
            # Average all window features
            final_feature = torch.stack(features_list).mean(dim=0)
            return final_feature
        except Exception as e:
            print(f"Error while extracting sliding-window features: {e}")
            return None

    def extract_features(self, img: Image.Image) -> np.ndarray:
        """Combine multi-scale and sliding-window features.

        Args:
            img: PIL image object
        Returns:
            Feature vector as a numpy array.
        """
        try:
            # Extract multi-scale features
            multi_scale_features = self._extract_multi_scale_features(img)
            if multi_scale_features is None:
                raise ValueError("Failed to extract multi-scale features")
            # Extract sliding-window features
            sliding_window_features = self._extract_sliding_window_features(img)
            if sliding_window_features is None:
                raise ValueError("Failed to extract sliding-window features")
            # Combine the features
            combined_feature = multi_scale_features * 0.6 + sliding_window_features * 0.4
            # L2-normalize the combined feature
            combined_feature = F.normalize(combined_feature, p=2, dim=0)
            # Convert to a numpy array
            return combined_feature.cpu().numpy()
        except Exception as e:
            print(f"Feature extraction failed: {e}")
            raise

    def add_image_from_url(self, image_path: str, product_id: str) -> bool:
        """Add an image to the index.

        Args:
            image_path: path to the (downloaded) image
            product_id: product ID associated with the image
        Returns:
            True if the image was added successfully, False otherwise.
        """
        try:
            # Reuse the existing feature extraction logic
            feature = self._process_image(image_path)
            if feature is None:
                print("Failed to extract features")
                return False

            # Convert to a numpy array and add it to the index
            feature_np = feature.cpu().numpy().reshape(1, -1).astype('float32')

            idx = self.faiss_id_max + 1
            print(f"Current idx: {idx}")
            if not isinstance(idx, int) or idx <= 0:
                print("ID generation failed")
                return False
            self.faiss_id_max = idx

            # Write the record to the database
            record = {
                "faiss_id": idx,
                "product_id": product_id,
                "vector": feature_np.flatten().tolist(),  # convert the numpy array to a list
                "created_at": datetime.datetime.utcnow()  # record creation time
            }
            self.mongo_collection.insert_one(record)

            # Add the vector to the Faiss index with its ID
            self.index.add_with_ids(feature_np, np.array([idx], dtype='int64'))
            print(f"Added image: product_id: {product_id}, faiss_id: {idx}")
            return True
        except Exception as e:
            print(f"Error while adding image: {e}")
            return False

    def get_product_id_by_faiss_id(self, faiss_id: int) -> Optional[str]:
        """Look up the product_id in MongoDB for a given faiss_id.

        Args:
            faiss_id: ID stored in the Faiss index
        Returns:
            The corresponding product_id, or None if not found.
        """
        try:
            faiss_id = int(faiss_id)
            # Validate the faiss_id
            if faiss_id < 0:
                print(f"Invalid faiss_id: {faiss_id}")
                return None
            # Query MongoDB
            query = {"faiss_id": faiss_id}
            record = self.mongo_collection.find_one(query)
            # Check whether a record was found
            if record is None:
                print(f"No record found for faiss_id {faiss_id}")
                return None
            # Return the product_id
            product_id = record.get("product_id")
            if product_id is None:
                print(f"Record is missing the product_id field: {record}")
                return None
            return str(product_id)  # make sure a string is returned
        except Exception as e:
            print(f"Error while querying the record for faiss_id {faiss_id}: {e}")
            return None

    def search(self, image_path: str = None, top_k: int = 5) -> List[Tuple[str, float]]:
        """Search the index for the images most similar to the query image."""
        try:
            if image_path is None:
                print("Failed to download the query image!")
                return []
            feature = self._process_image(image_path)
            if feature is None:
                print("Failed to extract features from the query image")
                return []

            # Convert the feature to a numpy array
            feature_np = feature.cpu().numpy().reshape(1, -1).astype('float32')

            start_vector_time = time.time()
            # Search for the most similar images
            distances, indices = self.index.search(feature_np, min(top_k, self.index.ntotal))
            end_vector_time = time.time()
            print(f"Vector search took: {end_vector_time - start_vector_time}")

            start_other_time = time.time()
            # Build the result list
            results = []
            for faiss_id, dist in zip(indices[0], distances[0]):
                if faiss_id == -1:  # Faiss returns -1 for invalid results
                    continue
                # Convert the distance to a similarity score in (0, 1], where 1 means identical
                similarity = 1.0 / (1.0 + dist)
                # Look up the product_id for this faiss_id
                print(f"Search hit -> faiss_id: {faiss_id}")
                product_id = self.get_product_id_by_faiss_id(faiss_id)
                if product_id:
                    results.append((product_id, similarity))
            end_other_time = time.time()
            print(f"Result lookup took: {end_other_time - start_other_time}")
            return results
        except Exception as e:
            print(f"Error while searching: {e}")
            return []

    def _load_index(self) -> bool:
        """Load vectors from the database in batches and initialize faiss_id_max."""
        try:
            # Configuration
            BATCH_SIZE = 10000
            # Total number of documents
            total_docs = self.mongo_collection.count_documents({})
            if total_docs == 0:
                print("Database is empty, skipping index load")
                return True  # an empty database is not an error

            # Track the maximum ID (also covers the empty-data case)
            max_faiss_id = -1

            # Load the data in batches
            cursor = self.mongo_collection.find({}).batch_size(BATCH_SIZE)
            for batch in self._batch_generator(cursor, BATCH_SIZE):
                # Process this batch
                batch_vectors = []
                batch_ids = []
                current_max = -1
                for doc in batch:
                    try:
                        # Validate the document
                        if len(doc['vector']) != self.dimension:
                            continue
                        if not isinstance(doc['faiss_id'], int):
                            continue
                        # Extract the data
                        faiss_id = int(doc['faiss_id'])
                        vector = doc['vector']
                        print(f"load faiss_id: {faiss_id}")
                        # Update the batch maximum
                        if faiss_id > current_max:
                            current_max = faiss_id
                        # Collect the data
                        batch_vectors.append(vector)
                        batch_ids.append(faiss_id)
                    except Exception as e:
                        print(f"Error while processing a document: {str(e)}")
                        continue

                # Add the batch to the index
                if batch_vectors:
                    vectors_np = np.array(batch_vectors, dtype='float32')
                    ids_np = np.array(batch_ids, dtype='int64')
                    self.index.add_with_ids(vectors_np, ids_np)

                # Update the global maximum
                if current_max > max_faiss_id:
                    max_faiss_id = current_max

            print(f"Total number of vectors: {self.index.ntotal}")

            # Initialize the ID counter to the largest ID seen
            if max_faiss_id != -1:
                new_id = max_faiss_id
                self.faiss_id_max = new_id
                print(f"ID counter initialized, current value: {new_id}")
            return True
        except Exception as e:
            print(f"Failed to load the index: {str(e)}")
            return False

    def clear(self) -> bool:
        """Clear all vectors from the index and all records from MongoDB.

        Returns:
            True on success, False on failure.
        """
        try:
            # Check whether the index supports reset
            if not hasattr(self.index, "reset"):
                print("The current index does not support reset")
                return False
            # Reset the Faiss index
            self.index.reset()
            print("All vectors have been cleared from the Faiss index")
            # Delete all records from MongoDB
            result = self.mongo_collection.delete_many({})
            print(f"Deleted {result.deleted_count} records from MongoDB")
            self.faiss_id_max = 0
            return True
        except Exception as e:
            print(f"Error while clearing the index: {e}")
            return False

    def remove_image(self, image_path: str) -> bool:
        """Remove the specified image from the index.

        Note: this legacy method relies on self.image_paths, self.product_ids and
        self._save_index(), which are not maintained by this class; prefer
        remove_by_product_id() instead.

        Args:
            image_path: path of the image to remove
        Returns:
            True if the image was removed, False otherwise.
        """
        try:
            if image_path in self.image_paths:
                idx = self.image_paths.index(image_path)
                # Build a new index
                new_index = faiss.IndexFlatL2(self.dimension)
                # Get all stored features
                all_features = faiss.vector_to_array(self.index.get_xb()).reshape(-1, self.dimension)
                # Drop the feature of the image being removed
                mask = np.ones(len(self.image_paths), dtype=bool)
                mask[idx] = False
                filtered_features = all_features[mask]
                # Rebuild the index
                if len(filtered_features) > 0:
                    new_index.add(filtered_features)
                # Update the bookkeeping lists
                self.image_paths.pop(idx)
                self.product_ids.pop(idx)
                # Swap in the new index
                self.index = new_index
                # Persist the change
                self._save_index()
                print(f"Removed image: {image_path}")
                return True
            else:
                print(f"Image not found: {image_path}")
                return False
        except Exception as e:
            print(f"Error while removing image: {e}")
            return False

    def remove_by_product_id(self, product_id: str) -> bool:
        """Delete the vector from the index and the record from the database by product_id.

        Args:
            product_id: product ID to delete
        Returns:
            True if the record was deleted successfully, False otherwise.
        """
        try:
            # Validate the product_id
            if not product_id or not isinstance(product_id, str):
                print(f"Invalid product_id: {product_id}")
                return False

            # Query MongoDB for the faiss_id
            query = {"product_id": product_id}
            record = self.mongo_collection.find_one(query)
            # Check whether a record was found
            if record is None:
                print(f"No record found for product_id {product_id}")
                return False

            # Extract the faiss_id
            faiss_id = record.get("faiss_id")
            if faiss_id is None:
                print(f"Record is missing the faiss_id field: {record}")
                return False

            # Remove the vector from the Faiss index
            if isinstance(self.index, faiss.IndexIDMap):
                # Optionally check whether faiss_id is present in the index:
                # ids = self.index.id_map.at(1)  # get all IDs
                # if faiss_id not in ids:
                #     print(f"faiss_id {faiss_id} is not in the index")
                #     return False
                # Remove the vector
                self.index.remove_ids(np.array([faiss_id], dtype='int64'))
                print(f"Removed faiss_id {faiss_id} from the Faiss index")
            else:
                print("The current index does not support deletion")
                return False

            # Delete the record from MongoDB
            result = self.mongo_collection.delete_one({"faiss_id": faiss_id})
            if result.deleted_count == 1:
                print(f"Deleted faiss_id {faiss_id} from MongoDB")
                return True
            else:
                print(f"No record found for faiss_id {faiss_id}")
                return False
        except Exception as e:
            print(f"Error while deleting the record for product_id {product_id}: {e}")
            traceback.print_exc()
            return False

    def get_index_size(self) -> int:
        """Return the number of images (vectors) stored in the index.

        Returns:
            The number of vectors currently in the Faiss index.
        """
        # self.image_paths is not maintained by this class; report the count
        # directly from the Faiss index instead
        return self.index.ntotal
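

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes the
# MongoDB instance configured in __init__ is reachable and that "product.jpg"
# and "query.jpg" are placeholder local image files; adjust the paths and
# product IDs to your environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    engine = ImageSearchEngine()

    # Index a local image under a (hypothetical) product ID
    engine.add_image_from_url("product.jpg", "SKU-12345")

    # Query with another local image and print the best matches
    for product_id, similarity in engine.search("query.jpg", top_k=5):
        print(f"{product_id}: similarity={similarity:.4f}")

    print(f"The index currently holds {engine.get_index_size()} vectors")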