image_search.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import faiss
  2. import numpy as np
  3. from PIL import Image
  4. import io
  5. import os
  6. from typing import List, Tuple, Optional, Union
  7. import torch
  8. import torchvision.transforms as transforms
  9. import torchvision.models as models
  10. from torchvision.models import ResNet50_Weights
  11. from scipy import ndimage
  12. import torch.nn.functional as F
  13. from pymongo import MongoClient
  14. import datetime
  15. import time
  16. class ImageSearchEngine:
  def __init__(self):
      """Set up the CPU-only feature extractor.

      Creates the shared preprocessing pipeline (grayscale replicated to three
      channels, tensor conversion, ImageNet mean/std normalisation) and a
      ResNet-50 with ImageNet weights whose last child module (the fully
      connected classifier) is removed, kept in float32 and eval mode.
      """
      # Inference is deliberately pinned to the CPU.
      self.device = torch.device("cpu")
      print(f"使用设备: {self.device}")

      # Shared preprocessing used by every feature extractor of this class.
      imagenet_norm = transforms.Normalize(
          mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
      )
      self.base_transform = transforms.Compose(
          [
              transforms.Grayscale(num_output_channels=3),  # gray, but still 3 channels
              transforms.ToTensor(),
              imagenet_norm,
          ]
      )

      # Keep every child of ResNet-50 except the final fully connected layer.
      backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
      trunk = torch.nn.Sequential(*list(backbone.children())[:-1])
      self.model = trunk.float().to(self.device).eval()

      # Length of a ResNet-50 feature vector.
      self.dimension = 2048
  35. def _batch_generator(self, cursor, batch_size):
  36. """从MongoDB游标中分批生成数据"""
  37. batch = []
  38. for doc in cursor:
  39. batch.append(doc)
  40. if len(batch) == batch_size:
  41. yield batch
  42. batch = []
  43. if batch:
  44. yield batch
  45. def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
  46. """处理单张图片并提取特征。
  47. Args:
  48. image_path: 图片路径
  49. Returns:
  50. 处理后的特征向量,如果处理失败返回None
  51. """
  52. try:
  53. # 读取图片
  54. image = Image.open(image_path)
  55. # 确保图片是RGB模式
  56. if image.mode != 'RGB':
  57. image = image.convert('RGB')
  58. start_ms_time = time.time()
  59. # 提取多尺度特征
  60. multi_scale_features = self._extract_multi_scale_features(image)
  61. end_ms_time = time.time()
  62. print(f"提取多尺度特征耗时: { end_ms_time - start_ms_time } s",)
  63. if multi_scale_features is None:
  64. return None
  65. start_sw_time = time.time()
  66. # 提取滑动窗口特征
  67. sliding_window_features = self._extract_sliding_window_features(image)
  68. end_sw_time = time.time()
  69. print(f"提取滑动窗口耗时: { end_sw_time - start_sw_time } s",)
  70. if sliding_window_features is None:
  71. return None
  72. # 组合特征(加权平均)
  73. combined_feature = multi_scale_features * 0.6 + sliding_window_features * 0.4
  74. # 标准化特征
  75. combined_feature = F.normalize(combined_feature, p=2, dim=0)
  76. return combined_feature
  77. except Exception as e:
  78. print(f"处理图片时出错: {e}")
  79. return None
  def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
      """Extract a resolution-aware multi-scale descriptor from ``image``.

      Steps:
        1. Compute short-side target sizes as ``base * {1, sqrt(2), 2}``, where
           ``base = min(longest edge, 2048)``. Clamp each to ``[224, 2048]``
           and de-duplicate.
        2. Resize the image (Lanczos) to a square whose edge is the largest
           power of two not above ``base``, then apply ``self.base_transform``.
        3. For every target size, bilinearly re-sample the square tensor so its
           short side equals that size and its long side follows the original
           aspect ratio, then run it through ``self.model``.
        4. Blend the per-scale outputs with weights proportional to
           ``1 / (|size - power_of_two_base| + 1e-6)``.

      Args:
          image: PIL image to describe.

      Returns:
          The weighted feature tensor (squeezed model outputs), or None if
          anything fails. The error is printed rather than raised.

      NOTE(review): fixing the (H, W) order and the double aspect-ratio scaling
      changes the produced vectors. Any persisted vectors from the old version
      should be regenerated.
      """
      try:
          # Kept from the original ("triple precision guarantee"). It is a no-op
          # while the model is already float32, as __init__ leaves it.
          self.model = self.model.float()

          orig_w, orig_h = image.size
          aspect_ratio = orig_w / orig_h  # > 1: landscape, < 1: portrait

          # --- target sizes (short side) -----------------------------------------
          min_size, max_size, num_scales = 224, 2048, 3
          base_size = min(max(orig_w, orig_h), max_size)
          scale_factors = np.logspace(0, 1, num_scales, base=2)  # 1, sqrt(2), 2
          window_sizes = sorted({
              min(max(int(base_size * factor), min_size), max_size)
              for factor in scale_factors
          })
          # The previous version also multiplied these sizes by the aspect ratio
          # for wide (> 1.5) and tall (< 0.67) images. The long side is derived
          # from the aspect ratio again below, so the long side grew by
          # aspect_ratio**2 and escaped the 2048 cap. A 3:1 image produced an
          # 18432x6144 input tensor.

          # --- square power-of-two base tensor -------------------------------------
          base_size = 2 ** int(np.log2(base_size))
          base_transform = transforms.Compose([
              transforms.Resize((base_size, base_size),
                                interpolation=transforms.InterpolationMode.LANCZOS),
              self.base_transform,
          ])
          img_base = base_transform(image).unsqueeze(0).to(torch.float32).to(self.device)

          # --- per-scale features ----------------------------------------------------
          features = []
          for size in window_sizes:
              # F.interpolate takes size=(H, W). The previous code passed (W, H),
              # which transposed the aspect ratio: landscape images were fed in
              # portrait orientation and vice versa.
              if aspect_ratio > 1:
                  target_hw = (size, int(size * aspect_ratio))
              else:
                  target_hw = (int(size / aspect_ratio), size)
              img_tensor = F.interpolate(
                  img_base,
                  size=target_hw,
                  mode='bilinear',
                  align_corners=False,
              ).to(torch.float32)
              # Optional hook, applied only if something sets self.adaptive_normalize.
              if hasattr(self, 'adaptive_normalize'):
                  img_tensor = self.adaptive_normalize(img_tensor)
              with torch.no_grad():
                  feature = self.model(img_tensor).to(torch.float32)
              features.append(feature.squeeze().float())

          # --- inverse-distance weighting --------------------------------------------
          # Sizes closer to the power-of-two base get more weight. The epsilon keeps
          # an exact match finite; such a match then dominates the blend.
          size_diffs = torch.tensor(
              [abs(size - base_size) for size in window_sizes],
              dtype=torch.float32,
              device=self.device,
          )
          weights = 1 / (size_diffs + 1e-6)
          weights = weights / weights.sum()
          return torch.stack([f * w for f, w in zip(features, weights)]).sum(dim=0)
      except Exception as e:
          print(f"智能特征提取失败: {e}")
          return None
  def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
      """Extract a descriptor by averaging model outputs over square crops of ``image``.

      Window sizes are powers of two derived from the longest edge: about 10 %
      of it, about 50 % of it, and the next power of two at or above it. These
      are intersected with the presets {256, 512, 1024, 2048}; if none match,
      the nearest end of the preset range is used.

      The image is resized so its short side equals the largest window. Each
      window size is slid over it with a size-dependent stride, enlarged when
      needed so that at most ``MAX_WINDOWS`` crops are taken. The crops are run
      through ``self.model`` in batches of ``BATCH_SIZE``, and the per-crop
      outputs are averaged channel by channel.

      Args:
          image: PIL image to describe.

      Returns:
          1-D tensor with one value per model output channel, or None when no
          crop could be taken or anything fails. The error is printed rather
          than raised.

      NOTE(review): the per-channel mean replaces an accidental global mean, so
      vectors produced by the old version differ. Regenerate them if they are
      persisted.
      """
      try:
          # Kept from the original. It is a no-op while the model is already float32.
          self.model = self.model.float()

          orig_w, orig_h = image.size
          aspect_ratio = orig_w / orig_h
          max_dim = max(orig_w, orig_h)

          # --- window configuration --------------------------------------------------
          preset_sizes = {256, 512, 1024, 2048}
          candidate_sizes = {
              int(2 ** np.round(np.log2(max_dim * 0.1))),  # ~10 % of the longest edge
              int(2 ** np.floor(np.log2(max_dim * 0.5))),  # ~50 % of the longest edge
              int(2 ** np.ceil(np.log2(max_dim))),         # >= the longest edge
          }
          window_sizes = sorted(candidate_sizes & preset_sizes)
          if not window_sizes:
              # No preset matches when the longest edge is <= 128 px or roughly
              # 29k px or more. Previously max() then raised and the image was
              # rejected, so fall back to the nearest end of the preset range.
              window_sizes = [
                  min(preset_sizes) if max_dim < max(preset_sizes) else max(preset_sizes)
              ]
          # Larger windows use a smaller stride ratio.
          stride_ratios = {256: 0.5, 512: 0.4, 1024: 0.3, 2048: 0.2}
          MAX_WINDOWS = 16  # upper bound on crops per window size
          BATCH_SIZE = 4    # crops per forward pass

          # --- base image: short side == largest window ---------------------------------
          max_win_size = window_sizes[-1]
          if aspect_ratio > 1:
              resize_hw = (max_win_size, int(max_win_size * aspect_ratio))
          else:
              resize_hw = (int(max_win_size / aspect_ratio), max_win_size)
          transform = transforms.Compose([
              transforms.Resize(resize_hw, interpolation=transforms.InterpolationMode.BILINEAR),
              self.base_transform,
          ])
          base_img = transform(image).to(torch.float32).to(self.device)
          h, w = base_img.shape[1:]

          # --- crop + batched inference ------------------------------------------------
          pooled = []
          for win_size in window_sizes:
              span_h, span_w = h - win_size, w - win_size
              if span_h < 0 or span_w < 0:
                  continue  # window does not fit; nothing to crop
              stride = int(win_size * stride_ratios.get(win_size, 0.3))

              # Cap the grid at MAX_WINDOWS crops by growing the stride. The old
              # formula sqrt(h * w * win_size**2 / MAX_WINDOWS) was dimensionally
              # wrong and always collapsed the grid to the single top-left crop.
              # Any stride that meets the cap exceeds
              # sqrt(span_h * span_w / MAX_WINDOWS), so the search starts there
              # and stops at the smallest valid stride.
              if (span_h // stride + 1) * (span_w // stride + 1) > MAX_WINDOWS:
                  stride = max(stride, int(np.sqrt(span_h * span_w / MAX_WINDOWS)))
                  while (span_h // stride + 1) * (span_w // stride + 1) > MAX_WINDOWS:
                      stride += 1

              # Row-major grid of square crops, each of shape (C, win_size, win_size).
              windows = [
                  base_img[:, top:top + win_size, left:left + win_size]
                  for top in range(0, span_h + 1, stride)
                  for left in range(0, span_w + 1, stride)
              ]

              with torch.no_grad():
                  for first in range(0, len(windows), BATCH_SIZE):
                      batch = torch.stack(windows[first:first + BATCH_SIZE]).to(torch.float32)
                      out = self.model(batch).to(torch.float32)
                      # With the ResNet-50 trunk built in __init__, the output is
                      # (B, 2048, 1, 1). Flatten it to (B, 2048) so the mean below is
                      # taken per channel. The old view(-1, f.shape[-1]) used the
                      # trailing spatial dim (size 1) and averaged every activation
                      # into a single scalar.
                      pooled.append(out.flatten(start_dim=1).cpu().float())

          # --- fusion ---------------------------------------------------------------------
          if not pooled:
              return None
          final_feature = torch.cat(pooled, dim=0).mean(dim=0).to(self.device)
          return final_feature.float()
      except Exception as e:
          print(f"滑动窗口特征提取失败: {e}")
          return None