library(tree)
8 Lab8 Árboles de decisión
Tipo | |
---|---|
Slides: | https://hastie.su.domains/ISLR2/Slides/Ch8_Tree_Based_Methods.pdf |
Lab-html: | https://hastie.su.domains/ISLR2/Labs/Rmarkdown_Notebooks/Ch8-baggboost-lab.html |
Lab-Rscript: | https://hastie.su.domains/ISLR2/Labs/R_Labs/Ch8-baggboost-lab.R |
Data: | Carseats y Boston (ISLR2 ) |
8.1 Árboles de clasificación
El paquete tree
se utiliza para construir árboles de clasificación y regresión.
Primero usamos árboles de clasificación para analizar el conjunto de datos Carseats
. En estos datos, Sales
es una variable continua (unidades vendidas -en miles- en cada ubicación), por lo que comenzamos recodificándola como una variable binaria. Usamos la función ifelse()
para crear una variable, llamada High
, que toma el valor Yes
si la variable Sales
excede de 8, y toma el valor No
de lo contrario. Y usamos la función data.frame()
para fusionar High
con el resto de los datos de Carseats
.
library(ISLR2)
<- factor(ifelse(Carseats$Sales > 8, "Yes", "No"))
High <- data.frame(Carseats, High) Carseats
La función tree()
permite ajustar un árbol de clasificación con el fin de predecir High
usando todas las variables excepto Sales
. Tiene una sintaxis bastante similar a la de la función lm()
.
<- tree(High ~ . - Sales, Carseats) tree.carseats
La función summary()
enumera las variables que se utilizan como nodos internos en el árbol, el número de nodos terminales y la tasa de error (entrenamiento).
summary(tree.carseats)
Classification tree:
tree(formula = High ~ . - Sales, data = Carseats)
Variables actually used in tree construction:
[1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
[6] "Advertising" "Age" "US"
Number of terminal nodes: 27
Residual mean deviance: 0.4575 = 170.7 / 373
Misclassification error rate: 0.09 = 36 / 400
Vemos que la tasa de error de entrenamiento es del 9%.
Para árboles de clasificación, la deviance reportada en la salida de summary()
viene dada por \[-2 \sum_m \sum_k n_{mk} \log \hat{p}_{mk},\] donde \(n_{mk}\) es el número de observaciones en el nodo terminal \(m\)-ésimo que pertenecen a la \(k\)-ésima clase. Esto está estrechamente relacionado con la entropía (definida en el libro, en (8.7)). Una pequeña deviance indica un árbol que proporciona un buen ajuste a los datos (de entrenamiento). La residual mean deviance reportada es simplemente la deviance dividida por \(n-|T_0|\), que en este caso es 400 - 27 = 373
.
8.1.1 Gráfico
Una de las propiedades más atractivas de los árboles es que se pueden representar gráficamente. Usamos la función plot()
para mostrar la estructura de árbol y la función text()
para mostrar las etiquetas de los nodos. El argumento pretty = 0
indica a R
que incluya los nombres de categoría para cualquier predictor cualitativo, en lugar de simplemente mostrar una letra para cada categoría.
plot(tree.carseats)
text(tree.carseats, pretty = 0)
El indicador más importante de Sales
parece ser la ubicación de las estanterías, ya que la primera rama diferencia las ubicaciones Good
de las ubicaciones Bad
y Medium
.
Nota Víctor: De nuevo falta en el material original interpretar el gráfico… A grandes rasgos, se lee de arriba a abajo. Tras cada nodo (variable y criterio asociado), por ejemplo,
Price < 135
, salen dos ramas, a la izquierda serían las de las observaciones queSí
cumplen el criterio, a la izquierda las queNo
lo cumplen.
Existen otros paquetes que ayudan en la visualización…
library(rpart)
library(rpart.plot)
<- rpart(High ~ . - Sales, Carseats, method = "class")
tree.carseats2 rpart.plot(tree.carseats2)
¡Los resultados difieren ligeramente!
Si solo escribimos el nombre del objeto del árbol, R
imprime la salida correspondiente a cada rama (node
) del árbol: split
es el criterio de división (p. ej., Price < 92,5
), n
es el número de observaciones en esa rama, la deviance
, yval
es la predicción general para la rama (Yes
o No
), y yprob
es la fracción de observaciones (probabilidad) en esa rama (en este ejemplo, primero la fracción de No
y luego la de Yes
). Las ramas que conducen a los nodos terminales se indican con asteriscos.
tree.carseats
node), split, n, deviance, yval, (yprob)
* denotes terminal node
1) root 400 541.500 No ( 0.59000 0.41000 )
2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
8.1.2 Validación
Para evaluar correctamente el rendimiento de un árbol de clasificación con estos datos, debemos estimar el error en el conjunto de validación en lugar de simplemente calcular el error de entrenamiento. Dividimos las observaciones en los conjuntos de entrenamiento y de validación o prueba, construimos el árbol usando el conjunto de entrenamiento, y evaluamos su desempeño en el conjunto de validación. Se utiliza la función predict()
para este propósito. En el caso de un árbol de clasificación, el argumento type = "class"
indica a R
que devuelva la predicción de clase real. Este enfoque conduce a predicciones correctas para alrededor del 77% de las ubicaciones en el conjunto de datos de prueba/validación.
set.seed(2)
<- sample(1:nrow(Carseats), 200)
train <- Carseats[-train,]
Carseats.test <- High[-train]
High.test <- tree(High ~ . - Sales, Carseats, subset = train)
tree.carseats <- predict(tree.carseats, Carseats.test, type = "class")
tree.pred table(tree.pred, High.test)
High.test
tree.pred No Yes
No 104 33
Yes 13 50
104 + 50) / 200 (
[1] 0.77
(Si vuelve a ejecutar la función predict()
, es posible que obtenga resultados ligeramente diferentes, debido a “empates”: por ejemplo, esto puede suceder cuando las observaciones de entrenamiento correspondientes a un nodo terminal se dividen equitativamente entre valores de respuesta Yes
y No
.)
8.1.3 Poda del árbol
A continuación, consideramos si podar el árbol podría conducir a mejores resultados. La función cv.tree()
realiza una validación cruzada para determinar el nivel óptimo de complejidad del árbol; la poda por costes de complejidad se utiliza para seleccionar una secuencia de árboles para su consideración. Usamos el argumento FUN = prune.misclass
para indicar que queremos que la tasa de error de clasificación guíe el proceso de validación cruzada y poda, en lugar del valor predeterminado para la función cv.tree()
, que es la deviance. La función cv.tree()
reporta el número de nodos terminales de cada árbol considerado (size
), así como la tasa de error correspondiente y el valor del parámetro coste-complejidad utilizado (k
, que corresponde a \(\alpha\) en (8.4)).
set.seed(7)
<- cv.tree(tree.carseats, FUN = prune.misclass)
cv.carseats names(cv.carseats)
[1] "size" "dev" "k" "method"
cv.carseats
$size
[1] 21 19 14 9 8 5 3 2 1
$dev
[1] 75 75 75 74 82 83 83 85 82
$k
[1] -Inf 0.0 1.0 1.4 2.0 3.0 4.0 9.0 18.0
$method
[1] "misclass"
attr(,"class")
[1] "prune" "tree.sequence"
A pesar de su nombre, dev
corresponde al número de errores de validación cruzada. El árbol con 9 nodos terminales da como resultado solo 74 errores de validación cruzada. Dibujamos la tasa de error como una función tanto de size
como de k
.
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")
Ahora aplicamos la función prune.misclass()
para podar el árbol y obtener el árbol de 9 nodos (terminales).
<- prune.misclass(tree.carseats, best = 9)
prune.carseats plot(prune.carseats)
text(prune.carseats, pretty = 0)
¿Cómo de bueno es el desempeño de este árbol podado en el conjunto de datos de prueba/validación? Una vez más, aplicamos la función predict()
.
<- predict(prune.carseats, Carseats.test, type = "class")
tree.pred table(tree.pred, High.test)
High.test
tree.pred No Yes
No 97 25
Yes 20 58
97 + 58) / 200 (
[1] 0.775
Ahora, el 77.5% de las observaciones de prueba se clasifican correctamente, por lo que el proceso de poda no solo produjo un árbol más interpretable, sino que también mejoró ligeramente la precisión de la clasificación.
Si aumentamos el valor de best
, obtenemos un árbol podado más grande con menor precisión de clasificación:
<- prune.misclass(tree.carseats, best = 14)
prune.carseats plot(prune.carseats)
text(prune.carseats, pretty = 0)
<- predict(prune.carseats, Carseats.test, type = "class")
tree.pred table(tree.pred, High.test)
High.test
tree.pred No Yes
No 102 31
Yes 15 52
102 + 52) / 200 (
[1] 0.77
8.2 Árboles de regresión
Ajustamos un árbol de regresión al conjunto de datos Boston
. Primero, creamos un conjunto de entrenamiento y ajustamos el árbol a los datos de entrenamiento.
set.seed(1)
<- sample(1:nrow(Boston), nrow(Boston) / 2)
train <- tree(medv ~ ., Boston, subset = train)
tree.boston summary(tree.boston)
Regression tree:
tree(formula = medv ~ ., data = Boston, subset = train)
Variables actually used in tree construction:
[1] "rm" "lstat" "crim" "age"
Number of terminal nodes: 7
Residual mean deviance: 10.38 = 2555 / 246
Distribution of residuals:
Min. 1st Qu. Median Mean 3rd Qu. Max.
-10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
Note que la salida de summary()
indica que solo cuatro de las variables han sido usadas en la construcción del árbol. En el contexto de árboles de regresión, la deviance es simplemente la suma de los errores al cuadrado del árbol. Ahora dibujamos el árbol.
plot(tree.boston)
text(tree.boston, pretty = 0)
La variable lstat
mide el porcentaje de población con bajo nivel socioeconómico, mientras que la variable rm
corresponde al número medio de habitaciones por casa. El árbol indica que valores más altos de rm
, o valores más bajos de lstat
, corresponden a casas más caras. Por ejemplo, el árbol predice un precio mediano medio de vivienda de 45.400$ para viviendas en distritos censales en los que rm >= 7.55
.
Vale la pena señalar que podríamos haber ajustado un árbol mucho más grande, pasando control = tree.control(nobs = length(train), mindev = 0)
a la función tree()
.
8.2.1 Poda
Ahora usamos la función cv.tree()
para ver si podar el árbol mejorará el rendimiento.
<- cv.tree(tree.boston)
cv.boston plot(cv.boston$size, cv.boston$dev, type = "b")
En este caso, el árbol más complejo bajo consideración se selecciona mediante validación cruzada. Sin embargo, si deseamos podar el árbol, podemos hacerlo de la siguiente manera, usando la función prune.tree()
:
<- prune.tree(tree.boston, best = 5)
prune.boston plot(prune.boston)
text(prune.boston, pretty = 0)
De acuerdo con los resultados de la validación cruzada, usamos el árbol no podado para hacer predicciones en el conjunto de validación.
<- predict(tree.boston, newdata = Boston[-train,])
yhat <- Boston[-train, "medv"]
boston.test plot(yhat, boston.test)
abline(0, 1)
mean((yhat - boston.test)^2)
[1] 35.28688
En otras palabras, el MSE del conjunto de validación asociado con el árbol de regresión es de 35.29
. Por lo tanto, la raíz cuadrada del MSE es de alrededor de 5.94
, lo que indica que este modelo conduce a predicciones en el conjunto de validación que están (en promedio) dentro de aproximadamente 5.94
del valor mediano real de la vivienda para el distrito censal.
8.3 Bagging y Random Forests
Aquí aplicamos bagging y random forests a los datos de Boston
, usando el paquete randomForest
en R
. Los resultados exactos obtenidos en esta sección pueden depender de la versión de R
y la versión del paquete randomForest
instalado en su computadora.
8.3.1 Bagging
El bagging es simplemente un caso especial de un random forest con \(m=p\). Por lo tanto, la función randomForest()
se puede utilizar para realizar ambos. Realizamos el bagging de la siguiente manera:
library(randomForest)
randomForest 4.7-1.1
Type rfNews() to see new features/changes/bug fixes.
set.seed(1)
<- randomForest(
bag.boston ~ .,
medv data = Boston,
subset = train,
mtry = 12,
importance = TRUE
) bag.boston
Call:
randomForest(formula = medv ~ ., data = Boston, mtry = 12, importance = TRUE, subset = train)
Type of random forest: regression
Number of trees: 500
No. of variables tried at each split: 12
Mean of squared residuals: 11.40162
% Var explained: 85.17
El argumento mtry = 12
indica que se deben considerar los 12 predictores (todos), para cada división del árbol; en otras palabras, se debe realizar el bagging. ¿Cómo de bueno es el desempeño de este modelo bagging en el conjunto de prueba/validación?
<- predict(bag.boston, newdata = Boston[-train,])
yhat.bag plot(yhat.bag, boston.test)
abline(0, 1)
mean((yhat.bag - boston.test)^2)
[1] 23.41916
El MSE del conjunto de validación asociado con el árbol de regresión bagging es de 23.42
, aproximadamente dos tercios del obtenido utilizando un árbol único podado de manera óptima. Podríamos cambiar el número de árboles cultivados por randomForest()
usando el argumento ntree
:
<- randomForest(
bag.boston ~ .,
medv data = Boston,
subset = train,
mtry = 12,
ntree = 25
)<- predict(bag.boston, newdata = Boston[-train,])
yhat.bag mean((yhat.bag - boston.test)^2)
[1] 25.75055
8.3.2 Random forest
El crecimiento de un random forest procede exactamente de la misma manera, excepto que usamos un valor más pequeño del argumento mtry
. Por defecto, randomForest()
usa
- \(p/3\) variables cuando construye un random forest de árboles de regresión, y
- \(\sqrt{p}\) variables cuando construye un random forest de árboles de clasificación.
Aquí usamos mtry = 6
.
set.seed(1)
<- randomForest(
rf.boston ~ .,
medv data = Boston,
subset = train,
mtry = 6,
importance = TRUE
)<- predict(rf.boston, newdata = Boston[-train,])
yhat.rf mean((yhat.rf - boston.test)^2)
[1] 20.06644
El MSE del conjunto de validación es 20.07
; esto indica que los random forests produjeron una mejora con respecto al bagging en este caso.
Usando la función importance()
, podemos ver la importancia de cada variable.
importance(rf.boston)
%IncMSE IncNodePurity
crim 19.435587 1070.42307
zn 3.091630 82.19257
indus 6.140529 590.09536
chas 1.370310 36.70356
nox 13.263466 859.97091
rm 35.094741 8270.33906
age 15.144821 634.31220
dis 9.163776 684.87953
rad 4.793720 83.18719
tax 4.410714 292.20949
ptratio 8.612780 902.20190
lstat 28.725343 5813.04833
Se reportan dos medidas de importancia de las variables. La primera se basa en la disminución media de la precisión en las predicciones sobre las muestras fuera de la bolsa (bag) cuando se permuta una variable determinada. La segunda es una medida de la disminución total en la impureza de los nodos que resulta de las divisiones sobre esa variable, promediada sobre todos los árboles (esto se representó en la Figura 8.9). En el caso de los árboles de regresión, la impureza del nodo se mide por el RSS de entrenamiento, y para los árboles de clasificación por la deviance. Los gráficos de estas medidas de importancia se pueden producir utilizando la función varImpPlot()
.
varImpPlot(rf.boston)
Los resultados indican que en todos los árboles considerados en el random forest, el tamaño de la casa (rm
) y el % de nivel socioeconómico bajo (lstat
) son, con mucho, las dos variables más importantes.
8.4 Boosting
Usamos la función gbm()
(paquete gbm
), para ajustar árboles de regresión potenciados (boosted) al conjunto de datos Boston
:
- con el argumento
distribution = "gaussian"
ya que este es un problema de regresión; - se usaría
distribution = "bernoulli"
si fuera un problema de clasificación binaria.
El argumento n.trees = 5000
indica que queremos 5.000 árboles, y la opción interaction.depth = 4
limita la profundidad de cada árbol.
library(gbm)
Loaded gbm 2.2.2
This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
set.seed(1)
<- gbm(
boost.boston ~ .,
medv data = Boston[train,],
distribution = "gaussian",
n.trees = 5000,
interaction.depth = 4
)
La función summary()
produce un gráfico de influencia relativa y también genera estadísticos de influencia relativa.
summary(boost.boston)
var rel.inf
rm rm 44.48249588
lstat lstat 32.70281223
crim crim 4.85109954
dis dis 4.48693083
nox nox 3.75222394
age age 3.19769210
ptratio ptratio 2.81354826
tax tax 1.54417603
indus indus 1.03384666
rad rad 0.87625748
zn zn 0.16220479
chas chas 0.09671228
Vemos que, de nuevo, rm
y lstat
son, con diferencia, las variables más importantes.
También podemos producir gráficos de dependencia parcial para estas dos variables. Estos gráficos ilustran el efecto marginal de las variables seleccionadas en la respuesta después de integrar las otras variables.
par(mfrow = c(1, 2))
plot(boost.boston, i = "rm")
plot(boost.boston, i = "lstat")
En este caso, como cabría esperar, el precio mediano de la vivienda aumentan con rm
y disminuyen con lstat
.
Ahora usamos el modelo potenciado para predecir medv
(el mencionado precio mediano) en el conjunto de validación:
<- predict(boost.boston,
yhat.boost newdata = Boston[-train,], n.trees = 5000)
mean((yhat.boost - boston.test)^2)
[1] 18.39057
El MSE de validación obtenido es de 18.39
: esto es “superior” al MSE de validación de random forests y bagging. Si queremos, podemos realizar un boosting con un valor diferente del parámetro de shrinkage \(\lambda\) en (8.10). El valor predeterminado es 0.001, pero esto se modifica fácilmente. Aquí tomamos \(\lambda=0.2\).
<- gbm(
boost.boston ~ .,
medv data = Boston[train,],
distribution = "gaussian",
n.trees = 5000,
interaction.depth = 4,
shrinkage = 0.2,
verbose = F
)<- predict(boost.boston,
yhat.boost newdata = Boston[-train,], n.trees = 5000)
mean((yhat.boost - boston.test)^2)
[1] 16.54778
En este caso, usar \(\lambda=0.2\) conduce a un MSE de validación más bajo que \(\lambda=0.001\).
8.5 Árboles de regresión aditivos bayesianos (omitido)
Nota Víctor: Considero que es suficiente para esta asignatura entender los árboles anteriores. Aquí sólo se mencionan las funciones en R que permiten el análisis pudiéndose consultar los detalles (escasos) en el material original.
La función gbart()
del paquete BART
permite ajustar un modelo de árbol de regresión aditivo bayesiano, para variables de respuesta cuantitativas. Para respuestas binarias están disponibles lbart()
y pbart()
.