博客
关于我
matlab 神经网络ann用于分类方法
阅读量:141 次
发布时间:2019-02-28

本文共 3570 字,大约阅读时间需要 11 分钟。

matlab关于ann的分类方法讲解了一个例子,Fishr集上鸢尾花(Iris)的分类,学习了这个方法可以套用在个人项目上使用,万变不离其宗,

1、Fishr集上鸢尾花Iris数据集的分类

①iris数据集简介
iris数据集的中文名是安德森鸢尾花卉数据集,英文全称是Anderson’s Iris data set。iris包含150个样本,对应数据集的每行数据。每行数据包含每个样本的四个特征和样本的类别信息,所以iris数据集是一个150行5列的二维表。
通俗地说,iris数据集是用来给花做分类的数据集,每个样本包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征(前4列),我们需要建立一个分类器,分类器可以通过样本的四个特征来判断样本属于山鸢尾、变色鸢尾还是维吉尼亚鸢尾(这三个名词都是花的品种)。
iris的每个样本都包含了品种信息,即目标属性(第5列,也叫target或label)。

样本局部截图:

load iris.dat

 将数据集载入到工作区,部分数据集如图所示。数据集的前四列分别为与鸢尾花种类相关的4个特征值,对应上图中的花萼长度、花萼宽度、花瓣长度及花瓣宽度;第五列为鸢尾花所属种类,分为1-Setosa、2-Versicolour、3-Virginica三类。

②数据预处理

这里的神经网络属于监督学习的模式,因此需要从上述数据集中分离出训练集和测试集,我们分别记为trainData和testData。我们从iris数据集中选取2/3数据作为训练集trainData,选取1/3数据作为测试集testData,并分别将其保存至trainData.txt和testData.txt文件,用于程序的数据导入源,方法见另外一个博客()

③分类源程序

%读取训练数据clearclc%------本代码采用ANN对鸢尾花进行分类,程序运行前,请准备好#鸢尾花样本#的训练集和测试集(可在MATLAB中载入iris.dat查看数据)-----%f1 f2 f3 f4是四个特征值[f1,f2,f3,f4,class] = textread('trainData.txt' , '%f%f%f%f%f',150);%特征值归一化[input,minI,maxI] = premnmx( [f1 , f2 , f3 , f4 ]')  ;%构造输出矩阵s = length( class) ;output = zeros( s , 3  ) ;for i = 1 : s    output( i , class( i )  ) = 1 ;end%创建神经网络net = newff( minmax(input) , [10 3] , { 'logsig' 'purelin' } , 'traingdx' ) ; %{    minmax(input):获取4个输入信号(存储在f1 f2 f3 f4中)的最大值和最小值;    [10,3]:表示使用2层网络,第一层网络节点数为10,第二层网络节点数为3;    { 'logsig' 'purelin' }:        表示每一层相应神经元的激活函数;        即:第一层神经元的激活函数为logsig(线性函数),第二层为purelin(对数S形转移函数)    'traingdx':表示学习规则采用的学习方法为traingdx(梯度下降自适应学习率训练函数)%}%设置训练參数net.trainparam.show = 50 ;% 显示中间结果的周期net.trainparam.epochs = 500 ;%最大迭代次数(学习次数)net.trainparam.goal = 0.01 ;%神经网络训练的目标误差net.trainParam.lr = 0.01 ;%学习速率(Learning rate)%开始训练%其中input为训练集的输入信号,对应output为训练集的输出结果net = train( net, input , output' ) ;%================================训练完成====================================%%=============================接下来进行测试=================================% %读取测试数据[t1 t2 t3 t4 c] = textread('testData.txt' , '%f%f%f%f%f',150); %测试数据归一化testInput = tramnmx ( [t1,t2,t3,t4]' , minI, maxI ) ;%[testInput,minI,maxI] = premnmx( [t1 , t2 , t3 , t4 ]')  ;%仿真%其中net为训练后得到的网络,返回的Y为Y = sim( net , testInput )  %统计识别正确率[s1 , s2] = size( Y ) ;hitNum = 0 ;for i = 1 : s2    [m , Index] = max( Y( : ,  i ) ) ;    if( Index  == c(i)   )         hitNum = hitNum + 1 ;     endendsprintf('识别率是 %3.3f%%',100 * hitNum / s2 )

