animeic, 2 years ago
Current commit
bf12acf3a0
11 files changed, with 657 insertions and 0 deletions
  1. Dockerfile (+16 -0)
  2. build.sh (+26 -0)
  3. milvus/docker-compose.yml (+50 -0)
  4. readme.md (+30 -0)
  5. src/__init__.py (+0 -0)
  6. src/config.py (+16 -0)
  7. src/encode.py (+21 -0)
  8. src/logs.py (+135 -0)
  9. src/main.py (+142 -0)
  10. src/milvus_helpers.py (+126 -0)
  11. src/operators.py (+95 -0)

+ 16 - 0
Dockerfile

@@ -0,0 +1,16 @@
+FROM python:3.9-slim-buster
+RUN sed -i -E "s/\w+.debian.org/mirrors.tuna.tsinghua.edu.cn/g" /etc/apt/sources.list
+RUN apt-get update && \
+    apt-get install -y ffmpeg libsm6 libxext6 && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
+    pip install diskcache==5.2.1 fastapi==0.65.2 pymilvus==2.1.0 towhee==0.8.0 uvicorn==0.13.4 protobuf==3.20.2 opencv-python==4.6.0.66 torch==1.12.1 torchvision==0.13.1 Pillow==9.2.0 timm==0.6.7 && \
+    pip cache purge
+
+WORKDIR /app/src
+COPY . /app
+
+EXPOSE 5000
+
+CMD python3 main.py

+ 26 - 0
build.sh

@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Image names
+local_image="search-image:v1.0.0"
+repository_image="registry.cn-chengdu.aliyuncs.com/infish/search-image:v1.0.0"
+
+# Remove the existing local copy of the remote image
+docker rmi $repository_image
+
+# Build the local image
+docker build -t $local_image .
+
+# Tag it for the registry
+docker tag $local_image $repository_image
+
+# Push to the registry; requires being logged in to the corresponding Docker registry account
+docker push $repository_image
+
+
+
+
+
+
+
+
+

+ 50 - 0
milvus/docker-compose.yml

@@ -0,0 +1,50 @@
+services:
+  etcd:
+    container_name: milvus-etcd
+    image: quay.io/coreos/etcd:v3.5.5
+    environment:
+      - ETCD_AUTO_COMPACTION_MODE=revision
+      - ETCD_AUTO_COMPACTION_RETENTION=1000
+      - ETCD_QUOTA_BACKEND_BYTES=4294967296
+      - ETCD_SNAPSHOT_COUNT=50000
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
+    command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
+
+  minio:
+    container_name: milvus-minio
+    image: minio/minio:RELEASE.2022-03-17T06-34-49Z
+    environment:
+      MINIO_ACCESS_KEY: minioadmin
+      MINIO_SECRET_KEY: minioadmin
+    ports:
+      - "9001:9001"
+      - "9000:9000"
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
+    command: minio server /minio_data --console-address ":9001"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+      interval: 30s
+      timeout: 20s
+      retries: 3
+
+  standalone:
+    container_name: milvus-standalone
+    image: milvusdb/milvus:v2.2.3
+    command: ["milvus", "run", "standalone"]
+    environment:
+      ETCD_ENDPOINTS: etcd:2379
+      MINIO_ADDRESS: minio:9000
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
+    ports:
+      - "19530:19530"
+      - "9091:9091"
+    depends_on:
+      - "etcd"
+      - "minio"
+
+networks:
+  default:
+    name: milvus
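
A minimal connectivity check after bringing the stack up with `docker compose up -d`, assuming pymilvus 2.x on the host and the port mapping above; the host address is a placeholder.

```python
from pymilvus import connections, utility

# Connect to the standalone instance exposed on the host (ports: 19530 above).
connections.connect(host="127.0.0.1", port="19530")

# A fresh deployment is expected to report an empty list of collections.
print(utility.list_collections())
```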

+ 30 - 0
readme.md

@@ -0,0 +1,30 @@
+# Reverse image search
+
+## Install Milvus
+
+`milvus/docker-compose.yml`
+
+## The search-image service
+
+1. Adapt `src/*` to your business requirements
+2. Build the service image: `sh build.sh`
+3. Reference the image in the application's `docker-compose.yml`
+
+Main APIs:
+
+```python
+# file: image, query: id, url, table_name
+# either url or image is required
+@app.post('/img/upload')
+
+# file: image, form: limit, table_name
+# image is required
+@app.post('/img/search')
+
+# table_name
+@app.post('/img/count')
+
+# table_name
+@app.post('/img/drop')
+
+```
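
A hedged usage sketch for the two main endpoints above, assuming the service is reachable at http://localhost:5000 and that `requests` is installed; the file name, id and table name are placeholders.

```python
import requests

BASE = "http://localhost:5000"  # assumed service address

# Upload: multipart image file plus optional id / url / table_name query parameters.
with open("cat.jpg", "rb") as f:
    r = requests.post(f"{BASE}/img/upload",
                      files={"image": f},
                      params={"id": "cat-001", "table_name": "demo"})
print(r.json())  # e.g. {"id": "cat-001"}

# Search: query image as a file, result limit as form data, table_name as a query parameter.
with open("cat.jpg", "rb") as f:
    r = requests.post(f"{BASE}/img/search",
                      files={"image": f},
                      data={"limit": 5},
                      params={"table_name": "demo"})
print(r.json())  # {"list": [[id, distance], ...]} sorted by distance
```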

+ 0 - 0
src/__init__.py


+ 16 - 0
src/config.py

@@ -0,0 +1,16 @@
+import os
+
+############### Milvus Configuration ###############
+MILVUS_HOST = os.getenv("MILVUS_HOST", "103.143.81.176")
+MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
+VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
+INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
+METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
+DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "reverse_img_search")
+TOP_K = int(os.getenv("TOP_K", "10"))
+
+############### Data Path ###############
+UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/images")
+
+############### Number of log files ###############
+LOGS_NUM = int(os.getenv("logs_num", "0"))
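
Every setting is read from the environment with a default, so it can be overridden before the module is imported. A small sketch, assuming it runs from `src/`; the values are placeholders:

```python
import os

# Override defaults before importing config, e.g. when running outside Docker.
os.environ["MILVUS_HOST"] = "127.0.0.1"      # placeholder host
os.environ["DEFAULT_TABLE"] = "demo_images"  # placeholder collection name

import config

print(config.MILVUS_HOST, config.DEFAULT_TABLE, config.VECTOR_DIMENSION)
```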

+ 21 - 0
src/encode.py

@@ -0,0 +1,21 @@
+import towhee
+from towhee.functional.option import _Reason
+
+class ResNet50:
+    """Towhee 0.8 pipeline: decode an image, embed it with ResNet-50, return an L2-normalized vector."""
+
+    def __init__(self):
+        self.pipe = (towhee.dummy_input()
+                    .image_decode()
+                    .image_embedding.timm(model_name='resnet50')
+                    .tensor_normalize()
+                    .as_function()
+        )
+
+    def resnet50_extract_feat(self, img_path):
+        # Run the pipeline; towhee wraps failures in an Option, so surface the underlying exception.
+        feat = self.pipe(img_path)
+        if isinstance(feat, _Reason):
+            raise feat.exception
+        return feat
+
+
+if __name__ == "__main__":
+    ResNet50().resnet50_extract_feat('https://i1.sinaimg.cn/dy/deco/2013/0329/logo/LOGO_1x.png')
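
A quick check that the embedding size matches `VECTOR_DIMENSION` (2048 for ResNet-50), assuming towhee 0.8 as pinned in the Dockerfile; the image path is a placeholder.

```python
from encode import ResNet50

model = ResNet50()
# Any local image path or URL accepted by towhee's image_decode should work here.
feat = model.resnet50_extract_feat("tmp/images/example.jpg")
print(len(feat))  # expected: 2048, matching VECTOR_DIMENSION in config.py
```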

+ 135 - 0
src/logs.py

@@ -0,0 +1,135 @@
+import os
+import re
+import datetime
+import logging
+import sys
+from config import LOGS_NUM
+
+try:
+    import codecs
+except ImportError:
+    codecs = None
+
+
+class MultiprocessHandler(logging.FileHandler):
+    """
+    Say something about the ExampleCalass...
+
+    Args:
+        args_0 (`type`):
+        ...
+    """
+    def __init__(self, filename, when='D', backupCount=0, encoding=None, delay=False):
+        self.prefix = filename
+        self.backupCount = backupCount
+        self.when = when.upper()
+        self.extMath = r"^\d{4}-\d{2}-\d{2}"
+
+        self.when_dict = {
+            'S': "%Y-%m-%d-%H-%M-%S",
+            'M': "%Y-%m-%d-%H-%M",
+            'H': "%Y-%m-%d-%H",
+            'D': "%Y-%m-%d"
+        }
+
+        self.suffix = self.when_dict.get(when)
+        if not self.suffix:
+            print('The specified date interval unit is invalid: ', self.when)
+            sys.exit(1)
+
+        self.filefmt = os.path.join('.', "logs", f"{self.prefix}-{self.suffix}.log")
+
+        self.filePath = datetime.datetime.now().strftime(self.filefmt)
+
+        _dir = os.path.dirname(self.filefmt)
+        try:
+            if not os.path.exists(_dir):
+                os.makedirs(_dir)
+        except Exception as e:
+            print('Failed to create log file: ', e)
+            print("log_path:" + self.filePath)
+            sys.exit(1)
+
+        if codecs is None:
+            encoding = None
+
+        logging.FileHandler.__init__(self, self.filePath, 'a+', encoding, delay)
+
+    def shouldChangeFileToWrite(self):
+        _filePath = datetime.datetime.now().strftime(self.filefmt)
+        if _filePath != self.filePath:
+            self.filePath = _filePath
+            return True
+        return False
+
+    def doChangeFile(self):
+        self.baseFilename = os.path.abspath(self.filePath)
+        if self.stream:
+            self.stream.close()
+            self.stream = None
+
+        if not self.delay:
+            self.stream = self._open()
+        if self.backupCount > 0:
+            for s in self.getFilesToDelete():
+                os.remove(s)
+
+    def getFilesToDelete(self):
+        dir_name, _ = os.path.split(self.baseFilename)
+        file_names = os.listdir(dir_name)
+        result = []
+        prefix = self.prefix + '-'
+        for file_name in file_names:
+            if file_name[:len(prefix)] == prefix:
+                suffix = file_name[len(prefix):-4]
+                if re.compile(self.extMath).match(suffix):
+                    result.append(os.path.join(dir_name, file_name))
+        result.sort()
+
+        if len(result) < self.backupCount:
+            result = []
+        else:
+            result = result[:len(result) - self.backupCount]
+        return result
+
+    def emit(self, record):
+        try:
+            if self.shouldChangeFileToWrite():
+                self.doChangeFile()
+            logging.FileHandler.emit(self, record)
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except:
+            self.handleError(record)
+
+
+def write_log():
+    logger = logging.getLogger()
+    logger.setLevel(logging.DEBUG)
+    # formatter = '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(module)s | %(lineno)s | %(message)s'
+    fmt = logging.Formatter(
+        '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(lineno)s | %(message)s')
+
+    stream_handler = logging.StreamHandler(sys.stdout)
+    stream_handler.setLevel(logging.INFO)
+    stream_handler.setFormatter(fmt)
+
+    log_name = "milvus"
+    file_handler = MultiprocessHandler(log_name, when='D', backupCount=LOGS_NUM)
+    file_handler.setLevel(logging.DEBUG)
+    file_handler.setFormatter(fmt)
+    file_handler.doChangeFile()
+
+    logger.addHandler(stream_handler)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+LOGGER = write_log()
+# if __name__ == "__main__":
+#     message = 'test writing logs'
+#     logger = write_log()
+#     logger.info(message)
+#     logger.debug(message)
+#     logger.error(message)

+ 142 - 0
src/main.py

@@ -0,0 +1,142 @@
+import uvicorn
+import os
+from diskcache import Cache
+from fastapi import FastAPI, File, UploadFile
+from fastapi.param_functions import Form
+from starlette.middleware.cors import CORSMiddleware
+from starlette.responses import FileResponse
+from milvus_helpers import MilvusHelper
+from config import TOP_K, UPLOAD_PATH
+from encode import ResNet50
+from operators import do_load, do_upload, do_search, do_count, do_drop
+from logs import LOGGER
+from pydantic import BaseModel
+from typing import Optional
+from urllib.request import urlretrieve
+
+app = FastAPI()
+origins = ["*"]
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+
+)
+MODEL = ResNet50()
+MILVUS_CLI = MilvusHelper()
+
+# Create the upload directory (UPLOAD_PATH) if it does not exist
+if not os.path.exists(UPLOAD_PATH):
+    os.makedirs(UPLOAD_PATH)
+    LOGGER.info(f"Created upload directory: {UPLOAD_PATH}")
+
+
+@app.get('/data')
+def get_img(image_path):
+    # Get the image file
+    try:
+        LOGGER.info(f"Successfully load image: {image_path}")
+        return FileResponse(image_path)
+    except Exception as e:
+        LOGGER.error(f"Get image error: {e}")
+        return {'status': False, 'msg': e}, 400
+
+
+@app.get('/progress')
+def get_progress():
+    # Report the progress of the current bulk image-loading job
+    try:
+        cache = Cache('./tmp')
+        return f"current: {cache['current']}, total: {cache['total']}"
+    except Exception as e:
+        LOGGER.error(f"Get progress error: {e}")
+        return {'status': False, 'msg': e}, 400
+
+
+class Item(BaseModel):
+    Table: Optional[str] = None
+    File: str
+
+
+@app.post('/img/load')
+async def load_images(item: Item):
+    # Insert all the images under the given directory into Milvus
+    try:
+        total_num = do_load(item.Table, item.File, MODEL, MILVUS_CLI)
+        LOGGER.info(f"Successfully loaded data, total count: {total_num}")
+        return "Successfully loaded data!"
+    except Exception as e:
+        LOGGER.error(e)
+        return {'status': False, 'msg': e}, 400
+
+# file:image,query:id,url,table_name
+@app.post('/img/upload')
+async def upload_images(image: UploadFile = File(None), id: str = None, url: str = None, table_name: str = None):
+    # Insert the uploaded image into Milvus
+    try:
+        # Save the uploaded image on the server.
+        if image is not None:
+            content = await image.read()
+            img_path = os.path.join(UPLOAD_PATH, image.filename)
+            with open(img_path, "wb+") as f:
+                f.write(content)
+        elif url is not None:
+            img_path = os.path.join(UPLOAD_PATH, os.path.basename(url))
+            urlretrieve(url, img_path)
+        else:
+            return {'status': False, 'msg': 'Either image or url is required'}, 400
+        vector_id = do_upload(table_name, id, img_path, MODEL, MILVUS_CLI)
+        LOGGER.info(f"Successfully uploaded data, vector id: {vector_id}")
+        return {'id': vector_id[0]}
+    except Exception as e:
+        LOGGER.error(e)
+        return {'status': False, 'msg': e}, 400
+
+# file:image,form:limit,table_name
+@app.post('/img/search')
+async def search_images(image: UploadFile = File(...), limit: int = Form(TOP_K), table_name: str = None):
+    # Search for similar images in Milvus
+    try:
+        # Save the uploaded image on the server.
+        content = await image.read()
+        img_path = os.path.join(UPLOAD_PATH, image.filename)
+        with open(img_path, "wb+") as f:
+            f.write(content)
+        paths, distances = do_search(table_name, img_path, limit, MODEL, MILVUS_CLI)
+        res = dict(zip(paths, distances))
+        res = sorted(res.items(), key=lambda item: item[1])
+        LOGGER.info("Successfully searched similar images!")
+        return {"list":res}
+    except Exception as e:
+        LOGGER.error(e)
+        return {'status': False, 'msg': e}, 400
+
+
+@app.post('/img/count')
+async def count_images(table_name: str = None):
+    # Returns the total number of images in the system
+    try:
+        num = do_count(table_name, MILVUS_CLI)
+        LOGGER.info("Successfully count the number of images!")
+        return num
+    except Exception as e:
+        LOGGER.error(e)
+        return {'status': False, 'msg': e}, 400
+
+
+@app.post('/img/drop')
+async def drop_tables(table_name: str = None):
+    # Drop the Milvus collection
+    try:
+        status = do_drop(table_name, MILVUS_CLI)
+        LOGGER.info("Successfully dropped the Milvus collection!")
+        return status
+    except Exception as e:
+        LOGGER.error(e)
+        return {'status': False, 'msg': e}, 400
+
+
+if __name__ == '__main__':
+    uvicorn.run(app=app, host='0.0.0.0', port=5000)
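
A hedged sketch of the remaining endpoints, assuming the service runs locally; `/img/load` takes a JSON body matching the `Item` model, where `File` is a directory path visible to the server. The paths and table name are placeholders.

```python
import requests

BASE = "http://localhost:5000"  # assumed service address

# Bulk-load every image under a server-side directory.
r = requests.post(f"{BASE}/img/load", json={"Table": "demo", "File": "/app/data/images"})
print(r.text)

# Count the vectors stored in a collection.
print(requests.post(f"{BASE}/img/count", params={"table_name": "demo"}).text)

# Drop the collection when it is no longer needed.
print(requests.post(f"{BASE}/img/drop", params={"table_name": "demo"}).text)
```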

+ 126 - 0
src/milvus_helpers.py

@@ -0,0 +1,126 @@
+import sys
+from config import MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION, METRIC_TYPE
+from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
+from logs import LOGGER
+
+
+class MilvusHelper:
+    """
+    Helper class for managing Milvus collections.
+
+    Args:
+        host (`str`):
+            Milvus server Host.
+        port (`str|int`):
+            Milvus server port.
+        ...
+    """
+    def __init__(self, host=MILVUS_HOST, port=MILVUS_PORT):
+        try:
+            self.collection = None
+            connections.connect(host=host, port=port)
+            LOGGER.debug(f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}")
+        except Exception as e:
+            LOGGER.error(f"Failed to connect Milvus: {e}")
+            sys.exit(1)
+
+    def set_collection(self, collection_name):
+        try:
+            self.collection = Collection(name=collection_name)
+        except Exception as e:
+            LOGGER.error(f"Failed to load data to Milvus: {e}")
+            sys.exit(1)
+
+    def has_collection(self, collection_name):
+        # Return if Milvus has the collection
+        try:
+            return utility.has_collection(collection_name)
+        except Exception as e:
+            LOGGER.error(f"Failed to load data to Milvus: {e}")
+            sys.exit(1)
+
+    def create_collection(self, collection_name):
+        # Create the Milvus collection if it does not exist
+        try:
+            if not self.has_collection(collection_name):
+                field1 = FieldSchema(name='id', dtype=DataType.VARCHAR, description='id of the image', max_length=500,
+                                     is_primary=True, auto_id=False)
+                field2 = FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to the image', max_length=500,
+                                     is_primary=False, auto_id=False)
+                field3 = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, description="image embedding vector",
+                                     dim=VECTOR_DIMENSION, is_primary=False)
+                schema = CollectionSchema(fields=[field1, field2, field3], description="collection_name: " + collection_name)
+                self.collection = Collection(name=collection_name, schema=schema)
+                self.create_index(collection_name)
+                LOGGER.debug(f"Create Milvus collection: {collection_name}")
+            else:
+                self.set_collection(collection_name)
+            return "OK"
+        except Exception as e:
+            LOGGER.error(f"Failed to load data to Milvus: {e}")
+            sys.exit(1)
+
+    def insert(self, collection_name, id, path, vectors):
+        # Batch-insert ids, paths and embedding vectors into the Milvus collection
+        try:
+            data = [id, path, vectors]
+            self.set_collection(collection_name)
+            mr = self.collection.insert(data)
+            ids = mr.primary_keys
+            self.collection.load()
+            LOGGER.debug(
+                    f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows")
+            return ids
+        except Exception as e:
+            LOGGER.error(f"Failed to load data to Milvus: {e}")
+            sys.exit(1)
+
+    def create_index(self, collection_name):
+        # Create an IVF_SQ8 index on the collection's embedding field
+        try:
+            self.set_collection(collection_name)
+            default_index = {"index_type": "IVF_SQ8", "metric_type": METRIC_TYPE, "params": {"nlist": 16384}}
+            status = self.collection.create_index(field_name="embedding", index_params=default_index)
+            if not status.code:
+                LOGGER.debug(
+                    f"Successfully create index in collection:{collection_name} with param:{default_index}")
+                return status
+            else:
+                raise Exception(status.message)
+        except Exception as e:
+            LOGGER.error(f"Failed to create index: {e}")
+            sys.exit(1)
+
+    def delete_collection(self, collection_name):
+        # Delete Milvus collection
+        try:
+            self.set_collection(collection_name)
+            self.collection.drop()
+            LOGGER.debug("Successfully drop collection!")
+            return "ok"
+        except Exception as e:
+            LOGGER.error(f"Failed to drop collection: {e}")
+            sys.exit(1)
+
+    def search_vectors(self, collection_name, vectors, top_k):
+        # Search vector in milvus collection
+        try:
+            self.set_collection(collection_name)
+            search_params = {"metric_type": METRIC_TYPE, "params": {"nprobe": 16}}
+            res = self.collection.search(vectors, anns_field="embedding", param=search_params, limit=top_k)
+            LOGGER.debug(f"Successfully search in collection: {res}")
+            return res
+        except Exception as e:
+            LOGGER.error(f"Failed to search vectors in Milvus: {e}")
+            sys.exit(1)
+
+    def count(self, collection_name):
+        # Get the number of entities in the collection
+        try:
+            self.set_collection(collection_name)
+            num = self.collection.num_entities
+            LOGGER.debug(f"Successfully get the num:{num} of the collection:{collection_name}")
+            return num
+        except Exception as e:
+            LOGGER.error(f"Failed to count vectors in Milvus: {e}")
+            sys.exit(1)
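
A minimal sketch of driving the helper directly; note that `insert` expects three parallel lists (ids, paths, embeddings) matching the schema above. The random vector only illustrates the expected shape, and the collection name is a placeholder.

```python
import numpy as np
from milvus_helpers import MilvusHelper
from config import VECTOR_DIMENSION

client = MilvusHelper()
client.create_collection("demo")  # creates the schema and IVF_SQ8 index if missing

# One row: id, path and a 2048-dim embedding (random here, normally from ResNet50).
vec = np.random.random(VECTOR_DIMENSION).astype("float32").tolist()
ids = client.insert("demo", ["img-001"], ["tmp/images/example.jpg"], [vec])
print(ids)

# Search the same vector back; hits expose .id and .distance as used in operators.do_search.
hits = client.search_vectors("demo", [vec], top_k=3)
print([(hit.id, hit.distance) for hit in hits[0]])
```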

+ 95 - 0
src/operators.py

@@ -0,0 +1,95 @@
+import sys
+from glob import glob
+from diskcache import Cache
+from config import DEFAULT_TABLE
+from logs import LOGGER
+
+
+def do_upload(table_name, id, img_path, model, milvus_client):
+    try:
+        if not table_name:
+            table_name = DEFAULT_TABLE
+        milvus_client.create_collection(table_name)
+        feat = model.resnet50_extract_feat(img_path)
+        ids = milvus_client.insert(table_name, [id], [img_path], [feat])
+        return ids
+    except Exception as e:
+        LOGGER.error(f"Error with upload : {e}")
+        sys.exit(1)
+
+
+def extract_features(img_dir, model):
+    img_list = []
+    for path in ['/*.png', '/*.jpg', '/*.jpeg', '/*.PNG', '/*.JPG', '/*.JPEG']:
+        img_list.extend(glob(img_dir+path))
+    try:
+        if len(img_list) == 0:
+            raise FileNotFoundError(f"No image files found in {img_dir} with extension .png/.jpg/.jpeg (any letter case)")
+        cache = Cache('./tmp')
+        feats = []
+        names = []
+        total = len(img_list)
+        cache['total'] = total
+        for i, img_path in enumerate(img_list):
+            try:
+                norm_feat = model.resnet50_extract_feat(img_path)
+                feats.append(norm_feat)
+                names.append(img_path)
+                cache['current'] = i + 1
+                print(f"Extracting feature from image No. {i + 1} , {total} images in total")
+            except Exception as e:
+                LOGGER.error(f"Error with extracting feature from image:{img_path}, error: {e}")
+                continue
+        return feats, names
+    except Exception as e:
+        LOGGER.error(f"Error with extracting feature from image {e}")
+        sys.exit(1)
+
+
+def do_load(table_name, image_dir, model, milvus_client):
+    if not table_name:
+        table_name = DEFAULT_TABLE
+    milvus_client.create_collection(table_name)
+    vectors, paths = extract_features(image_dir, model)
+    # insert() expects parallel id, path and vector lists; the file path doubles as the id here.
+    ids = milvus_client.insert(table_name, paths, paths, vectors)
+    return len(ids)
+
+
+def do_search(table_name, img_path, top_k, model, milvus_client):
+    try:
+        if not table_name:
+            table_name = DEFAULT_TABLE
+        feat = model.resnet50_extract_feat(img_path)
+        vectors = milvus_client.search_vectors(table_name, [feat], top_k)
+        paths = [str(x.id) for x in vectors[0]]
+        distances = [x.distance for x in vectors[0]]
+        return paths, distances
+    except Exception as e:
+        LOGGER.error(f"Error with search : {e}")
+        sys.exit(1)
+
+
+def do_count(table_name, milvus_cli):
+    if not table_name:
+        table_name = DEFAULT_TABLE
+    try:
+        if not milvus_cli.has_collection(table_name):
+            return None
+        num = milvus_cli.count(table_name)
+        return num
+    except Exception as e:
+        LOGGER.error(f"Error with count table {e}")
+        sys.exit(1)
+
+
+def do_drop(table_name, milvus_cli):
+    if not table_name:
+        table_name = DEFAULT_TABLE
+    try:
+        if not milvus_cli.has_collection(table_name):
+            return f"Milvus doesn't have a collection named {table_name}"
+        status = milvus_cli.delete_collection(table_name)
+        return status
+    except Exception as e:
+        LOGGER.error(f"Error with drop table: {e}")
+        sys.exit(1)