123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- import numpy as np
- from PIL import Image
- from typing import Optional, Tuple
- import torch
- import torchvision.transforms as transforms
- import torchvision.models as models
- from torchvision.models import ResNet50_Weights
- import torch.nn.functional as F
- from torch.cuda.amp import autocast
- import time
- import gc
class ImageSearchEngine:
    """Compute L2-normalised ResNet-50 image embeddings for similarity search.

    Each image is embedded twice -- from centre crops at several scales and by
    averaging features over overlapping sliding windows -- and the two results
    are fused with weights 0.7 / 0.3 into one unit-length feature vector.
    """

    def __init__(self):
        """Set up the device, preprocessing pipeline, backbone and batching limits."""
        # Use the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        # Image preprocessing parameters.
        self.max_image_size = 2048  # longest allowed image side, in pixels
        self.base_transform = transforms.Compose([
            # Grayscale replicated to 3 channels so the RGB backbone still accepts it.
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            # ImageNet normalisation statistics.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # ImageNet-pretrained ResNet-50 with its final fully connected layer removed,
        # so the model ends at the global average pool.
        backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.model = torch.nn.Sequential(*list(backbone.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Feature dimension.
        self.dimension = 2048

        # Memory management parameters.
        self.min_batch_size = 4
        self.max_batch_size = 16
        self.memory_threshold = 0.8  # GPU memory usage ratio above which we back off

    def _autocast(self):
        """Return a mixed-precision context that is active only on CUDA.

        ``torch.cuda.amp.autocast()`` warns and disables itself on CPU-only
        machines; gating ``enabled`` on the device keeps CPU runs quiet and fp32.
        """
        return torch.autocast(device_type=self.device.type,
                              enabled=self.device.type == "cuda")

    def get_available_memory(self) -> Tuple[float, float]:
        """Report GPU memory use on the current CUDA device.

        Returns:
            ``(usage_ratio, remaining_bytes)``: memory held by live tensors divided
            by the device's total memory, and total minus that amount. Memory kept
            by the caching allocator or by other processes is not counted.
            ``(0.0, 0.0)`` when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return 0.0, 0.0
        # Query the same device for both numbers (previously total came from device 0).
        device_index = torch.cuda.current_device()
        total = torch.cuda.get_device_properties(device_index).total_memory
        # memory_allocated() excludes the allocator cache, so calling empty_cache()
        # (a costly synchronisation) before measuring would not change the result.
        allocated = torch.cuda.memory_allocated(device_index)
        return allocated / total, total - allocated

    def get_dynamic_batch_size(self, image_size: int) -> int:
        """Choose a batch size for inputs whose side length is ``image_size`` pixels.

        The size scales with ``(max_image_size / image_size) ** 2`` and is clamped
        to ``[min_batch_size, max_batch_size]``. Returns ``min_batch_size`` on CPU
        or when GPU usage exceeds ``memory_threshold``.
        """
        if not torch.cuda.is_available():
            return self.min_batch_size

        usage_ratio, _ = self.get_available_memory()
        if usage_ratio > self.memory_threshold:
            return self.min_batch_size

        # Smaller inputs allow larger batches; max(..., 1) avoids division by zero.
        size_factor = (self.max_image_size / max(image_size, 1)) ** 2
        return min(self.max_batch_size,
                   max(self.min_batch_size, int(self.max_batch_size * size_factor)))

    def preprocess_image(self, image: Image.Image) -> Optional[Image.Image]:
        """Downscale ``image`` so its longest side is at most ``max_image_size``.

        The aspect ratio is preserved and each side is kept at least 1 px.
        Images that are already small enough are returned unchanged; ``None`` is
        returned (and the error printed) if resizing fails.
        """
        try:
            longest_side = max(image.size)
            if longest_side > self.max_image_size:
                scale = self.max_image_size / longest_side
                # max(1, ...) stops extremely elongated images collapsing to a 0-px side.
                new_size = tuple(max(1, int(dim * scale)) for dim in image.size)
                image = image.resize(new_size, Image.BILINEAR)  # BILINEAR for speed
            return image
        except Exception as e:
            print(f"图像预处理失败: {e}")
            return None

    def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
        """Load the image at ``image_path`` and compute its fused embedding.

        Returns:
            The L2-normalised fusion ``0.7 * multi_scale + 0.3 * sliding_window``
            on ``self.device``, or ``None`` if loading or either stage fails
            (errors are printed).
        """
        try:
            # Close the file handle deterministically; convert() always returns a
            # fully loaded copy (even for RGB input), which stays valid after close.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')

            image = self.preprocess_image(image)
            if image is None:
                return None

            with torch.no_grad():
                start = time.perf_counter()
                multi_scale_features = self._extract_multi_scale_features(image)
                print(f"提取多尺度特征耗时: {time.perf_counter() - start:.2f}s")
                if multi_scale_features is None:
                    return None

                start = time.perf_counter()
                sliding_window_features = self._extract_sliding_window_features(image)
                print(f"提取滑动窗口特征耗时: {time.perf_counter() - start:.2f}s")
                if sliding_window_features is None:
                    return None

                # Weighted fusion of the two views, then L2 normalisation.
                combined_feature = multi_scale_features * 0.7 + sliding_window_features * 0.3
                return F.normalize(combined_feature, p=2, dim=0)

        except Exception as e:
            print(f"处理图片时出错: {e}")
            return None
        finally:
            # Release cached GPU memory and Python garbage after every image.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Weighted average of backbone features over square centre crops.

        Crop sides are 25/50/75/100 % of the image's shorter side, deduplicated
        and limited to ``[64, max_image_size]`` px. Weights fall linearly from 2
        (smallest crop) to 1 (largest) and are normalised to sum to 1, which
        favours matching on parts of an image.

        Returns:
            A 1-D feature tensor on ``self.device``, or ``None`` if no crop size
            qualifies or extraction fails.
        """
        try:
            shorter_side = min(image.size)
            crop_sizes = sorted({int(shorter_side * s) for s in (0.25, 0.5, 0.75, 1.0)})
            crop_sizes = [s for s in crop_sizes if 64 <= s <= self.max_image_size]
            if not crop_sizes:
                return None

            features = []
            for size in crop_sizes:
                transform = transforms.Compose([
                    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                    transforms.CenterCrop(size),
                    self.base_transform,
                ])
                img_tensor = transform(image).unsqueeze(0).to(self.device)
                with torch.no_grad(), self._autocast():
                    # Batch of one -> drop every singleton axis of the pooled output.
                    features.append(self.model(img_tensor).flatten())

            # crop_sizes is ascending, so the smallest crop receives the largest weight.
            weights = torch.linspace(2, 1, len(features), device=self.device)
            weights = weights / weights.sum()
            return (torch.stack(features) * weights[:, None]).sum(dim=0)

        except Exception as e:
            print(f"提取多尺度特征时出错: {e}")
            return None

    def _embed_windows(self, windows: list,
                       batch_size: int) -> Tuple[Optional[torch.Tensor], int, int]:
        """Run the backbone over ``windows`` in chunks and sum their features.

        A chunk that raises ``RuntimeError`` (e.g. CUDA out of memory) is retried
        with half the chunk size; once the size is ``min_batch_size`` a failing
        chunk is skipped. Each window contributes one flattened feature row,
        accumulated in float32.

        Args:
            windows: equally sized ``(C, H, W)`` tensors on ``self.device``.
            batch_size: initial number of windows per forward pass.

        Returns:
            ``(feature_sum, num_embedded, batch_size)`` -- the summed features
            (``None`` if no chunk succeeded), how many windows contributed, and
            the possibly reduced batch size.
        """
        feature_sum = None
        num_embedded = 0
        start = 0
        while start < len(windows):
            if self.get_available_memory()[0] > self.memory_threshold:
                print("显存使用率过高,正在清理...")
                torch.cuda.empty_cache()
                gc.collect()

            chunk = windows[start:start + batch_size]
            try:
                with torch.no_grad(), self._autocast():
                    # flatten(1) keeps one feature row per window; the old
                    # view(-1, shape[-1]) turned (B, C, 1, 1) into (B*C, 1).
                    chunk_features = self.model(torch.stack(chunk)).flatten(1)
            except RuntimeError as e:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                if batch_size > self.min_batch_size:
                    print(f"批处理失败,减小批次大小: {e}")
                    batch_size = max(batch_size // 2, self.min_batch_size)
                    continue  # retry the same windows with a smaller chunk
                print(f"批处理失败,跳过 {len(chunk)} 个窗口: {e}")
                start += len(chunk)
                continue

            chunk_sum = chunk_features.float().sum(dim=0)
            feature_sum = chunk_sum if feature_sum is None else feature_sum + chunk_sum
            num_embedded += len(chunk)
            start += len(chunk)

        return feature_sum, num_embedded, batch_size

    def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Mean backbone feature over overlapping square windows of ``image``.

        The image is resized so its shorter side equals the window side
        ``min(512, shorter side)``. Windows then slide with a stride of a quarter
        of the window side (at least 16 px), and the features of all windows are
        averaged deterministically.

        Returns:
            A 1-D feature tensor on ``self.device``, or ``None`` if the image is
            missing, has a side under 64 px, or extraction fails.
        """
        try:
            if image is None or image.size[0] < 64 or image.size[1] < 64:
                return None

            orig_w, orig_h = image.size
            win_size = min(512, orig_w, orig_h)

            # Resize so the shorter side equals win_size, preserving the aspect ratio.
            aspect_ratio = orig_w / orig_h
            if aspect_ratio > 1:
                target_w, target_h = int(win_size * aspect_ratio), win_size
            else:
                target_w, target_h = win_size, int(win_size / aspect_ratio)

            transform = transforms.Compose([
                # torchvision's Resize takes (height, width).
                transforms.Resize((target_h, target_w),
                                  interpolation=transforms.InterpolationMode.BILINEAR),
                self.base_transform,
            ])
            try:
                base_img = transform(image).to(self.device)
            except Exception as e:
                print(f"图像转换失败: {e}")
                return None

            stride = max(int(win_size * 0.25), 16)
            _, height, width = base_img.shape
            # Tensor slicing returns views, so listing every window copies no pixels.
            windows = [
                base_img[:, top:top + win_size, left:left + win_size]
                for top in range(0, height - win_size + 1, stride)
                for left in range(0, width - win_size + 1, stride)
            ]
            # Drop windows containing NaN/Inf. The per-window scan (one device sync
            # each) only runs when the base image has any non-finite value.
            if not torch.isfinite(base_img).all():
                windows = [w for w in windows if torch.isfinite(w).all()]

            batch_size = self.get_dynamic_batch_size(win_size)
            feature_sum, num_embedded, _ = self._embed_windows(windows, batch_size)
            if feature_sum is None:
                return None

            print(f"总处理窗口数: {len(windows)}")
            # Exact, deterministic mean over all embedded windows (no random subsampling).
            return feature_sum / num_embedded

        except Exception as e:
            print(f"滑动窗口特征提取失败: {e}")
            return None
        finally:
            # Release cached GPU memory and Python garbage.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
|