app.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. from flask import Flask, request, jsonify, render_template, send_from_directory
  2. import os
  3. from image_search import ImageSearchEngine
  4. import magic
  5. from urllib.request import urlretrieve
  6. import time
  7. from . import config
  # WSGI application object; the route decorators in this module register on it.
  app = Flask(__name__)

  # Make sure the directory used for temporarily downloaded images exists
  # before any request handler tries to write into it.
  os.makedirs(config.UPLOAD_FOLDER, exist_ok=True)

  # Single image search engine instance shared by every request handler.
  search_engine = ImageSearchEngine()
  def allowed_file(filename):
      """Tell whether ``filename`` carries an extension that may be uploaded.

      The extension is whatever follows the last dot, compared in lower case
      against ``config.ALLOWED_EXTENSIONS``. A name without any dot is rejected.
      """
      _stem, dot, extension = filename.rpartition('.')
      if not dot:
          # No dot at all, so there is no extension to check.
          return False
      return extension.lower() in config.ALLOWED_EXTENSIONS
  def is_valid_image(file_path):
      """Tell whether the file at ``file_path`` is detected as an image.

      The MIME type is sniffed with ``magic``; the file counts as an image when
      that type starts with ``image/``. Any failure while probing (for example
      an unreadable file) is reported as ``False`` rather than raised.
      """
      try:
          detector = magic.Magic(mime=True)
          return detector.from_file(file_path).startswith('image/')
      except Exception:
          # Deliberately broad: whatever goes wrong, the file is not usable.
          return False
  @app.route('/')
  def index():
      """Serve the home page rendered from the ``index.html`` template."""
      page = render_template('index.html')
      return page
  @app.route('/upload', methods=['POST'])
  def upload_file():
      """Download an image by URL and add it to the search index.

      Expects a JSON object body with:
          product_id: identifier passed to the search engine with the image.
          url: ``http://`` or ``https://`` address of the image to download.

      Returns:
          A JSON response. On success ``{'message', 'product_id'}``; otherwise
          ``{'error': ...}`` with status 400 for bad request data, or 500 when
          the download, the image validation or the indexing fails.
      """
      # Function-scope import so the handler does not depend on a new
      # module-level import.
      from urllib.parse import urlparse

      try:
          # silent=True returns None for a missing or malformed JSON body instead
          # of raising, so the client receives the 400 below and not a 500.
          data = request.get_json(silent=True)
          # A JSON array or scalar has no .get(); treat it like a missing body.
          if not isinstance(data, dict) or not data:
              return jsonify({'error': '请求必须包含JSON数据'}), 400

          product_id = data.get('product_id')
          image_url = data.get('url')
          if not product_id:
              return jsonify({'error': '缺少product_id参数'}), 400
          if not image_url:
              return jsonify({'error': '缺少url参数'}), 400
          if not isinstance(image_url, str) or not image_url.startswith(('http://', 'https://')):
              return jsonify({'error': '无效的图片URL'}), 400

          # Build the temporary file name from the URL *path* only, so a query
          # string or fragment ("a.jpg?size=big") neither ends up in the file
          # name nor defeats the extension check. The millisecond timestamp
          # prefix keeps names from different requests apart.
          time_base = str(time.time() * 1000)
          remote_name = os.path.basename(urlparse(image_url).path)
          image_path = os.path.join(config.UPLOAD_FOLDER, time_base + remote_name)

          try:
              # NOTE(review): urlretrieve takes no timeout argument and fetches
              # whatever host the client names (SSRF exposure) - consider a
              # bounded, allow-listed download.
              urlretrieve(image_url, image_path)

              if not (allowed_file(image_path) and is_valid_image(image_path)):
                  return jsonify({'error': '图片错误'}), 500

              if not search_engine.add_image_from_url(image_path, product_id):
                  return jsonify({'error': '处理图片失败'}), 500

              return jsonify({'message': '上传成功', 'product_id': product_id})
          finally:
              # Always delete the temporary download, not only on success, so
              # rejected or failed uploads do not pile up in the upload folder.
              try:
                  os.remove(image_path)
              except FileNotFoundError:
                  # The download failed before the file was created.
                  pass
      except Exception as e:
          return jsonify({'error': str(e)}), 500
  @app.route('/search', methods=['POST'])
  def search():
      """Download an image by URL and return the most similar indexed products.

      Expects a JSON object body with:
          url: ``http://`` or ``https://`` address of the query image.
          limit (optional): maximum number of hits requested from the engine,
              must be > 0; defaults to ``config.TOP_K``.
          min_score / max_score (optional): inclusive similarity bounds in
              percent (0-100); default to ``config.MIN_SCORE`` and
              ``config.MAX_SCORE``.

      Returns:
          A JSON response. On success ``{'results': [{'product_id',
          'similarity'}, ...]}`` where ``similarity`` is the engine score times
          100, rounded to two decimals and filtered to the requested range.
          Otherwise ``{'error': ...}`` with status 400 for bad request data, or
          500 when the download, the image validation or the search fails.
      """
      # Function-scope import so the handler does not depend on a new
      # module-level import.
      from urllib.parse import urlparse

      try:
          # silent=True returns None for a missing or malformed JSON body instead
          # of raising, so the client receives the 400 below and not a 500.
          data = request.get_json(silent=True)
          # A JSON array or scalar has no .get(); treat it like a missing body.
          if not isinstance(data, dict) or not data:
              return jsonify({'error': '请求必须包含JSON数据'}), 400

          image_url = data.get('url')
          if not image_url:
              return jsonify({'error': '缺少url参数'}), 400
          if not isinstance(image_url, str) or not image_url.startswith(('http://', 'https://')):
              return jsonify({'error': '无效的图片URL'}), 400

          # Optional tuning parameters, falling back to the configured defaults.
          limit = data.get('limit', config.TOP_K)
          min_score = data.get('min_score', config.MIN_SCORE)
          max_score = data.get('max_score', config.MAX_SCORE)

          try:
              limit = int(limit)
              min_score = float(min_score)
              max_score = float(max_score)
          except (TypeError, ValueError, OverflowError):
              # TypeError covers null/list/object values, OverflowError covers
              # int(inf); both used to escape as a generic 500.
              return jsonify({'error': '参数类型错误'}), 400

          if limit <= 0:
              return jsonify({'error': 'limit必须大于0'}), 400
          # Written as "not (0 <= x <= 100)" so that NaN is rejected as well.
          if not 0 <= min_score <= 100:
              return jsonify({'error': 'min_score必须在0到100之间'}), 400
          if not 0 <= max_score <= 100:
              return jsonify({'error': 'max_score必须在0到100之间'}), 400
          if min_score > max_score:
              return jsonify({'error': 'min_score不能大于max_score'}), 400

          start_download_time = time.time()
          # Build the temporary file name from the URL *path* only, so a query
          # string or fragment neither ends up in the file name nor defeats the
          # extension check.
          time_base = str(time.time() * 1000)
          remote_name = os.path.basename(urlparse(image_url).path)
          image_path = os.path.join(config.UPLOAD_FOLDER, time_base + remote_name)

          try:
              # NOTE(review): urlretrieve takes no timeout argument and fetches
              # whatever host the client names (SSRF exposure) - consider a
              # bounded, allow-listed download.
              urlretrieve(image_url, image_path)
              download_elapsed = time.time() - start_download_time
              print(f"下载耗时: {download_elapsed} s")

              if not (allowed_file(image_path) and is_valid_image(image_path)):
                  # Previously answered with HTTP 200; use the same error status
                  # the /upload handler returns for this condition.
                  return jsonify({'error': '图片错误'}), 500

              # Query the engine with the caller-specified limit.
              start_search_time = time.time()
              results = search_engine.search(image_path, top_k=limit)
          finally:
              # Always delete the temporary download, including when validation
              # or the search itself fails.
              try:
                  os.remove(image_path)
              except FileNotFoundError:
                  # The download failed before the file was created.
                  pass

          # Convert scores to percentages and keep those inside the range.
          formatted_results = []
          for product_id, score in results:
              similarity = round(score * 100, 2)
              if min_score <= similarity <= max_score:
                  formatted_results.append({
                      'product_id': product_id,
                      'similarity': similarity,
                  })

          search_elapsed = time.time() - start_search_time
          print(f"搜索耗时: {search_elapsed} s")
          return jsonify({'results': formatted_results})
      except Exception as e:
          return jsonify({'error': str(e)}), 500
  117. # @app.route('/images')
  118. # def list_images():
  119. # """列出所有已索引的图片"""
  120. # try:
  121. # images = []
  122. # for path in search_engine.image_paths:
  123. # filename = os.path.basename(path)
  124. # images.append({
  125. # 'filename': filename,
  126. # 'path': '/static/images/' + filename
  127. # })
  128. # return jsonify({'images': images})
  129. # except Exception as e:
  130. # return jsonify({'error': str(e)}), 500
  @app.route('/remove/<product_id>', methods=['POST'])
  def remove_image(product_id):
      """Remove the indexed image features stored for ``product_id``.

      Returns:
          A JSON response: a success message with the product ID, a 404 error
          when the engine reports that nothing was removed, or a 500 error
          carrying the exception text when the removal raises.
      """
      try:
          removed = search_engine.remove_by_product_id(product_id)
          if not removed:
              return jsonify({'error': '商品ID不存在或删除失败'}), 404
          return jsonify({'message': '删除成功', 'product_id': product_id})
      except Exception as exc:
          return jsonify({'error': str(exc)}), 500
  @app.route('/clear/<password>', methods=['POST'])
  def clear_index(password):
      """Clear the whole search index when the correct password is supplied.

      Args:
          password: secret taken from the URL path.

      Returns:
          A JSON response: a success message once the index is cleared, a 403
          error for a wrong password, or a 500 error carrying the exception
          text when clearing raises. Previously the success and wrong-password
          paths returned ``None``, which Flask rejects as an invalid response.
      """
      # Function-scope import so the handler does not depend on a new
      # module-level import.
      import hmac

      # NOTE(review): the secret is hard-coded and travels in the URL path,
      # where it is likely to end up in access logs - consider moving it to
      # configuration and sending it in a header or the request body.
      expected = "infish@2025"

      # Constant-time comparison; operands are encoded to bytes because
      # compare_digest rejects non-ASCII str values.
      if not hmac.compare_digest(password.encode('utf-8'), expected.encode('utf-8')):
          return jsonify({'error': '密码错误'}), 403

      try:
          search_engine.clear()
      except Exception as e:
          return jsonify({'error': str(e)}), 500
      return jsonify({'message': '清空成功'})
  if __name__ == '__main__':
      # Started directly as a script: listen on every interface, port 5000,
      # with the Flask debugger switched off.
      app.run(debug=False, port=5000, host='0.0.0.0')