Removes all nodes that did not improve the loss by more than cp times the initial loss, either by themselves or through one of their successors. Note that the tree is pruned in place. If you intend to keep the original tree, make a copy of it before pruning.

# S3 method for class 'SDTree'
prune(object, cp, ...)

Arguments

object

An SDTree object.

cp

Complexity parameter. The higher the value, the more nodes are pruned.

...

Further arguments passed to or from other methods.

Value

A pruned SDTree object

Author

Markus Ulmer

Examples

set.seed(1)
X <- matrix(rnorm(10 * 20), nrow = 10)
Y <- rnorm(10)
tree <- SDTree(x = X, y = Y)
pruned_tree <- prune(tree, 0.2)
tree
#> $predictions
#>  [1] -0.4113688  1.4656657  1.4656657  1.4656657 -0.4113688  1.4656657
#>  [7]  1.4656657 -0.4113688 -0.4113688 -0.4113688
#> 
#> $tree
#>      name left right j         s      value     dloss res_dloss         cp
#> [1,]    1    2     3 2 0.4918723  0.5271484 1.4898237 0.6965944 10.0000000
#> [2,]    1    0     0 0 0.0000000  1.4656657 0.6965944 0.0000000  0.4675684
#> [3,]    2    0     0 0 0.0000000 -0.4113688 0.6965944 0.0000000  0.4675684
#>      n_samples leaf
#> [1,]        10    2
#> [2,]         5    1
#> [3,]         5    1
#> 
#> $var_names
#>  [1] "X1"  "X2"  "X3"  "X4"  "X5"  "X6"  "X7"  "X8"  "X9"  "X10" "X11" "X12"
#> [13] "X13" "X14" "X15" "X16" "X17" "X18" "X19" "X20"
#> 
#> $var_importance
#>        X1        X2        X3        X4        X5        X6        X7        X8 
#> 0.0000000 0.6965944 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 
#>        X9       X10       X11       X12       X13       X14       X15       X16 
#> 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 
#>       X17       X18       X19       X20 
#> 0.0000000 0.0000000 0.0000000 0.0000000 
#> 
#> attr(,"class")
#> [1] "SDTree"
pruned_tree
#> $tree
#>      name left right j         s      value     dloss res_dloss         cp
#> [1,]    1    2     3 2 0.4918723  0.5271484 1.4898237 0.6965944 10.0000000
#> [2,]    1    0     0 0 0.0000000  1.4656657 0.6965944 0.0000000  0.4675684
#> [3,]    2    0     0 0 0.0000000 -0.4113688 0.6965944 0.0000000  0.4675684
#>      n_samples leaf
#> [1,]        10    2
#> [2,]         5    1
#> [3,]         5    1
#> 
#> $var_names
#>  [1] "X1"  "X2"  "X3"  "X4"  "X5"  "X6"  "X7"  "X8"  "X9"  "X10" "X11" "X12"
#> [13] "X13" "X14" "X15" "X16" "X17" "X18" "X19" "X20"
#> 
#> $var_importance
#>        X1        X2        X3        X4        X5        X6        X7        X8 
#> 0.0000000 0.6965944 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 
#>        X9       X10       X11       X12       X13       X14       X15       X16 
#> 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 
#>       X17       X18       X19       X20 
#> 0.0000000 0.0000000 0.0000000 0.0000000 
#> 
#> attr(,"class")
#> [1] "SDTree"
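
The two printouts above show the same three-row tree, which is consistent with the description: with cp = 0.2, no node falls below the threshold. The following base R sketch illustrates why, using only numbers copied from the printed `$tree` matrix. The data frame `nodes` and the helper `survives()` are hypothetical and not part of the package, and the rule "a node is kept when its stored cp exceeds the threshold" is an assumption drawn from the description, not the package's actual implementation.

```r
# Hypothetical stand-in for the printed `$tree` matrix (relevant columns only).
nodes <- data.frame(
  row   = 1:3,
  dloss = c(1.4898237, 0.6965944, 0.6965944),
  cp    = c(10, 0.4675684, 0.4675684)
)

# In this output, the cp stored for the two leaves matches the loss
# improvement of the split divided by the initial loss (root dloss).
rel_improvement <- 0.6965944 / 1.4898237  # approximately 0.4676

# Assumed rule (illustration only): a node is kept when its cp value
# is larger than the cp threshold passed to prune().
survives <- function(nodes, cp) nodes$row[nodes$cp > cp]

kept_low  <- survives(nodes, 0.2)  # rows 1, 2, 3: nothing is pruned
kept_high <- survives(nodes, 0.5)  # row 1 only: the split would be removed
```

Under these assumptions, a threshold above roughly 0.47 would be needed before pruning changes this particular tree.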