文章
问答
冒泡
本地使用Python实现以图搜图

环境

  • Elasticsearch 8.15

实现思路

1、把图片转换成向量

import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
def image_process(img_path, target_size=(224, 224)):
    """
    Load an image from disk as an RGB PIL image of a fixed size.

    :param str img_path: path to the image file
    :param tuple target_size: (height, width) the image is resized to; the
        default matches the 224x224 input used by the feature extractor
    :return: <class 'PIL.Image.Image'>
    """
    # load_img already resizes to target_size and, with its default
    # color_mode='rgb', converts to RGB, so no extra resize/convert is needed.
    return image.load_img(img_path, target_size=target_size)

_resnet_model = None


def _get_model():
    """Build the ResNet50 feature extractor on first use and reuse it afterwards."""
    global _resnet_model
    if _resnet_model is None:
        # include_top=False + pooling='avg' gives one pooled feature vector per image.
        _resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
    return _resnet_model


def extract(img_path):
    """
    Convert one image file into a feature vector.

    :param str img_path: path to the image file
    :return: 1-D numpy array holding every second value of the ResNet50 output
        (2048 -> 1024 dimensions)
    """
    img = image_process(img_path)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch axis expected by predict()
    x = preprocess_input(x)
    # The model is cached: loading ImageNet weights per call is very slow.
    feature = _get_model().predict(x)
    return feature.flatten()[::2]  # 维度 2048 -> 1024

2、把图片的 S3 地址、向量等属性保存到 ES，然后通过向量搜索相似图片。建议事先在索引 mapping 中把 `feature` 字段显式声明为 `dense_vector`，并令 `dims` 等于提取出的向量维度（1024）。

  • S3存储代码

import os
from time import time

import boto3


class OssUtils:
    """Upload local images to an S3 bucket and return their URLs.

    Credentials, region and endpoint are read from environment variables when
    the class is defined; the bucket name is read when an instance is created.
    """

    AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
    AWS_S3_REGION_NAME = os.getenv("AWS_S3_REGION_NAME")
    AWS_S3_ENDPOINT_URL = os.getenv("AWS_S3_ENDPOINT_URL")

    def __init__(self, key_prefix="test"):
        """
        :param str key_prefix: "folder" inside the bucket that objects are
            stored under. Surrounding "/" are stripped because the key must
            not start with "/".
        """
        self.AWS_STORAGE_BUCKET_NAME = os.getenv("BUCKET_NAME")
        self.key_prefix = key_prefix.strip("/")
        # NOTE(review): the host is hard-coded to the "aisence" bucket in
        # cn-northwest-1 and is not derived from BUCKET_NAME — confirm they match.
        self.BASE_URL = "https://aisence.s3.cn-northwest-1.amazonaws.com.cn/{key}/{name}"
        session = boto3.Session(
            aws_access_key_id=self.AWS_ACCESS_KEY_ID,
            aws_secret_access_key=self.AWS_SECRET_ACCESS_KEY,
            region_name=self.AWS_S3_REGION_NAME,
        )
        self.client = session.client("s3", endpoint_url=self.AWS_S3_ENDPOINT_URL)

    def upload(self, file_path):
        """
        Upload a local file under a timestamp-based name.

        :param str file_path: path of the local image
        :return: URL of the uploaded object
        """
        file_name = f"{time()}.png"
        object_key = f"{self.key_prefix}/{file_name}"
        # Close the file as soon as the upload finishes.
        with open(file_path, "rb") as body:
            self.client.put_object(
                Body=body,
                Key=object_key,
                Bucket=self.AWS_STORAGE_BUCKET_NAME,
            )
        return self.BASE_URL.format(key=self.key_prefix, name=file_name)


ou = OssUtils()
  • ES代码

import sys
from elasticsearch import Elasticsearch

