How to count decision tree rules in R

I used rpart to build a decision tree, and that works without a problem. But I need to find out (or count) how many times the tree has been split. In other words, how many rules (if-else statements) does the tree have? For instance:

    X
    +-- if (a < 9)  -> Y
    |     +-- if (b > 2) -> Z
    +-- if (a >= 9) -> H

There are 3 rules.

When I run summary() on my model, I get:

summary(model_dt)

Call:
rpart(formula = Alert ~ ., data = train)
  n= 18576811 

         CP nsplit  rel error     xerror         xstd
1 0.9597394      0 1.00000000 1.00000000 0.0012360956
2 0.0100000      1 0.04026061 0.05290522 0.0002890205

Variable importance
         ip.src frame.protocols   tcp.flags.ack tcp.flags.reset       frame.len 
             20              17              17              17              16 
         ip.ttl 
             12 

Node number 1: 18576811 observations,    complexity param=0.9597394
  predicted class=yes  expected loss=0.034032  P(node) =1
    class counts: 632206 1.79446e+07
   probabilities: 0.034 0.966 
  left son=2 (627091 obs) right son=3 (17949720 obs)
  Primary splits:
      ip.src          splits as LLLLLLLRRRLLRR ............ LLRLRLRRRRRRRRRRRRRRRR, improve=1170831.0, (0 missing)
      ip.dts splits as  LLLLLLLLLLLLLLLLLLLRLLLLLLLLLLL, improve=1013082.0, (0 missing)
      tcp.flags.ctl   < 1.5   to the right, improve=1007953.0, (2645 missing)
      tcp.flags.syn < 1.5   to the right, improve=1007953.0, (2645 missing)
      frame.len       < 68    to the right, improve= 972871.3, (30 missing)
  Surrogate splits:
      frame.protocols splits as  LLLLLLLLLLLLLLLLLLLRLLLLLLLLLLL, agree=0.995, adj=0.841, (0 split)
      tcp.flags.ack   < 1.5   to the right, agree=0.994, adj=0.836, (0 split)
      tcp.flags.reset < 1.5   to the right, agree=0.994, adj=0.836, (0 split)
      frame.len       < 68    to the right, agree=0.994, adj=0.809, (0 split)
      ip.ttl          < 230.5 to the right, agree=0.987, adj=0.612, (0 split)

Node number 2: 627091 observations
  predicted class=no   expected loss=0.01621615  P(node) =0.03375666
    class counts: 616922 10169
   probabilities: 0.984 0.016 

Node number 3: 17949720 observations
  predicted class=yes  expected loss=0.0008514896  P(node) =0.9662433
    class counts: 15284 1.79344e+07
   probabilities: 0.001 0.999

I would be grateful if anyone could help me understand this.

Sincerely, Eray

Demetri asked 30/5, 2014 at 18:46

There are a couple of ways to achieve this, both relying on some knowledge of how the fitted tree object is structured (see ?rpart.object).

I'll show two ways using the kyphosis data set, which ships with rpart, following the first example in ?rpart:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

Option 1

> tail(fit$cptable[, "nsplit"], 1)
3 
4
> unname(tail(fit$cptable[, "nsplit"], 1)) ## or
[1] 4

This reads the nsplit column of the cptable component, which contains the cost-complexity information for trees of each size:

> fit$cptable
          CP nsplit rel error   xerror      xstd
1 0.17647059      0 1.0000000 1.000000 0.2155872
2 0.01960784      1 0.8235294 1.176471 0.2282908
3 0.01000000      4 0.7647059 1.176471 0.2282908

From what I recall, the last row of this table refers to the current (largest) tree. If you prune the tree to a particular size based on CP, the last row of the pruned tree's cptable contains the information for the tree of that size:

> fit2 <- prune(fit, cp = 0.02)
> fit2$cptable
         CP nsplit rel error   xerror      xstd
1 0.1764706      0 1.0000000 1.000000 0.2155872
2 0.0200000      1 0.8235294 1.176471 0.2282908
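Option 1 can be wrapped in a small helper so it works the same way on a full or a pruned tree. This is a minimal sketch: n_splits_cp() is a hypothetical name, and it assumes (as in the tables above) that the cptable rows are ordered by increasing nsplit, so the last row describes the tree as currently fitted or pruned.

```r
library(rpart)

# Hypothetical helper: number of splits in the tree as currently fitted or
# pruned, read from the last row of its cptable (rows ordered by nsplit).
n_splits_cp <- function(tree) {
  cp_tab <- tree$cptable
  unname(cp_tab[nrow(cp_tab), "nsplit"])
}

fit  <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
fit2 <- prune(fit, cp = 0.02)

n_splits_cp(fit)   # full tree
n_splits_cp(fit2)  # pruned tree
```

For the kyphosis fit this should agree with the cptables shown above: 4 splits for fit and 1 for fit2.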

Option 2

A second option is to count the occurrences of <leaf> in the var column of the frame component of the fitted model:

> fit$frame
      var  n wt dev yval complexity ncompete nsurrogate    yval2.V1    yval2.V2
1   Start 81 81  17    1 0.17647059        2          1  1.00000000 64.00000000
2   Start 62 62   6    1 0.01960784        2          2  1.00000000 56.00000000
4  <leaf> 29 29   0    1 0.01000000        0          0  1.00000000 29.00000000
5     Age 33 33   6    1 0.01960784        2          2  1.00000000 27.00000000
10 <leaf> 12 12   0    1 0.01000000        0          0  1.00000000 12.00000000
11    Age 21 21   6    1 0.01960784        2          0  1.00000000 15.00000000
22 <leaf> 14 14   2    1 0.01000000        0          0  1.00000000 12.00000000
23 <leaf>  7  7   3    2 0.01000000        0          0  2.00000000  3.00000000
3  <leaf> 19 19   8    2 0.01000000        0          0  2.00000000  8.00000000
      yval2.V3    yval2.V4    yval2.V5 yval2.nodeprob
1  17.00000000  0.79012346  0.20987654     1.00000000
2   6.00000000  0.90322581  0.09677419     0.76543210
4   0.00000000  1.00000000  0.00000000     0.35802469
5   6.00000000  0.81818182  0.18181818     0.40740741
10  0.00000000  1.00000000  0.00000000     0.14814815
11  6.00000000  0.71428571  0.28571429     0.25925926
22  2.00000000  0.85714286  0.14285714     0.17283951
23  4.00000000  0.42857143  0.57142857     0.08641975
3  11.00000000  0.42105263  0.57894737     0.23456790

The number of <leaf> entries minus 1 is the number of splits. To do the counting we can use:

> grepl("^<leaf>$", as.character(fit$frame$var))
[1] FALSE FALSE  TRUE FALSE  TRUE FALSE  TRUE  TRUE  TRUE
> sum(grepl("^<leaf>$", as.character(fit$frame$var))) - 1
[1] 4

The regular expression I use is probably overkill, but it checks for a string that starts (^) and ends ($) with "<leaf>", i.e. "<leaf>" is the entire string. grepl() returns the matches on the var column as a logical vector, so we can sum up the TRUEs and subtract 1. As var is stored as a factor, I convert it to a character vector in the grepl() call.
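Subtracting 1 works because an rpart tree is binary: a tree with L leaves always has L - 1 internal (split) nodes. A sketch of an equivalent count that skips the regular expression and tallies the non-leaf rows of fit$frame directly (comparing a factor with a string via == compares its labels, so no conversion is needed here):

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Each row of fit$frame is one node; terminal nodes have var == "<leaf>".
is_leaf  <- fit$frame$var == "<leaf>"
n_leaves <- sum(is_leaf)    # terminal nodes
n_splits <- sum(!is_leaf)   # internal nodes, i.e. splits

# In a binary tree the two counts always differ by exactly one.
stopifnot(n_splits == n_leaves - 1)
c(leaves = n_leaves, splits = n_splits)
```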

You could also do this using grep() to return the indices of the matches and use length() to count them:

> grep("^<leaf>$", as.character(fit$frame$var))
[1] 3 5 7 8 9
> length(grep("^<leaf>$", as.character(fit$frame$var))) - 1
[1] 4
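If by "rules" you mean the full if-else condition that leads to each terminal node, rpart's path.rpart() can return the path from the root to given nodes. A sketch, assuming the leaf node numbers are taken from the row names of fit$frame (which is where rpart stores the node numbers shown in the output above):

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Node numbers are stored as the row names of fit$frame.
leaf_nodes <- as.integer(rownames(fit$frame)[fit$frame$var == "<leaf>"])

# path.rpart() returns (invisibly) a list with one character vector per node:
# the chain of split labels from the root down to that node.
rules <- path.rpart(fit, nodes = leaf_nodes, print.it = FALSE)

length(rules)  # one rule per leaf
for (node in names(rules)) {
  # Drop the first label ("root") and join the remaining conditions.
  cat("node", node, ":", paste(rules[[node]][-1], collapse = " & "), "\n")
}
```

Each printed line is one if-else rule, so the number of rules in this sense equals the number of leaves (splits + 1).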
Rashad answered 30/5, 2014 at 19:20
Thank you very much for your explanations! Running sum(grepl("^<leaf>$", as.character(model_dt$frame$var))) - 1 gives [1] 1, so does that mean I have only one splitting rule? If so, what is the meaning of, for example, "ip.dts splits as LLLLLLLLLLLLLLLLLLLRLLLLLLLLLLL"? I think there are 31 split rules for this part. My dataset is more than 10GB; is it plausible that I have only 1 split? - Demetri
Type fit (where fit is the object containing your tree) at the R prompt and it will print the tree, showing how many splits you have. That way you can check. - Rashad
Hi, I have a similar requirement: I want to select the CP value for pruning so that I get at most 5 leaves. Is there a way to make such a selection? Also, how can I know the conditions that each of the leaves satisfies? I am using rpart trees to group the levels of my categorical variable so that they are associated with my continuous variables. Your input would help me a lot! Thank you! - Cavitation
I have posted my question as #49229286; it would help a lot if you could suggest how to achieve this. - Cavitation
