В последнее время я публикую заметки, которые демонстрируют работу с пакетом tidymodels . Я разбираю как простые, так и более сложные модели. Сегодняшняя заметка подойдет тем, кто только начинает свое знакомство с пакетом tidymodels

Знакомимся с данными

Мы будем работать с набором данных, в котором хранится информация о пингвинах, живущих на архипелаге Палмера. Наша задача — предсказать пол пингвинов. Для этого мы будем использовать модель классификации. 

library(tidyverse)
library(palmerpenguins)

penguins
# A tibble: 344 × 8
   species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex   
   <fct>   <fct>              <dbl>         <dbl>             <int>       <int> <fct> 
 1 Adelie  Torgersen           39.1          18.7               181        3750 male  
 2 Adelie  Torgersen           39.5          17.4               186        3800 female
 3 Adelie  Torgersen           40.3          18                 195        3250 female
 4 Adelie  Torgersen           NA            NA                  NA          NA NA    
 5 Adelie  Torgersen           36.7          19.3               193        3450 female
 6 Adelie  Torgersen           39.3          20.6               190        3650 male  
 7 Adelie  Torgersen           38.9          17.8               181        3625 female
 8 Adelie  Torgersen           39.2          19.6               195        4675 male  
 9 Adelie  Torgersen           34.1          18.1               193        3475 NA    
10 Adelie  Torgersen           42            20.2               190        4250 NA    
# … with 334 more rows, and 1 more variable: year <int>

В нашем датасете есть такие переменные как:

  • species - вид пингвина;

  • island — остров, на котором обитает особь;

  • bill_length_mm — длина клюва;

  • bill_depth_mm — глубина клюва;

  • flipper_length_mm — длина плавника;

  • body_mass_g — масса тела;

  • sex — пол;

  • year — год наблюдения.

Хочу обратить внимание, что если мы решим использовать модель классификации для предсказания вида пингвина, то обнаружим, что переменные о клюве и массе тела дают высокую точность прогноза. Поэтому мы сосредоточимся на предсказании пола особи, так как эта задача чуть интересней.

Посмотрим на переменную пола особи:

penguins %>% 
  count(sex)
# A tibble: 3 × 2
  sex        n
  <fct>  <int>
1 female   165
2 male     168
3 NA        11

В наших данных есть 11 пропущенных значений, что, в целом, не является критической проблемой для нас.

Давайте визуализируем наши данные:

penguins %>%
  #исключим пропущенные значения
  filter(!is.na(sex)) %>% 
  ggplot(aes(flipper_length_mm,
             bill_length_mm,
             color = sex,
             size = body_mass_g)) +
  geom_point(alpha = 0.5) +
  facet_wrap(~species) +
  theme_light()

Похоже, что самки более мелкие, чем особи мужского пола. Что, в целом, очевидно. Давайте приступим к нашему моделированию. Стоит отметить, что мы не будем использовать переменные island и year.

penguins_df <- penguins %>% 
  #Удалим из наших данных пропущенные значения
  filter(!is.na(sex)) %>%
  #Удалим из наших данных переменные island и year
  select(-year, -island)

Построение модели

Начнем наше построение модели с подключения пакета tidymodels и разделим данные на тестовую и обучающую выборки. 

library(tidymodels)

set.seed(123)
# Мы разделим нашу выборку по полу, чтобы соблюсти пропорции самцов и самок
penguin_split <- initial_split(penguins_df, strata = sex)

penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)

Так как у нас недостаточно наблюдений мы можем создать набор данных с повторной выборкой.

set.seed(123)
penguin_boot <- bootstraps(penguin_train)
penguin_boot
# Bootstrap sampling 
# A tibble: 25 × 2
   splits           id         
   <list>           <chr>      
 1 <split [249/93]> Bootstrap01
 2 <split [249/91]> Bootstrap02
 3 <split [249/90]> Bootstrap03
 4 <split [249/91]> Bootstrap04
 5 <split [249/85]> Bootstrap05
 6 <split [249/87]> Bootstrap06
 7 <split [249/94]> Bootstrap07
 8 <split [249/88]> Bootstrap08
 9 <split [249/95]> Bootstrap09
10 <split [249/89]> Bootstrap10
# … with 15 more rows

Давайте сравним две различные модели: модель логистической регрессии и модель случайного леса. Начнем с создания спецификаций моделей.

# спецификация для логистичсекой модели
glm_spec <- logistic_reg() %>%
  set_engine("glm")

glm_spec
Logistic Regression Model Specification (classification)

Computational engine: glm 
# спецификация для модели случайного леса
rf_spec <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger")

rf_spec
Random Forest Model Specification (classification)

Computational engine: ranger 