class EsUtils:
    """Helper around the Elasticsearch client for indexing and searching image vectors.

    Creating an instance connects immediately and terminates the process if
    the cluster does not answer a ping.
    """

    def __init__(self, hosts=None):
        """
        Connect to Elasticsearch.

        Args:
            hosts (list[str] | None): node URLs. Defaults to port 9200 on
                ``elasticsearch_host``, a name this snippet expects to be
                defined elsewhere (TODO confirm). The 8.x client requires a
                scheme, so it should look like ``"http://localhost"``.
        """
        if hosts is None:
            hosts = [f"{elasticsearch_host}:9200"]
        # request_timeout replaces the deprecated ``timeout`` argument in the 8.x client.
        self.es = Elasticsearch(hosts=hosts, request_timeout=60)
        # Check the connection up front.
        if self.es.ping():
            print("elasticsearch连接成功")
        else:
            print("elasticsearch连接失败")
            # Non-zero status so callers/supervisors see the failure;
            # a bare sys.exit() would report success (status 0).
            sys.exit(1)

    def index(self, index, body):
        """
        Store one document in an Elasticsearch index.

        Args:
            index (str): index name.
            body (dict): document to store.

        Returns:
            The client's response to the index request.
        """
        return self.es.index(index=index, document=body)

    # Search images by vector similarity.
    def feature_search(self, query, result_count=config.result_count):
        """
        Find images whose stored ``feature`` vector is similar to ``query``.

        Args:
            query (list | numpy.ndarray): query image vector.
            result_count (int): maximum number of hits to request.

        Returns:
            list: ``[url, name]`` pairs; empty when nothing passes the score filters.
        """
        results = self.es.search(
            index=config.elasticsearch_index,
            size=result_count,
            query={
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        # cosineSimilarity is in [-1, 1]; +1.0 keeps scores
                        # non-negative, so every score lies in [0, 2].
                        "source": "cosineSimilarity(params.queryVector, 'feature') + 1.0",
                        "params": {"queryVector": query},
                    },
                }
            },
        )
        hits = results['hits']['hits']
        max_score = results['hits']['max_score']
        # An empty page (e.g. result_count == 0) has no max_score to compare against.
        if not hits or max_score is None:
            return []

        print(f"最高分: {max_score}")
        # NOTE(review): scores lie in [0, 2], so this floor only rejects
        # near-opposite vectors — confirm the intended threshold.
        if max_score < 0.35:
            return []

        answers = []
        for rank, hit in enumerate(hits, start=1):
            print(f"第{rank}张匹配分: {hit['_score']}")
            # Keep only hits scoring above half of the best match.
            if hit['_score'] > 0.5 * max_score:
                source = hit['_source']
                # '#' would start a URL fragment; percent-encode it so the image loads.
                answers.append([source['url'].replace("#", "%23"), source['name']])
        return answers


es = EsUtils()

注意⚠️：搜索脚本中不要写 doc['feature']，而是直接把向量字段名 'feature' 作为字符串传入，即 cosineSimilarity(params.queryVector, 'feature')。

  • 业务逻辑：遍历本地图片，先批量转换成向量，再把图片上传到 S3，最后将图片名称、向量和地址写入 ES

    def extract_batch(self, img_folder_path, batch_size=32):
        """
        Index every image found under ``img_folder_path``.

        Paths are obtained in batches from ``get_file_paths`` and filtered by
        extension (``config.types``). Each batch is run through ``self.model``
        in a single ``predict`` call. Each image is then uploaded to S3 and
        stored in Elasticsearch as a ``{'name', 'feature', 'url'}`` document.

        :param str img_folder_path: folder to scan
        :param int batch_size: batch size passed to ``get_file_paths``
        """
        cnt = 0  # number of images processed
        time_start = time.time()
        for img_path_list in get_file_paths(img_folder_path, batch_size):
            # Drop files whose extension is not an accepted image type.
            img_path_list = [name for name in img_path_list
                             if os.path.splitext(name)[1] in config.types]
            if not img_path_list:
                # The batch held only non-image files: nothing to predict.
                continue

            # Size the array to the images actually present. A larger array would
            # add all-zero rows that get predicted and index past img_path_list.
            batch_images = np.zeros((len(img_path_list), 224, 224, 3))
            for i, img_path in enumerate(img_path_list):
                x = image.img_to_array(image_process(img_path))
                batch_images[i] = preprocess_input(np.expand_dims(x, axis=0))[0]
            cnt += len(img_path_list)

            embeddings = self.model.predict(batch_images)  # one feature row per image
            for img_path, embedding in zip(img_path_list, embeddings):
                feature = embedding[::2]  # 2048 -> 1024 dims, same reduction as extract()
                # Upload to S3 and get the image URL back.
                # TODO: the S3 upload is the slow step.
                img_url = ou.upload(img_path)

                # Store name, vector and URL in Elasticsearch.
                doc = {'name': os.path.basename(img_path), 'feature': feature, 'url': img_url}
                es_utils.es.index(config.elasticsearch_index, body=doc)

        time_sum = time.time() - time_start
        print("提取结束,提取成功图片: {} 张 总耗时: {} 秒\n".format(
            cnt, time_sum))

使用Flask暴露接口

# -*- coding: utf-8 -*-
import os

import numpy as np
# -*- coding: utf-8 -*-
from PIL import Image
from flask import Flask, request, render_template

import config
import es_utils
import feature_extractor
import oss_utils

# Image search ("search by image") service.

app = Flask(__name__)
# Keep non-ASCII (e.g. Chinese) text unescaped in JSON responses.
# JSON_AS_ASCII is only read by Flask < 2.3; newer versions use the app.json provider.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json'):
    app.json.ensure_ascii = False


