123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- import uvicorn
- import os
- from diskcache import Cache
- from fastapi import FastAPI, File, UploadFile
- from fastapi.param_functions import Form
- from starlette.middleware.cors import CORSMiddleware
- from starlette.responses import FileResponse
- from milvus_helpers import MilvusHelper
- from config import TOP_K, UPLOAD_PATH
- from encode import ResNet50
- from operators import do_delete, do_load, do_upload, do_search, do_count, do_drop
- from logs import LOGGER
- from pydantic import BaseModel
- from typing import Optional
- from urllib.request import urlretrieve
app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; restrict `origins` in production.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level encoder and Milvus client shared by every endpoint below.
MODEL = ResNet50()
MILVUS_CLI = MilvusHelper()

# Make sure the configured UPLOAD_PATH exists: uploaded and downloaded images are
# written there before being encoded. exist_ok=True avoids a FileExistsError when
# another worker process creates the directory between the check and makedirs.
if not os.path.exists(UPLOAD_PATH):
    os.makedirs(UPLOAD_PATH, exist_ok=True)
    LOGGER.info(f"mkdir the path:{UPLOAD_PATH}")
- # @app.get('/data')
- # def get_img(image_path):
- # # Get the image file
- # try:
- # LOGGER.info(f"Successfully load image: {image_path}")
- # return FileResponse(image_path)
- # except Exception as e:
- # LOGGER.error(f"Get image error: {e}")
- # return {'status': False, 'msg': e}, 400
- # @app.get('/progress')
- # def get_progress():
- # # Get the progress of dealing with images
- # try:
- # cache = Cache('./tmp')
- # return f"current: {cache['current']}, total: {cache['total']}"
- # except Exception as e:
- # LOGGER.error(f"upload image error: {e}")
- # return {'status': False, 'msg': e}, 400
- # class Item(BaseModel):
- # Table: Optional[str] = None
- # File: str
- # @app.post('/img/load')
- # async def load_images(item: Item):
- # # Insert all the image under the file path to Milvus/MySQL
- # try:
- # total_num = do_load(item.Table, item.File, MODEL, MILVUS_CLI)
- # LOGGER.info(f"Successfully loaded data, total count: {total_num}")
- # return "Successfully loaded data!"
- # except Exception as e:
- # LOGGER.error(e)
- # return {'status': False, 'msg': e}, 400
# Upload one image (multipart file or remote URL) and index it in Milvus.
@app.post('/img/upload')
async def upload_images(
    table_name: str = None,
    partition_name: str = None,
    image: UploadFile = File(None),
    im_hash: str = None,
    product_id: str = None,
    url: str = None,
):
    """Save an image under UPLOAD_PATH and insert its vector via ``do_upload``.

    The image comes from the multipart ``image`` field or, when no file is
    sent, is downloaded from ``url``. ``table_name``, ``partition_name``,
    ``im_hash`` and ``product_id`` are forwarded unchanged to ``do_upload``.

    Returns ``{'id': <first element returned by do_upload>}`` on success.
    If neither ``image`` nor ``url`` is given, or any step raises, it returns
    ``({'status': False, 'msg': <text>}, 400)``.
    """
    try:
        if image is not None:
            content = await image.read()
            # The filename is client-controlled and may contain "../" or other
            # path components; keep only the last component so the write stays
            # inside UPLOAD_PATH.
            img_path = os.path.join(UPLOAD_PATH, os.path.basename(image.filename))
            with open(img_path, "wb+") as f:
                f.write(content)
        elif url is not None:
            # NOTE(review): urlretrieve fetches whatever URL the caller supplies
            # (including file:// and internal hosts) and blocks the event loop
            # inside this async handler; consider allow-listing schemes/hosts.
            img_path = os.path.join(UPLOAD_PATH, os.path.basename(url))
            urlretrieve(url, img_path)
        else:
            return {'status': False, 'msg': 'Image or url is required'}, 400
        vector_id = do_upload(
            table_name, partition_name, im_hash, product_id, img_path, MODEL, MILVUS_CLI
        )
        LOGGER.info(f"Successfully uploaded data, vector id: {vector_id}")
        return {'id': vector_id[0]}
    except Exception as e:
        LOGGER.error(e)
        # str(e) keeps the error text; a raw exception object does not
        # serialize to its message in the JSON response.
        # NOTE(review): FastAPI serializes this tuple as a JSON array with
        # HTTP 200, not a 400 status; consider HTTPException/JSONResponse.
        return {'status': False, 'msg': str(e)}, 400
