|
@@ -1,25 +1,20 @@
|
|
|
-import faiss
|
|
|
+
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
-import io
|
|
|
-import os
|
|
|
-from typing import List, Tuple, Optional, Union
|
|
|
+from typing import Optional
|
|
|
import torch
|
|
|
import torchvision.transforms as transforms
|
|
|
import torchvision.models as models
|
|
|
from torchvision.models import ResNet50_Weights
|
|
|
-from scipy import ndimage
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
-from pymongo import MongoClient
|
|
|
-import datetime
|
|
|
import time
|
|
|
|
|
|
class ImageSearchEngine:
|
|
|
+
|
|
|
def __init__(self):
|
|
|
-
|
|
|
- # 强制使用CPU设备
|
|
|
- self.device = torch.device("cpu")
|
|
|
+ # 检查GPU是否可用(仅用于PyTorch模型)
|
|
|
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
print(f"使用设备: {self.device}")
|
|
|
|
|
|
# 定义基础预处理转换
|
|
@@ -30,28 +25,21 @@ class ImageSearchEngine:
|
|
|
])
|
|
|
|
|
|
|
|
|
- # 加载预训练的ResNet模型并保持全精度
|
|
|
+ # 加载预训练的ResNet模型
|
|
|
self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
|
|
|
# 移除最后的全连接层
|
|
|
self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
|
|
|
- self.model = self.model.float().to(self.device).eval()
|
|
|
+ self.model = self.model.to(self.device)
|
|
|
+ self.model.eval()
|
|
|
|
|
|
# 初始化FAISS索引(2048是ResNet50的特征维度)
|
|
|
self.dimension = 2048
|
|
|
# self.index = faiss.IndexFlatL2(self.dimension)
|
|
|
|
|
|
-
|
|
|
- def _batch_generator(self, cursor, batch_size):
|
|
|
- """从MongoDB游标中分批生成数据"""
|
|
|
- batch = []
|
|
|
- for doc in cursor:
|
|
|
- batch.append(doc)
|
|
|
- if len(batch) == batch_size:
|
|
|
- yield batch
|
|
|
- batch = []
|
|
|
- if batch:
|
|
|
- yield batch
|
|
|
-
|
|
|
+ # 改为支持删除的索引
|
|
|
+ # base_index = faiss.IndexFlatL2(self.dimension)
|
|
|
+ # self.index = faiss.IndexIDMap(base_index)
|
|
|
+
|
|
|
|
|
|
def _process_image(self, image_path: str) -> Optional[torch.Tensor]:
|
|
|
"""处理单张图片并提取特征。
|
|
@@ -109,9 +97,6 @@ class ImageSearchEngine:
|
|
|
多尺度特征向量,处理失败返回None
|
|
|
"""
|
|
|
try:
|
|
|
- # 三重精度保障
|
|
|
- self.model = self.model.float()
|
|
|
-
|
|
|
# 获取原图信息
|
|
|
orig_w, orig_h = image.size
|
|
|
max_edge = max(orig_w, orig_h)
|
|
@@ -119,14 +104,14 @@ class ImageSearchEngine:
|
|
|
|
|
|
# 动态调整策略 -------------------------------------------
|
|
|
# 策略1:根据最大边长确定基准尺寸
|
|
|
- base_size = min(max_edge, 2048) # 最大尺寸限制
|
|
|
+ base_size = min(max_edge, 3000) # 不超过模型支持的最大尺寸
|
|
|
|
|
|
# 策略2:自动生成窗口尺寸(等比数列)
|
|
|
min_size = 224 # 最小特征尺寸
|
|
|
- num_scales = 3 # 采样点数
|
|
|
+ num_scales = 4 # 固定采样点数
|
|
|
scale_factors = np.logspace(0, 1, num_scales, base=2)
|
|
|
window_sizes = [int(base_size * f) for f in scale_factors]
|
|
|
- window_sizes = sorted({min(max(s, min_size), 2048) for s in window_sizes})
|
|
|
+ window_sizes = sorted({min(max(s, min_size), 3000) for s in window_sizes})
|
|
|
|
|
|
# 策略3:根据长宽比调整尺寸组合
|
|
|
if aspect_ratio > 1.5: # 宽幅图像
|
|
@@ -143,8 +128,9 @@ class ImageSearchEngine:
|
|
|
self.base_transform
|
|
|
])
|
|
|
|
|
|
- #
|
|
|
- img_base = base_transform(image).unsqueeze(0).to(torch.float32).to(self.device)
|
|
|
+ # 半精度加速
|
|
|
+ self.model.half()
|
|
|
+ img_base = base_transform(image).unsqueeze(0).to(self.device).half()
|
|
|
|
|
|
# 动态特征提取 ------------------------------------------
|
|
|
features = []
|
|
@@ -152,21 +138,21 @@ class ImageSearchEngine:
|
|
|
# 保持长宽比的重采样
|
|
|
target_size = (int(size*aspect_ratio), size) if aspect_ratio > 1 else (size, int(size/aspect_ratio))
|
|
|
|
|
|
- # CPU兼容的插值
|
|
|
+ # GPU加速的智能插值
|
|
|
img_tensor = torch.nn.functional.interpolate(
|
|
|
img_base,
|
|
|
size=target_size,
|
|
|
- mode='bilinear',
|
|
|
+ mode= 'area' if size < base_size else 'bicubic', # 下采样用area,上采样用bicubic
|
|
|
align_corners=False
|
|
|
- ).to(torch.float32)
|
|
|
+ )
|
|
|
|
|
|
# 自适应归一化(保持原图统计特性)
|
|
|
if hasattr(self, 'adaptive_normalize'):
|
|
|
img_tensor = self.adaptive_normalize(img_tensor)
|
|
|
|
|
|
# 混合精度推理
|
|
|
- with torch.no_grad():
|
|
|
- feature = self.model(img_tensor).to(torch.float32)
|
|
|
+ with torch.no_grad(), torch.cuda.amp.autocast():
|
|
|
+ feature = self.model(img_tensor)
|
|
|
|
|
|
features.append(feature.squeeze().float())
|
|
|
|
|
@@ -185,9 +171,6 @@ class ImageSearchEngine:
|
|
|
print(f"智能特征提取失败: {e}")
|
|
|
return None
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
|
|
|
"""优化版滑动窗口特征提取(动态调整+批量处理)
|
|
|
|
|
@@ -198,9 +181,6 @@ class ImageSearchEngine:
|
|
|
滑动窗口特征向量,处理失败返回None
|
|
|
"""
|
|
|
try:
|
|
|
- # 三重精度保障
|
|
|
- self.model = self.model.float()
|
|
|
-
|
|
|
# 获取原图信息
|
|
|
orig_w, orig_h = image.size
|
|
|
aspect_ratio = orig_w / orig_h
|
|
@@ -212,10 +192,10 @@ class ImageSearchEngine:
|
|
|
int(2 ** np.round(np.log2(max_dim * 0.1))), # 约10%尺寸
|
|
|
int(2 ** np.floor(np.log2(max_dim * 0.5))), # 约50%尺寸
|
|
|
int(2 ** np.ceil(np.log2(max_dim))) # 接近原图尺寸
|
|
|
- } & {256, 512, 1024, 2048}) # 与预设尺寸取交集
|
|
|
+ } & {256, 512, 1024, 2048, 3000}) # 与预设尺寸取交集
|
|
|
|
|
|
# 智能步长调整(窗口尺寸越大步长越大)
|
|
|
- stride_ratios = {256:0.5, 512:0.4, 1024:0.3, 2048:0.2}
|
|
|
+ stride_ratios = {256:0.5, 512:0.4, 1024:0.3, 2048:0.2, 3000:0.15}
|
|
|
|
|
|
# 预处理优化 --------------------------------------------
|
|
|
# 生成基准图像(最大窗口尺寸)
|
|
@@ -224,11 +204,15 @@ class ImageSearchEngine:
|
|
|
(max_win_size, int(max_win_size / aspect_ratio))
|
|
|
|
|
|
transform = transforms.Compose([
|
|
|
- transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.BILINEAR),
|
|
|
+ transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.LANCZOS),
|
|
|
self.base_transform
|
|
|
])
|
|
|
- base_img = transform(image).to(torch.float32).to(self.device)
|
|
|
+ base_img = transform(image).to(self.device)
|
|
|
|
|
|
+ # 半精度加速
|
|
|
+ self.model.half()
|
|
|
+ base_img = base_img.half()
|
|
|
+
|
|
|
# 批量特征提取 ------------------------------------------
|
|
|
all_features = []
|
|
|
for win_size in window_sizes:
|
|
@@ -241,7 +225,7 @@ class ImageSearchEngine:
|
|
|
num_w = (w - win_size) // stride + 1
|
|
|
|
|
|
# 调整窗口数量上限(防止显存溢出)
|
|
|
- MAX_WINDOWS = 16 # 最大窗口数
|
|
|
+ MAX_WINDOWS = 32 # 根据显存调整
|
|
|
if num_h * num_w > MAX_WINDOWS:
|
|
|
stride = int(np.sqrt(h * w * win_size**2 / MAX_WINDOWS))
|
|
|
num_h = (h - win_size) // stride + 1
|
|
@@ -260,11 +244,11 @@ class ImageSearchEngine:
|
|
|
continue
|
|
|
|
|
|
# 批量处理(自动分块防止OOM)
|
|
|
- BATCH_SIZE = 4 # 批处理大小
|
|
|
- with torch.no_grad():
|
|
|
+ BATCH_SIZE = 8 # 根据显存调整
|
|
|
+ with torch.no_grad(), torch.cuda.amp.autocast():
|
|
|
for i in range(0, len(windows), BATCH_SIZE):
|
|
|
- batch = torch.stack(windows[i:i+BATCH_SIZE]).to(torch.float32)
|
|
|
- features = self.model(batch).to(torch.float32)
|
|
|
+ batch = torch.stack(windows[i:i+BATCH_SIZE])
|
|
|
+ features = self.model(batch)
|
|
|
all_features.append(features.cpu().float()) # 转移至CPU释放显存
|
|
|
|
|
|
# 特征融合 ---------------------------------------------
|
|
@@ -274,10 +258,9 @@ class ImageSearchEngine:
|
|
|
final_feature = torch.cat([f.view(-1, f.shape[-1]) for f in all_features], dim=0)
|
|
|
final_feature = final_feature.mean(dim=0).to(self.device)
|
|
|
|
|
|
- return final_feature.float()
|
|
|
+ return final_feature
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"滑动窗口特征提取失败: {e}")
|
|
|
return None
|
|
|
|
|
|
-
|