import numpy as np
from PIL import Image
from typing import Optional
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
import torch.nn.functional as F
import time


class ImageSearchEngine:
    """Extract L2-normalised ResNet-50 image embeddings.

    Each image embedding combines two descriptors produced by a ResNet-50
    whose final fully-connected layer has been removed:

    * a multi-scale descriptor built from square centre crops at several
      resolutions, and
    * a sliding-window descriptor, the mean feature over square windows laid
      out on a regular grid.
    """

    # Upper bound (soft) on the number of windows evaluated per window size;
    # the stride is enlarged when the grid would exceed it.
    MAX_WINDOWS_PER_SIZE = 64
    # Maximum number of per-window features averaged in the sliding-window
    # descriptor; beyond this an evenly spaced subset is used.
    MAX_FUSED_FEATURES = 1000

    def __init__(self):
        # Run the PyTorch model on the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        # Shared preprocessing: grayscale replicated to 3 channels (the
        # backbone takes 3-channel input), then ImageNet mean/std normalisation.
        self.base_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Pretrained ResNet-50 without its final fully-connected layer, so the
        # model ends at the global average pool: output shape (N, 2048, 1, 1).
        backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.model = torch.nn.Sequential(*list(backbone.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Embedding dimensionality of the headless ResNet-50.
        # NOTE: intended for a FAISS index (e.g. faiss.IndexIDMap over
        # faiss.IndexFlatL2 to support deletion), which is not created here.
        self.dimension = 2048

    def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
        """Load one image and compute its combined, L2-normalised embedding.

        Args:
            image_path: Path of the image file.

        Returns:
            A 1-D tensor of length 2048 on ``self.device``, or ``None`` if
            loading or either feature extractor fails.
        """
        try:
            # Open inside ``with`` so the file handle is always released.
            # ``convert`` returns an in-memory copy (even when the image is
            # already RGB), so ``image`` stays usable after the file closes.
            with Image.open(image_path) as raw:
                image = raw.convert('RGB')

            start_ms_time = time.perf_counter()
            multi_scale_features = self._extract_multi_scale_features(image)
            end_ms_time = time.perf_counter()
            print(f"提取多尺度特征耗时: {end_ms_time - start_ms_time} s")
            if multi_scale_features is None:
                return None

            start_sw_time = time.perf_counter()
            sliding_window_features = self._extract_sliding_window_features(image)
            end_sw_time = time.perf_counter()
            print(f"提取滑动窗口耗时: {end_sw_time - start_sw_time} s")
            if sliding_window_features is None:
                return None

            # Weighted combination of the two (2048,) descriptors, then L2
            # normalisation of the result.
            combined_feature = multi_scale_features * 0.6 + sliding_window_features * 0.4
            return F.normalize(combined_feature, p=2, dim=0)
        except Exception as e:
            print(f"处理图片时出错: {e}")
            return None

    def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Compute a multi-scale descriptor from square centre crops.

        Candidate crop sizes are fractions of the shorter side plus a few fixed
        sizes; sizes below 64 px or above 1.5x the longer side are discarded.
        For each remaining size the image is resized (shorter side = size),
        centre-cropped to ``size x size`` and embedded. Per-scale features are
        averaged with linearly increasing weights, so larger sizes count more.

        Args:
            image: RGB PIL image.

        Returns:
            A 1-D tensor of length 2048 on ``self.device``, or ``None`` when no
            crop size qualifies or extraction fails.
        """
        try:
            width, height = image.size
            min_dim = min(width, height)
            max_dim = max(width, height)

            # Candidate sizes derived from the image plus fixed presets.
            scales = [0.25, 0.5, 0.75, 1.0, 1.5, 2.0]
            fixed_sizes = [256, 512, 1024, 2048]
            candidate_sizes = [int(min_dim * s) for s in scales] + fixed_sizes
            max_allowed = int(max_dim * 1.5)
            window_sizes = sorted({s for s in candidate_sizes if 64 <= s <= max_allowed})
            if not window_sizes:
                return None

            features_list = []
            for size in window_sizes:
                # Keep aspect ratio while resizing the shorter side, then crop.
                transform = transforms.Compose([
                    transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS),
                    transforms.CenterCrop(size),
                    self.base_transform,
                ])
                img_tensor = transform(image).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    feature = self.model(img_tensor)
                # (1, 2048, 1, 1) -> (2048,)
                features_list.append(feature.flatten())

            # Weights rise linearly from 1 to 2 over ascending sizes.
            weights = torch.linspace(1, 2, len(features_list), device=self.device)
            weights /= weights.sum()
            return (torch.stack(features_list) * weights[:, None]).sum(dim=0)
        except Exception as e:
            print(f"提取多尺度特征时出错: {e}")
            return None

    @staticmethod
    def _sliding_window_sizes(max_dim: int) -> list:
        """Pick sliding-window sizes from the presets {256, 512, 1024, 2048}.

        Keeps the presets equal to ``2**floor(log2(max_dim) - 1)`` (about half
        the longer side) or ``2**ceil(log2(max_dim))`` (at or just above it).
        If neither is a preset, falls back to the preset closest to
        ``max_dim`` in log2 space. Returns the sizes in ascending order.
        """
        base_sizes = {256, 512, 1024, 2048}
        log_size = np.log2(max_dim)
        dynamic_sizes = {
            int(2 ** np.floor(log_size - 1)),
            int(2 ** np.ceil(log_size)),
        }
        window_sizes = sorted(base_sizes & dynamic_sizes)
        if not window_sizes:
            window_sizes = [min(base_sizes, key=lambda x: abs(np.log2(x) - log_size))]
        return window_sizes

    @staticmethod
    def _stride_ratio(win_size: int) -> float:
        """Return the stride as a fraction of the window size.

        ``win_size / 2048`` is clipped to [0.2, 0.8] and mapped to
        ``0.2 + 0.3 * ratio``, so the result lies in [0.26, 0.44]; larger
        windows step proportionally further.
        """
        size_ratio = np.clip(win_size / 2048, 0.2, 0.8)
        return 0.2 + size_ratio * 0.3

    def _embed_windows(self, windows: list) -> list:
        """Embed a list of equally sized (3, S, S) window tensors.

        The list is sent through the model as one batch. If that raises
        ``RuntimeError`` (for example CUDA out-of-memory), the list is split in
        half and each half is retried recursively. A single window that still
        fails is logged and skipped.

        Returns:
            A list of CPU tensors of shape (B, 2048), one per successful batch.
        """
        if not windows:
            return []
        try:
            with torch.no_grad():
                features = self.model(torch.stack(windows))
            # (B, 2048, 1, 1) -> (B, 2048); move to CPU to free device memory.
            return [features.flatten(1).cpu()]
        except RuntimeError as e:
            if len(windows) == 1:
                print(f"窗口特征提取失败,已跳过: {e}")
                return []
            print(f"批处理失败,尝试减小批次大小: {e}")
            mid = len(windows) // 2
            return self._embed_windows(windows[:mid]) + self._embed_windows(windows[mid:])

    def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Compute a descriptor by averaging features over sliding windows.

        The image is resized once, preserving aspect ratio, so its shorter
        side equals the largest chosen window size. For each window size,
        square windows are taken on a regular grid, embedded in batches, and
        all per-window features are averaged.

        Args:
            image: RGB PIL image; must be at least 64x64.

        Returns:
            A 1-D tensor of length 2048 on ``self.device``, or ``None`` on
            failure.
        """
        try:
            if image is None or image.size[0] < 64 or image.size[1] < 64:
                print("图片无效或尺寸过小")
                return None

            orig_w, orig_h = image.size
            aspect_ratio = orig_w / orig_h
            window_sizes = self._sliding_window_sizes(max(orig_w, orig_h))

            # Resize once so the shorter side equals the largest window; every
            # window size then fits at least once along both axes.
            max_win_size = max(window_sizes)
            if aspect_ratio > 1:
                base_size = (int(max_win_size * aspect_ratio), max_win_size)
            else:
                base_size = (max_win_size, int(max_win_size / aspect_ratio))
            transform = transforms.Compose([
                # Resize takes (h, w); base_size is (w, h).
                transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.LANCZOS),
                self.base_transform,
            ])
            try:
                base_img = transform(image).to(self.device)
            except Exception as e:
                print(f"图像转换失败: {e}")
                return None

            # Check for NaN/Inf once on the whole image. Windows are only
            # tested individually when the image contains such values, which
            # avoids a device sync for every window.
            check_windows = not bool(torch.isfinite(base_img).all())

            all_features = []
            total_windows = 0
            h, w = base_img.shape[1:]
            for win_size in window_sizes:
                stride = max(int(win_size * self._stride_ratio(win_size)), 16)
                num_h = (h - win_size) // stride + 1
                num_w = (w - win_size) // stride + 1

                # Limit memory by enlarging the stride when the grid is too dense.
                if num_h * num_w > self.MAX_WINDOWS_PER_SIZE:
                    adjusted_stride = int(np.sqrt((h * w) / self.MAX_WINDOWS_PER_SIZE))
                    stride = max(stride, adjusted_stride)
                    num_h = (h - win_size) // stride + 1
                    num_w = (w - win_size) // stride + 1

                print(f"处理窗口 {win_size}x{win_size}, 步长 {stride}, 窗口数 {num_h * num_w}")

                # Slices are views into base_img, so collecting them is cheap.
                windows = []
                for i in range(num_h):
                    for j in range(num_w):
                        top, left = i * stride, j * stride
                        window = base_img[:, top:top + win_size, left:left + win_size]
                        if check_windows and not bool(torch.isfinite(window).all()):
                            continue
                        windows.append(window)
                total_windows += len(windows)

                batch_size = max(1, min(16, num_h * num_w))
                for start in range(0, len(windows), batch_size):
                    all_features.extend(self._embed_windows(windows[start:start + batch_size]))

            if not all_features:
                print("未能提取到有效特征")
                return None
            print(f"总处理窗口数: {total_windows}")

            try:
                # Each entry is (B, 2048); concatenating keeps one row per
                # window. (Reshaping to (-1, 1) would average individual
                # activations into a single scalar instead.)
                final_features = torch.cat(all_features, dim=0)
                if final_features.size(0) > self.MAX_FUSED_FEATURES:
                    # Evenly spaced, deterministic subsample for reproducibility.
                    indices = torch.linspace(
                        0, final_features.size(0) - 1, self.MAX_FUSED_FEATURES
                    ).long()
                    final_features = final_features[indices]
                return final_features.mean(dim=0).to(self.device)
            except Exception as e:
                print(f"特征融合失败: {e}")
                return None
        except Exception as e:
            print(f"滑动窗口特征提取失败: {e}")
            return None