环境
ES8.15
实现思路
1、把图片转换成向量
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
def image_process(img_path):
    """Load an image from disk as a 224x224 RGB PIL image.

    :param str img_path: path of the image file
    :return: the loaded image (``PIL.Image.Image``)
    """
    # load_img already resizes to target_size and converts to RGB when
    # color_mode is 'rgb', so no extra resize()/convert() pass is needed.
    return image.load_img(img_path, color_mode='rgb', target_size=(224, 224))
# Lazily created ResNet50 instance shared by every extract() call.
_RESNET50_MODEL = None


def _get_model():
    """Return the shared ResNet50 feature extractor, building it on first use."""
    global _RESNET50_MODEL
    if _RESNET50_MODEL is None:
        _RESNET50_MODEL = ResNet50(weights='imagenet', include_top=False, pooling='avg')
    return _RESNET50_MODEL


def extract(img_path):
    """Turn one image file into a feature vector.

    :param str img_path: path of the image file
    :return: 1-D numpy array holding every second component of the
        flattened model output (2048 -> 1024 dimensions)
    """
    img = image_process(img_path)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # the model expects a leading batch axis
    x = preprocess_input(x)
    # Building the model is expensive, so it is created once and reused
    # instead of being rebuilt on every call.
    feature = _get_model().predict(x)
    return feature.flatten()[::2]  # dimension 2048 -> 1024
2、把图片S3地址和向量等其他属性保存到ES,然后通过向量搜图片
S3存储代码
import os
from time import time

import boto3
class OssUtils:
    """Small wrapper around a boto3 S3 client used to upload image files.

    Credentials, region and endpoint are read from environment variables
    when the class body is executed; the bucket name is read from
    ``BUCKET_NAME`` when an instance is created.
    """

    AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
    AWS_S3_REGION_NAME = os.getenv("AWS_S3_REGION_NAME")
    AWS_S3_ENDPOINT_URL = os.getenv("AWS_S3_ENDPOINT_URL")

    def __init__(self):
        """Create the boto3 session and the S3 client."""
        self.AWS_STORAGE_BUCKET_NAME = os.getenv("BUCKET_NAME")
        # URL template: ``{key}`` is the folder prefix, ``{}`` the object name.
        self.BASE_URL = "https://aisence.s3.cn-northwest-1.amazonaws.com.cn/{key}/{}"
        session = boto3.Session(
            aws_access_key_id=self.AWS_ACCESS_KEY_ID,
            aws_secret_access_key=self.AWS_SECRET_ACCESS_KEY,
            region_name=self.AWS_S3_REGION_NAME,
        )
        self.client = session.client("s3", endpoint_url=self.AWS_S3_ENDPOINT_URL)

    def upload(self, file_path, key="test"):
        """Upload a local file to the bucket and return its URL.

        The object is stored as ``<key>/<timestamp>.png``.

        :param str file_path: path of the local file to upload
        :param str key: folder prefix inside the bucket, without a leading
            slash. NOTE(review): the original code referenced an undefined
            ``key``; "test" is assumed from the callers' comments - confirm.
        :return: URL built from ``BASE_URL`` for the uploaded object
        """
        file_name = f"{time()}.png"
        # Close the file handle once the upload call has returned.
        with open(file_path, "rb") as body:
            self.client.put_object(
                Body=body,
                Key=f"{key}/{file_name}",
                Bucket=self.AWS_STORAGE_BUCKET_NAME,
            )
        # BASE_URL has both a named and a positional field; supply both.
        return self.BASE_URL.format(file_name, key=key)


