|
@@ -27,8 +27,8 @@ class ImageSearchEngine:
|
|
|
# 初始化一个id生成计数
|
|
|
self.faiss_id_max = 0
|
|
|
|
|
|
- # 强制使用CPU设备
|
|
|
- self.device = torch.device("cpu")
|
|
|
+ # 检查GPU是否可用(仅用于PyTorch模型)
|
|
|
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
print(f"使用设备: {self.device}")
|
|
|
|
|
|
# 定义基础预处理转换
|
|
@@ -39,11 +39,12 @@ class ImageSearchEngine:
|
|
|
])
|
|
|
|
|
|
|
|
|
- # 加载预训练的ResNet模型并保持全精度
|
|
|
+ # 加载预训练的ResNet模型
|
|
|
self.model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
|
|
|
# 移除最后的全连接层
|
|
|
self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
|
|
|
- self.model = self.model.float().to(self.device).eval()
|
|
|
+ self.model = self.model.to(self.device)
|
|
|
+ self.model.eval()
|
|
|
|
|
|
# 初始化FAISS索引(2048是ResNet50的特征维度)
|
|
|
self.dimension = 2048
|
|
@@ -126,9 +127,6 @@ class ImageSearchEngine:
|
|
|
多尺度特征向量,处理失败返回None
|
|
|
"""
|
|
|
try:
|
|
|
- # 三重精度保障
|
|
|
- self.model = self.model.float()
|
|
|
-
|
|
|
# 获取原图信息
|
|
|
orig_w, orig_h = image.size
|
|
|
max_edge = max(orig_w, orig_h)
|
|
@@ -136,14 +134,14 @@ class ImageSearchEngine:
|
|
|
|
|
|
# 动态调整策略 -------------------------------------------
|
|
|
# 策略1:根据最大边长确定基准尺寸
|
|
|
- base_size = min(max_edge, 2048) # 最大尺寸限制
|
|
|
+ base_size = min(max_edge, 3000) # 不超过模型支持的最大尺寸
|
|
|
|
|
|
# 策略2:自动生成窗口尺寸(等比数列)
|
|
|
min_size = 224 # 最小特征尺寸
|
|
|
- num_scales = 3 # 采样点数
|
|
|
+ num_scales = 4 # 固定采样点数
|
|
|
scale_factors = np.logspace(0, 1, num_scales, base=2)
|
|
|
window_sizes = [int(base_size * f) for f in scale_factors]
|
|
|
- window_sizes = sorted({min(max(s, min_size), 2048) for s in window_sizes})
|
|
|
+ window_sizes = sorted({min(max(s, min_size), 3000) for s in window_sizes})
|
|
|
|
|
|
# 策略3:根据长宽比调整尺寸组合
|
|
|
if aspect_ratio > 1.5: # 宽幅图像
|
|
@@ -160,8 +158,9 @@ class ImageSearchEngine:
|
|
|
self.base_transform
|
|
|
])
|
|
|
|
|
|
- #
|
|
|
- img_base = base_transform(image).unsqueeze(0).to(torch.float32).to(self.device)
|
|
|
+ # 半精度加速
|
|
|
+ self.model.half()
|
|
|
+ img_base = base_transform(image).unsqueeze(0).to(self.device).half()
|
|
|
|
|
|
# 动态特征提取 ------------------------------------------
|
|
|
features = []
|
|
@@ -169,21 +168,21 @@ class ImageSearchEngine:
|
|
|
# 保持长宽比的重采样
|
|
|
target_size = (int(size*aspect_ratio), size) if aspect_ratio > 1 else (size, int(size/aspect_ratio))
|
|
|
|
|
|
- # CPU兼容的插值
|
|
|
+ # GPU加速的智能插值
|
|
|
img_tensor = torch.nn.functional.interpolate(
|
|
|
img_base,
|
|
|
size=target_size,
|
|
|
- mode='bilinear',
|
|
|
+ mode= 'area' if size < base_size else 'bicubic', # 下采样用area,上采样用bicubic
|
|
|
align_corners=False
|
|
|
- ).to(torch.float32)
|
|
|
+ )
|
|
|
|
|
|
# 自适应归一化(保持原图统计特性)
|
|
|
if hasattr(self, 'adaptive_normalize'):
|
|
|
img_tensor = self.adaptive_normalize(img_tensor)
|
|
|
|
|
|
# 混合精度推理
|
|
|
- with torch.no_grad():
|
|
|
- feature = self.model(img_tensor).to(torch.float32)
|
|
|
+ with torch.no_grad(), torch.cuda.amp.autocast():
|
|
|
+ feature = self.model(img_tensor)
|
|
|
|
|
|
features.append(feature.squeeze().float())
|
|
|
|
|
@@ -204,6 +203,48 @@ class ImageSearchEngine:
|
|
|
|
|
|
|
|
|
|
|
|
+ def _extract_multi_scale_features_bak(self, image: Image.Image) -> Optional[torch.Tensor]:
|
|
|
+ """提取多尺度特征。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: PIL图片对象
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 多尺度特征向量,如果处理失败返回None
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ features_list = []
|
|
|
+ window_sizes = [256, 512,1024,1560,2048,2560,3000]
|
|
|
+ # 多尺度转换 - 增加更多尺度
|
|
|
+ #self.multi_scale_sizes = [224, 384, 512, 768, 1024, 1536,2048,3000]
|
|
|
+ for size in window_sizes:
|
|
|
+ # 调整图片大小
|
|
|
+ transform = transforms.Compose([
|
|
|
+ transforms.Resize((size, size), interpolation=transforms.InterpolationMode.LANCZOS),
|
|
|
+ self.base_transform
|
|
|
+ ])
|
|
|
+
|
|
|
+ # 应用变换
|
|
|
+ img_tensor = transform(image).unsqueeze(0).to(self.device)
|
|
|
+
|
|
|
+ # 提取特征
|
|
|
+ with torch.no_grad():
|
|
|
+ feature = self.model(img_tensor)
|
|
|
+
|
|
|
+ features_list.append(feature.squeeze())
|
|
|
+
|
|
|
+ # 计算加权平均,较大尺度的权重更高
|
|
|
+ weights = torch.linspace(1, 2, len(features_list)).to(self.device)
|
|
|
+ weights = weights / weights.sum()
|
|
|
+
|
|
|
+ weighted_features = torch.stack([f * w for f, w in zip(features_list, weights)])
|
|
|
+ final_feature = weighted_features.sum(dim=0)
|
|
|
+
|
|
|
+ return final_feature
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"提取多尺度特征时出错: {e}")
|
|
|
+ return None
|
|
|
|
|
|
def _extract_sliding_window_features(self, image: Image.Image) -> Optional[torch.Tensor]:
|
|
|
"""优化版滑动窗口特征提取(动态调整+批量处理)
|
|
@@ -215,9 +256,6 @@ class ImageSearchEngine:
|
|
|
滑动窗口特征向量,处理失败返回None
|
|
|
"""
|
|
|
try:
|
|
|
- # 三重精度保障
|
|
|
- self.model = self.model.float()
|
|
|
-
|
|
|
# 获取原图信息
|
|
|
orig_w, orig_h = image.size
|
|
|
aspect_ratio = orig_w / orig_h
|
|
@@ -229,10 +267,10 @@ class ImageSearchEngine:
|
|
|
int(2 ** np.round(np.log2(max_dim * 0.1))), # 约10%尺寸
|
|
|
int(2 ** np.floor(np.log2(max_dim * 0.5))), # 约50%尺寸
|
|
|
int(2 ** np.ceil(np.log2(max_dim))) # 接近原图尺寸
|
|
|
- } & {256, 512, 1024, 2048}) # 与预设尺寸取交集
|
|
|
+ } & {256, 512, 1024, 2048, 3000}) # 与预设尺寸取交集
|
|
|
|
|
|
# 智能步长调整(窗口尺寸越大步长越大)
|
|
|
- stride_ratios = {256:0.5, 512:0.4, 1024:0.3, 2048:0.2}
|
|
|
+ stride_ratios = {256:0.5, 512:0.4, 1024:0.3, 2048:0.2, 3000:0.15}
|
|
|
|
|
|
# 预处理优化 --------------------------------------------
|
|
|
# 生成基准图像(最大窗口尺寸)
|
|
@@ -241,11 +279,15 @@ class ImageSearchEngine:
|
|
|
(max_win_size, int(max_win_size / aspect_ratio))
|
|
|
|
|
|
transform = transforms.Compose([
|
|
|
- transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.BILINEAR),
|
|
|
+ transforms.Resize(base_size[::-1], interpolation=transforms.InterpolationMode.LANCZOS),
|
|
|
self.base_transform
|
|
|
])
|
|
|
- base_img = transform(image).to(torch.float32).to(self.device)
|
|
|
+ base_img = transform(image).to(self.device)
|
|
|
|
|
|
+ # 半精度加速
|
|
|
+ self.model.half()
|
|
|
+ base_img = base_img.half()
|
|
|
+
|
|
|
# 批量特征提取 ------------------------------------------
|
|
|
all_features = []
|
|
|
for win_size in window_sizes:
|
|
@@ -258,7 +300,7 @@ class ImageSearchEngine:
|
|
|
num_w = (w - win_size) // stride + 1
|
|
|
|
|
|
# 调整窗口数量上限(防止显存溢出)
|
|
|
- MAX_WINDOWS = 16 # 最大窗口数
|
|
|
+ MAX_WINDOWS = 32 # 根据显存调整
|
|
|
if num_h * num_w > MAX_WINDOWS:
|
|
|
stride = int(np.sqrt(h * w * win_size**2 / MAX_WINDOWS))
|
|
|
num_h = (h - win_size) // stride + 1
|
|
@@ -277,11 +319,11 @@ class ImageSearchEngine:
|
|
|
continue
|
|
|
|
|
|
# 批量处理(自动分块防止OOM)
|
|
|
- BATCH_SIZE = 4 # 批处理大小
|
|
|
- with torch.no_grad():
|
|
|
+ BATCH_SIZE = 8 # 根据显存调整
|
|
|
+ with torch.no_grad(), torch.cuda.amp.autocast():
|
|
|
for i in range(0, len(windows), BATCH_SIZE):
|
|
|
- batch = torch.stack(windows[i:i+BATCH_SIZE]).to(torch.float32)
|
|
|
- features = self.model(batch).to(torch.float32)
|
|
|
+ batch = torch.stack(windows[i:i+BATCH_SIZE])
|
|
|
+ features = self.model(batch)
|
|
|
all_features.append(features.cpu().float()) # 转移至CPU释放显存
|
|
|
|
|
|
# 特征融合 ---------------------------------------------
|
|
@@ -291,13 +333,72 @@ class ImageSearchEngine:
|
|
|
final_feature = torch.cat([f.view(-1, f.shape[-1]) for f in all_features], dim=0)
|
|
|
final_feature = final_feature.mean(dim=0).to(self.device)
|
|
|
|
|
|
- return final_feature.float()
|
|
|
+ return final_feature
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"滑动窗口特征提取失败: {e}")
|
|
|
return None
|
|
|
|
|
|
|
|
|
+ def _extract_sliding_window_features_bak(self, image: Image.Image) -> Optional[torch.Tensor]:
|
|
|
+ """使用滑动窗口提取特征。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: PIL图片对象
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 滑动窗口特征向量,如果处理失败返回None
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ window_sizes = [256, 512,1024,1560,2048,2560,3000]
|
|
|
+ stride_ratio = 0.25 # 步长比例
|
|
|
+
|
|
|
+ features_list = []
|
|
|
+
|
|
|
+ for window_size in window_sizes:
|
|
|
+ # 调整图片大小,保持宽高比
|
|
|
+ aspect_ratio = image.size[0] / image.size[1]
|
|
|
+ if aspect_ratio > 1:
|
|
|
+ new_width = int(window_size * aspect_ratio)
|
|
|
+ new_height = window_size
|
|
|
+ else:
|
|
|
+ new_width = window_size
|
|
|
+ new_height = int(window_size / aspect_ratio)
|
|
|
+
|
|
|
+ transform = transforms.Compose([
|
|
|
+ transforms.Resize((new_height, new_width), interpolation=transforms.InterpolationMode.LANCZOS),
|
|
|
+ self.base_transform
|
|
|
+ ])
|
|
|
+
|
|
|
+ # 转换图片
|
|
|
+ img_tensor = transform(image)
|
|
|
+
|
|
|
+ # 计算步长
|
|
|
+ stride = int(window_size * stride_ratio)
|
|
|
+
|
|
|
+ # 使用滑动窗口提取特征
|
|
|
+ for i in range(0, img_tensor.size(1) - window_size + 1, stride):
|
|
|
+ for j in range(0, img_tensor.size(2) - window_size + 1, stride):
|
|
|
+ window = img_tensor[:, i:i+window_size, j:j+window_size].unsqueeze(0).to(self.device)
|
|
|
+
|
|
|
+ with torch.no_grad():
|
|
|
+ feature = self.model(window)
|
|
|
+
|
|
|
+ features_list.append(feature.squeeze())
|
|
|
+
|
|
|
+ # 如果没有提取到特征,返回None
|
|
|
+ if not features_list:
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 计算所有特征的平均值
|
|
|
+ final_feature = torch.stack(features_list).mean(dim=0)
|
|
|
+
|
|
|
+ return final_feature
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"提取滑动窗口特征时出错: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
def extract_features(self, img: Image.Image) -> np.ndarray:
|
|
|
"""结合多尺度和滑动窗口提取特征。
|
|
|
|
|
@@ -663,6 +764,7 @@ class ImageSearchEngine:
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"删除 product_id 为 {product_id} 的记录时出错: {e}")
|
|
|
+ traceback.print_exc()
|
|
|
return False
|
|
|
|
|
|
def get_index_size(self) -> int:
|