株式会社ホクソエムのブログ

R, Python, データ分析, 機械学習

Rでのナウなデータ分割のやり方: rsampleパッケージによる交差検証

前処理大全の「分割」の章では、予測モデルの評価のためのデータセット分割方法が解説されています。基礎から時系列データへ適用する際の注意まで説明されているだけでなく、awesomeなコードの例がRおよびPythonで書かれており、実践的な側面もあります(お手元にぜひ!)。

しかし今回は、Awesome例とは異なる、より新しいやり方で・簡単にRでのデータ分割を行う方法を紹介したいと思います。前処理大全でも取り上げられているcaretパッケージですが、その開発者のMax Kuhnが開発するパッケージの中に rsample を使う方法です。ここでは前処理大全で書かれている一般的なデータと時系列データの交差検証による分割をrsampleの使い方を紹介しながらやっていきます。加えて、rsampleの層化サンプリングについても最後に触れます。

rsampleは名前の通り、再標本化法、ホールドアウト検証を利用するモデル性能推定用データセットの作成を行います(モデリング等は他のパッケージが担います1これらのtidymodelsを使った一連の処理は id:dropout009さんの記事が参考になります。)。

扱うデータセットは前処理大全で使われているものを利用させていただきます。

github.com

この記事では以下のパッケージに含まれる関数を利用します。メインはrsampleです。

library(readr) # データ読み込み
library(dplyr) # データ操作一般
library(assertr) # データのチェック
library(rsample)

前処理大全[データ分析のためのSQL/R/Python実践テクニック]

前処理大全[データ分析のためのSQL/R/Python実践テクニック]

1. レコードデータにおけるモデル検証用のデータ分割

一般的な交差検証の例として紹介されている、交差数4の交差検証を行う処理をやってみます。 まず20%をテストデータとして確保、残りのデータを交差検証に回します。

rsampleパッケージのinitial_split()でデータセットを訓練とテスト用に分けられます。ここではprop =によりその比率を調整可能です。今回は例題と同じく、訓練に8割のデータが含まれるように指定しました。

# サポートページで公開されているデータを読み込む(製造レコード)
production_tb <- 
  read_csv("https://raw.githubusercontent.com/ghmagazine/awesomebook/master/data/production.csv") %>% 
  verify(expr = dim(.) == c(1000, 4))
# prop では学習データへの分割の比率を指定します
df_split <- initial_split(production_tb, prop = 0.8)

initial_split()の返り値はrsplitオブジェクトです。出力してみると、分割したデータの情報を得ることができます。区切り文字で区切られた数値はそれぞれ、学習データ、テストデータ、元のデータの件数を示します。

df_split

## <801/199/1000>

この時点では実際にデータが分割されている訳ではありません。データの分割は次のtraining()testing()によって実行します。rsplitオブジェクトを引数に渡して実行すると先ほどの件数分のデータがランダムに割り当てられます。

df_train <- 
  training(df_split) %>% # 学習データ
  verify(expr = nrow(.) == 801L)

df_test <- 
  testing(df_split) %>%  # テストデータ
  verify(expr = nrow(.) == 199L)

続いて学習データを交差検証のためにさらに分割していきましょう。rsampleでは関数名vfold_cv()でk分割交差検証 (k-fold cross validation) を実行します(名前こそ違いますが、やっていることは同じです… 学習データをk個に分割、そのうちk-1個を学習用に、残りの1個をモデル精度を評価するために用いる)。

vfold_cv()の結果を見てみるとデータが4行のデータフレームになっているのがわかります。列はsplits,idの2列からなり、各行にFoldのデータセットが含まれています。

train_folds <- vfold_cv(df_train, v = 4)

train_folds

## #  4-fold cross-validation 
## # A tibble: 4 x 2
##   splits            id   
##   <list>            <chr>
## 1 <split [600/201]> Fold1
## 2 <split [601/200]> Fold2
## 3 <split [601/200]> Fold3
## 4 <split [601/200]> Fold4

ここで分割したデータセットの中身をk分割交差検証のイメージと合わせて確認しましょう。例題では交差数が4なので、下記の図のようにデータが分割されています。学習データ全体をk(=4)に分割しk-1を学習用、残りを検証用として利用するようにします。1回分のデータでは検証用のデータがkだけなので分割後のデータがもれなく検証データに割り当てられるよう、kの回数分繰り返されます。

f:id:u_ribo:20190608190640j:plain

Foldのデータを参照するにはanalysis()assessment()を使います。これらの関数はそれぞれ学習、検証データを参照します。

analysis(train_folds$splits[[1]])

