机器都在不停学习
人怎么能停止脚步

利用sklearn对鸢尾花数据集进行逻辑回归

这几日一直在学习逻辑回归,使用sklearn进行逻辑回归的时候,用到了鸢尾花数据集,这里记录下来。

# 导入必要的几个包
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris   
from sklearn.linear_model import LogisticRegression 

sklearn.datasets中的鸢尾花数据集一共包含4个特征变量,1个类别变量。共有150个样本,这里存储了其萼片和花瓣的长宽,共4个属性,鸢尾植物分三类,种类分别为山鸢尾、杂色鸢尾、维吉尼亚鸢尾。。

sklearn.datasets已经包含了鸢尾花数据集

# 载入数据集,Y的值有0,1,2三种情况,每种特征50个样本
iris = load_iris()         
X = iris.data[:, :2]   #获取花卉两列数据集
Y = iris.target

这里可以看到X是特征对应的值,这里只取了两个特征

Y的值有0,1,2三种情况,分别表示三种不同的花,每种特征50个样本

#逻辑回归模型,C=1e5表示目标函数。
lr = LogisticRegression(C=1e5)  
lr = lr.fit(X,Y)

# 将样本集花在坐标上
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa') #前50个样本
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor') #中间50个
plt.scatter(X[100:], Y[100:],color='green', marker='+', label='Virginica') #后50个样本
plt.legend(loc=2) #左上角
plt.show()

获取的鸢尾花两列数据,对应为花萼长度和花萼宽度,每个点的坐标就是(x,y)。 先取X二维数组的第一列(长度)的最小值、最大值和步长h(设置为0.02)生成数组,再取X二维数组的第二列(宽度)的最小值、最大值和步长h生成数组, 最后用meshgrid函数生成两个网格矩阵xx和yy:

# meshgrid函数生成两个网格矩阵,h表示步进长度
h = .01
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

xx的值

yy的值

调用ravel()函数将xx和yy的两个矩阵转变成一维数组,由于两个矩阵大小相等,因此两个一维数组大小也相等。

将xx与yy分别转换为一维数组

使用np.c_[xx.ravel(), yy.ravel()]合并两个数组。(这里写了一些关于numpy中np.c_和np.r_的介绍)

数组合并成功

调用predict()函数进行预测,预测结果赋值给Z

Z = lr.predict(np.c_[xx.ravel(), yy.ravel()])

预测结果Z

# 对预测的结果进行可视化
Z = Z.reshape(xx.shape)
plt.figure(1, figsize=(8,6))
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)
plt.show()

调用pcolormesh()函数将xx、yy两个网格矩阵和对应的预测结果Z绘制在图片上,可以发现输出为三个颜色区块,分布表示分类的三类区域。cmap=plt.cm.Paired表示绘图样式选择Paired主题。输出的区域如下图所示:

预测结果

最后,将样本集一起画到图中,方便进行对比。

# 导入样本集,绘制散点图
plt.scatter(X[:50,0], X[:50,1], color='red',marker='*', label='setosa')
plt.scatter(X[50:100,0], X[50:100,1], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:,0], X[100:,1], color='green', marker='s', label='Virginica') 

plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.legend(loc=2) 
plt.show()

样本集与预测结果

输出如上图所示,经过逻辑回归后划分为三个区域,左上角部分为红色的圆点,对应setosa鸢尾花;右上角部分为绿色方块,对应virginica鸢尾花;中间下部分为蓝色星形,对应versicolor鸢尾花。散点图为各数据点真实的花类型,划分的三个区域为数据点预测的花类型,预测的分类结果与训练数据的真实结果结果基本一致,部分鸢尾花出现交叉。

回归算法作为统计学中最重要的工具之一,它通过建立一个回归方程用来预测目标值,并求解这个回归方程的回归系数。本篇文章详细讲解了逻辑回归模型的原理知识,结合Sklearn机器学习库的LogisticRegression算法分析了鸢尾花分类情

赞(0)
转载请注明出处机器在学习 » 利用sklearn对鸢尾花数据集进行逻辑回归
分享到: 更多 (0)

评论 抢沙发

Scroll Up