④代码的相关说明

A. 语句net = newff( minmax(input) , [10 3] , { 'logsig' 'purelin' } , 'traingdx' ) ;用于创建神经网络,其参数含义和用法如下:

    (1)minmax(input):获取4个输入信号(存储在f1 f2 f3 f4中)的最大值和最小值;

    (2) [10,3]:表示使用2层网络,第一层网络节点数为10,第二层网络节点数为3。其中最后一层的网络包含的节点数一定要与网络的理论输出个数保持一致,例如本例中鸢尾花的种类数为3,因此最后一层的网络节点数为3;
    (3){ 'logsig' 'purelin' }:表示每一层相应神经元的激活函数,即:第一层神经元的激活函数为logsig(线性函数),第二层为purelin(对数S形转移函数),其他激活函数和用法请参见神经网络与深度学习之激活函数;
    (4) 'traingdx':表示学习规则采用的学习方法为traingdx(梯度下降自适应学习率训练函数)。常见的训练函数(学习方法)有:

  traingd :梯度下降BP训练函数(Gradient descent backpropagation)

  traingdx :梯度下降自适应学习率训练函数

    (5)创建的神经网络用MATLAB神经网络工具箱显示如图,图中更形象的展示了构造的神经网络模型。

B. 关于正确率的统计算法的说明

第一次看到这里的正确率统计算法时,我自己是不大明白的,之后又从网上搜了一些资料并查阅了MATLAB的帮助文档,才明白代码的含义。

语句net = train( net, input , output' ) ;是对网络进行训练,该语句明确了网络的输出为output,通过对output矩阵的构造方式分析,我们可知网络的输出可以看成3个,我们不妨即为C1、C2、C3,分别代表鸢尾花的三个种类,例如:

(1)当output的某一行为1 0 0,则说明该花属于C1类

(2)当output的某一行为0 1 0,则说明该花属于C2类

(3)当output的某一行为0 1 0,则说明该花属于C3类

语句Y = sim( net , testInput ) 是对训练后的网络net进行仿真测试,测试用的数据为testInput;这里,Y返回的是网络训练后对测试输入的预测值,例如:

(1)当Y的某一行为1.0220  -0.0020  -0.0091,代表输出结果C1=1.0220, 输出结果C2=-0.0020,C3=-0.0091

(2)当Y的某一行为-0.0108  0.9884  -0.0216,代表输出结果C1=-0.0108,输出结果C2=0.9884,C3=-0.0216

输出结果中只包含一个1和两个0是理想情况下的结果,在进行仿真时,分类输出往往达不到这样的结果,但我们可以根据哪个结果对应的值与1的接近程度来进行判断,例如仿真结果(1)说明该花极有可能属于C1类,仿真结果(2)说明该花极有可能属于C2类。

*从以上描述中我们可以明白神经网络算法也可以应用于具体数值的预测,且应用广泛。

 

转载地址:http://iebc.baihongyu.com/

你可能感兴趣的文章
MySQL
查看>>
MySQL
查看>>
mysql
查看>>
MTK Android 如何获取系统权限
查看>>
MySQL - 4种基本索引、聚簇索引和非聚索引、索引失效情况、SQL 优化
查看>>
MySQL - ERROR 1406
查看>>
mysql - 视图
查看>>
MySQL - 解读MySQL事务与锁机制
查看>>
MTTR、MTBF、MTTF的大白话理解
查看>>
mt_rand
查看>>
mysql /*! 50100 ... */ 条件编译
查看>>
mudbox卸载/完美解决安装失败/如何彻底卸载清除干净mudbox各种残留注册表和文件的方法...
查看>>
mysql 1264_关于mysql 出现 1264 Out of range value for column 错误的解决办法
查看>>
mysql 1593_Linux高可用(HA)之MySQL主从复制中出现1593错误码的低级错误
查看>>
mysql 5.6 修改端口_mysql5.6.24怎么修改端口号
查看>>
MySQL 8.0 恢复孤立文件每表ibd文件
查看>>
MySQL 8.0开始Group by不再排序
查看>>
mysql ansi nulls_SET ANSI_NULLS ON SET QUOTED_IDENTIFIER ON 什么意思
查看>>
multi swiper bug solution
查看>>
MySQL Binlog 日志监听与 Spring 集成实战
查看>>