相似图像搜索从训练到服务全过程

最近完成了一个以图搜图的项目,项目总共用时三个多月。记录一下项目中用到机器学习的地方,以及各种踩过的坑。总的来说,项目分为一下几个部分:

一、训练目标函数

1、    设定基础模型

2、    添加新层

3、    冻结 base 层

4、    编译模型

5、    训练

6、    保存模型

二、特征提取

三、创建索引

四、构建服务

1、flask 开发

2、Gunicorn 异步,增加服务稳健性

3、Supervisor 部署监控服务

五、总结

一、训练目标函数

项目是在预训练模型 vgg16 的基础上进行微调(fine_tune),并将特征的维度从原先的 2048 维降为 1024 维度。

模型的微调又分为以下几个步骤:

1、设定基础模型

本次采用预训练的 VGG16基础模型,利用其 bottleneck 特征

# 设定基础模型

base_model = VGG16(weights=‘./model/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5‘, include_top=False)

#指定权重路径

# include_top= False 不加载三层全连接层

2、添加新层

将自己要目标图片,简单分类,统计类别(在训练模型时需要指定类别)

# 添加新层

def add_new_last_layer(base_model, nb_classes):

    ‘‘‘
    添加最后的层
    :param base_model: 预训练模型
    :param nb_classes: 分类数量
    :return: 新的 model
    ‘‘‘
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation=‘relu‘)(x) #输出的特征维度 88
    predictions = Dense(nb_classes, activation=‘softmax‘)(x)
    model = Model(input=base_model.input, output=predictions)
    return model

3、冻结 base 层

以前的参数可以使用预训练好的参数,不需要重新训练,所以需要冻结,不让其改变。

def freeze_base_layer(model, base_model):

        for layer in base_model.layers:

        layer.trainable = False

 4、编译模型

model.compile(optimizer=‘rmsprop‘, loss=‘categorical_crossentropy‘, metrics= [‘accuracy‘])

# optimizer: 优化器

# loss: 损失函数,多类的对数损失需要将分类标签转换为(将标签转化为形如(nb_samples, nb_classes)的二值序列)

# metrics: 列表,包含评估模型在训练和测试时的网络性能的指标准备训练数据。

5、训练

