import os
from typing import Optional
from urllib.request import urlretrieve

import uvicorn
from diskcache import Cache
from fastapi import FastAPI, File, UploadFile
from fastapi.param_functions import Form
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse

from milvus_helpers import MilvusHelper
from config import TOP_K, UPLOAD_PATH
from encode import ResNet50
from operators import do_delete, do_load, do_upload, do_search, do_count, do_drop
from logs import LOGGER

# Application object that every route below registers on.
app = FastAPI()

# NOTE(review): every origin is allowed together with credentials; confirm
# this is intended before exposing the service beyond a trusted network.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared, module-level instances handed to the operators by each endpoint.
MODEL = ResNet50()
MILVUS_CLI = MilvusHelper()

# Make sure the upload directory exists (original note: '/tmp/search-images').
if not os.path.exists(UPLOAD_PATH):
    # exist_ok=True: another worker process may create the directory between
    # the check above and this call; that must not crash start-up.
    os.makedirs(UPLOAD_PATH, exist_ok=True)
    LOGGER.info(f"mkdir the path:{UPLOAD_PATH}")


# Disabled endpoints, kept for reference.
# @app.get('/data')
# def get_img(image_path):
#     # Get the image file
#     try:
#         LOGGER.info(f"Successfully load image: {image_path}")
#         return FileResponse(image_path)
#     except Exception as e:
#         LOGGER.error(f"Get image error: {e}")
#         return {'status': False, 'msg': e}, 400

# @app.get('/progress')
# def get_progress():
#     # Get the progress of dealing with images
#     try:
#         cache = Cache('./tmp')
#         return f"current: {cache['current']}, total: {cache['total']}"
#     except Exception as e:
#         LOGGER.error(f"upload image error: {e}")
#         return {'status': False, 'msg': e}, 400

# class Item(BaseModel):
#     Table: Optional[str] = None
#     File: str

# @app.post('/img/load')
# async def load_images(item: Item):
#     # Insert all the image under the file path to Milvus/MySQL
#     try:
#         total_num = do_load(item.Table, item.File, MODEL, MILVUS_CLI)
#         LOGGER.info(f"Successfully loaded data, total count: {total_num}")
#         return "Successfully loaded data!"
#     except Exception as e:
#         LOGGER.error(e)
#         return {'status': False, 'msg': e}, 400


def _error_response(msg, status_code=400):
    """Build a JSON error reply that really carries an HTTP error status.

    Returning ``(payload, 400)`` from a FastAPI handler does not set the
    status code; the tuple itself is serialised as the response body. This
    helper returns an explicit response object instead.

    Args:
        msg: Error description or exception; converted with ``str()`` so the
            payload is always JSON-serialisable.
        status_code: HTTP status to send (400 by default).

    Returns:
        A ``JSONResponse`` with body ``{'status': False, 'msg': <text>}``.
    """
    # Imported here so this section stays self-contained and does not rely on
    # a change to the file's top-level import block.
    from starlette.responses import JSONResponse

    return JSONResponse(
        status_code=status_code,
        content={'status': False, 'msg': str(msg)},
    )


def _save_upload(filename, content):
    """Write uploaded bytes into UPLOAD_PATH and return the resulting path.

    Only the final path component of ``filename`` is used, so a client-chosen
    name such as ``../../etc/passwd`` cannot escape the upload directory.

    Args:
        filename: File name as supplied by the client.
        content: Raw bytes of the uploaded file.

    Returns:
        The path the file was written to.
    """
    # Normalise Windows separators first so basename() strips them as well.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    img_path = os.path.join(UPLOAD_PATH, safe_name)
    with open(img_path, "wb") as f:
        f.write(content)
    return img_path


