123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384 |
- import faiss
- import numpy as np
- from typing import List, Tuple, Optional
- # import torch
- from pymongo import MongoClient
- import datetime
- import time
- import requests
- import config
- import os
def RequestAliyunFc(url, timeout: Optional[float] = 60):
    """Ask the Aliyun Function Compute endpoint for an image's feature vector.

    POSTs ``{"url": url}`` as JSON to ``config.ALIYUN_FUNCTION_URL``.

    Args:
        url: Image URL forwarded to the cloud function.
        timeout: Seconds to wait for the HTTP request (forwarded to
            ``requests.post``). ``None`` waits indefinitely.

    Returns:
        The ``"vector"`` field of the JSON response when the HTTP status is
        200 (``None`` if that field is absent), otherwise ``None``.

    Raises:
        requests.RequestException: on connection errors or timeouts.
        ValueError: (or a subclass) if a 200 response body is not valid JSON.
    """
    api = config.ALIYUN_FUNCTION_URL
    print("func api:", api)
    payload = {'url': url}
    print("payload", payload)
    # json= serializes the payload and sets Content-Type: application/json.
    # The timeout keeps a stalled endpoint from blocking the caller forever.
    response = requests.post(api, json=payload, timeout=timeout)
    if response.status_code == 200:
        return response.json().get("vector")

    # Report the actual status code plus the body to aid debugging.
    print(f'请求失败,状态码: {response.status_code}, 响应: {response.text}')
    return None