@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the home page with the search and upload forms."""
    page = 'index.html'
    return render_template(page)


# Search images.
@app.route('/search', methods=['GET', 'POST'])
def search():
    """
    Search for images similar to the uploaded ``query_img``.

    On POST, the image is saved temporarily, converted to a vector and
    matched in Elasticsearch; the ``[url, name]`` matches are rendered as
    ``scores``. Other methods just render the page.
    """
    if request.method != 'POST':
        return render_template('index.html')

    file = request.files['query_img']
    # basename() drops any directory part of the client-supplied filename, so a
    # crafted name such as "../../app.py" cannot escape static/uploaded/.
    name = os.path.basename((file.filename or '').replace('\\', '/'))
    if not name:
        return render_template('index.html')

    # Temporarily store the image on disk.
    img = Image.open(file.stream)
    relative_path = "static/uploaded/" + name
    uploaded_img_path = os.path.join(config.root_path, relative_path)
    img.save(uploaded_img_path)
    try:
        query = feature_extractor.fe.extract(uploaded_img_path)
        answers = es_utils.es.feature_search(query)
    finally:
        # Always remove the temporary copy, even if extraction or search failed.
        if os.path.exists(uploaded_img_path):
            os.remove(uploaded_img_path)
        else:
            print('删除图片失败:', uploaded_img_path)

    return render_template('index.html',
                           query_path=relative_path.replace("#", "%23"),
                           scores=answers)


# Upload images.
@app.route('/upload', methods=['GET', 'POST'])
def upload():
    """
    Index the images sent in the ``upload_img`` field.

    On POST, each file is saved temporarily, converted to a vector, uploaded
    to S3 and written to Elasticsearch with its name and S3 URL. Every method
    ends by rendering the page.
    """
    if request.method != 'POST':
        return render_template('index.html')

    upload_dir = os.path.join(config.root_path, 'static', 'uploaded')
    for file in request.files.getlist('upload_img'):
        # basename() drops any directory part of the client-supplied filename, so a
        # crafted name such as "../../app.py" cannot escape upload_dir.
        name = os.path.basename((file.filename or '').replace('\\', '/'))
        if not name:
            continue  # empty file input: nothing to store

        # Temporarily store the image on disk.
        img = Image.open(file.stream)
        uploaded_img_path = os.path.join(upload_dir, name)
        img.save(uploaded_img_path)
        try:
            feature = feature_extractor.fe.extract(uploaded_img_path)
            feature = np.array(feature).flatten()

            # Upload to S3 and get the image URL back.
            img_url = oss_utils.ou.upload(uploaded_img_path)

            # Store name, vector and URL in Elasticsearch.
            doc = {'name': name, 'feature': feature, 'url': img_url}
            es_utils.es.index(config.elasticsearch_index, body=doc)
        finally:
            # Always remove the temporary copy, even if a step above failed.
            if os.path.exists(uploaded_img_path):
                os.remove(uploaded_img_path)
            else:
                print('删除图片失败:', uploaded_img_path)

    return render_template('index.html')


if __name__ == "__main__":
    # Werkzeug's interactive debugger lets anyone who can reach it run arbitrary
    # code, so it is opt-in via FLASK_DEBUG=1 instead of always on while bound to 0.0.0.0.
    app.run("0.0.0.0", port=8088, debug=os.getenv("FLASK_DEBUG") == "1")

前端demo

<!doctype html>
<html lang="zh-CN">
<head>
   <meta charset="utf-8">
   <!-- Send no Referer header when the page loads result images from S3. -->
   <meta name="referrer" content="no-referrer">
   <title>以图搜图</title>
   <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css">
</head>
<body>
<div class="container">
   <h1>搜索图片</h1>
   <form action="/search" method="POST" enctype="multipart/form-data">
       <input type="file" name="query_img" accept="image/*" required><br>
       <input type="submit">
   </form>
   <h1>上传图片</h1>
   <form action="/upload" method="POST" enctype="multipart/form-data">
       <input type="file" name="upload_img" accept="image/*" multiple required>
       <br>
       <input type="submit">
   </form>
   <h2>结果:</h2>
   {# scores is a list of [image_url, image_name] pairs from /search #}
   {% for score in scores %}
       <figure style="float: left; margin-right: 20px; margin-bottom: 20px;">
           <img src="{{ score[0] }}" alt="{{ score[1] }}" height="200">
           <figcaption>{{ score[1] }}</figcaption>
       </figure>
   {% endfor %}
</div>
</body>
</html>

实现效果


python

关于作者

小乙哥
学海无涯,回头是岸
获得点赞
文章被阅读