#数据准备
IM_WIDTH, IM_HEIGHT = 224,224
train_dir = ‘./refine_img_data/train‘
val_dir = ‘./refine_img_data/test‘
nb_classes = 5
np_epoch = 3
batch_size = 16
nb_train_samples = get_nb_files(train_dir)
nb_classes = len(glob.glob(train_dir + ‘/*‘))
nb_val_samples = get_nb_files(val_dir)

# 根据现有数据,设置新数据生成参数
train_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)

test_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)

# 从文件夹获取数据
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(IM_WIDTH, IM_HEIGHT),
batch_size=batch_size,
class_mode=‘categorical‘
)

validation_generator = test_datagen.flow_from_directory(
val_dir,
target_size=(IM_WIDTH, IM_HEIGHT),
batch_size=batch_size,
class_mode=‘categorical‘
)

# 训练
history_t1 = model.fit_generator(
train_generator,
epochs=1,
steps_per_epoch=10,
validation_data=validation_generator,
validation_steps=10,
class_weight=‘auto‘
)

6、保存模型

将模型保存到指定路径一般保存为 .h5 格式

 model.save(‘/model/test_model.h5‘)

  

二、特征提取

加载我们训练好的模型,根据需要,取指定层的特征。

# 可用 model.summary() 查看模型结构

#根据模型提取图片特征

target_size = (224,224)

def my_feature(mod, path):
    img = image.load_img(path,target_size=target_size)
    img = image.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)
    return mod.predict(img)

# 创建模型,获取指定层特征
model_path = ‘./model/my_model.h5‘
base_model = load_model(model_path)
model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘dense_1‘).output)

# 提取特征
img_path = ‘./my_img/bus.jpg‘
feat = my_feature(model,img_path) # shape 为 (1,128)
print(feat)
print(feat.shape)

#注意, 当需要提取的图片特征数量较大,比如千万以上,需要的时间是比较长的,这时我们可以采用多核与批处理来进行 (python 由于 GIL 的问题对多线程不友好)。
def pre_processs_image(path):
    if path is not None and os.path.exists(path) and len(path) > 10:
      try:
          img = cv2.imread(path, cv2.IMREAD_COLOR)
          img = cv2.resize(img, (224, 224))
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
          img = img.transpose(2, 0, 1)
          return [material_id,img, flag]
      except Exception as err:
          traceback.print_exc()
          return None
    else:
    logging.error(‘could not find path: ‘ + path)
    return None

#cpu 部分,调用多核处理函数,指定核数为 20
with ProcessPoolExecutor(max_workers=20) as executor:
feat_paras = list(executor.map(pre_processs_image,, material_batch))

# GPU 部分采用批处理
# TODO

三、创建索引

此处我们使用 Facebook 开源的近邻索引框架 faiss 。


# create index
d = 128
nlist = 100 # 切分数量
nprobe = 8 # 每次查找分片数量
quantizer_img = faiss.IndexFlatL2(d) #根据欧式距离创建索引

image_index = None
model_index = None

if image_feat_array is not None and len(img_feat_list) > 100:
  image_index = faiss.IndexIVFFlat(quantizer_img, d, nlist, faiss.METRIC_L2)
  image_index.train(image_feat_array)
  image_index.add_with_ids(image_feat_array,image_id_array)
  image_index.nprobe = nprobe
  image_index.dont_dealloc_me = quantizer_img

# 保存当前索引到指定路径
faiss.write_index(img_index,path)

# 测试当前索引
temp_feat = img_feat_list[1]
res_2 = image_index.search(temp_feat, k=5)
logging.info(‘image search result is:‘ + str(res_2))

四、构建服务

1、flask 开发

参考文档 http://docs.jinkan.org/docs/flask/quickstart.html#a-minimal-application

2、Gunicorn 异步,增加服务稳健性

基础语法:

Gunicorn –w process_num –b ip:port –k ‘gevent‘ fileName:app

# 注意:此处不选择 –k ‘gevent‘ 则为同步运行

同步部署:

gunicorn -b 0.0.0.0:9090 my_service:app

异步部署:

gunicorn -b 0.0.0.0:9090 -k gevent my_service:app

用了 Gunicorn 来部署应用后, 对比 flask , qps 提升了一倍。原 flask 框架中由于我的接口中 request 了其他的接口,线程在此处会阻塞,导致程序非常容易假死。改用后,稳定又了极大的提升。

3、Supervisor 部署监控服务

可参考以下文档 https://www.cnblogs.com/gjack/p/8076419.html

五、总结

项目到这个地方,基本的服务框架已经有了。许多地方只说了大体思路,但是结构是完整。文中的许多用了许多方法工具,如 gunicorn 的异步等, 但是原理却不甚了解,还需要花功夫去学习。由于上线压力大,时间紧,许多地方来不及仔细琢磨,肯定有不少纰漏,后面再查漏补缺吧。

原文地址:https://www.cnblogs.com/yaolin1228/p/9557588.html

时间: 2024-10-31 06:56:00

相似图像搜索从训练到服务全过程的相关文章

百度图像搜索探秘

源地址:http://blog.sina.com.cn/s/blog_6ae183910101gily.html 昨天,百度上线了新的相似图(similar image search)搜索,试了风景.人物.文字等不同类型query的效果,感觉效果非常赞.尤其对于人物搜索,返回的结果在颜色.以及姿态方面具有非常大的相似性.特别是在输入某个pose的美女图片时,会搜到一系列相近pose的美女图片,真的是宅男之福啊.本着娱乐精神,贴一个搜索结果供大家yy. 我们知道这个产品底层的技术是余凯老师领导的百

全域图像搜索给你更精准的搜索体验

摘要: 2018飞天技术汇,阿里巴巴机器智能技术实验室的刘磊带来题为全域精准图像搜索介绍的演讲,主要从四个方面进行了阐述,第一部分介绍了图像搜索的基本概念,第二部分主要是讲解了图像搜索的技术架构及其优势,第三部分对应用场景及案例进行了分析,最后对商品使用情况以及定价做了简单介绍. 2018飞天技术汇,阿里巴巴机器智能技术实验室的刘磊带来题为全域精准图像搜索介绍的演讲,主要从四个方面进行了阐述,第一部分介绍了图像搜索的基本概念,第二部分主要是讲解了图像搜索的技术架构及其优势,第三部分对应用场景及案

图像搜索技术发展应知道

什么是图像搜索?图像搜索,是通过搜索图像文本或者视觉特征,为用户提供相关图形图像资料检索服务.?从图像搜索的发展过程来看,主要包含两种搜索方式:基于文本的图像搜索(Text-Based Image Retrieval,TBIR),将图像作为数据库中的存储对象,利用与图像相关联的文本关键词进行匹配,返回搜索结果.基于内容的图像搜索(Content-Based Image Retrieval,CBIR),提取图像的视觉内容特征作为索引,例如颜色.纹理.形状等,通过输入一张图片比较特征向量之间的相似度

看起来像它——图像搜索其实也不难 (图像相似,图像指纹,phash hash,图像搜索) 使用时候记得看这文章的评论

链接: http://pan.baidu.com/s/1o7ScyVo 密码: h8eb    这个文章的代码 另一个类似的代码  链接: http://pan.baidu.com/s/1hsFDCNy 密码: jxus http://blog.csdn.net/luoweifu/article/details/8220992                 使用时候记得看这文章的评论 看起来像它——图像搜索其实也不难 标签: pHash图像搜索图像识别图片搜索算法 2012-11-24 23:

基于感知哈希算法的图像搜索实现

无意中看见一篇博客,是讲仿造google搜图的,链接如下: Google 以图搜图 - 相似图片搜索原理 - Java实现 觉得挺好玩的,博主使用Java实现的,于是我用 OpenCv实现了下. 根据看到的博文,里面说到,Google图像搜索的关键技术是"感知压缩算法"(Perceptual hash algorithm),它的作用是对每张图片生成一个"指纹"(fingerprint)字符串,然后比较不同图片的指纹.结果越接近,就说明图片越相似.看到这里我就突然来了

Boosting从原理到实现图像数组的训练

Boosting原理 众做周知,boosting就是所谓的有多个弱分类器组成一个强分类器.而什么叫做弱分类学习和什么时候需要使用弱分类学习呢? 弱分类学习 弱分类学习:识别一组概念的正确率仅比随机猜的概率高一点. 同理,当需要分类的训练组具有上述特点时,可以优先考虑使用boosting算法. Boosting的重要历史事件 Kearns & Valiant (1984) : Boosting由来 Kearns & Valiant (1989) : 证明了弱学习器和强学习器的等价问题. Sc

就是看起来像而已——图像搜索内核探索

这是我第一次翻译外文文章,如果翻译的不好,还望大家多包含!以下黑色部分是作者原文的翻译,红色部分是我本人自己的理解和对其的补充. 原文:Looks Like It 在google里对的搜索结果是 下面是我用pHash算法(Java)实现的结果: 十张比较的图如下: source: f0a0000030400000 1-5    2-5    3-0    4-5    5-5    6-5    7-5    8-7    9-6    10-3    11-5 f0a0000030400000

Ubuntu 下配置 SSH服务全过程及问题解决

Windows下做Linux开发,装虚拟机里,怎么可以不用SSH呢.有人说,“做Linux开发,还不直接装机器上跑起来了,还挂虚拟机,开SSH……闲的蛋疼了吧”,不管怎样,我接触Linux算是3年了,用了3年的Ubuntu了 吧,对Ubuntu算是了解,Ubuntu在Linux众多发行版里做的算是可圈可点的了,但是Linux在PC系统中的很多方面并不是非常到位,用户体 验也好,性能也罢.我也热爱Linux,不过不能否认它在这方面的一些略势,当然他强大的命令行.网络服务,以及那华丽的界面也是win

solairs11 配置samba文件服务全过程(solairs11系统的安装,网络配置,samba的安装及配置)

叨叨 由于项目中产生很多了很多的文档,为了集中存放,决定搭建一个服务器跑samba.也不知道当初怎么想的,就决定下来用unix系统了,这中间从freebsd到solairs真是个折腾啊!(不然也不会有这篇文章^_^). 其实用windows或是linux安装配置起来都是很方便的,选择unix的原因无非是因其稳定,相对安全.刚开始是在虚拟机中安装freebsd来着,想着如果能够在虚拟机里配置起来再配真机应该很简单了,可是在虚拟机的freebsd中采用的是最小安装,然后执行安装samba时,那可叫个