第一个机器学习模型

前面已经介绍了一个机器学习问题,一般分为以下三步:

  1. 分析问题,数据探索及预处理
  2. 选择合适的模型,利用训练集训练模型
  3. 模型评估及使用,利用测试集评估模型

下面,我们按照这三个步骤来做一个简单的机器学习模型:鸢尾花分类。
鸢尾花是一种植物,如下图所示。

它有四个特征:

  • 花萼(sepal)长度
  • 花萼宽度
  • 花瓣(petal)长度
  • 花瓣宽度

这种植物有三个类别,分别是:Setosa,Versicolour,Virginica。
我们现在的任务是根据这四个特征对鸢尾花进行分类。
下面直接进入代码演示环节。

分析问题,数据探索及预处理

鸢尾花数据集可以从sklearn中导入,代码如下。

from sklearn.datasets import load_iris
iris_dataset=load_iris()

说明:这里用到了sklearn库,需要提前安装,安装命令:pip install scikit-learn。
如果网速太慢,可以使用国内镜像源安装:
pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple

对于sklearn中导入的数据,可以通过DESCR属性查看数据的描述。

print(iris_dataset['DESCR'])

运行结果:

.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica

    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.
此处省略一万字......
   - Many, many more ...

从以上输出结果可以看到以下信息:

  • Iris数据集有150个样本,每种类别有50个。
  • 4个特征:花萼长度、宽度,花瓣长度、宽度,单位都是cm。
  • 3个类别:分别是Setosa、Versicolour、Virginica。
  • 还给出了数据集的描述性统计分析:每个属性的最值、均值、标准差、相关性等。
  • 最后还给出关于这个数据集更为详细的描述,有兴趣的话可以自己阅读。

通过data属性可以查看样本的特征数据,这里显示前10条记录。

print(iris_dataset['data'][:10])

运行结果:

[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.4 3.9 1.7 0.4]
 [4.6 3.4 1.4 0.3]
 [5.  3.4 1.5 0.2]
 [4.4 2.9 1.4 0.2]
 [4.9 3.1 1.5 0.1]]

通过shape属性可以查看数据的形状。

print(iris_dataset['data'].shape)

运行结果:

(150, 4)

通过feature_names属性可以获取特征名称。

print(iris_dataset['feature_names'])

运行结果:

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

接着,通过target属性查看样本的标签数据。

print(iris_dataset["target"])

运行结果:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

说明:之前提到的三个类别是通过0,1,2来标注的,0对应的是setosa,1对应的是versicolor,2对应的是virginica。
通过target_names属性可以获取样本的标签名称。

print(iris_dataset['target_names'])

运行结果:

['setosa' 'versicolor' 'virginica']

选择合适的模型,利用训练集训练模型

接下来进行建模,建模前需要将收集好的带标签的数据分成两部分,一部分数据用于构建机器学习模型,叫做训练集(train set),其余的数据用来评估模型性能,叫做测试集(test set)。
scikit-learn中的train_test_split函数可以帮助我们实现这一功能。

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(iris_dataset['data'],iris_dataset['target'],test_size=0.3,random_state=0)

说明:

  • sklearn中的train_test_split函数可以将数据按照一定比例分成测试集和训练集。
  • 这里将总体数据的70%作为训练集,30%作为验证集。
  • random_state指定随机数种子,使得函数的输出固定不变,方便结果重现。

这里采用机器学习中最简单的算法,K近邻算法(KNN)。

from sklearn.neighbors import KNeighborsClassifier

knn=KNeighborsClassifier(n_neighbors=1) #选取邻居数为1
knn.fit(X_train,y_train)

这一步主要是用训练集去训练KNN算法,设置K=1,K是指它的邻居数,KNN算法的原理后面会讲。

模型评估及使用,利用测试集评估模型

接下来,使用测试集对模型进行评估,这里直接调用sklearn中的score方法可以直接给出精度,也就是预测正确的比例。

knn.score(X_test,y_test)

运行结果:

0.9736842105263158

从上面可以看到,精度达到了97%,说明模型的效果还算不错,接着进行预测吧!
例如,我们在野外发现了一朵鸢尾花,花萼长5cm宽2.9cm,花瓣长1cm宽0.2cm,这朵鸢尾花属于哪个品种呢?

import numpy as np
X1=np.array([[5,2.9,1,0.2]])
y1=knn.predict(X1)
print(y1)

运行结果:

[0]

从上面看出,这朵新的鸢尾花的类别标签是0,也就是说它属于Setosa品种。
以上就是一个基本的机器学习应用。

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注