打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
用Keras TensorFlow,实现ImageNet数据集日常对象的识别
王新民 编译自 Deep Learning Sandbox博客
量子位 出品 | 公众号 QbitAI

在计算机视觉领域里,有3个最受欢迎且影响非常大的学术竞赛:ImageNet ILSVRC(大规模视觉识别挑战赛),PASCAL VOC(关于模式分析,统计建模和计算学习的研究)和微软COCO图像识别大赛。这些比赛大大地推动了在计算机视觉研究中的多项发明和创新,其中很多都是免费开源的。

博客Deep Learning Sandbox作者Greg Chu打算通过一篇文章,教你用Keras和TensorFlow,实现对ImageNet数据集中日常物体的识别。

量子位翻译了这篇文章:

你想识别什么?

看看ILSVRC竞赛中包含的物体对象。如果你要研究的物体对象是该列表1001个对象中的一个,运气真好,可以获得大量该类别图像数据!以下是这个数据集包含的部分类别:

椅子
汽车键盘箱子
婴儿床旗杆iPod播放器
轮船面包车项链
降落伞枕头桌子
钱包球拍步枪
校车萨克斯管足球
袜子舞台火炉
火把吸尘器自动售货机
眼镜红绿灯菜肴
盘子西兰花红酒
 表1 ImageNet ILSVRC的类别摘录

完整类别列表见:https://gist.github.com/gregchu/134677e041cd78639fea84e3e619415b

如果你研究的物体对象不在该列表中,或者像医学图像分析中具有多种差异较大的背景,遇到这些情况该怎么办?可以借助迁移学习(transfer learning)和微调(fine-tuning),我们以后再另外写文章讲。

图像识别

图像识别,或者说物体识别是什么?它回答了一个问题:“这张图像中描绘了哪几个物体对象?”如果你研究的是基于图像内容进行标记,确定盘子上的食物类型,对癌症患者或非癌症患者的医学图像进行分类,以及更多的实际应用,那么就能用到图像识别。

Keras和TensorFlow

Keras是一个高级神经网络库,能够作为一种简单好用的抽象层,接入到数值计算库TensorFlow中。另外,它可以通过其keras.applications模块获取在ILSVRC竞赛中获胜的多个卷积网络模型,如由Microsoft Research开发的ResNet50网络和由Google Research开发的InceptionV3网络,这一切都是免费和开源的。具体安装参照以下说明进行操作:

Keras安装:https://keras.io/#installation

TensorFlow安装:https://www.tensorflow.org/install/

实现过程

我们的最终目标是编写一个简单的python程序,只需要输入本地图像文件的路径或是图像的URL链接就能实现物体识别。

以下是输入非洲大象照片的示例:

1. python classify.py --image African_Bush_Elephant.jpg
2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

输入:

输出将如下所示:

 该图像最可能的前3种预测类别及其相应概率

预测功能

我们接下来要载入ResNet50网络模型。首先,要加载keras.preprocessingkeras.applications.resnet50模块,并使用在ImageNet ILSVRC比赛中已经训练好的权重。

想了解ResNet50的原理,可以阅读论文《基于深度残差网络的图像识别》。地址:https://arxiv.org/pdf/1512.03385.pdf

import numpy as np
from keras.preprocessing import image
from keras.applications.resnet50
import ResNet50, preprocess_input, decode_predictionsmodel = ResNet50(weights='imagenet')

接下来定义一个预测函数:

def predict(model, img, target_size, top_n=3):  '''Run model prediction on image  Args:    model: keras model    img: PIL format image    target_size: (width, height) tuple    top_n: # of top predictions to return  Returns:    list of predicted labels and their probabilities  '''  if img.size != target_size:    img = img.resize(target_size)  x = image.img_to_array(img)  x = np.expand_dims(x, axis=0)  x = preprocess_input(x)  preds = model.predict(x)  
return decode_predictions(preds, top=top_n)[0]

在使用ResNet50网络结构时需要注意,输入大小target_size必须等于(224,224)。许多CNN网络结构具有固定的输入大小,ResNet50正是其中之一,作者将输入大小定为(224,224)

image.img_to_array:将PIL格式的图像转换为numpy数组。

np.expand_dims:将我们的(3,224,224)大小的图像转换为(1,3,224,224)。因为model.predict函数需要4维数组作为输入,其中第4维为每批预测图像的数量。这也就是说,我们可以一次性分类多个图像。

preprocess_input:使用训练数据集中的平均通道值对图像数据进行零值处理,即使得图像所有点的和为0。这是非常重要的步骤,如果跳过,将大大影响实际预测效果。这个步骤称为数据归一化。

model.predict:对我们的数据分批处理并返回预测值。

decode_predictions:采用与model.predict函数相同的编码标签,并从ImageNet ILSVRC集返回可读的标签。

keras.applications模块还提供4种结构:ResNet50、InceptionV3、VGG16、VGG19和XCeption,你可以用其中任何一种替换ResNet50。更多信息可以参考https://keras.io/applications/。

绘图

我们可以使用matplotlib函数库将预测结果做成柱状图,如下所示:

def plot_preds(image, preds):    '''Displays image and the top-n predicted probabilities     in a bar graph    Args:        image: PIL image    preds: list of predicted labels and their probabilities    '''    #image  plt.imshow(image)  plt.axis('off')  #bar graph  plt.figure()    order = list(reversed(range(len(preds))))    bar_preds = [pr[2] for pr in preds]  labels = (pr[1] for pr in preds)  plt.barh(order, bar_preds, alpha=0.5)  plt.yticks(order, labels)  plt.xlabel('Probability')  plt.xlim(0, 1.01)  plt.tight_layout()  plt.show()

主体部分

为了实现以下从网络中加载图片的功能:

1. python classify.py --image African_Bush_Elephant.jpg
2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

我们将定义主函数如下:

if __name__=='__main__':  a = argparse.ArgumentParser()  a.add_argument('--image',
help='path to image')  a.add_argument('--image_url',
help='url to image')  args = a.parse_args()
if args.image is None and args.image_url is None:    a.print_help()    sys.exit(1)
if args.image is not None:    img = Image.open(args.image)    print_preds(predict(model, img, target_size))
if args.image_url is not None:    response = requests.get(args.image_url)    img = Image.open(BytesIO(response.content))    print_preds(predict(model, img, target_size))

其中在写入image_url功能后,用python中的Requests库就能很容易地从URL链接中下载图像。

完工

将上述代码组合起来,你就创建了一个图像识别系统。项目的完整程序和示例图像请查看GitHub链接:

https://github.com/DeepLearningSandbox/DeepLearningSandbox/tree/master/image_recognition

招聘

我们正在招募编辑记者、运营等岗位,工作地点在北京中关村,期待你的到来,一起体验人工智能的风起云涌。

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
别磨叽,学完这篇你也是图像识别专家了
使用迁移学习和 TensorFlow 进行食品分类
深入探索图像处理:从基础到高级应用
训练一个自己的分类 | 【包教包会,数据都准备好了】
数据科学家必须知道的10个深度学习架构
干货|多重预训练视觉模型的迁移学习
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服