## # A tibble: 600 x 4
##    type  length thickness fault_flg
##    <chr>  <dbl>     <dbl> <lgl>    
##  1 E      274.      40.2  FALSE    
##  2 D       86.3     16.9  FALSE    
##  3 E      124.       1.02 FALSE    
##  4 B      245.      29.1  FALSE    
##  5 B      226.      39.8  FALSE    
##  6 A      201.      12.2  FALSE    
##  7 C      276.      29.9  FALSE    
##  8 E      215.      41.8  FALSE    
##  9 D      218.      11.8  FALSE    
## 10 C      239.      18.9  FALSE    
## # … with 590 more rows

assessment(train_folds$splits[[1]])

## # A tibble: 201 x 4
##    type  length thickness fault_flg
##    <chr>  <dbl>     <dbl> <lgl>    
##  1 C       332.     16.8  FALSE    
##  2 E       168.      1.27 FALSE    
##  3 E       218.     39.6  FALSE    
##  4 C       326.     44.3  FALSE    
##  5 A       132.      4.72 TRUE     
##  6 D       182.     19.6  FALSE    
##  7 B       153.      1.10 FALSE    
##  8 B       150.     11.0  FALSE    
##  9 D       157.     11.2  FALSE    
## 10 A       206.     21.8  FALSE    
## # … with 191 more rows

zeallotによる代入

Pythonでのa, b = 0, 1といった parallel assignment を可能にするzeallotパッケージの演算子を使うと学習・テストデータへの割り当ては次のように実行できます。

library(zeallot)

df_split = initial_split(production_tb,  p = 0.8)
c(df_train, df_test) %<-% list(
  training(df_split),
  testing(df_split))

2. 時系列データにおけるモデル検証用のデータ分割

f:id:u_ribo:20190608190613j:plain

先ほどの無作為に行われる学習データ、検証データの分割を時系列データに適用すると、学習データに未来と過去のデータが混同してしまうことになるため単純なk分割交差検証ではダメだと前処理大全では記されています。またそれに対する方法として、データ全体を時系列に並べ、学習と検証に利用するデータをスライドさせていくという処理が紹介されています。これもrsampleでやってみましょう。今度のデータは月ごとの経営指標のデータセットとなっており、行ごとに各月の値が記録されています。先ほどと同じく、サポートページからデータを読み込みんだらデータ型といくつかの行の値を表示してみましょう。

monthly_index_tb <- 
  read_csv("https://raw.githubusercontent.com/ghmagazine/awesomebook/master/data/monthly_index.csv")

glimpse(monthly_index_tb)

## Observations: 120
## Variables: 3
## $ year_month      <chr> "2010-01", "2010-02", "2010-03", "2010-04", "20…
## $ sales_amount    <dbl> 7191240, 6253663, 6868320, 7147388, 8755929, 83…
## $ customer_number <dbl> 6885, 6824, 7834, 8552, 8171, 8925, 10104, 1123…

ここでyear_monthの値がYYYY-MMの形式で与えられていて、2010年1月を起点として並べられていること、データ型が文字列であることに注意してください。次の時系列データのためのデータ分割を適用するrolling_origin()は日付データに限らず、ある並びを考慮してランダムではない方法での抽出を行います。

例題の通り、学習用24ヶ月(周期性をみるために2年)、検証用12ヶ月のデータとなるようにデータを分割します。スキップの単位も12ヶ月です。これらのオプションは引数で指定可能です。

df_split <- 
  rolling_origin(monthly_index_tb, 
                 initial = 24, 
                 assess = 12, 
                 skip = 12, 
                 cumulative = FALSE)

df_split

## # Rolling origin forecast resampling 
## # A tibble: 7 x 2
##   splits          id    
##   <list>          <chr> 
## 1 <split [24/12]> Slice1
## 2 <split [24/12]> Slice2
## 3 <split [24/12]> Slice3
## 4 <split [24/12]> Slice4
## 5 <split [24/12]> Slice5
## 6 <split [24/12]> Slice6
## 7 <split [24/12]> Slice7

24、12行にデータが分かれたことがわかります。またデータセット全体では120行あるため、7通りの学習、検証データセットがあります。分割後の値を参照するには再びanalysis()assessment()を使います。最初のsplitデータでは2010年1月から24ヶ月分のデータ、つまり2011年12月までが含まれています。同様に検証データでは2012年月からの12ヶ月の値が格納されています。

analysis(df_split$splits[[1]]) %>% 
  verify(expr = nrow(.) == 24L)

## # A tibble: 24 x 3
## 省略


assessment(df_split$splits[[1]]) %>% 
  verify(expr = nrow(.) == 12L)

