image_search.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import numpy as np
  2. from PIL import Image
  3. from typing import Optional
  4. import torch
  5. import torchvision.transforms as transforms
  6. import torchvision.models as models
  7. from torchvision.models import ResNet50_Weights
  8. import torch.nn.functional as F
  9. import time
  10. class ImageSearchEngine:
  11. def __init__(self):
  12. # 检查GPU是否可用(仅用于PyTorch模型)
  13. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  14. print(f"使用设备: {self.device}")
  15. # 定义基础预处理转换
  16. self.base_transform = transforms.Compose([
  17. transforms.Grayscale(num_output_channels=3), # 转换为灰度图但保持3通道
  18. transforms.ToTensor(),
  19. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  20. ])
  21. # 加载预训练的ResNet模型
  22. self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
  23. # 移除最后的全连接层
  24. self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
  25. self.model = self.model.to(self.device)
  26. self.model.eval()
  27. # 初始化FAISS索引(2048是ResNet50的特征维度)
  28. self.dimension = 2048
  29. # self.index = faiss.IndexFlatL2(self.dimension)
  30. # 改为支持删除的索引
  31. # base_index = faiss.IndexFlatL2(self.dimension)
  32. # self.index = faiss.IndexIDMap(base_index)
  33. def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
  34. """处理单张图片并提取特征。
  35. Args:
  36. image_path: 图片路径
  37. Returns:
  38. 处理后的特征向量,如果处理失败返回None
  39. """
  40. try:
  41. # 读取图片
  42. image = Image.open(image_path)
  43. # 确保图片是RGB模式
  44. if image.mode != 'RGB':
  45. image = image.convert('RGB')
  46. start_ms_time = time.time()
  47. # 提取多尺度特征
  48. multi_scale_features = self._extract_multi_scale_features(image)
  49. end_ms_time = time.time()
  50. print(f"提取多尺度特征耗时: { end_ms_time - start_ms_time } s",)
  51. if multi_scale_features is None:
  52. return None
  53. start_sw_time = time.time()
  54. # 提取滑动窗口特征
  55. sliding_window_features = self._extract_sliding_window_features(image)
  56. end_sw_time = time.time()
  57. print(f"提取滑动窗口耗时: { end_sw_time - start_sw_time } s",)
  58. if sliding_window_features is None:
  59. return None
  60. # 组合特征(加权平均)
  61. combined_feature = multi_scale_features * 0.6 + sliding_window_features * 0.4
  62. # 标准化特征
  63. combined_feature = F.normalize(combined_feature, p=2, dim=0)
  64. return combined_feature
  65. except Exception as e:
  66. print(f"处理图片时出错: {e}")
  67. return None
  68. def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
  69. """提取多尺度特征。"""
  70. try:
  71. features_list = []
  72. width, height = image.size
  73. min_dim = min(width, height)
  74. max_dim = max(width, height)
  75. # 动态生成候选尺寸,基于原图尺寸
  76. scales = [0.25, 0.5, 0.75, 1.0, 1.5, 2.0]
  77. fixed_sizes = [256, 512, 1024, 2048]
  78. candidate_sizes = [int(min_dim * s) for s in scales] + fixed_sizes
  79. max_allowed = int(max_dim * 1.5)
  80. window_sizes = [size for size in candidate_sizes if 64 <= size <= max_allowed]
  81. window_sizes = sorted(list(set(window_sizes)))
  82. if not window_sizes:
  83. return None
  84. for size in window_sizes:
  85. # 保持宽高比调整较小边,并中心裁剪
  86. transform = transforms.Compose([
  87. transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS),
  88. transforms.CenterCrop(size),
  89. self.base_transform
  90. ])
  91. img_tensor = transform(image).unsqueeze(0).to(self.device)
  92. with torch.no_grad():
  93. feature = self.model(img_tensor)
  94. features_list.append(feature.squeeze())
  95. # 加权平均(较大尺度权重更高)
  96. weights = torch.linspace(1, 2, len(features_list), device=self.device)
  97. weights /= weights.sum()
  98. final_feature = torch.stack(features_list) * weights[:, None]
  99. return final_feature.sum(dim=0)
  100. except Exception as e:
  101. print(f"提取多尺度特征时出错: {e}")
  102. return None
  103. def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
  104. """优化版滑动窗口特征提取(动态调整+批量处理)
  105. Args:
  106. image: PIL图片对象
  107. Returns:
  108. 滑动窗口特征向量,处理失败返回None
  109. """
  110. try:
  111. # 基础图片检查
  112. if image is None or image.size[0] < 64 or image.size[1] < 64:
  113. print("图片无效或尺寸过小")
  114. return None
  115. # 获取原图信息
  116. orig_w, orig_h = image.size
  117. aspect_ratio = orig_w / orig_h
  118. max_dim = max(orig_w, orig_h)
  119. # 动态窗口配置 -------------------------------------------
  120. # 使用对数尺度生成窗口尺寸,确保合理的尺寸分布
  121. base_sizes = {256, 512, 1024, 2048}
  122. log_size = np.log2(max_dim)
  123. dynamic_sizes = {
  124. int(2 ** size) for size in [
  125. np.floor(log_size - 1), # 约50%原尺寸
  126. np.ceil(log_size), # 接近原尺寸
  127. ]
  128. }
  129. window_sizes = sorted(base_sizes & dynamic_sizes)
  130. if not window_sizes:
  131. # 如果没有合适的预设尺寸,选择最接近的基础尺寸
  132. closest_size = min(base_sizes, key=lambda x: abs(np.log2(x) - log_size))
  133. window_sizes = [closest_size]
  134. # 智能步长配置(窗口越大,步长比例越大)
  135. def get_stride_ratio(size):
  136. # 使用线性插值计算步长比例
  137. size_ratio = np.clip(size / 2048, 0.2, 0.8)
  138. return 0.2 + size_ratio * 0.3 # 步长比例范围:0.2-0.5
  139. # 预处理优化 --------------------------------------------
  140. # 生成基准图像(使用最大窗口尺寸)
  141. max_win_size = max(window_sizes)
  142. if aspect_ratio > 1:
  143. base_size = (int(max_win_size * aspect_ratio), max_win_size)
  144. else:
  145. base_size = (max_win_size, int(max_win_size / aspect_ratio))
  146. # 图像转换和加载
  147. transform = transforms.Compose([
  148. transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.LANCZOS),
  149. self.base_transform
  150. ])
  151. try:
  152. base_img = transform(image).to(self.device)
  153. except Exception as e:
  154. print(f"图像转换失败: {e}")
  155. return None
  156. # 特征提取 ---------------------------------------------
  157. all_features = []
  158. total_windows = 0
  159. for win_size in window_sizes:
  160. # 计算动态步长
  161. stride_ratio = get_stride_ratio(win_size)
  162. stride = max(int(win_size * stride_ratio), 16) # 确保最小步长
  163. # 计算窗口数量
  164. h, w = base_img.shape[1:]
  165. num_h = (h - win_size) // stride + 1
  166. num_w = (w - win_size) // stride + 1
  167. # 内存优化:控制单个尺寸下的最大窗口数
  168. MAX_WINDOWS_PER_SIZE = 64
  169. if num_h * num_w > MAX_WINDOWS_PER_SIZE:
  170. adjusted_stride = int(np.sqrt((h * w) / MAX_WINDOWS_PER_SIZE))
  171. stride = max(stride, adjusted_stride)
  172. num_h = (h - win_size) // stride + 1
  173. num_w = (w - win_size) // stride + 1
  174. print(f"处理窗口 {win_size}x{win_size}, 步长 {stride}, 窗口数 {num_h * num_w}")
  175. # 批量处理窗口
  176. batch = []
  177. batch_size = min(16, num_h * num_w) # 动态批次大小
  178. for i in range(num_h):
  179. for j in range(num_w):
  180. top = i * stride
  181. left = j * stride
  182. window = base_img[:, top:top+win_size, left:left+win_size]
  183. if torch.isnan(window).any() or torch.isinf(window).any():
  184. continue
  185. batch.append(window)
  186. total_windows += 1
  187. if len(batch) >= batch_size:
  188. with torch.no_grad():
  189. try:
  190. batch_tensor = torch.stack(batch)
  191. features = self.model(batch_tensor)
  192. all_features.append(features.cpu()) # 转移到CPU释放显存
  193. except RuntimeError as e:
  194. print(f"批处理失败,尝试减小批次大小: {e}")
  195. if batch_size > 4:
  196. batch_size //= 2
  197. continue
  198. batch = []
  199. # 处理剩余的窗口
  200. if batch:
  201. with torch.no_grad():
  202. try:
  203. batch_tensor = torch.stack(batch)
  204. features = self.model(batch_tensor)
  205. all_features.append(features.cpu())
  206. except RuntimeError as e:
  207. print(f"处理剩余窗口失败: {e}")
  208. # 特征融合 ---------------------------------------------
  209. if not all_features:
  210. print("未能提取到有效特征")
  211. return None
  212. print(f"总处理窗口数: {total_windows}")
  213. # 合并所有特征
  214. try:
  215. final_features = torch.cat([f.view(-1, f.shape[-1]) for f in all_features], dim=0)
  216. # 如果特征数量过多,进行随机采样
  217. if final_features.size(0) > 1000:
  218. indices = torch.randperm(final_features.size(0))[:1000]
  219. final_features = final_features[indices]
  220. return final_features.mean(dim=0).to(self.device)
  221. except Exception as e:
  222. print(f"特征融合失败: {e}")
  223. return None
  224. except Exception as e:
  225. print(f"滑动窗口特征提取失败: {e}")
  226. return None