import gc
import time
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.cuda.amp import autocast
from torchvision.models import ResNet50_Weights


class ImageSearchEngine:
    def __init__(self):
        # Check whether a GPU is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Image preprocessing parameters
        self.max_image_size = 2048  # upper bound on the longer image side
        self.base_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # convert to grayscale but keep 3 channels
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Load a pretrained ResNet-50
        self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Drop the final fully connected layer so the model outputs pooled features
        self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Feature dimension (ResNet-50 global-pooling output)
        self.dimension = 2048

        # Memory management parameters
        self.min_batch_size = 4
        self.max_batch_size = 16
        self.memory_threshold = 0.8  # GPU memory usage threshold

    def get_available_memory(self) -> Tuple[float, float]:
        """Return (fraction of GPU memory in use, free bytes)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            total = torch.cuda.get_device_properties(0).total_memory
            allocated = torch.cuda.memory_allocated()
            return allocated / total, total - allocated
        return 0.0, 0.0

    def get_dynamic_batch_size(self, image_size: int) -> int:
        """Compute a batch size dynamically from memory pressure and image size."""
        if not torch.cuda.is_available():
            return self.min_batch_size

        mem_usage, _ = self.get_available_memory()
        if mem_usage > self.memory_threshold:
            return self.min_batch_size

        # Scale the batch size with the squared ratio of the size limit to the actual size
        size_factor = (self.max_image_size / image_size) ** 2
        batch_size = min(self.max_batch_size,
                         max(self.min_batch_size,
                             int(self.max_batch_size * size_factor)))
        return batch_size

    def preprocess_image(self, image: Image.Image) -> Optional[Image.Image]:
        """Image preprocessing: cap the longer side at max_image_size."""
        try:
            if max(image.size) > self.max_image_size:
                scale = self.max_image_size / max(image.size)
                new_size = tuple(int(dim * scale) for dim in image.size)
                image = image.resize(new_size, Image.BILINEAR)  # BILINEAR for speed
            return image
        except Exception as e:
            print(f"Image preprocessing failed: {e}")
            return None

    def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
        """Process a single image and extract its feature vector."""
        try:
            image = Image.open(image_path)
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Preprocess the image
            image = self.preprocess_image(image)
            if image is None:
                return None

            # Extract features with mixed precision
            with torch.no_grad(), autocast():
                start_ms_time = time.time()
                multi_scale_features = self._extract_multi_scale_features(image)
                end_ms_time = time.time()
                print(f"Multi-scale feature extraction took {end_ms_time - start_ms_time:.2f}s")
                if multi_scale_features is None:
                    return None

                start_sw_time = time.time()
                sliding_window_features = self._extract_sliding_window_features(image)
                end_sw_time = time.time()
                print(f"Sliding-window feature extraction took {end_sw_time - start_sw_time:.2f}s")
                if sliding_window_features is None:
                    return None

                # Feature fusion (weighted average)
                combined_feature = multi_scale_features * 0.7 + sliding_window_features * 0.3
                combined_feature = F.normalize(combined_feature, p=2, dim=0)
                return combined_feature

        except Exception as e:
            print(f"Error while processing image: {e}")
            return None
        finally:
            # Free GPU memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

    def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Optimized multi-scale feature extraction."""
        try:
            features_list = []
            width, height = image.size
            min_dim = min(width, height)

            # Reduced set of scales
            scales = [0.25, 0.5, 0.75, 1.0]
            window_sizes = sorted({int(min_dim * s) for s in scales})
            window_sizes = [s for s in window_sizes if 64 <= s <= self.max_image_size]
            if not window_sizes:
                return None

            batch_size = self.get_dynamic_batch_size(max(window_sizes))

            for size in window_sizes:
                transform = transforms.Compose([
                    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                    transforms.CenterCrop(size),
                    self.base_transform
                ])

                with torch.no_grad(), autocast():
                    img_tensor = transform(image).unsqueeze(0).to(self.device)
                    feature = self.model(img_tensor)
                    features_list.append(feature.squeeze())

            # Weighted average (smaller scales get higher weight, to favour partial-image search)
            weights = torch.linspace(2, 1, len(features_list), device=self.device)
            weights /= weights.sum()
            final_feature = torch.stack(features_list) * weights[:, None]
            return final_feature.sum(dim=0)

        except Exception as e:
            print(f"Error during multi-scale feature extraction: {e}")
            return None

    def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Optimized sliding-window feature extraction."""
        try:
            if image is None or image.size[0] < 64 or image.size[1] < 64:
                return None

            orig_w, orig_h = image.size
            aspect_ratio = orig_w / orig_h

            # Window configuration: use a relatively small base window
            base_size = min(512, min(orig_w, orig_h))
            window_sizes = [base_size]

            # Resize so the shorter side equals the base window size, keeping the aspect ratio
            if aspect_ratio > 1:
                resized_wh = (int(base_size * aspect_ratio), base_size)
            else:
                resized_wh = (base_size, int(base_size / aspect_ratio))

            transform = transforms.Compose([
                transforms.Resize(resized_wh[::-1],  # torchvision Resize expects (height, width)
                                  interpolation=transforms.InterpolationMode.BILINEAR),
                self.base_transform
            ])

            try:
                base_img = transform(image).to(self.device)
            except Exception as e:
                print(f"Image transform failed: {e}")
                return None

            all_features = []
            total_windows = 0

            for win_size in window_sizes:
                # Dynamic stride: fixed ratio of 0.25 of the window size, at least 16 px
                stride = max(int(win_size * 0.25), 16)

                # Number of window positions
                h, w = base_img.shape[1:]
                num_h = (h - win_size) // stride + 1
                num_w = (w - win_size) // stride + 1

                # Memory-aware batch size
                batch_size = self.get_dynamic_batch_size(win_size)
                batch = []

                for i in range(num_h):
                    for j in range(num_w):
                        if self.get_available_memory()[0] > self.memory_threshold:
                            print("GPU memory usage too high, clearing cache...")
                            torch.cuda.empty_cache()
                            gc.collect()

                        top = i * stride
                        left = j * stride
                        window = base_img[:, top:top + win_size, left:left + win_size]

                        if torch.isnan(window).any() or torch.isinf(window).any():
                            continue

                        batch.append(window)
                        total_windows += 1

                        if len(batch) >= batch_size:
                            with torch.no_grad(), autocast():
                                try:
                                    batch_tensor = torch.stack(batch)
                                    features = self.model(batch_tensor)
                                    # Model output is (B, 2048, 1, 1); flatten to (B, 2048) before storing
                                    all_features.append(features.flatten(1).cpu())
                                except RuntimeError as e:
                                    print(f"Batch failed, reducing batch size: {e}")
                                    if batch_size > self.min_batch_size:
                                        batch_size = max(batch_size // 2, self.min_batch_size)
                                        # Keep the accumulated windows and retry with the smaller batch size
                                        continue
                            batch = []

                # Process any remaining windows
                if batch:
                    with torch.no_grad(), autocast():
                        try:
                            batch_tensor = torch.stack(batch)
                            features = self.model(batch_tensor)
                            all_features.append(features.flatten(1).cpu())
                        except RuntimeError as e:
                            print(f"Failed to process remaining windows: {e}")

            if not all_features:
                return None

            print(f"Total windows processed: {total_windows}")

            # Feature fusion
            try:
                final_features = torch.cat(all_features, dim=0)  # (N, 2048)

                # Subsample if there are too many windows
                if final_features.size(0) > 1000:
                    indices = torch.randperm(final_features.size(0))[:1000]
                    final_features = final_features[indices]

                return final_features.mean(dim=0).to(self.device)
            except Exception as e:
                print(f"Feature fusion failed: {e}")
                return None

        except Exception as e:
            print(f"Sliding-window feature extraction failed: {e}")
            return None
        finally:
            # Free GPU memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()
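

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of how the engine above might be driven: extract fused
# features for two images and compare them with cosine similarity. The file
# names "query.jpg" and "candidate.jpg" are hypothetical placeholders, and the
# cosine-similarity comparison is an assumption about how the 2048-d features
# would be used for search; it is not defined by the class itself.
if __name__ == "__main__":
    engine = ImageSearchEngine()

    query_feature = engine._process_image("query.jpg")          # hypothetical path
    candidate_feature = engine._process_image("candidate.jpg")  # hypothetical path

    if query_feature is not None and candidate_feature is not None:
        # Both vectors are already L2-normalized, so the dot product equals the
        # cosine similarity; F.cosine_similarity is used here for clarity.
        similarity = F.cosine_similarity(query_feature.unsqueeze(0),
                                         candidate_feature.unsqueeze(0)).item()
        print(f"Cosine similarity between query and candidate: {similarity:.4f}")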