cforest prints empty tree

I'm trying to use the cforest function (R, party package).

This is how I construct the forest:

library("party")
set.seed(42)
readingSkills.cf <- cforest(score ~ ., data = readingSkills, 
                         control = cforest_unbiased(mtry = 2, ntree = 50))

Then, to print the first tree, I run:

party:::prettytree(readingSkills.cf@ensemble[[1]],names(readingSkills.cf@data@get("input")))

The result looks like this:

1) shoeSize <= 28.29018; criterion = 1, statistic = 89.711
  2) age <= 6; criterion = 1, statistic = 48.324
    3) age <= 5; criterion = 0.997, statistic = 8.917
      4)*  weights = 0 
    3) age > 5
      5)*  weights = 0 
  2) age > 6
    6) age <= 7; criterion = 1, statistic = 13.387
      7) shoeSize <= 26.66743; criterion = 0.214, statistic = 0.073
        8)*  weights = 0 
      7) shoeSize > 26.66743
        9)*  weights = 0 
    6) age > 7
      10)*  weights = 0 
1) shoeSize > 28.29018
  11) age <= 9; criterion = 1, statistic = 36.836
    12) nativeSpeaker == {}; criterion = 0.998, statistic = 9.347
      13)*  weights = 0 
    12) nativeSpeaker == {}
      14)*  weights = 0 
  11) age > 9
    15) nativeSpeaker == {}; criterion = 1, statistic = 19.124
      16) age <= 10; criterion = 1, statistic = 18.441
        17)*  weights = 0 
      16) age > 10
        18)*  weights = 0 
    15) nativeSpeaker == {}
      19)*  weights = 0 

Why is it empty (the weights in every node are zero)?

Skydive asked 12/11, 2013 at 8:36

Short answer: the case weights (weights) in each node are NULL, i.e. they are not stored. The prettytree function prints weights = 0 because sum(NULL) is 0 in R.
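A minimal illustration of why a missing weights vector shows up as "weights = 0" rather than as an error. The node below is a hand-made stand-in for a tree node (only the two fields needed here), not an object produced by party:

```r
# Hand-made stand-in for a tree node whose "weights" element is NULL,
# as in the cforest ensemble (only two fields, for illustration).
node <- list(nodeID = 4L, weights = NULL)

is.null(node$weights)   # TRUE: the element exists but holds nothing
sum(node$weights)       # 0: summing NULL gives zero, hence "weights = 0"
```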


Consider the following ctree example:

library("party")
x <- ctree(Species ~ ., data=iris)
plot(x, type="simple")

[Image: ctree plot]

For the resulting object x (class BinaryTree) the case weights are stored in each node:

R> sum(x@tree$left$weights)
[1] 50
R> sum(x@tree$right$weights)
[1] 100
R> sum(x@tree$right$left$weights)
[1] 54
R> sum(x@tree$right$right$weights)
[1] 46
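The same check can be done for every terminal node at once with a small recursive walk over x@tree. The helper terminal_weight_sums below is hypothetical (it is not part of party); it only relies on the node fields nodeID, terminal, weights, left and right used elsewhere in this thread. Since the terminal nodes of a ctree partition the learning sample, their sums should add up to the 150 iris observations:

```r
library("party")

# Hypothetical helper (not part of party): return sum(weights) for every
# terminal node below `node`, named by node ID.
terminal_weight_sums <- function(node) {
  if (isTRUE(node$terminal)) {
    return(setNames(sum(node$weights), node$nodeID))
  }
  # inner node: collect the results of both subtrees
  c(terminal_weight_sums(node$left), terminal_weight_sums(node$right))
}

x <- ctree(Species ~ ., data = iris)
sums <- terminal_weight_sums(x@tree)
print(sums)
```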

Now let's take a closer look at cforest:

y <- cforest(Species ~ ., data=iris, control=cforest_control(mtry=2))
tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
plot(new("BinaryTree", tree=tr, data=y@data, responses=y@responses))

[Image: cforest tree]

The case weights are not stored in the tree ensemble, which can be verified as follows:

fixInNamespace("print.TerminalNode", "party")

and change the print method to:

function (x, n = 1, ...)
{
    print(names(x))
    print(x$weights)
    cat(paste(paste(rep(" ", n - 1), collapse = ""), x$nodeID,
        ")* ", sep = "", collapse = ""), "weights =", sum(x$weights),
        "\n")
}

Now we can observe that weights is NULL in every node:

R> tr
1) Petal.Width <= 0.4; criterion = 10.641, statistic = 10.641
 [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
 [6] "ssplits"    "prediction" "left"       "right"      NA          
NULL
  2)*  weights = 0 
1) Petal.Width > 0.4
  3) Petal.Width <= 1.6; criterion = 8.629, statistic = 8.629
 [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
 [6] "ssplits"    "prediction" "left"       "right"      NA          
NULL
    4)*  weights = 0 
  3) Petal.Width > 1.6
 [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
 [6] "ssplits"    "prediction" "left"       "right"      NA          
NULL
    5)*  weights = 0 
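If you would rather not edit the package namespace, a read-only check gives the same answer. The helper terminal_weights_null is hypothetical (not part of party) and simply reports, for each terminal node of the pretty-printed tree, whether its weights element is NULL; ntree = 50 is only there to keep the example fast:

```r
library("party")

# Hypothetical helper (not part of party): TRUE for each terminal node whose
# weights element is NULL, without modifying print.TerminalNode.
terminal_weights_null <- function(node) {
  if (isTRUE(node$terminal)) return(is.null(node$weights))
  c(terminal_weights_null(node$left), terminal_weights_null(node$right))
}

set.seed(1)
y <- cforest(Species ~ ., data = iris,
             control = cforest_control(mtry = 2, ntree = 50))
tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
print(all(terminal_weights_null(tr)))
```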

Update: this is a hack to display the sums of the case weights:

update_tree <- function(x) {
  if(!x$terminal) {
    x$left <- update_tree(x$left)
    x$right <- update_tree(x$right)
  } else {
    x$weights <- x[[9]]
    x$weights_ <- x[[9]]
  }
  x
}
tr_weights <- update_tree(tr)
plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))

[Image: cforest tree with case weights]
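Another route, sketched here under assumptions rather than tested against every party version: the fitted RandomForest object appears to keep, per tree, the case weights used to grow it (slot weights, one vector per tree) and the terminal node each learning observation falls into (slot where). If that holds, per-node sums can be tabulated without touching the tree list at all; check slotNames(cf) on your installation first:

```r
library("party")

set.seed(42)
cf <- cforest(score ~ ., data = readingSkills,
              control = cforest_unbiased(mtry = 2, ntree = 50))

k <- 1
# Assumption: cf@weights[[k]] holds the case weights used to grow tree k
# and cf@where[[k]] the terminal node ID of each learning observation.
w <- cf@weights[[k]]
node_id <- cf@where[[k]]

# Sum of case weights per terminal node of tree k, named by node ID
node_sums <- tapply(w, node_id, sum)
print(node_sums)
```

With cforest_unbiased the trees are grown on subsamples drawn without replacement, so these weights should be 0/1 and the sums are in-bag node sizes.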

Comnenus answered 21/11, 2013 at 21:14

Comments:
Skydive: Thank you for your answer. But does that mean cforest constructs an empty forest? And what do you think is the reason for this result?
Comnenus: The trees are not "empty". The case weights are simply not stored in the tree ensemble, so the output is misleading. I don't know exactly why; perhaps to save memory.
Skydive: OK. But how can I get the weights?

The solution proposed by @rcs in the Update is interesting, but it does not work with cforest when the dependent variable is numeric. The code:

set.seed(12345)
y <- cforest(score ~ ., data = readingSkills,
       control = cforest_unbiased(mtry = 2, ntree = 50))
tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
tr_weights <- update_tree(tr)
plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))

generates the following error message:

R> Error in valid.data(rep(units, length.out = length(x)), data) :
   no string supplied for 'strwidth/height' unit

and the following plot:

[Image: resulting plot, not preserved]

Below I suggest an improved version of the hack proposed by @rcs:

get_cTree <- function(cf, k = 1) {
  dt <- cf@data@get("input")
  tr <- party:::prettytree(cf@ensemble[[k]], names(dt))
  # The ensemble stores no case weights, so seed the root with the in-bag
  # weights of tree k (RandomForest slot "weights") and push them down
  tr$weights <- cf@weights[[k]]
  tr_updated <- update_tree(tr, dt)
  new("BinaryTree", tree = tr_updated, data = cf@data, responses = cf@responses,
      cond_distr_response = cf@cond_distr_response,
      predict_response = cf@predict_response)
}

# Recursively fill in the weights of every node below x
update_tree <- function(x, dt) {
  if (!x$terminal) {
    x <- update_weights(x, dt)
    x$left <- update_tree(x$left, dt)
    x$right <- update_tree(x$right, dt)
  }
  x
}

# Route the parent's weights to its two children using the primary split
update_weights <- function(x, dt) {
  splt <- x$psplit
  spltVar <- dt[, splt$variableName]
  if (inherits(splt, "nominalSplit")) {
    spltVarLev <- levels(spltVar)
    # restore the level labels of the split variable
    attr(x$psplit$splitpoint, "levels") <- spltVarLev
    filt <- spltVar %in% spltVarLev[as.logical(splt$splitpoint)]
  } else {
    filt <- spltVar <= splt$splitpoint
  }
  # multiply by the parent's weights so that observations excluded
  # higher up in the tree stay excluded
  x$left$weights <- x$weights * as.numeric(filt)
  x$right$weights <- x$weights * as.numeric(!filt)
  x
}

plot(get_cTree(y, 1))

[Image: plot of get_cTree(y, 1), not preserved]
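The heart of update_weights is routing each observation to the left or right child according to the split rule. A toy, package-free sketch of that routing on made-up data (the values of age and speaker are hypothetical): a child's weights are its parent's weights masked by the split indicator, so an observation that left the branch higher up cannot reappear further down. For the nominal split it follows the convention used in update_weights above, where the 0/1 splitpoint flags mark the levels that go left:

```r
# Hypothetical toy learning sample
age <- c(5, 6, 7, 8, 9, 10)
speaker <- factor(c("no", "yes", "yes", "no", "yes", "no"))

root_w <- rep(1, length(age))          # every observation starts in the root

# Ordered split "age <= 7"
filt <- age <= 7
left_w  <- root_w * as.numeric(filt)
right_w <- root_w * as.numeric(!filt)

# Nominal split on the right child: one 0/1 flag per level of speaker
splitpoint <- c(0, 1)                  # levels are c("no", "yes"): "yes" goes left
filt2 <- speaker %in% levels(speaker)[as.logical(splitpoint)]
right_left_w  <- right_w * as.numeric(filt2)
right_right_w <- right_w * as.numeric(!filt2)

# Terminal node sizes: 3, 1, 2 (they add up to the 6 observations)
c(sum(left_w), sum(right_left_w), sum(right_right_w))
```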

Franz answered 30/12, 2015 at 18:22
