Preferred performant procedure for R data.table row-wise operations?

Does the following code represent the preferred procedure for traversing the rows of an R data.table and passing the values found at each row to a function? Or is there a more performant way to do this?

library(data.table)
set.seed(2)
n <- 100
b <- c(0.5, 1.5, -1)
phi <- 0.8
X <- cbind(1, matrix(rnorm(n*2, 0, 1), ncol = 2))
y <- X %*% matrix(b, ncol = 1) + rnorm(n, 0, phi)
d <- data.table(y, X)
setnames(d, c("y", "x0", "x1", "x2"))

logpost <- function(d, b1, b2, b3, phi, mub = 1, taub = 10, a = 0.5, z = 0.7){
    N <- nrow(d)
    mu <- b1 + b2 * d$x1 + b3 * d$x2
    lp <- -N * log(phi) -
        (1/(2*phi^2)) * sum( (d$y-mu)^2  ) -
        (1/(2*taub^2))*( (b1-mub)^2 + (b2-mub)^2 + (b3-mub)^2 ) -
        (a+1)*log(phi) - (z/phi)
    lp
}
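
For reference, the value logpost() returns for one grid point can be written as below. Here μ_i = b1 + b2·x1_i + b3·x2_i, N = nrow(d), and μ_b, τ_b, a, z stand for the defaults mub, taub, a, z. The formula shows that every call sums over all N rows of d, and that the two log φ terms combine into one:

```latex
\begin{aligned}
\ell(b_1,b_2,b_3,\phi) &= -N\log\phi - \frac{1}{2\phi^2}\sum_{i=1}^{N}(y_i-\mu_i)^2
  - \frac{1}{2\tau_b^2}\sum_{k=1}^{3}(b_k-\mu_b)^2 - (a+1)\log\phi - \frac{z}{\phi} \\
&= -(N+a+1)\log\phi - \frac{1}{2\phi^2}\sum_{i=1}^{N}(y_i-\mu_i)^2
  - \frac{1}{2\tau_b^2}\sum_{k=1}^{3}(b_k-\mu_b)^2 - \frac{z}{\phi}
\end{aligned}
```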

nn <- 21
grid <- data.table(
  expand.grid(b1 = seq(0, 1, len = nn),
              b2 = seq(1, 2, len = nn),
              b3 = seq(-1.5, -0.5, len = nn),
              phi = seq(0.4, 1.2, len = nn)))
grid[, id := 1:.N]
setkey(grid, id)

wraplogpost <- function(dd){
    logpost(d, dd$b1, dd$b2, dd$b3, dd$phi)
}
start <- Sys.time()
grid[, lp := wraplogpost(.SD), by = seq_len(nrow(grid))]
difftime(Sys.time(), start)
# Time difference of 2.081544 secs

Edit: display first few records

> head(grid)
     b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152
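
One way to stay inside data.table's j without creating a separate group (and a one-row .SD) for every row is to pass the grid columns to mapply(). This calls logpost() once per row with scalar arguments. The sketch below is not a benchmark: it still makes one logpost() call per grid row. To run on its own, the block rebuilds the question's data on a coarser grid (nn = 5 instead of 21).

```r
library(data.table)

# Rebuild the question's data (same seed and model) so this block runs on its own
set.seed(2)
n <- 100
X <- cbind(1, matrix(rnorm(n * 2, 0, 1), ncol = 2))
y <- X %*% matrix(c(0.5, 1.5, -1), ncol = 1) + rnorm(n, 0, 0.8)
d <- data.table(y = drop(y), x0 = X[, 1], x1 = X[, 2], x2 = X[, 3])

# logpost() as in the question
logpost <- function(d, b1, b2, b3, phi, mub = 1, taub = 10, a = 0.5, z = 0.7) {
  N <- nrow(d)
  mu <- b1 + b2 * d$x1 + b3 * d$x2
  -N * log(phi) -
    (1 / (2 * phi^2)) * sum((d$y - mu)^2) -
    (1 / (2 * taub^2)) * ((b1 - mub)^2 + (b2 - mub)^2 + (b3 - mub)^2) -
    (a + 1) * log(phi) - (z / phi)
}

# Coarser grid than the question's (nn = 5, i.e. 625 rows) to keep the example quick
nn <- 5
grid <- data.table(expand.grid(b1 = seq(0, 1, len = nn),
                               b2 = seq(1, 2, len = nn),
                               b3 = seq(-1.5, -0.5, len = nn),
                               phi = seq(0.4, 1.2, len = nn)))

# mapply() walks the four columns in parallel and calls logpost() once per row;
# with the default SIMPLIFY = TRUE it returns a plain numeric vector for :=
grid[, lp := mapply(function(b1, b2, b3, phi) logpost(d, b1, b2, b3, phi),
                    b1, b2, b3, phi)]
head(grid)
```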

I have also tried using set(), but that approach is about ten times slower:

start <- Sys.time()
grid[, lp := NA_real_]
for(i in 1:nrow(grid)){
    llpp <- wraplogpost(grid[i])
    set(grid, i, "lp", llpp)
}
difftime(Sys.time(), start)
# Time difference of 21.71291 secs

Edit: display first few records

> head(grid)
     b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152
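
The loop above probably spends most of its time in grid[i], which builds a new one-row data.table on every iteration, rather than in set() itself. The sketch below extracts the parameter columns once as plain vectors and uses set() only to write each result. It rebuilds the question's data on a coarser grid so it runs on its own, and it makes no timing claims.

```r
library(data.table)

# Rebuild the question's data (same seed and model) so this block runs on its own
set.seed(2)
n <- 100
X <- cbind(1, matrix(rnorm(n * 2, 0, 1), ncol = 2))
y <- X %*% matrix(c(0.5, 1.5, -1), ncol = 1) + rnorm(n, 0, 0.8)
d <- data.table(y = drop(y), x0 = X[, 1], x1 = X[, 2], x2 = X[, 3])

# logpost() as in the question
logpost <- function(d, b1, b2, b3, phi, mub = 1, taub = 10, a = 0.5, z = 0.7) {
  N <- nrow(d)
  mu <- b1 + b2 * d$x1 + b3 * d$x2
  -N * log(phi) -
    (1 / (2 * phi^2)) * sum((d$y - mu)^2) -
    (1 / (2 * taub^2)) * ((b1 - mub)^2 + (b2 - mub)^2 + (b3 - mub)^2) -
    (a + 1) * log(phi) - (z / phi)
}

# Coarser grid than the question's (nn = 5, i.e. 625 rows)
nn <- 5
grid <- data.table(expand.grid(b1 = seq(0, 1, len = nn),
                               b2 = seq(1, 2, len = nn),
                               b3 = seq(-1.5, -0.5, len = nn),
                               phi = seq(0.4, 1.2, len = nn)))

grid[, lp := NA_real_]

# Pull the parameter columns out once as plain vectors, so the loop body does
# no data.table subsetting (grid[i]) and no repeated $ extraction from grid
b1 <- grid$b1; b2 <- grid$b2; b3 <- grid$b3; phi <- grid$phi
for (i in seq_len(nrow(grid))) {
  # set() assigns by reference, skipping the overhead of [.data.table
  set(grid, i = i, j = "lp", value = logpost(d, b1[i], b2[i], b3[i], phi[i]))
}
head(grid)
```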

Suggestions or pointers to the relevant docs would be appreciated.

Edit: following the comments, I tried by = .I:

> start <- Sys.time()
> grid[, lp := wraplogpost(.SD), by = .I]
Warning messages:
1: In b2 * d$x1 :
    longer object length is not a multiple of shorter object length
2: In b3 * d$x2 :
    longer object length is not a multiple of shorter object length
3: In d$y - mu :
    longer object length is not a multiple of shorter object length
> difftime(Sys.time(), start)
Time difference of 0.01199317 secs
> head(grid)
     b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -620977.2
2: 0.05  1 -1.5 0.4  2 -620977.2
3: 0.10  1 -1.5 0.4  3 -620977.2
4: 0.15  1 -1.5 0.4  4 -620977.2
5: 0.20  1 -1.5 0.4  5 -620977.2
6: 0.25  1 -1.5 0.4  6 -620977.2

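This is much faster, but it generates the wrong values for lp: every row gets the same value, which suggests that wraplogpost() received whole columns rather than single rows.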
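
Judging from the recycling warnings, by = .I did not produce one group per row in the data.table version used here. Because grid already has a unique id column, grouping by that column is one way to get a genuine one-row-per-call evaluation. It is equivalent to by = seq_len(nrow(grid)), so no speed-up over the original should be expected. The sketch below again rebuilds the data on a coarser grid.

```r
library(data.table)

# Rebuild the question's data (same seed and model) so this block runs on its own
set.seed(2)
n <- 100
X <- cbind(1, matrix(rnorm(n * 2, 0, 1), ncol = 2))
y <- X %*% matrix(c(0.5, 1.5, -1), ncol = 1) + rnorm(n, 0, 0.8)
d <- data.table(y = drop(y), x0 = X[, 1], x1 = X[, 2], x2 = X[, 3])

# logpost() as in the question
logpost <- function(d, b1, b2, b3, phi, mub = 1, taub = 10, a = 0.5, z = 0.7) {
  N <- nrow(d)
  mu <- b1 + b2 * d$x1 + b3 * d$x2
  -N * log(phi) -
    (1 / (2 * phi^2)) * sum((d$y - mu)^2) -
    (1 / (2 * taub^2)) * ((b1 - mub)^2 + (b2 - mub)^2 + (b3 - mub)^2) -
    (a + 1) * log(phi) - (z / phi)
}

nn <- 5
grid <- data.table(expand.grid(b1 = seq(0, 1, len = nn),
                               b2 = seq(1, 2, len = nn),
                               b3 = seq(-1.5, -0.5, len = nn),
                               phi = seq(0.4, 1.2, len = nn)))

# .I inside j (with no by) is the vector of row numbers, i.e. the same as 1:.N
grid[, id := .I]

wraplogpost <- function(dd) {
  logpost(d, dd$b1, dd$b2, dd$b3, dd$phi)
}

# Every id value is unique, so each group is a single row and wraplogpost()
# receives a one-row .SD (the grouping column id itself is not part of .SD)
grid[, lp := wraplogpost(.SD), by = id]
head(grid)
```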

Edit: Thank you for the comments and responses. I am aware that this scenario could be addressed with alternative methods; my interest is in the preferred way to do this when using data.table.

Edit: Thank you again for the responses. As none of them addresses how to do this explicitly with data.table, I am assuming for now that there is no ideal way to achieve this without turning to base R.

Casebook answered 18/4, 2021 at 7:58 Comment(4)
Try by = .I. It's faster, see ?.I. – Ululate
Thank you. I think the help is saying that .I should be used for obtaining the row indices in j rather than as a by term. Also the answer here: https://mcmap.net/q/676576/-row-operations-in-data-table-using-by-i, suggests (to me at least) that .I should not be used in the by clause. Am I interpreting that answer incorrectly? – Casebook
Yes, I believe you are misinterpreting that answer. .I returns seq_len(nrow(grid)) but is faster since it's a value computed by data.table. Try it. – Ululate
This is generally slow because you are doing a lot of $ extractions. Your loop would work better if your data were a matrix as opposed to a list (i.e., a data.table). – Bushelman

I think you can use matrix multiplication and other vectorization techniques to simplify your code, which lets you avoid calling logpost row by row.

Below is a vectorized version of logpost, called logpost2:

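logpost2 <- function(d, dd, mub = 1, taub = 10, a = 0.5, z = 0.7) {
  # G x 3 matrix of coefficients, one row per grid point
  bmat <- as.matrix(dd[, .(b1, b2, b3)])
  # N x 3 design matrix with an intercept column
  xmat <- cbind(1, as.matrix(d[, .(x1, x2)]))
  phi <- dd$phi
  phi_log <- log(phi)
  # tcrossprod(xmat, bmat) is N x G: column g holds mu for grid point g.
  # The two log(phi) terms of logpost are merged into one coefficient.
  lp <- -(a + nrow(d) + 1) * phi_log -
    (1 / (2 * phi^2)) * colSums((d$y - tcrossprod(xmat, bmat))^2) -
    (1 / (2 * taub^2)) * rowSums((bmat - mub)^2) - (z / phi)
  lp
}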
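
To check that logpost2() reproduces logpost(), the sketch below evaluates both over a coarser grid (nn = 5) built from the question's data and compares them with all.equal(). The mapply() reference computation is added here only for the comparison; it is not part of this answer.

```r
library(data.table)

# Rebuild the question's data (same seed and model) so this block runs on its own
set.seed(2)
n <- 100
X <- cbind(1, matrix(rnorm(n * 2, 0, 1), ncol = 2))
y <- X %*% matrix(c(0.5, 1.5, -1), ncol = 1) + rnorm(n, 0, 0.8)
d <- data.table(y = drop(y), x0 = X[, 1], x1 = X[, 2], x2 = X[, 3])

# logpost() as in the question
logpost <- function(d, b1, b2, b3, phi, mub = 1, taub = 10, a = 0.5, z = 0.7) {
  N <- nrow(d)
  mu <- b1 + b2 * d$x1 + b3 * d$x2
  -N * log(phi) -
    (1 / (2 * phi^2)) * sum((d$y - mu)^2) -
    (1 / (2 * taub^2)) * ((b1 - mub)^2 + (b2 - mub)^2 + (b3 - mub)^2) -
    (a + 1) * log(phi) - (z / phi)
}

# logpost2() as in this answer
logpost2 <- function(d, dd, mub = 1, taub = 10, a = 0.5, z = 0.7) {
  bmat <- as.matrix(dd[, .(b1, b2, b3)])
  xmat <- cbind(1, as.matrix(d[, .(x1, x2)]))
  phi <- dd$phi
  phi_log <- log(phi)
  lp <- -(a + nrow(d) + 1) * phi_log -
    (1 / (2 * phi^2)) * colSums((d$y - tcrossprod(xmat, bmat))^2) -
    (1 / (2 * taub^2)) * rowSums((bmat - mub)^2) - (z / phi)
  lp
}

nn <- 5
grid <- data.table(expand.grid(b1 = seq(0, 1, len = nn),
                               b2 = seq(1, 2, len = nn),
                               b3 = seq(-1.5, -0.5, len = nn),
                               phi = seq(0.4, 1.2, len = nn)))

# Vectorized: one call for the whole grid
grid[, lp_vec := logpost2(d, .SD)]
# Reference: one logpost() call per grid row
grid[, lp_row := mapply(function(b1, b2, b3, phi) logpost(d, b1, b2, b3, phi),
                        b1, b2, b3, phi)]
grid[, all.equal(lp_vec, lp_row)]
```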

and you will see

> start <- Sys.time()
> grid[, lp := logpost2(d, .SD)]
> difftime(Sys.time(), start)
Time difference of 0.1966231 secs

and

> head(grid)
     b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152
Minivet answered 21/4, 2021 at 9:38 Comment(0)

If you want better performance (run time), you could rewrite the row-wise function as a calculation with matrices.

start <- Sys.time()
# G x 4 matrix: one row (b1, b2, b3, 1) per grid point
grid_mat <- as.matrix(grid[, list(b1, b2, b3, 1)])
# function parameters
N <- nrow(d); mub <- 1; taub <- 10; a <- 0.5; z <- 0.7
d[, const := 1]  # intercept column, added by reference

# combining d$y - mu in this step already: mu_op is 4 x N with rows
# (-1, -x1, -x2, y), so grid_mat %*% mu_op is the G x N matrix of y - mu
mu_op <- matrix(c(-d$const, -d$x1, -d$x2, d$y), nrow = 4, byrow = TRUE)
mu_mat <- grid_mat %*% mu_op
mub_mat <- (grid_mat[, c("b1", "b2", "b3")] - mub)^2
# compute log(phi) only once
phi <- grid$phi
log_phi <- log(phi)

grid[, lp2 := -N * log_phi -
       (1/(2*phi^2)) * rowSums(mu_mat^2) -
       (1/(2*taub^2)) * rowSums(mub_mat) -
       (a+1)*log_phi - (z/phi)]
head(grid)
difftime(Sys.time(), start)
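
Why one matrix product is enough: write G for grid_mat, whose row g is (b_{1g}, b_{2g}, b_{3g}, 1), and M for mu_op, whose column i is (-1, -x_{1i}, -x_{2i}, y_i)^T. Each entry of mu_mat = G M is then a residual, and rowSums(mu_mat^2) is the residual sum of squares for each grid point:

```latex
(GM)_{gi} = -b_{1g} - b_{2g}x_{1i} - b_{3g}x_{2i} + y_i = y_i - \mu_{gi},
\qquad
\sum_{i=1}^{N}(GM)_{gi}^2 = \sum_{i=1}^{N}(y_i - \mu_{gi})^2,
\qquad
\mu_{gi} = b_{1g} + b_{2g}x_{1i} + b_{3g}x_{2i}
```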

The first rows:

     b1 b2   b3 phi id        lp       lp2
1: 0.00  1 -1.5 0.4  1 -398.7618 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152 -322.4152

For the timing:

# your code on my PC:
Time difference of 4.390684 secs
# my code on my PC:
Time difference of 0.680476 secs
Phalangeal answered 19/4, 2021 at 8:21 Comment(2)
Great answer, you found the key to the speed-up. Upvoted your answer! – Minivet
Thanks, but you made that into a much more usable function with your answer. – Phalangeal