sklearn中的train_test_split()函数解析

train_test_split()函数:机器学习中用于分割数据集(训练集和测试集)

X_train,X_test,y_train,y_test = train_test_split(X, y, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)

参数说明:

参数含义
X待划分的样本特征集
y待划分的样本标签
test_size默认值为none,值为0.0-1.0时表示测试集占总样本比例;值为整数时表示测试集数量
train_size默认值为none,值为0.0-1.0时表示训练集占总样本比例;值为整数时表示训练集数量
random_state默认值none, 随机数种子(下面详细介绍)
shuffle默认值True, 表示是否在拆分前打乱数据, 若为False则stratify必须置为none
stratify默认值none,如果不是none,则以分层方式拆分数据,并将其用作类标签

返回值说明:

名称含义
X_train训练数据集
X_test测试数据集
y_train训练标签集
y_test测试标签集
  1. random_state说明:
    • 为什么要设置这个参数?
      train_test_split 函数将数据集随机拆分成训练集和测试集,如果random_state不设置,则每次运行拆分时得到的训练集和测试集都与上次不相同,构建的模型也就不同
      【示例如下】
>a, b = np.arange(8).reshape(4,2), range(4)
#a,b的值:
a: [[0 1]
   [2 3]
   [4 5]
   [6 7]]
 b:range(0, 4)
 >X_train,X_test,y_train,y_test = train_test_split(a,b)
 >print('训练数据集:\n{}'.format(X_train))
 >print('训练标签集:\n{}'.format(y_train))
 >print('测试数据集:\n{}'.format(X_test))
 >print('测试标签集:\n{}'.format(y_test))
  • 第一次运行结果:
    训练数据集:[[6 7] [0 1] [2 3]]
    训练标签集:[3, 0, 1]
    测试数据集:[[4 5]]
    测试标签集:[2]
  • 第二次运行结果:
    训练数据集:[[4 5] [6 7] [2 3]]
    训练标签集:[2, 3, 1]
    测试数据集:[[0 1]]
    测试标签集:[0]

可以看出每次运行结果的拆分方式都是随机的

X_train,X_test,y_train,y_test = train_test_split(a,b,random_state=1)

设置了random_state=1后,每次运行结果都是一样的:
训练数据集:[[4 5] [0 1] [2 3]]
训练标签集:[2, 0, 1]
测试数据集:[[6 7]]
测试标签集:[3]

  • shuffle说明
    能够在划分数据前打乱数据,当数据分布不均衡时,可能会导致划分后的训练集和测试集不均匀,比如测试集中的类0占了99%,类1只占了1%(像sklearn中鸢尾花数据集前50个样本都是同一类别,此时打乱数据很有必要,因此默认值为True是合理的)

  • stratify说明:
    能够保持划分前类的分布,比如(参考文章):
    有100个样本,80个属于标签0,20个属于标签1,如果前面参数test_size=0.25,则
    75个训练数据,60个标签为0,15个标签为1
    25个测试数据,20个标签为0,5个标签为1

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