import numpy as np
from PIL import Image
from typing import Optional
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ResNet50_Weights
import torch.nn.functional as F
import time


class ImageSearchEngine:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Shared preprocessing: grayscale (replicated to 3 channels) + ImageNet normalization.
        self.base_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # ResNet-50 backbone with the classification head removed; the global
        # average-pooling output serves as the feature vector.
        self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Feature dimensionality (ResNet-50 penultimate layer).
        self.dimension = 2048

    def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
        """Process a single image and extract its feature vector.

        Args:
            image_path: Path to the image file.

        Returns:
            The processed feature vector, or None if processing fails.
        """
        try:
            image = Image.open(image_path)

            if image.mode != 'RGB':
                image = image.convert('RGB')

            start_ms_time = time.time()
            multi_scale_features = self._extract_multi_scale_features(image)
            end_ms_time = time.time()
            print(f"Multi-scale feature extraction took {end_ms_time - start_ms_time:.3f} s")
            if multi_scale_features is None:
                return None

            start_sw_time = time.time()
            sliding_window_features = self._extract_sliding_window_features(image)
            end_sw_time = time.time()
            print(f"Sliding-window feature extraction took {end_sw_time - start_sw_time:.3f} s")
            if sliding_window_features is None:
                return None

            # Weighted fusion of the two feature types, followed by L2 normalization.
            combined_feature = multi_scale_features * 0.6 + sliding_window_features * 0.4
            combined_feature = F.normalize(combined_feature, p=2, dim=0)

            return combined_feature

        except Exception as e:
            print(f"Error while processing image: {e}")
            return None

    def _extract_multi_scale_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Multi-scale feature extraction driven by the original image resolution.

        Args:
            image: PIL image object.

        Returns:
            Multi-scale feature vector, or None if extraction fails.
        """
        try:
            orig_w, orig_h = image.size
            max_edge = max(orig_w, orig_h)
            aspect_ratio = orig_w / orig_h

            # Cap the base size to keep memory usage bounded.
            base_size = min(max_edge, 3000)

            # Four scales spaced logarithmically between 1x and 2x the base size,
            # clamped to [224, 3000].
            min_size = 224
            num_scales = 4
            scale_factors = np.logspace(0, 1, num_scales, base=2)
            window_sizes = [int(base_size * f) for f in scale_factors]
            window_sizes = sorted({min(max(s, min_size), 3000) for s in window_sizes})

            # Stretch the scales for strongly non-square images.
            if aspect_ratio > 1.5:
                window_sizes = [int(s * aspect_ratio) for s in window_sizes]
            elif aspect_ratio < 0.67:
                window_sizes = [int(s / aspect_ratio) for s in window_sizes]

            # Round the base size down to a power of two and resize once; all other
            # scales are derived from this tensor via interpolation.
            base_size = 2 ** int(np.log2(base_size))
            base_transform = transforms.Compose([
                transforms.Resize((base_size, base_size),
                                  interpolation=transforms.InterpolationMode.LANCZOS),
                self.base_transform
            ])

            # Use half precision only on GPU; fp16 convolutions are generally not
            # supported on CPU.
            use_fp16 = self.device.type == 'cuda'
            if use_fp16:
                self.model.half()
            img_base = base_transform(image).unsqueeze(0).to(self.device)
            if use_fp16:
                img_base = img_base.half()

            features = []
            for size in window_sizes:
                # F.interpolate expects (H, W); preserve the original aspect ratio.
                if aspect_ratio > 1:
                    target_size = (size, int(size * aspect_ratio))
                else:
                    target_size = (int(size / aspect_ratio), size)

                # 'area' for downscaling, 'bicubic' for upscaling; align_corners is
                # only valid for the bicubic mode.
                mode = 'area' if size < base_size else 'bicubic'
                interp_kwargs = {'align_corners': False} if mode == 'bicubic' else {}
                img_tensor = F.interpolate(img_base, size=target_size, mode=mode, **interp_kwargs)

                # Optional hook: apply adaptive normalization if the instance defines it.
                if hasattr(self, 'adaptive_normalize'):
                    img_tensor = self.adaptive_normalize(img_tensor)

                with torch.no_grad(), torch.autocast(self.device.type, enabled=use_fp16):
                    feature = self.model(img_tensor)

                features.append(feature.squeeze().float())

            # Weight each scale by its closeness to the base resolution.
            size_diffs = [abs(size - base_size) for size in window_sizes]
            weights = 1 / (torch.tensor(size_diffs, dtype=torch.float32, device=self.device) + 1e-6)
            weights = weights / weights.sum()

            final_feature = torch.stack([f * w for f, w in zip(features, weights)]).sum(dim=0)

            return final_feature
        except Exception as e:
            print(f"Multi-scale feature extraction failed: {e}")
            return None

    def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
        """Sliding-window feature extraction with dynamic window sizes and batching.

        Args:
            image: PIL image object.

        Returns:
            Sliding-window feature vector, or None if extraction fails.
        """
        try:
            orig_w, orig_h = image.size
            aspect_ratio = orig_w / orig_h

            # Candidate window sizes derived from the image resolution, restricted
            # to a fixed set of scales.
            max_dim = max(orig_w, orig_h)
            window_sizes = sorted({
                int(2 ** np.round(np.log2(max_dim * 0.1))),
                int(2 ** np.floor(np.log2(max_dim * 0.5))),
                int(2 ** np.ceil(np.log2(max_dim)))
            } & {256, 512, 1024, 2048, 3000})
            if not window_sizes:
                # Small images may not intersect the fixed scale set; fall back to 256.
                window_sizes = [256]

            # Smaller windows use a relatively larger stride.
            stride_ratios = {256: 0.5, 512: 0.4, 1024: 0.3, 2048: 0.2, 3000: 0.15}

            # Resize once so the shorter edge matches the largest window, preserving
            # the aspect ratio; base_size is (W, H) and Resize expects (H, W).
            max_win_size = max(window_sizes)
            base_size = (int(max_win_size * aspect_ratio), max_win_size) if aspect_ratio > 1 else \
                        (max_win_size, int(max_win_size / aspect_ratio))

            transform = transforms.Compose([
                transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.LANCZOS),
                self.base_transform
            ])
            base_img = transform(image).to(self.device)

            # Half precision only on GPU.
            use_fp16 = self.device.type == 'cuda'
            if use_fp16:
                self.model.half()
                base_img = base_img.half()

            all_features = []
            for win_size in window_sizes:
                stride = int(win_size * stride_ratios.get(win_size, 0.3))

                h, w = base_img.shape[1:]
                num_h = (h - win_size) // stride + 1
                num_w = (w - win_size) // stride + 1

                # Cap the number of windows per scale by enlarging the stride.
                MAX_WINDOWS = 32
                if num_h * num_w > MAX_WINDOWS:
                    stride = int(np.ceil(np.sqrt(max(h - win_size, 1) * max(w - win_size, 1) / MAX_WINDOWS)))
                    num_h = (h - win_size) // stride + 1
                    num_w = (w - win_size) // stride + 1

                windows = []
                for i in range(num_h):
                    for j in range(num_w):
                        top = i * stride
                        left = j * stride
                        window = base_img[:, top:top + win_size, left:left + win_size]
                        windows.append(window)

                if not windows:
                    continue

                # Run the windows through the backbone in small batches.
                BATCH_SIZE = 8
                with torch.no_grad(), torch.autocast(self.device.type, enabled=use_fp16):
                    for i in range(0, len(windows), BATCH_SIZE):
                        batch = torch.stack(windows[i:i + BATCH_SIZE])
                        features = self.model(batch)
                        all_features.append(features.cpu().float())

            if not all_features:
                return None

            # Each batch has shape (B, 2048, 1, 1); flatten to (B, 2048) and average
            # across all windows and scales.
            final_feature = torch.cat([f.view(f.shape[0], -1) for f in all_features], dim=0)
            final_feature = final_feature.mean(dim=0).to(self.device)
            return final_feature
        except Exception as e:
            print(f"Sliding-window feature extraction failed: {e}")
            return None
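

# --- Minimal usage sketch -------------------------------------------------------
# A sketch of how the extractor above might be driven; the image paths and the
# single-entry gallery are hypothetical illustrations, not part of the class.
# Because _process_image already L2-normalizes its output, a plain dot product
# gives the cosine similarity directly.
if __name__ == "__main__":
    engine = ImageSearchEngine()

    # Extract fused multi-scale + sliding-window features for two images.
    feat_query = engine._process_image("query.jpg")          # hypothetical path
    feat_candidate = engine._process_image("candidate.jpg")  # hypothetical path

    if feat_query is not None and feat_candidate is not None:
        similarity = torch.dot(feat_query, feat_candidate).item()
        print(f"Cosine similarity: {similarity:.4f}")

        # Brute-force nearest-neighbour lookup over a small in-memory gallery
        # (one row per indexed image); a larger deployment would persist this
        # matrix or hand it to a vector index of dimension engine.dimension.
        gallery = torch.stack([feat_candidate])
        scores = gallery @ feat_query
        best = int(torch.argmax(scores))
        print(f"Best match: index {best}, score {scores[best].item():.4f}")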