class ImageSearchEngine:
    """Image similarity search over a Faiss index persisted in MongoDB.

    Feature vectors come from the Aliyun cloud function (``RequestAliyunFc``).
    Each vector is stored in the ``faiss_index.mat_vectors`` collection with
    its ``product_id`` and an integer ``faiss_id``. It is also mirrored into an
    in-memory ``IndexFlatL2`` wrapped in ``IndexIDMap``, so entries are
    addressed and removed by ``faiss_id``. The in-memory index is rebuilt from
    MongoDB on construction.
    """

    def __init__(self, mongodb_uri: Optional[str] = None, dimension: int = 2048):
        """Connect to MongoDB, create the Faiss index and load stored vectors.

        Args:
            mongodb_uri: MongoDB connection string. Falls back to the
                ``MONGODB`` environment variable, then to a built-in local URI.
            dimension: Feature vector length (default 2048). Stored vectors
                of any other length are not loaded into the index.
        """
        if not mongodb_uri:
            mongodb_uri = os.getenv("MONGODB")
        if not mongodb_uri:
            # SECURITY(review): credentials are hard-coded in source; move this
            # fallback to the environment or a secret store.
            mongodb_uri = "mongodb://root:hangzhou_manage_mat_2025@localhost:37018/"
        self.mongo_client = MongoClient(mongodb_uri)
        self.mongo_db = self.mongo_client["faiss_index"]
        self.mongo_collection = self.mongo_db["mat_vectors"]
        # One record per product and one record per faiss_id.
        self.mongo_collection.create_index([("product_id", 1)], unique=True)
        self.mongo_collection.create_index([("faiss_id", 1)], unique=True)
        # Largest faiss_id handed out so far; the next id is faiss_id_max + 1.
        self.faiss_id_max = 0
        self.dimension = dimension
        # IndexIDMap lets us supply our own int64 ids and supports remove_ids().
        self.index = faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))

        # Rebuild the in-memory index from what MongoDB already holds.
        if self._load_index():
            print("成功加载现有索引")

    def _batch_generator(self, cursor, batch_size):
        """Yield lists of at most ``batch_size`` items drawn from ``cursor``.

        The final list may be shorter. Nothing is yielded for an empty cursor.
        """
        batch = []
        for doc in cursor:
            batch.append(doc)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def _process_image(self, image_url: str):
        """Return the cloud-function feature for ``image_url``, or None on any error."""
        try:
            return RequestAliyunFc(image_url)
        except Exception as e:
            # Network/decoding failures are reported and treated as "no feature".
            print(f"处理图片时出错: {e}")
            return None

    def _to_feature_array(self, feature) -> Optional[np.ndarray]:
        """Convert a raw feature sequence into a ``(1, dimension)`` float32 array.

        Returns None (and reports it) when the length is not ``self.dimension``.
        Checking this up front avoids Faiss failing after MongoDB was written.
        """
        feature_np = np.asarray(feature, dtype='float32').reshape(1, -1)
        if feature_np.shape[1] != self.dimension:
            print(f"特征维度不匹配: 期望 {self.dimension}, 实际 {feature_np.shape[1]}")
            return None
        return feature_np

    def add_image_from_url(self, image_path: str, product_id: str) -> bool:
        """Extract an image's feature and add it to MongoDB and the Faiss index.

        Args:
            image_path: Image URL passed to the feature-extraction cloud function.
            product_id: Product the image belongs to. ``product_id`` is a unique
                index, so adding the same product twice fails.

        Returns:
            True on success. False if extraction failed, the vector has the
            wrong dimension, or the MongoDB / Faiss insert raised.
        """
        try:
            feature = self._process_image(image_path)
            if feature is None:
                print("无法提取特征")
                return False

            feature_np = self._to_feature_array(feature)
            if feature_np is None:
                # Rejected before touching MongoDB, so nothing bad is persisted.
                return False

            idx = self.faiss_id_max + 1
            print(f"当前: idx { idx }")
            # Reserve the id before writing. If a later step fails, the id is
            # simply skipped, which is harmless. Reusing it could collide with
            # a record that was already inserted.
            self.faiss_id_max = idx

            # Plain insert: the unique product_id index rejects duplicates.
            record = {
                "faiss_id": idx,
                "product_id": product_id,
                "vector": feature_np.flatten().tolist(),
                # Timezone-aware UTC; replaces the deprecated utcnow().
                "created_at": datetime.datetime.now(datetime.timezone.utc),
            }
            self.mongo_collection.insert_one(record)
            self.index.add_with_ids(feature_np, np.array([idx], dtype='int64'))

            print(f"已添加图片: product_id: {product_id}, faiss_id: {idx}")
            return True
        except Exception as e:
            print(f"添加图片时出错: {e}")
            return False

    def get_product_id_by_faiss_id(self, faiss_id: int) -> Optional[str]:
        """Look up the ``product_id`` stored in MongoDB for ``faiss_id``.

        Args:
            faiss_id: Id of the vector in the Faiss index (any int-convertible value).

        Returns:
            The product id as a string. None if the id is negative, no record
            exists, the record lacks ``product_id``, or the query raised.
        """
        try:
            faiss_id = int(faiss_id)
            if faiss_id < 0:
                print(f"无效的 faiss_id: {faiss_id}")
                return None
            record = self.mongo_collection.find_one({"faiss_id": faiss_id})
            if record is None:
                print(f"未找到 faiss_id 为 {faiss_id} 的记录")
                return None
            product_id = record.get("product_id")
            if product_id is None:
                print(f"记录中缺少 product_id 字段: {record}")
                return None
            return str(product_id)
        except Exception as e:
            print(f"查询 faiss_id 为 {faiss_id} 的记录时出错: {e}")
            return None

    def search(self, image_path: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
        """Find the products whose stored images are closest to ``image_path``.

        Args:
            image_path: Query image URL. ``None`` yields an empty result.
            top_k: Maximum number of hits, capped at the index size.

        Returns:
            ``(product_id, similarity)`` pairs in Faiss result order, where
            ``similarity = 1 / (1 + distance)``. Higher means more similar,
            and an exact match scores 1. Hits whose faiss_id has no MongoDB
            record are dropped. Any error yields an empty list.
        """
        try:
            if image_path is None:
                print("搜索图片下载失败!")
                return []
            k = min(top_k, self.index.ntotal)
            if k <= 0:
                # Empty index or non-positive top_k. Faiss needs k > 0, so skip
                # the remote feature extraction entirely.
                print("索引为空或 top_k 无效, 跳过搜索")
                return []

            feature = self._process_image(image_path)
            if feature is None:
                print("无法提取查询图片的特征")
                return []
            feature_np = self._to_feature_array(feature)
            if feature_np is None:
                return []

            start_vector_time = time.time()
            # IndexFlatL2 reports squared L2 distances.
            distances, indices = self.index.search(feature_np, k)
            print(f"搜索vector耗时: {time.time() - start_vector_time}")

            start_other_time = time.time()
            results = []
            for faiss_id, dist in zip(indices[0], distances[0]):
                if faiss_id == -1:  # Faiss uses -1 for "no result" slots
                    continue
                print(f"搜索结果->faiss_id: { faiss_id }")
                product_id = self.get_product_id_by_faiss_id(faiss_id)
                if product_id:
                    # Convert to a plain float (Faiss yields float32) so
                    # results are JSON-serializable and match the annotation.
                    results.append((product_id, float(1.0 / (1.0 + dist))))
            print(f"查询结果耗时: {time.time() - start_other_time}")
            return results
        except Exception as e:
            print(f"搜索图片时出错: {e}")
            return []

    def _load_index(self) -> bool:
        """Rebuild the Faiss index from MongoDB in batches and restore ``faiss_id_max``.

        Some documents are not indexed: those whose ``faiss_id`` is not an int
        and those whose ``vector`` length differs from ``self.dimension``. Every
        int ``faiss_id`` still raises ``faiss_id_max``, so newly assigned ids
        never collide with an existing record (``faiss_id`` is a unique index).
        The counter is updated per batch and survives a later failure.

        Returns:
            True on success (including an empty collection), False if loading raised.
        """
        try:
            BATCH_SIZE = 10000
            total_docs = self.mongo_collection.count_documents({})
            if total_docs == 0:
                print("数据库为空,跳过索引加载")
                return True  # an empty database is not an error

            skipped = 0
            # Only transfer the fields needed to rebuild the index.
            cursor = self.mongo_collection.find(
                {}, {"_id": 0, "faiss_id": 1, "vector": 1}
            ).batch_size(BATCH_SIZE)
            for batch in self._batch_generator(cursor, BATCH_SIZE):
                batch_vectors = []
                batch_ids = []
                batch_max = -1
                for doc in batch:
                    try:
                        faiss_id = doc['faiss_id']
                        if not isinstance(faiss_id, int):
                            skipped += 1
                            continue
                        faiss_id = int(faiss_id)
                        # Count the id even if its vector turns out to be unusable.
                        batch_max = max(batch_max, faiss_id)
                        vector = doc['vector']
                        if len(vector) != self.dimension:
                            skipped += 1
                            continue
                        batch_vectors.append(vector)
                        batch_ids.append(faiss_id)
                    except Exception as e:
                        skipped += 1
                        print(f"文档处理异常: {str(e)}")
                if batch_vectors:
                    self.index.add_with_ids(
                        np.array(batch_vectors, dtype='float32'),
                        np.array(batch_ids, dtype='int64'),
                    )
                self.faiss_id_max = max(self.faiss_id_max, batch_max)

            print(f"向量总数: {self.index.ntotal}")
            if skipped:
                print(f"跳过无效文档: {skipped}")
            print(f"ID计数器初始化完成,当前值: {self.faiss_id_max}")
            return True
        except Exception as e:
            print(f"索引加载失败: {str(e)}")
            return False

    def clear(self) -> bool:
        """Delete every MongoDB record, empty the Faiss index and reset the id counter.

        MongoDB is cleared first. If that fails, the in-memory index still
        matches the database.

        Returns:
            True on success, False if the index cannot be reset or an error occurred.
        """
        try:
            if not hasattr(self.index, "reset"):
                print("当前索引不支持重置操作")
                return False
            result = self.mongo_collection.delete_many({})
            print(f"已从 MongoDB 中删除 {result.deleted_count} 条记录")
            self.index.reset()
            print("已清除 Faiss 索引中的所有向量")
            self.faiss_id_max = 0
            return True
        except Exception as e:
            print(f"清除索引时出错: {e}")
            return False

    def remove_by_product_id(self, product_id: str) -> bool:
        """Remove a product's vector from MongoDB and the Faiss index.

        Args:
            product_id: Product whose vector should be removed (non-empty str).

        Returns:
            True if the MongoDB record was deleted. False for invalid input, an
            unknown product, a record without ``faiss_id``, or any error.
        """
        try:
            if not product_id or not isinstance(product_id, str):
                print(f"无效的 product_id: {product_id}")
                return False
            record = self.mongo_collection.find_one({"product_id": product_id})
            if record is None:
                print(f"未找到 product_id 为 {product_id} 的记录")
                return False
            faiss_id = record.get("faiss_id")
            if faiss_id is None:
                print(f"记录中缺少 faiss_id 字段: {record}")
                return False
            if not isinstance(self.index, faiss.IndexIDMap):
                print("当前索引不支持删除操作")
                return False

            # Delete the MongoDB record first. search() drops hits whose
            # faiss_id has no record, so if the Faiss removal below fails,
            # the product still cannot be returned.
            result = self.mongo_collection.delete_one({"faiss_id": faiss_id})
            deleted = result.deleted_count == 1
            if deleted:
                print(f"已从 MongoDB 中删除 faiss_id: {faiss_id}")
            else:
                print(f"未找到 faiss_id 为 {faiss_id} 的记录")
            # Remove the vector in any case so no orphan stays searchable.
            self.index.remove_ids(np.array([faiss_id], dtype='int64'))
            print(f"已从 Faiss 索引中删除 faiss_id: {faiss_id}")
            return deleted
        except Exception as e:
            print(f"删除 product_id 为 {product_id} 的记录时出错: {e}")
            return False

    def get_index_size(self) -> int:
        """Return the number of vectors currently held in the Faiss index."""
        return int(self.index.ntotal)
|