main.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
import os
from typing import Optional
from urllib.parse import urlparse
from urllib.request import urlretrieve

import uvicorn
from diskcache import Cache
from fastapi import FastAPI, File, UploadFile
from fastapi.param_functions import Form
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse, JSONResponse

from config import TOP_K, UPLOAD_PATH
from encode import ResNet50
from logs import LOGGER
from milvus_helpers import MilvusHelper
from operators import do_delete, do_load, do_upload, do_search, do_count, do_drop
  16. app = FastAPI()
  17. origins = ["*"]
  18. app.add_middleware(
  19. CORSMiddleware,
  20. allow_origins=origins,
  21. allow_credentials=True,
  22. allow_methods=["*"],
  23. allow_headers=["*"],
  24. )
  25. MODEL = ResNet50()
  26. MILVUS_CLI = MilvusHelper()
  27. # Mkdir '/tmp/search-images'
  28. if not os.path.exists(UPLOAD_PATH):
  29. os.makedirs(UPLOAD_PATH)
  30. LOGGER.info(f"mkdir the path:{UPLOAD_PATH}")
  31. # @app.get('/data')
  32. # def get_img(image_path):
  33. # # Get the image file
  34. # try:
  35. # LOGGER.info(f"Successfully load image: {image_path}")
  36. # return FileResponse(image_path)
  37. # except Exception as e:
  38. # LOGGER.error(f"Get image error: {e}")
  39. # return {'status': False, 'msg': e}, 400
  40. # @app.get('/progress')
  41. # def get_progress():
  42. # # Get the progress of dealing with images
  43. # try:
  44. # cache = Cache('./tmp')
  45. # return f"current: {cache['current']}, total: {cache['total']}"
  46. # except Exception as e:
  47. # LOGGER.error(f"upload image error: {e}")
  48. # return {'status': False, 'msg': e}, 400
  49. # class Item(BaseModel):
  50. # Table: Optional[str] = None
  51. # File: str
  52. # @app.post('/img/load')
  53. # async def load_images(item: Item):
  54. # # Insert all the image under the file path to Milvus/MySQL
  55. # try:
  56. # total_num = do_load(item.Table, item.File, MODEL, MILVUS_CLI)
  57. # LOGGER.info(f"Successfully loaded data, total count: {total_num}")
  58. # return "Successfully loaded data!"
  59. # except Exception as e:
  60. # LOGGER.error(e)
  61. # return {'status': False, 'msg': e}, 400
# Inputs: image (file upload) or url, plus query params table_name, partition_name, im_hash, product_id
  63. @app.post('/img/upload')
  64. async def upload_images(
  65. table_name: str = None,
  66. partition_name: str = None,
  67. image: UploadFile = File(None),
  68. im_hash: str = None,
  69. product_id: str = None,
  70. url: str = None
  71. ):
  72. try:
  73. if image is not None:
  74. content = await image.read()
  75. img_path = os.path.join(UPLOAD_PATH, image.filename)
  76. with open(img_path, "wb+") as f:
  77. f.write(content)
  78. elif url is not None:
  79. img_path = os.path.join(UPLOAD_PATH, os.path.basename(url))
  80. urlretrieve(url, img_path)
  81. else:
  82. return {'status': False, 'msg': 'Image and url are required'}, 400
  83. vector_id = do_upload(table_name,partition_name,im_hash, product_id, img_path, MODEL, MILVUS_CLI)
  84. LOGGER.info(f"Successfully uploaded data, vector id: {vector_id}")
  85. return {'id':vector_id[0]}
  86. except Exception as e:
  87. LOGGER.error(e)
  88. return {'status': False, 'msg': e}, 400
# Inputs: image (file upload), plus query params table_name, partition_name, limit
  90. @app.post('/img/search')
  91. async def search_images(
  92. table_name: str = None,
  93. partition_name: str = None,
  94. image: UploadFile = File(...),
  95. limit: int = TOP_K
  96. ):
  97. try:
  98. content = await image.read()
  99. img_path = os.path.join(UPLOAD_PATH, image.filename)
  100. with open(img_path, "wb+") as f:
  101. f.write(content)
  102. res = do_search(table_name,partition_name,img_path, limit, MODEL, MILVUS_CLI)
  103. list = []
  104. for hits in res:
  105. for hit in hits:
  106. list.append({"im_hash":hit.id,"product_id":hit.entity.product_id,"score":hit.distance})
  107. # res = dict(zip(paths, distances))
  108. # res = sorted(res.items(), key=lambda item: item[1])
  109. LOGGER.info("Successfully searched similar images!")
  110. return list
  111. except Exception as e:
  112. LOGGER.error(e)
  113. return {'status': False, 'msg': e}, 400
  114. @app.post('/delete/record')
  115. async def delete_record(table_name: str = None,partition_name: str=None,expr: str=None):
  116. if len(expr) < 1:
  117. return {'status': False, 'msg': 'expr is not empty'}
  118. try:
  119. res = do_delete(table_name,partition_name,expr)
  120. print(res)
  121. return {'status': True}
  122. except Exception as e:
  123. LOGGER.error(e)
  124. return {'status': False, 'msg': e}, 400
  125. @app.post('/collection/count')
  126. async def count_images(table_name: str = None):
  127. # Returns the total number of images in the system
  128. try:
  129. num = do_count(table_name, MILVUS_CLI)
  130. LOGGER.info("Successfully count the number of images!")
  131. return num
  132. except Exception as e:
  133. LOGGER.error(e)
  134. return {'status': False, 'msg': e}, 400
  135. @app.post('/collection/drop')
  136. async def drop_tables(table_name: str = None):
  137. # Delete the collection of Milvus and MySQL
  138. try:
  139. status = do_drop(table_name, MILVUS_CLI)
  140. LOGGER.info("Successfully drop tables in Milvus and MySQL!")
  141. return status
  142. except Exception as e:
  143. LOGGER.error(e)
  144. return {'status': False, 'msg': e}, 400
  145. if __name__ == '__main__':
  146. uvicorn.run(app=app, host='0.0.0.0', port=5000)