# Search for images similar to an uploaded query image.
@app.post('/img/search')
async def search_images(
    table_name: str = None,
    partition_name: str = None,
    image: UploadFile = File(...),
    limit: int = TOP_K,
):
    """Save the query image, run ``do_search`` and flatten the hits.

    ``table_name``, ``partition_name`` and ``limit`` (default ``TOP_K``) are
    forwarded to ``do_search``.

    Returns a list of ``{'im_hash', 'product_id', 'score'}`` dicts, one per
    hit, in the order ``do_search`` yields them. On any exception it returns
    ``({'status': False, 'msg': <error text>}, 400)``.
    """
    try:
        content = await image.read()
        # The filename is client-controlled; keep only its last path
        # component so the write cannot escape UPLOAD_PATH.
        img_path = os.path.join(UPLOAD_PATH, os.path.basename(image.filename))
        with open(img_path, "wb+") as f:
            f.write(content)
        res = do_search(table_name, partition_name, img_path, limit, MODEL, MILVUS_CLI)
        # Flatten the nested hit lists; hit.id is reported under 'im_hash'.
        results = [
            {"im_hash": hit.id, "product_id": hit.entity.product_id, "score": hit.distance}
            for hits in res
            for hit in hits
        ]
        LOGGER.info("Successfully searched similar images!")
        return results
    except Exception as e:
        LOGGER.error(e)
        # str(e) keeps the error text; a raw exception object does not
        # serialize to its message in the JSON response.
        # NOTE(review): FastAPI serializes this tuple as a JSON array with
        # HTTP 200, not a 400 status; consider HTTPException/JSONResponse.
        return {'status': False, 'msg': str(e)}, 400
@app.post('/delete/record')
async def delete_record(table_name: str = None, partition_name: str = None, expr: str = None):
    """Delete the records matching ``expr`` via ``do_delete``.

    ``table_name``, ``partition_name`` and ``expr`` are forwarded unchanged.
    Returns ``{'status': True}`` on success. If ``expr`` is missing or empty,
    it returns ``{'status': False, 'msg': ...}``. If ``do_delete`` raises, it
    returns ``({'status': False, 'msg': <error text>}, 400)``.
    """
    # expr defaults to None. len(None) would raise TypeError here, outside the
    # try block, so check truthiness to reject both None and "".
    if not expr:
        return {'status': False, 'msg': 'expr must not be empty'}
    try:
        # NOTE(review): unlike the other operators called in this file,
        # do_delete is not passed MILVUS_CLI; confirm it obtains a client itself.
        res = do_delete(table_name, partition_name, expr)
        LOGGER.info(f"Delete result: {res}")
        return {'status': True}
    except Exception as e:
        LOGGER.error(e)
        # NOTE(review): FastAPI serializes this tuple as a JSON array with
        # HTTP 200, not a 400 status.
        return {'status': False, 'msg': str(e)}, 400
-
@app.post('/collection/count')
async def count_images(table_name: str = None):
    """Return the number of images stored in ``table_name``.

    The value is whatever ``do_count(table_name, MILVUS_CLI)`` returns. On any
    exception it returns ``({'status': False, 'msg': <error text>}, 400)``.
    """
    try:
        num = do_count(table_name, MILVUS_CLI)
        LOGGER.info("Successfully count the number of images!")
        return num
    except Exception as e:
        LOGGER.error(e)
        # str(e) keeps the error text; a raw exception object does not
        # serialize to its message in the JSON response.
        # NOTE(review): FastAPI serializes this tuple as a JSON array with
        # HTTP 200, not a 400 status.
        return {'status': False, 'msg': str(e)}, 400
@app.post('/collection/drop')
async def drop_tables(table_name: str = None):
    """Drop ``table_name`` via ``do_drop`` and return its status.

    Only the Milvus client is passed here. The log message also mentions
    MySQL; presumably do_drop handles that itself, so verify this in operators.
    On any exception it returns ``({'status': False, 'msg': <error text>}, 400)``.
    """
    try:
        status = do_drop(table_name, MILVUS_CLI)
        LOGGER.info("Successfully drop tables in Milvus and MySQL!")
        return status
    except Exception as e:
        LOGGER.error(e)
        # str(e) keeps the error text; a raw exception object does not
        # serialize to its message in the JSON response.
        # NOTE(review): FastAPI serializes this tuple as a JSON array with
        # HTTP 200, not a 400 status.
        return {'status': False, 'msg': str(e)}, 400
if __name__ == '__main__':
    # Script entry point: serve the app on every interface at port 5000.
    uvicorn.run(app, port=5000, host='0.0.0.0')
|