Сейчас это пустые спецификации модели, в которых нет данных. Нам нужно “собрать” модель и для этого мы будем использовать функцию workflow()

На первом шаге нам нужно указать формулу, где мы выберем переменные для предсказания пола особи.

penguin_wf <- workflow() %>%
  add_formula(sex ~ .) # будем использовать все переменные для предсказания пола особи

penguin_wf
══ Workflow ══════════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: None

── Preprocessor ──────────────────────────────────────────────────────────────────────
sex ~ .

На втором шаге, нам нужно добавить спецификацию модели. Начнем с логистической регрессии. 

glm_rs <- penguin_wf %>%
  add_model(glm_spec) %>%
  fit_resamples(
    resamples = penguin_boot,
    control = control_resamples(save_pred = TRUE)
  )

glm_rs
# Resampling results
# Bootstrap sampling 
# A tibble: 25 × 5
   splits           id          .metrics         .notes           .predictions     
   <list>           <chr>       <list>           <list>           <list>           
 1 <split [249/93]> Bootstrap01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [93 × 6]>
 2 <split [249/91]> Bootstrap02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
 3 <split [249/90]> Bootstrap03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [90 × 6]>
 4 <split [249/91]> Bootstrap04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
 5 <split [249/85]> Bootstrap05 <tibble [2 × 4]> <tibble [1 × 3]> <tibble [85 × 6]>
 6 <split [249/87]> Bootstrap06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [87 × 6]>
 7 <split [249/94]> Bootstrap07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [94 × 6]>
 8 <split [249/88]> Bootstrap08 <tibble [2 × 4]> <tibble [1 × 3]> <tibble [88 × 6]>
 9 <split [249/95]> Bootstrap09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [95 × 6]>
10 <split [249/89]> Bootstrap10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [89 × 6]>
# … with 15 more rows

Сделаем тоже самое для модели случайного леса:

rf_rs <- penguin_wf %>%
  add_model(rf_spec) %>%
  fit_resamples(
    resamples = penguin_boot,
    control = control_resamples(save_pred = TRUE)
  )

rf_rs
# Resampling results
# Bootstrap sampling 
# A tibble: 25 × 5
   splits           id          .metrics         .notes           .predictions     
   <list>           <chr>       <list>           <list>           <list>           
 1 <split [249/93]> Bootstrap01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [93 × 6]>
 2 <split [249/91]> Bootstrap02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
 3 <split [249/90]> Bootstrap03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [90 × 6]>
 4 <split [249/91]> Bootstrap04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
 5 <split [249/85]> Bootstrap05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [85 × 6]>
 6 <split [249/87]> Bootstrap06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [87 × 6]>
 7 <split [249/94]> Bootstrap07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [94 × 6]>
 8 <split [249/88]> Bootstrap08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [88 × 6]>
 9 <split [249/95]> Bootstrap09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [95 × 6]>
10 <split [249/89]> Bootstrap10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [89 × 6]>
# … with 15 more rows

Оцениваем модель

Теперь давайте проверим, что у нас получилось.

Посмотрим метрики точности для модели случайного леса:

collect_metrics(rf_rs)
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.914    25 0.00545 Preprocessor1_Model1
2 roc_auc  binary     0.977    25 0.00202 Preprocessor1_Model1

Теперь посмотрим метрики точности для логистической регрессии:

collect_metrics(glm_rs)
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.918    25 0.00639 Preprocessor1_Model1
2 roc_auc  binary     0.979    25 0.00254 Preprocessor1_Model1

Мы видим, что модель glm_rsсработала чуть лучше, чем модель rf_rs. Учитывая, что эти модели дают примерно одинаковый результат, есть смысл выбрать ту, которая по своей сути проще. Поэтому мы остановимся на модели логистической регрессии.

Построим матрицу ошибок:

glm_rs %>%
  conf_mat_resampled()
# A tibble: 4 × 3
  Prediction Truth   Freq
  <fct>      <fct>  <dbl>
1 female     female  41.1
2 female     male     3  
3 male       female   4.4
4 male       male    42.3

Мы видим, что у нас нет проблем с предсказанием пола особи. Модель справляется достаточно хорошо.

Одним из наших показателей был roc_auc поэтому было бы интересно посмотреть на ROC-кривые.

