sklearn中的train_test_split()函数解析
train_test_split函数参数解析
·
sklearn中的train_test_split()函数解析
train_test_split()函数:机器学习中用于分割数据集(训练集和测试集)
X_train,X_test,y_train,y_test = train_test_split(X, y, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)
参数说明:
参数 | 含义 |
---|---|
X | 待划分的样本特征集 |
y | 待划分的样本标签 |
test_size | 默认值为none,值为0.0-1.0时表示测试集占总样本比例;值为整数时表示测试集数量 |
train_size | 默认值为none,值为0.0-1.0时表示训练集占总样本比例;值为整数时表示训练集数量 |
random_state | 默认值none, 随机数种子(下面详细介绍) |
shuffle | 默认值True, 表示是否在拆分前打乱数据, 若为False则stratify必须置为none |
stratify | 默认值none,如果不是none,则以分层方式拆分数据,并将其用作类标签 |
返回值说明:
名称 | 含义 |
---|---|
X_train | 训练数据集 |
X_test | 测试数据集 |
y_train | 训练标签集 |
y_test | 测试标签集 |
- random_state说明:
- 为什么要设置这个参数?
train_test_split 函数将数据集随机拆分成训练集和测试集,如果random_state不设置,则每次运行拆分时得到的训练集和测试集都与上次不相同,构建的模型也就不同
【示例如下】
- 为什么要设置这个参数?
>a, b = np.arange(8).reshape(4,2), range(4)
#a,b的值:
a: [[0 1]
[2 3]
[4 5]
[6 7]]
b:range(0, 4)
>X_train,X_test,y_train,y_test = train_test_split(a,b)
>print('训练数据集:\n{}'.format(X_train))
>print('训练标签集:\n{}'.format(y_train))
>print('测试数据集:\n{}'.format(X_test))
>print('测试标签集:\n{}'.format(y_test))
- 第一次运行结果:
训练数据集:[[6 7] [0 1] [2 3]]
训练标签集:[3, 0, 1]
测试数据集:[[4 5]]
测试标签集:[2] - 第二次运行结果:
训练数据集:[[4 5] [6 7] [2 3]]
训练标签集:[2, 3, 1]
测试数据集:[[0 1]]
测试标签集:[0]
可以看出每次运行结果的拆分方式都是随机的
X_train,X_test,y_train,y_test = train_test_split(a,b,random_state=1)
设置了random_state=1后,每次运行结果都是一样的:
训练数据集:[[4 5] [0 1] [2 3]]
训练标签集:[2, 0, 1]
测试数据集:[[6 7]]
测试标签集:[3]
-
shuffle说明:
能够在划分数据前打乱数据,当数据分布不均衡时,可能会导致划分后的训练集和测试集不均匀,比如测试集中的类0占了99%,类1只占了1%(像sklearn中鸢尾花数据集前50个样本都是同一类别,此时打乱数据很有必要,因此默认值为True是合理的) -
stratify说明:
能够保持划分前类的分布,比如(参考文章):
有100个样本,80个属于标签0,20个属于标签1,如果前面参数test_size=0.25,则
75个训练数据,60个标签为0,15个标签为1
25个测试数据,20个标签为0,5个标签为1
更多推荐
已为社区贡献4条内容
所有评论(0)