# Shared module-level instance used by the other snippets.
ou = OssUtils()
ES代码
import sys
from elasticsearch import Elasticsearch
class EsUtils:
    """Helper around the Elasticsearch client for storing and searching image vectors."""

    def __init__(self):
        """Connect to Elasticsearch and exit the process if the ping fails."""
        self.es = Elasticsearch(
            # NOTE(review): the original `elasticsearch_host:9200` was not valid
            # Python; the host is assumed to live in config.elasticsearch_host
            # and to be served over plain http - confirm both.
            hosts=[f"http://{config.elasticsearch_host}:9200"],
            timeout=60,
        )
        # Check that the connection works
        if self.es.ping():
            print("elasticsearch连接成功")
        else:
            print("elasticsearch连接失败")
            # Non-zero status so the failure is visible to the caller/shell.
            sys.exit(1)

    def index(self, index, body):
        """Store one document in an Elasticsearch index.

        Args:
            index (str): name of the index.
            body (dict): document to store.

        Returns:
            The response returned by the Elasticsearch client.
        """
        return self.es.index(index=index, document=body)  # save to elasticsearch

    # Elasticsearch query that finds images by their feature vector
    def feature_search(self, query, result_count=config.result_count):
        """Search for images whose stored vector is similar to ``query``.

        Args:
            query (list): feature vector of the query image.
            result_count (int): maximum number of hits requested.

        Returns:
            list: ``[url, name]`` pairs of the matching images; empty when
            nothing matches or the best score is below the threshold.
        """
        results = self.es.search(
            index=config.elasticsearch_index,
            body={
                "size": result_count,
                "query": {
                    "script_score": {
                        "query": {
                            "match_all": {}
                        },
                        "script": {
                            "source": "cosineSimilarity(params.queryVector, 'feature') + 1.0",
                            "params": {
                                "queryVector": query
                            }
                        }
                    }
                }
            })
        hits = results['hits']
        answers = []
        if hits['total']['value'] <= 0:
            return answers

        max_score = hits['max_score']
        print(f"最高分: {max_score}")
        # NOTE(review): the script adds 1.0 to the cosine similarity, so scores
        # lie in [0, 2]; a 0.35 cut-off is therefore very permissive - confirm.
        if max_score is None or max_score < 0.35:
            return answers

        for rank, hit in enumerate(hits['hits'], start=1):
            print(f"第{rank}张匹配分: {hit['_score']}")
            # Keep only hits scoring above half of the best score.
            if hit['_score'] > 0.5 * max_score:
                source = hit['_source']
                # '#' would start a URL fragment, so percent-encode it.
                imgurl = source['url'].replace("#", "%23")
                answers.append([imgurl, source['name']])
        return answers


# Shared module-level instance used by the other snippets.
es = EsUtils()
注意⚠️:搜索脚本中不能使用 doc['feature'],而应直接传入向量字段名 'feature'。
业务逻辑,遍历图片,先把本地图片转换成向量后上传到ES
def extract_batch(img_folder_path, batch_size=32):
    """Extract features for every image in a folder and upload them.

    Images are read in batches, passed through the model in one predict
    call per batch, uploaded to object storage and then indexed in
    Elasticsearch together with their feature vector.

    :param str img_folder_path: folder that contains the images
    :param int batch_size: number of paths requested from get_file_paths
        per batch
    """
    # Local imports keep this function independent of how `time` and `os`
    # are imported at module level (`time.time()` needs the module).
    import os
    import time

    # The original referenced `self.model` although this is a plain function;
    # build the model once per call instead.
    model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

    cnt = 0  # number of images processed
    time_start = time.time()
    # Generator yielding lists of file paths
    for img_path_list in get_file_paths(img_folder_path, batch_size):
        # Drop files whose extension is not a configured image type
        img_path_list = [name for name in img_path_list
                         if os.path.splitext(name)[1] in config.types]
        if not img_path_list:
            # Nothing left after filtering; avoid predicting on an empty batch.
            continue

        # Size the buffer from the actual number of images in this batch
        batch_images = np.zeros((len(img_path_list), 224, 224, 3))
        cnt += len(img_path_list)
        for i, img_path in enumerate(img_path_list):
            x = image.img_to_array(image_process(img_path))
            x = preprocess_input(np.expand_dims(x, axis=0))
            batch_images[i] = x[0]  # store the image in its batch slot

        embeddings = model.predict(batch_images)  # run feature extraction
        for img_path, embedding in zip(img_path_list, embeddings):
            feature = embedding[::2]  # dimension 2048 -> 1024
            # Upload to object storage and get the image URL back
            # TODO: the upload is slow
            img_url = ou.upload(img_path)
            # Index the document in Elasticsearch
            doc = {'name': os.path.basename(img_path), 'feature': feature, 'url': img_url}
            es_utils.es.index(config.elasticsearch_index, body=doc)

    time_sum = time.time() - time_start
    print("提取结束,提取成功图片: {} 张 总耗时: {} 秒\n".format(
        cnt, time_sum))