glm_rs %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(sex, .pred_female) %>%
  ggplot(aes(1 - specificity, sensitivity, color = id)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
  coord_equal() +
  theme_light()

Мы видим, что ROC-кривая ступенчатая — это связано с тем, что у нас не большой набор данных.

Теперь мы можем вернуться к тестовому набору. Обратите внимание, что мы еще не использовали тестовую выборку. Ее мы можем использовать только для оценки производительности модели на новых данных.

penguin_final <- penguin_wf %>%
  add_model(glm_spec) %>%
  last_fit(penguin_split)

penguin_final
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits           id               .metrics         .notes   .predictions .workflow 
  <list>           <chr>            <list>           <list>   <list>       <list>    
1 <split [249/84]> train/test split <tibble [2 × 4]> <tibble> <tibble>     <workflow>

Теперь мы можем посмотреть на оценки модели для тестового набора данных:

collect_metrics(penguin_final)
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.857 Preprocessor1_Model1
2 roc_auc  binary         0.938 Preprocessor1_Model1

Мы видим, что результат не сильно отличается от нашего примера выше, что является хорошим признаком.

Мы также можем посмотреть на предсказанные данные и построить матрицу ошибок.

collect_predictions(penguin_final)
# A tibble: 84 × 7
   id               .pred_female .pred_male  .row .pred_class sex    .config          
   <chr>                   <dbl>      <dbl> <int> <fct>       <fct>  <chr>            
 1 train/test split   0.597         0.403       2 female      female Preprocessor1_Mo…
 2 train/test split   0.928         0.0724      3 female      female Preprocessor1_Mo…
 3 train/test split   0.647         0.353       4 female      female Preprocessor1_Mo…
 4 train/test split   0.219         0.781      18 male        female Preprocessor1_Mo…
 5 train/test split   0.0132        0.987      25 male        male   Preprocessor1_Mo…
 6 train/test split   0.970         0.0298     28 female      female Preprocessor1_Mo…
 7 train/test split   0.0000232     1.00       31 male        male   Preprocessor1_Mo…
 8 train/test split   0.872         0.128      34 female      female Preprocessor1_Mo…
 9 train/test split   0.998         0.00250    38 female      female Preprocessor1_Mo…
10 train/test split   0.00000253    1.00       39 male        male   Preprocessor1_Mo…
# … with 74 more rows

Строим матрицу ошибок:

collect_predictions(penguin_final) %>%
  conf_mat(sex, .pred_class)
          Truth
Prediction female male
    female     37    7
    male        5   35

Мы видим, что наша модель достаточно хорошо справляется.

Теперь мы можем посмотреть на подходящий рабочий процесс нашей модели.

penguin_final$.workflow[[1]]
══ Workflow [trained] ════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: logistic_reg()

── Preprocessor ──────────────────────────────────────────────────────────────────────
sex ~ .

── Model ─────────────────────────────────────────────────────────────────────────────

Call:  stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)

Coefficients:
      (Intercept)   speciesChinstrap      speciesGentoo     bill_length_mm  
       -1.042e+02         -8.892e+00         -1.138e+01          6.459e-01  
    bill_depth_mm  flipper_length_mm        body_mass_g  
        2.124e+00          5.654e-02          8.102e-03  

Degrees of Freedom: 248 Total (i.e. Null);  242 Residual
Null Deviance:	    345.2 
Residual Deviance: 70.02 	AIC: 84.02

Здесь мы видим какую модель мы использовали, а также мы можем увидеть рассчитанные коэффициенты нашей модели.

Используя функцию tidy() , мы получим данные коэффициенты в аккуратном виде. Кроме того, мы можем применить аргумент exponentiate = TRUE , преобразовав их в коэффициенты шансов.

penguin_final$.workflow[[1]] %>%
  tidy(exponentiate = TRUE)
# A tibble: 7 × 5
  term              estimate std.error statistic     p.value
  <chr>                <dbl>     <dbl>     <dbl>       <dbl>
1 (Intercept)       5.75e-46  19.6        -5.31  0.000000110
2 speciesChinstrap  1.37e- 4   2.34       -3.79  0.000148   
3 speciesGentoo     1.14e- 5   3.75       -3.03  0.00243    
4 bill_length_mm    1.91e+ 0   0.180       3.60  0.000321   
5 bill_depth_mm     8.36e+ 0   0.478       4.45  0.00000868 
6 flipper_length_mm 1.06e+ 0   0.0611      0.926 0.355      
7 body_mass_g       1.01e+ 0   0.00176     4.59  0.00000442 

Мы видим, что глубина и длина клюва являются основными предикторами для классификации пола особи. Увеличение глубины клюва на 1 мм почти в 8 раза увеличивает шанс быть самцом.

Посмотрим на наш график, который мы строили в начале, но заменим переменную flipper_length_mmна bill_depth_mm:

penguins %>%
  filter(!is.na(sex)) %>%
  ggplot(aes(bill_depth_mm,
             bill_length_mm,
             color = sex,
             size = body_mass_g)) +
  geom_point(alpha = 0.5) +
  facet_wrap(~species) +
  theme_light()

Да, можно с уверенность сказать, что переменная bill_depth_mmделит особей по полу более четко.

Комментарии (0)