import faiss import numpy as np from typing import List, Tuple, Optional # import torch from pymongo import MongoClient import datetime import time import requests import config import os def RequestAliyunFc(url): # 定义URL api = config.ALIYUN_FUNCTION_URL # 替换为实际的URL print("func api:",api) # 定义参数 payload = {'url': url} print("payload", payload) # headers = { # 'Content-Type': 'application/json' # 明确指定请求头为JSON # } # 发送POST请求 response = requests.post(api, json=payload) # 检查请求是否成功 if response.status_code == 200: # 获取返回的JSON数据 json_data = response.json() return json_data.get("vector") print(f'请求失败,状态码: {response.text}') return None class ImageSearchEngine: def __init__(self): mongodburi = os.getenv("MONGODB") if not mongodburi: mongodburi = "mongodb://root:hangzhou_manage_mat_2025@localhost:37018/" # 添加mongodb self.mongo_client = MongoClient(mongodburi) # MongoDB 连接字符串 self.mongo_db = self.mongo_client["faiss_index"] # 数据库名称 self.mongo_collection = self.mongo_db["mat_vectors"] # 集合名称 self.mongo_collection.create_index([("product_id", 1)], unique=True) self.mongo_collection.create_index([("faiss_id", 1)], unique=True) # 初始化一个id生成计数 self.faiss_id_max = 0 self.dimension = 2048 # 检查GPU是否可用(仅用于PyTorch模型) # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"使用设备: {self.device}") # 改为支持删除的索引 base_index = faiss.IndexFlatL2(self.dimension) self.index = faiss.IndexIDMap(base_index) # 尝试加载现有索引,如果不存在则创建新索引 if self._load_index(): print("成功加载现有索引") def _batch_generator(self, cursor, batch_size): """从MongoDB游标中分批生成数据""" batch = [] for doc in cursor: batch.append(doc) if len(batch) == batch_size: yield batch batch = [] if batch: yield batch def _process_image(self, image_url: str): try: # 这里调用阿里云云函数处理图片 return RequestAliyunFc(image_url) except Exception as e: print(f"处理图片时出错: {e}") return None def add_image_from_url(self, image_path: str, product_id: str) -> bool: """从URL添加图片到索引。 Args: url: 图片URL product_id: 图片对应的商品ID Returns: 添加成功返回True,失败返回False """ try: # 使用原有的特征提取逻辑 feature = self._process_image(image_path) if feature is None: print("无法提取特征") return False # 转换为numpy数组并添加到索引 # feature_np = feature.cpu().numpy().reshape(1, -1).astype('float32') feature_np = np.array(feature).reshape(1, -1).astype('float32') idx = self.faiss_id_max + 1 print(f"当前: idx { idx }") if not isinstance(idx, int) or idx <= 0: print("ID生成失败") return False self.faiss_id_max = idx #走upsert的逻辑 # 向数据库写入记录 record = { "faiss_id": idx, "product_id": product_id, "vector": feature_np.flatten().tolist(), # 将numpy数组转为列表 "created_at": datetime.datetime.utcnow() # 记录创建时间 } self.mongo_collection.insert_one(record) # 为向量设置ID并添加到Faiss索引 self.index.add_with_ids(feature_np, np.array([idx], dtype='int64')) print(f"已添加图片: product_id: {product_id}, faiss_id: {idx}") return True except Exception as e: print(f"添加图片时出错: {e}") return False def get_product_id_by_faiss_id(self, faiss_id: int) -> Optional[str]: """根据 faiss_id 查找 MongoDB 中的 product_id。 Args: faiss_id: Faiss 索引中的 ID Returns: 对应的 product_id,如果未找到则返回 None """ try: faiss_id = int(faiss_id) # 检查 faiss_id 是否有效 if faiss_id < 0: print(f"无效的 faiss_id: {faiss_id}") return None # 查询 MongoDB query = {"faiss_id": faiss_id} record = self.mongo_collection.find_one(query) # 检查是否找到记录 if record is None: print(f"未找到 faiss_id 为 {faiss_id} 的记录") return None # 返回 product_id product_id = record.get("product_id") if product_id is None: print(f"记录中缺少 product_id 字段: {record}") return None return str(product_id) # 确保返回字符串类型 except Exception as e: print(f"查询 faiss_id 为 {faiss_id} 的记录时出错: {e}") return None def search(self, image_path: str = None, top_k: int = 5) -> List[Tuple[str, float]]: try: if image_path is None: print("搜索图片下载失败!") return [] feature = self._process_image(image_path) if feature is None: print("无法提取查询图片的特征") return [] # 将特征转换为numpy数组 fn = np.array(feature) feature_np = fn.reshape(1, -1).astype('float32') start_vector_time = time.time() # 搜索最相似的图片 distances, indices = self.index.search(feature_np, min(top_k, self.index.ntotal)) end_vector_time = time.time() print(f"搜索vector耗时: {end_vector_time - start_vector_time}") start_other_time = time.time() # 返回结果 results = [] for faiss_id, dist in zip(indices[0], distances[0]): if faiss_id == -1: # Faiss返回-1表示无效结果 continue # 将距离转换为相似度分数(0-1之间,1表示完全相似) similarity = 1.0 / (1.0 + dist) # 根据faiss_id获取product_id print(f"搜索结果->faiss_id: { faiss_id }") product_id = self.get_product_id_by_faiss_id(faiss_id) if product_id: results.append((product_id, similarity)) end_other_time = time.time() print(f"查询结果耗时: {end_other_time - start_other_time}") return results except Exception as e: print(f"搜索图片时出错: {e}") return [] def _load_index(self) -> bool: """从数据库分批加载数据并初始化faiss_id_max""" try: # 配置参数 BATCH_SIZE = 10000 # 获取文档总数 total_docs = self.mongo_collection.count_documents({}) if total_docs == 0: print("数据库为空,跳过索引加载") return True # 空数据库不算错误 # 用于跟踪最大ID(兼容空数据情况) max_faiss_id = -1 # 分批加载数据 cursor = self.mongo_collection.find({}).batch_size(BATCH_SIZE) for batch in self._batch_generator(cursor, BATCH_SIZE): # 处理批次数据 batch_vectors = [] batch_ids = [] current_max = -1 for doc in batch: try: # 数据校验 if len(doc['vector']) != self.dimension: continue if not isinstance(doc['faiss_id'], int): continue # 提取数据 faiss_id = int(doc['faiss_id']) vector = doc['vector'] print(f"load faiss_id :{ faiss_id }") # 更新最大值 if faiss_id > current_max: current_max = faiss_id # 收集数据 batch_vectors.append(vector) batch_ids.append(faiss_id) except Exception as e: print(f"文档处理异常: {str(e)}") continue # 批量添加到索引 if batch_vectors: vectors_np = np.array(batch_vectors, dtype='float32') ids_np = np.array(batch_ids, dtype='int64') self.index.add_with_ids(vectors_np, ids_np) # 更新全局最大值 if current_max > max_faiss_id: max_faiss_id = current_max print(f"向量总数: {self.index.ntotal}") # 设置初始值(如果已有更大值则保留) if max_faiss_id != -1: new_id = max_faiss_id self.faiss_id_max = new_id print(f"ID计数器初始化完成,当前值: {new_id}") return True except Exception as e: print(f"索引加载失败: {str(e)}") return False def clear(self) -> bool: """清除所有索引和 MongoDB 中的记录。 Returns: 清除成功返回 True,失败返回 False """ try: # 检查索引是否支持重置操作 if not hasattr(self.index, "reset"): print("当前索引不支持重置操作") return False # 重置 Faiss 索引 self.index.reset() print("已清除 Faiss 索引中的所有向量") # 删除 MongoDB 中的所有记录 result = self.mongo_collection.delete_many({}) print(f"已从 MongoDB 中删除 {result.deleted_count} 条记录") self.faiss_id_max = 0 return True except Exception as e: print(f"清除索引时出错: {e}") return False def remove_by_product_id(self, product_id: str) -> bool: """通过 product_id 删除向量索引和数据库记录。 Args: product_id: 要删除的商品 ID Returns: 删除成功返回 True,失败返回 False """ try: # 检查 product_id 是否有效 if not product_id or not isinstance(product_id, str): print(f"无效的 product_id: {product_id}") return False # 查询 MongoDB 获取 faiss_id query = {"product_id": product_id} record = self.mongo_collection.find_one(query) # 检查是否找到记录 if record is None: print(f"未找到 product_id 为 {product_id} 的记录") return False # 提取 faiss_id faiss_id = record.get("faiss_id") if faiss_id is None: print(f"记录中缺少 faiss_id 字段: {record}") return False # 删除 Faiss 索引中的向量 if isinstance(self.index, faiss.IndexIDMap): # 检查 faiss_id 是否在索引中 # ids = self.index.id_map.at(1) # 获取所有 ID # if faiss_id not in ids: # print(f"faiss_id {faiss_id} 不在索引中") # return False # 删除向量 self.index.remove_ids(np.array([faiss_id], dtype='int64')) print(f"已从 Faiss 索引中删除 faiss_id: {faiss_id}") else: print("当前索引不支持删除操作") return False # 删除 MongoDB 中的记录 result = self.mongo_collection.delete_one({"faiss_id": faiss_id}) if result.deleted_count == 1: print(f"已从 MongoDB 中删除 faiss_id: {faiss_id}") return True else: print(f"未找到 faiss_id 为 {faiss_id} 的记录") return False except Exception as e: print(f"删除 product_id 为 {product_id} 的记录时出错: {e}") return False def get_index_size(self) -> int: """获取索引中的图片数量。 Returns: 索引中的图片数量 """ return len(self.image_paths)