# file:image,query:id,url,table_name
@app.post('/img/upload')
async def upload_images(
    table_name: str = None,
    partition_name: str = None,
    image: UploadFile = File(None),
    im_hash: str = None,
    product_id: str = None,
    url: str = None
):
    """Store one image (uploaded file or remote URL) and pass it to do_upload.

    Returns ``{'id': <first element of the do_upload result>}`` on success,
    or a 400 JSON error when neither ``image`` nor ``url`` is given or when
    any step raises.
    """
    try:
        if image is not None:
            img_path = _save_upload(image.filename, await image.read())
        elif url is not None:
            # NOTE(review): the URL comes from the caller and is fetched
            # as-is (any scheme/host urlretrieve accepts); consider
            # restricting it if this endpoint is reachable by untrusted clients.
            img_path = os.path.join(UPLOAD_PATH, os.path.basename(url))
            urlretrieve(url, img_path)
        else:
            return _error_response('Image or url is required')
        vector_id = do_upload(table_name, partition_name, im_hash, product_id, img_path, MODEL, MILVUS_CLI)
        LOGGER.info(f"Successfully uploaded data, vector id: {vector_id}")
        return {'id': vector_id[0]}
    except Exception as e:
        LOGGER.error(e)
        return _error_response(e)


# file:image,form:limit,table_name
@app.post('/img/search')
async def search_images(
    table_name: str = None,
    partition_name: str = None,
    image: UploadFile = File(...),
    limit: int = TOP_K
):
    """Save the uploaded image, run do_search on it and flatten the hits.

    Returns a list of ``{"im_hash", "product_id", "score"}`` dicts, one per
    hit, or a 400 JSON error when any step raises.
    """
    try:
        img_path = _save_upload(image.filename, await image.read())
        res = do_search(table_name, partition_name, img_path, limit, MODEL, MILVUS_CLI)
        # do_search yields groups of hits; flatten them into one result list.
        matches = [
            {"im_hash": hit.id, "product_id": hit.entity.product_id, "score": hit.distance}
            for hits in res
            for hit in hits
        ]
        # res = dict(zip(paths, distances))
        # res = sorted(res.items(), key=lambda item: item[1])
        LOGGER.info("Successfully searched similar images!")
        return matches
    except Exception as e:
        LOGGER.error(e)
        return _error_response(e)


@app.post('/delete/record')
async def delete_record(table_name: str = None, partition_name: str = None, expr: str = None):
    """Delete records matching ``expr`` through do_delete.

    Returns ``{'status': True}`` on success, a ``status: False`` payload when
    ``expr`` is missing or empty, or a 400 JSON error when do_delete raises.
    """
    # `not expr` also covers an omitted parameter (None), which previously
    # crashed on len(None) before the try block was entered.
    if not expr:
        return {'status': False, 'msg': 'expr must not be empty'}
    try:
        # NOTE(review): unlike the other operators, do_delete is not handed
        # MILVUS_CLI here; confirm against its signature.
        res = do_delete(table_name, partition_name, expr)
        LOGGER.info(f"Delete result: {res}")
        return {'status': True}
    except Exception as e:
        LOGGER.error(e)
        return _error_response(e)


@app.post('/collection/count')
async def count_images(table_name: str = None):
    """Return the value do_count reports for ``table_name``.

    Responds with a 400 JSON error when do_count raises.
    """
    # Returns the total number of images in the system
    try:
        num = do_count(table_name, MILVUS_CLI)
        LOGGER.info("Successfully count the number of images!")
        return num
    except Exception as e:
        LOGGER.error(e)
        return _error_response(e)


@app.post('/collection/drop')
async def drop_tables(table_name: str = None):
    """Drop ``table_name`` through do_drop and return its status value.

    Responds with a 400 JSON error when do_drop raises.
    """
    # Delete the collection of Milvus and MySQL
    try:
        status = do_drop(table_name, MILVUS_CLI)
        LOGGER.info("Successfully drop tables in Milvus and MySQL!")
        return status
    except Exception as e:
        LOGGER.error(e)
        return _error_response(e)


if __name__ == '__main__':
    # Run the development server on all interfaces, port 5000.
    uvicorn.run(app=app, host='0.0.0.0', port=5000)