使用Flask暴露接口
# -*- coding: utf-8 -*-
import os
import numpy as np
# -*- coding: utf-8 -*-
from PIL import Image
from flask import Flask, request, render_template
import config
import es_utils
import feature_extractor
import oss_utils
'''
以图搜图服务
'''
# Flask application object; the route decorators below register handlers on it.
app = Flask(__name__)
# Ask Flask's JSON encoder to emit non-ASCII characters as-is instead of
# \uXXXX escapes.
# NOTE(review): recent Flask releases (2.3+) appear to ignore this key - confirm.
app.config['JSON_AS_ASCII'] = False
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the landing page (index.html)."""
    page = render_template('index.html')
    return page
# Search for similar images
@app.route('/search', methods=['GET', 'POST'])
def search():
    """Handle an image-search request.

    On POST the uploaded ``query_img`` is stored temporarily, turned into a
    feature vector and used to query Elasticsearch; the matches are rendered
    in index.html. Any other method just renders the page.
    """
    if request.method != 'POST':
        return render_template('index.html')

    file = request.files['query_img']
    # The filename comes from the client: keep only its last path component
    # so it cannot point outside the upload folder.
    filename = os.path.basename((file.filename or '').replace('\\', '/'))
    if not filename:
        return render_template('index.html')

    uploaded_img_path = "static/uploaded/" + filename
    try:
        # Temporarily store the image on local disk
        img = Image.open(file.stream)
        img.save(uploaded_img_path)
        query = feature_extractor.fe.extract(uploaded_img_path)
        answers = es_utils.es.feature_search(query)
    finally:
        # Remove the temporary file even when extraction or search fails
        if os.path.exists(uploaded_img_path):
            os.remove(uploaded_img_path)
        else:
            print('删除图片失败:', uploaded_img_path)
    return render_template('index.html',
                           query_path=uploaded_img_path.replace("#", "%23"),
                           scores=answers)
# Upload images
@app.route('/upload', methods=['GET', 'POST'])
def upload():
    """Handle an image-upload request.

    On POST every file in ``upload_img`` is stored temporarily, turned into a
    feature vector, uploaded to object storage and indexed in Elasticsearch.
    The page index.html is rendered in every case.
    """
    if request.method != 'POST':
        return render_template('index.html')

    for file in request.files.getlist('upload_img'):
        # The filename comes from the client: keep only its last path
        # component so it cannot point outside the upload folder.
        name = os.path.basename((file.filename or '').replace('\\', '/'))
        if not name:
            continue
        uploaded_img_path = config.root_path + '/static/uploaded/' + name
        try:
            # Temporarily store the image on local disk
            img = Image.open(file.stream)
            img.save(uploaded_img_path)
            feature = feature_extractor.fe.extract(uploaded_img_path)
            feature = np.array(feature).flatten()
            # Upload to object storage and get the image URL back
            img_url = oss_utils.ou.upload(uploaded_img_path)
            # Index the document in Elasticsearch
            doc = {'name': name, 'feature': feature, 'url': img_url}
            es_utils.es.index(config.elasticsearch_index, body=doc)
        finally:
            # Remove the temporary file even when a step above fails
            if os.path.exists(uploaded_img_path):
                os.remove(uploaded_img_path)
            else:
                print('删除图片失败:', uploaded_img_path)
    return render_template('index.html')
if __name__ == "__main__":
    # The server listens on all interfaces; the interactive debugger lets a
    # client run arbitrary code, so it is only enabled when FLASK_DEBUG=1 is
    # set explicitly instead of being always on.
    app.run("0.0.0.0", port=8088, debug=os.getenv("FLASK_DEBUG") == "1")
前端demo
<!doctype html>
<html>
<head>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css">
<meta charset="utf-8">
<meta name="referrer" content="no-referrer">
</head>
<body>
<div class="container">
<h1>搜索图片</h1>
<form action="/search" method="POST" enctype="multipart/form-data">
<input type="file" name="query_img"><br>
<input type="submit">
</form>
<h1>上传图片</h1>
<form action="/upload" method="POST" enctype="multipart/form-data">
<input type="file" name="upload_img" multiple>
<br>
<input type="submit">
</form>
{# <h2>Query:</h2>#}
{# {% if query_path %}#}
{# <img src="{{ query_path }}" width="300px">#}
{# {% endif %}#}
<h2>结果:</h2>
{% for score in scores %}
<figure style="float: left; margin-right: 20px; margin-bottom: 20px;">
<img src="{{ score[0] }}" height="200px">
{# <p>{{ score[0] }}</p>#}
<figcaption>{{ score[1] }}</figcaption>
</figure>
{% endfor %}
</div>
</body>
<script>
{#var data = document.getElementById('dataid').getAttribute('d');//绑定以获取data值#}
{#dataJson = JSON.parse(data);#}
{#console.log(dataJson.name)#}
{#console.log(dataJson.num)#}
</script>
</html>