zzz.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import numpy as np
  2. from PIL import Image
  3. from typing import Optional, Tuple
  4. import torch
  5. import torchvision.transforms as transforms
  6. import torchvision.models as models
  7. from torchvision.models import ResNet50_Weights
  8. import torch.nn.functional as F
  9. from torch.cuda.amp import autocast
  10. import time
  11. import gc
  12. class ImageSearchEngine:
  def __init__(self):
      """Configure the device, preprocessing, feature backbone and batching limits.

      The backbone is torchvision's ResNet-50 with IMAGENET1K_V2 weights and
      its last child module (the final fully-connected layer) removed, so the
      model yields pooled feature maps rather than class scores.
      """
      # Run on the GPU when CUDA is available, otherwise on the CPU.
      has_cuda = torch.cuda.is_available()
      self.device = torch.device("cuda" if has_cuda else "cpu")
      print(f"使用设备: {self.device}")

      # Images whose longer side exceeds this many pixels get downscaled.
      self.max_image_size = 2048
      # Shared tail of every transform: grayscale replicated to 3 channels
      # (so the RGB backbone still accepts it), tensor conversion, and
      # ImageNet mean/std normalisation.
      imagenet_mean = [0.485, 0.456, 0.406]
      imagenet_std = [0.229, 0.224, 0.225]
      self.base_transform = transforms.Compose([
          transforms.Grayscale(num_output_channels=3),
          transforms.ToTensor(),
          transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
      ])

      # Pretrained ResNet-50 minus its final fully-connected layer.
      resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
      layers = list(resnet.children())[:-1]
      self.model = torch.nn.Sequential(*layers).to(self.device)
      self.model.eval()

      # Length of the feature vector produced by the backbone.
      self.dimension = 2048

      # Batch-size bounds used by the feature extractors.
      self.min_batch_size = 4
      self.max_batch_size = 16
      # Fraction of GPU memory in use above which batching backs off.
      self.memory_threshold = 0.8
  def get_available_memory(self) -> Tuple[float, float]:
      """Report PyTorch's memory allocation on the current CUDA device.

      Returns:
          ``(usage_fraction, remaining_bytes)``: the share of the device's
          total memory currently allocated to tensors by PyTorch, and total
          minus allocated (in bytes, as a float). Memory held by PyTorch's
          caching allocator or by other processes is not subtracted.
          Both values are ``0.0`` when CUDA is unavailable.

      This is a pure query. It deliberately does not call
      ``torch.cuda.empty_cache()``: that call cannot change
      ``memory_allocated()``, yet callers invoke this method inside hot loops.
      """
      if not torch.cuda.is_available():
          return 0.0, 0.0
      # Query the same device for both numbers.
      device = torch.cuda.current_device()
      total = torch.cuda.get_device_properties(device).total_memory
      allocated = torch.cuda.memory_allocated(device)
      return allocated / total, float(total - allocated)
  44. def get_dynamic_batch_size(self, image_size: int) -> int:
  45. """动态计算批处理大小"""
  46. if not torch.cuda.is_available():
  47. return self.min_batch_size
  48. mem_usage, _ = self.get_available_memory()
  49. if mem_usage > self.memory_threshold:
  50. return self.min_batch_size
  51. # 根据图像大小动态调整
  52. size_factor = (self.max_image_size / image_size) ** 2
  53. batch_size = min(self.max_batch_size,
  54. max(self.min_batch_size,
  55. int(self.max_batch_size * size_factor)))
  56. return batch_size
  def preprocess_image(self, image: Image.Image) -> Optional[Image.Image]:
      """Downscale ``image`` so its longer side is at most ``self.max_image_size``.

      The aspect ratio is preserved; images already within the limit are
      returned unchanged. If resizing raises, the error is printed and
      ``None`` is returned.
      """
      try:
          longest = max(image.size)
          if longest > self.max_image_size:
              scale = self.max_image_size / longest
              # round() rather than int(): float error in dim * scale could
              # otherwise truncate the long side to max_image_size - 1.
              # max(1, ...) keeps very thin images from collapsing to 0 px.
              new_size = tuple(max(1, round(dim * scale)) for dim in image.size)
              image = image.resize(new_size, Image.BILINEAR)  # BILINEAR for speed
          return image
      except Exception as e:
          print(f"图像预处理失败: {e}")
          return None
  def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
      """Load one image from disk and compute its fused, L2-normalised feature.

      The feature is ``0.7 * multi-scale + 0.3 * sliding-window`` features,
      normalised to unit L2 norm along dim 0. The duration of each
      extraction stage is printed.

      Args:
          image_path: Path of the image file to open.

      Returns:
          The fused feature tensor, or ``None`` if loading, preprocessing or
          either extractor fails (errors are printed, not raised).
      """
      try:
          # The context manager guarantees the file handle is released.
          # convert() returns a fully loaded in-memory copy, so the image
          # stays usable after the file is closed.
          with Image.open(image_path) as src:
              image = src.convert('RGB')
          image = self.preprocess_image(image)
          if image is None:
              return None
          with torch.no_grad(), autocast():  # mixed-precision inference
              start = time.perf_counter()
              multi_scale_features = self._extract_multi_scale_features(image)
              print(f"提取多尺度特征耗时: {time.perf_counter() - start:.2f}s")
              if multi_scale_features is None:
                  return None
              start = time.perf_counter()
              sliding_window_features = self._extract_sliding_window_features(image)
              print(f"提取滑动窗口特征耗时: {time.perf_counter() - start:.2f}s")
              if sliding_window_features is None:
                  return None
              # Weighted fusion of the two descriptors.
              combined_feature = multi_scale_features * 0.7 + sliding_window_features * 0.3
              return F.normalize(combined_feature, p=2, dim=0)
      except Exception as e:
          print(f"处理图片时出错: {e}")
          return None
      finally:
          # Release cached GPU memory and collect garbage after every image.
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
          gc.collect()
  def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
      """Compute a weighted average of backbone features at several resolutions.

      The shorter image side is multiplied by 0.25, 0.5, 0.75 and 1.0. Each
      distinct resulting size within ``[64, self.max_image_size]`` is used to
      resize the image (shorter side -> size) and centre-crop a size x size
      square, which is passed through ``self.model``. The per-size features
      are combined with weights decreasing linearly from 2 (smallest size) to
      1 (largest size), normalised to sum to 1.

      Returns:
          A 1-D feature tensor on ``self.device``, or ``None`` if no scale
          fits the size range or extraction fails (errors are printed).
      """
      try:
          min_dim = min(image.size)
          scales = [0.25, 0.5, 0.75, 1.0]
          # Deduplicate and sort ascending so weights[0] pairs with the
          # smallest size.
          window_sizes = sorted({int(min_dim * s) for s in scales})
          window_sizes = [s for s in window_sizes if 64 <= s <= self.max_image_size]
          if not window_sizes:
              return None

          features_list = []
          for size in window_sizes:
              transform = transforms.Compose([
                  transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                  transforms.CenterCrop(size),
                  self.base_transform,
              ])
              with torch.no_grad(), autocast():
                  img_tensor = transform(image).unsqueeze(0).to(self.device)
                  # Flatten the single-image model output into a 1-D vector.
                  features_list.append(self.model(img_tensor).flatten())

          # Smaller sizes get larger weights (intended for partial-image search).
          weights = torch.linspace(2, 1, len(features_list), device=self.device)
          weights /= weights.sum()
          return (torch.stack(features_list) * weights[:, None]).sum(dim=0)
      except Exception as e:
          print(f"提取多尺度特征时出错: {e}")
          return None
  def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
      """Average backbone features over overlapping square windows of the image.

      The image is resized (aspect ratio kept) so its shorter side equals
      ``min(512, shorter side)``. Square windows of that side are slid across
      it with a stride of a quarter window (at least 16 px) and embedded in
      batches. The mean embedding over at most 1000 randomly sampled windows
      is returned.

      Returns:
          A 1-D tensor with one entry per backbone output channel, on
          ``self.device``. Returns ``None`` if the image is missing or under
          64 px on a side, the transform fails, no window could be embedded,
          or fusion fails (errors are printed).
      """
      try:
          if image is None or image.size[0] < 64 or image.size[1] < 64:
              return None
          orig_w, orig_h = image.size
          aspect_ratio = orig_w / orig_h
          win_size = min(512, orig_w, orig_h)
          # Target (width, height) whose shorter side equals the window size.
          if aspect_ratio > 1:
              target_wh = (int(win_size * aspect_ratio), win_size)
          else:
              target_wh = (win_size, int(win_size / aspect_ratio))
          transform = transforms.Compose([
              # Resize takes (height, width), hence the reversal.
              transforms.Resize(target_wh[::-1], interpolation=transforms.InterpolationMode.BILINEAR),
              self.base_transform,
          ])
          try:
              base_img = transform(image).to(self.device)
          except Exception as e:
              print(f"图像转换失败: {e}")
              return None

          windows = self._slice_windows(base_img, win_size)
          batch_size = self.get_dynamic_batch_size(win_size)
          all_features = self._embed_windows(windows, batch_size)
          if not all_features:
              return None
          print(f"总处理窗口数: {len(windows)}")

          try:
              # One row per window: (num_windows, channels).
              final_features = torch.cat(all_features, dim=0)
              # Cap the number of windows averaged.
              if final_features.size(0) > 1000:
                  indices = torch.randperm(final_features.size(0))[:1000]
                  final_features = final_features[indices]
              return final_features.mean(dim=0).to(self.device)
          except Exception as e:
              print(f"特征融合失败: {e}")
              return None
      except Exception as e:
          print(f"滑动窗口特征提取失败: {e}")
          return None
      finally:
          # Release cached GPU memory and collect garbage after extraction.
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
          gc.collect()

  def _slice_windows(self, base_img: torch.Tensor, win_size: int) -> list:
      """Return ``win_size`` x ``win_size`` views of a (C, H, W) tensor, row-major.

      The stride is a quarter of the window (at least 16 px). Windows
      containing NaN or infinite values are skipped.
      """
      stride = max(int(win_size * 0.25), 16)
      height, width = base_img.shape[1:]
      # A single finiteness check for the whole image avoids a host/device
      # round trip per window in the usual all-finite case.
      image_is_finite = bool(torch.isfinite(base_img).all())
      windows = []
      for top in range(0, height - win_size + 1, stride):
          for left in range(0, width - win_size + 1, stride):
              window = base_img[:, top:top + win_size, left:left + win_size]
              if image_is_finite or bool(torch.isfinite(window).all()):
                  windows.append(window)
      return windows

  def _embed_windows(self, windows: list, batch_size: int) -> list:
      """Run ``windows`` through ``self.model`` in batches of ``batch_size``.

      On RuntimeError (e.g. CUDA out-of-memory), the batch size is halved,
      down to ``self.min_batch_size``, and the same windows are retried. A
      batch that still fails at the minimum size is skipped.

      Returns:
          A list of CPU float32 tensors of shape (batch, channels), one per
          successful batch.
      """
      features = []
      batch_size = max(1, batch_size)
      start = 0
      while start < len(windows):
          if self.get_available_memory()[0] > self.memory_threshold:
              print("显存使用率过高,正在清理...")
              torch.cuda.empty_cache()
              gc.collect()
          chunk = windows[start:start + batch_size]
          try:
              with torch.no_grad(), autocast():
                  output = self.model(torch.stack(chunk))
          except RuntimeError as e:
              if torch.cuda.is_available():
                  torch.cuda.empty_cache()
              if batch_size > self.min_batch_size:
                  print(f"批处理失败,减小批次大小: {e}")
                  batch_size = max(batch_size // 2, self.min_batch_size, 1)
                  continue  # retry the same windows with a smaller batch
              print(f"批处理失败,跳过该批次: {e}")
              start += len(chunk)
              continue
          # Flatten (B, C, 1, 1) pooled maps to (B, C): one row per window.
          features.append(output.flatten(1).float().cpu())
          start += len(chunk)
      return features