## # A tibble: 12 x 3
## 省略

少数データへの対策として出されている、学習データを増やしてく処理にはcumulative = TRUEを指定するだけです(デフォルトでTRUE)。

df_split <- 
  rolling_origin(monthly_index_tb, 
                 initial = 24, 
                 assess = 12, 
                 skip = 12, 
                 cumulative = TRUE)

df_split

## # Rolling origin forecast resampling 
## # A tibble: 7 x 2
##   splits           id    
##   <list>           <chr> 
## 1 <split [24/12]>  Slice1
## 2 <split [37/12]>  Slice2
## 3 <split [50/12]>  Slice3
## 4 <split [63/12]>  Slice4
## 5 <split [76/12]>  Slice5
## 6 <split [89/12]>  Slice6
## 7 <split [102/12]> Slice7

今度の分割では学習データの件数が分割のたびに増えていることに注意してください。

# 最初の分割データセットでは学習24、検証に12のデータ
analysis(df_split$splits[[1]]) %>% 
  verify(expr = nrow(.) == 24L)

## # A tibble: 24 x 3
## 省略

assessment(df_split$splits[[1]]) %>% 
  verify(expr = nrow(.) == 12L)

## # A tibble: 12 x 3
## 省略

# 2番目の分割データセットには最初の分割データと同じ期間 + 13件のデータ
# 最初と最後の行を確認
analysis(df_split$splits[[3]]) %>% 
  slice(c(1, nrow(.)))

## # A tibble: 2 x 3
##   year_month sales_amount customer_number
##   <chr>             <dbl>           <dbl>
## 1 2010-01         7191240            6885
## 2 2014-02        41809454           35630

おまけ: 層化抽出法

データに含まれる出身地や性別などの属性を「層」として扱い、層ごとに抽出を行う方法として(層化サンプリング stratified sampling)があります。層化抽出法は母集団の各層の比率を反映して抽出を行う方法で、無作為抽出よりもサンプル数が少ない層を抽出可能にするものです。rsampleではstrata引数がオプションに用意されており、これを分割用の関数実行時に層の名前を指定して実行することで層化サンプリングを実現します。

例としてアヤメのデータセットを使います。元のデータは3種 (Species)が50件ずつ均等に含まれているため130件に限定して偏りを生じさています。

iris %>% 
  count(Species) %>%
  mutate(prop = n / sum(n))

## # A tibble: 3 x 3
##   Species        n  prop
##   <fct>      <int> <dbl>
## 1 setosa        50 0.333
## 2 versicolor    50 0.333
## 3 virginica     50 0.333

# データセットの一部を抽出し、データセットに含まれる件数を種ごとに変える
df_iris_subset <- iris[1:130, ]
df_iris_subset %>% 
  count(Species) %>%
  mutate(prop = n / sum(n))

## # A tibble: 3 x 3
##   Species        n  prop
##   <fct>      <int> <dbl>
## 1 setosa        50 0.385
## 2 versicolor    50 0.385
## 3 virginica     30 0.231

加工したデータでは 3種のアヤメのうち、virginicaが30件(23%)と減っています。

それでは層化しない方法と比較してみましょう。次の例は、k=5に分割したデータに含まれるvirginicaの割合を示します。2番目のstrata = "Species"を与えて実行したものが層化サンプリングの結果です。

# ホールドごとに含まれる割合が異なる
set.seed(13)
folds1 <- vfold_cv(df_iris_subset, v = 5)
purrr::map_dbl(folds1$splits,
               function(x) {
                 dat <- as.data.frame(x)$Species
                 mean(dat == "virginica")})

##              1              2              3              4              5 
## 0.259615384615 0.230769230769 0.240384615385 0.211538461538 0.211538461538

# strata = による層化を行うことで元データの偏りを反映してサンプリング
set.seed(13)
folds2 <- vfold_cv(df_iris_subset, strata = "Species", v = 5)
purrr::map_dbl(folds2$splits,
               function(x) {
                 dat <- as.data.frame(x)$Species
                 mean(dat == "virginica")})

##              1              2              3              4              5 
## 0.230769230769 0.230769230769 0.230769230769 0.230769230769 0.230769230769

層化サンプリングした場合では、ホールド間で元データの偏り(virginica ... 23%)を反映することができました。便利ですね。

Enjoy!


  1. Max Kuhnが中心となって整備されているtidymodelsを利用すると便利です。tidymodelsはざっくりいうとtidyverseのデータモデリング版。tidyverseのパッケージや、パイプ処理フレンドリーな統一的関数を提供するパッケージ群。rsampleの他にinferparsnipなど