打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
如何导出CNTK网络节点训练结果?

前言

CNTK作为微软在机器学习领域的重点项目,已经成功的应用于各种产品之中,CNTK可以快速且方便的训练神经网络并使用神经网络。其中CNTK的一大亮点就是CNTK的速度,在机器学习领域,训练速度以及计算决定了工具是否好用。

但是一些情况下,也许我们只需要使用CNTK对于网络训练的速度,之后导出训练的结果用于其他第三方软件使用,也许是绘图,也许是进行实际的生产计算。

通过dumpnode命令导出节点数据

dumpnode是一个action,在配置文件中可以进行指定,通过这个命令可以导出指定的节点信息(包括其结果)以文本的形式导出至一个文件。在使用这个dumpnode时需要制定如下参数。

modelPath 指定网络模型的文件。一般这个参数再配置文件中应该已经在上级给出了指定。
nodeName (可选)指定用于导出的网络节点名称。如果不指定或者指定的是一个不存在的名称则导出所有的节点。
nodeNameRegex (可选)可以通过正则表达式的方式指定网络节点名称,导出相匹配的节点。如果指定了这个参数的话,nodeName将值会被忽视。
outputFile (可选)指定导出文件的位置,如果不指定的情况下,默认会被设置为通网络模型文件同路径下的一个文本文件。在网络模型的文件名后面加上.{nodename}.txt
printValues (可选)是否导出网络节点的值,默认值为true
printMetadata (可选)是否导出元数据信息,包括节点名称、维度等,默认值为true

我们以Simple2D(位于CNTK的~\CNTK\Examples\Other\Simple2d\目录)为例,
我们只需要在配置文件末尾的位置添加如下内容,不需要设置modelPath是因为modelPath已经在上面指定了,

#########################################  DUMP NODE INFORMATION               #########################################Simple_Demo_DumpNode=[    action = "dumpnode"]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

之后我们在配置文件的command中指定新设置的Simple_Demo_DumpNode即可,

command = Simple_Demo_Train:Simple_Demo_Test:Simple_Demo_Output:Simple_Demo_DumpNode
  • 1
  • 1

之后从新运行后,运行的结果如下,并在其中会多出有关dumpnode的相关内容,


(注,那个Action "dumpnode" complete.是笔者后期标黄的,不是默认输出就是黄色的。)

此时,我们可以去往model所在目录(~\CNTK\Examples\Other\Simple2d\Output\Models)打开导出的文件,

默认情况下,是导出的网络节点是包括参数以及元数据的,所以数据信息比较多。

通过探究CNTK源码来实现

CNTK可以通过dumpnode命令给出的文本格式的导出文件,如果是用来看的则没什么大问题,但是如果希望实现自动化的某些操作,则需要解析这个文件后输入下游程序中,解析过程也许会十分的繁琐,

我们也可以通过探究CNTK的源码来找出节点数据存储的位置,直接使用CNTK的文件即可,
这种情况,只需要我们在我们的程序中加载CNTK中的EvalDLL读取网络模型文件后,直接使用其参数,后者根据需要导出成我们所期望的形式。

CNTK中核心的一个类是ComputationNetwork,一个网络模型对应着一个ComputationNetwork的实例。
如果我们只是为了读取网络模型,那么我们可以通过如下代码打开一个CNTK的网络模型文件。

ComputationNetwork net(-1);    // -1 means we only need to use CPUnet.Load<ElemType>(modelPath); // ElemType should be a float type (float or double)
  • 1
  • 2
  • 1
  • 2

这样我们就可以将模型文件加载到net中,进而我们可以通过net的方法去获取我们所需要的内容,

class ComputationNetwork : ...{...public:    ComputationNetwork();    ComputationNetwork(DEVICEID_TYPE deviceId);    virtual ~ComputationNetwork();    template <class ElemType>     void Load(const std::wstring& fileName);    ComputationNodeBasePtr GetNodeFromName(const std::wstring& name) const;    // GetNodesFromName - Get all the nodes from a name that may match a wildcard '*' pattern    //   only patterns with a single '*' at the beginning, in the middle, or at the end are accepted    // name - node name (with possible wildcard)    // returns: vector of nodes that match the pattern, may return an empty vector for no match    std::vector<ComputationNodeBasePtr> GetNodesFromName(const std::wstring& name) const;    // these are specified as such by the user    const std::vector<ComputationNodeBasePtr>& FeatureNodes();    const std::vector<ComputationNodeBasePtr>& LabelNodes();    const std::vector<ComputationNodeBasePtr>& FinalCriterionNodes();    const std::vector<ComputationNodeBasePtr>& EvaluationNodes();    const std::vector<ComputationNodeBasePtr>& OutputNodes();    ...private:    // main node holder    std::map<const std::wstring, ComputationNodeBasePtr, nocase_compare> m_nameToNodeMap; // [name] -> node; this is the main container that holds this networks' nodes    // node groups    // These are specified by the user by means of tags or explicitly listing the node groups.    // TODO: Are these meant to be disjoint?    std::vector<ComputationNodeBasePtr> m_featureNodes;    // tag="feature"    std::vector<ComputationNodeBasePtr> m_labelNodes;      // tag="label"    std::vector<ComputationNodeBasePtr> m_criterionNodes;  // tag="criterion"    std::vector<ComputationNodeBasePtr> m_evaluationNodes; // tag="evaluation"    std::vector<ComputationNodeBasePtr> m_outputNodes;     // tag="output"...};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

在Net中,我们能够获取每个节点的ComputationNodeBasePtr, 我们进而可以通过这个指针来进行其节点中数据或者参数的访问。

class ComputationNodeBase : ...{...public:    // -----------------------------------------------------------------------    // accessors for value and gradient    // -----------------------------------------------------------------------    const Matrix<ElemType>& Value() const { return *m_value; }    Matrix<ElemType>&       Value()       { return *m_value; }    MatrixBasePtr ValuePtr() const override final { return m_value; }    // readers want this as a shared_ptr straight    // Note: We cannot return a const& since returning m_value as a MatrixBasePtr is a type cast that generates a temporary. Interesting.    const Matrix<ElemType>& Gradient() const { return *m_gradient; }    Matrix<ElemType>&       Gradient()       { return *m_gradient; }    MatrixBasePtr GradientPtr() const { return m_gradient; }...};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

这样我们就可以通过Value来获取来获取内部的计算矩阵以应用于实际的项目中。

总结

本文主要是描述了CNTK中网络节点训练结果的导出方法,首先将CNTK当做工具使用情况下的导出方法,之后接受的是将CNTK看做一个开源项目的前提下,通过研究源码的方式针对其内部数据结果进行探究。希望能够对大家使用或者学习CNTK有所帮助,如本文中有任何错误或者读者有任何意见或者建议,请在评论区给出,谢谢。

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
手把手:四色猜想、七桥问题…程序员眼里的图论,了解下?(附大量代码和手绘)
从七桥问题开始:全面介绍图论及其应用
2019-2020Nowcoder Girl初赛题解
Vector类模板界面及其函数的实现
C++ find()函数用法(一般用于vector的查找)
C++ 多态技术
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服