data.table sample with probabilities stored in columns

I have a data table with probabilities for a discrete distribution stored in columns.

For example:

library(data.table)
dt <- data.table(p1 = c(0.5, 0.25, 0.1), p2 = c(0.25, 0.5, 0.1), p3 = c(0.25, 0.25, 0.8))

I'd like to create a new column of a random variable sampled using the probabilities in the same row. In data.table syntax I imagine it working like this:

dt[, sample := sample(1:3, 1, prob = c(p1, p2, p3))]

If there were a 'psample' function analogous to 'pmin' and 'pmax', this would work. I was able to make it work using apply, but with my real data set this takes longer than I would like. Is there a way to do this with data.table? The apply solution is given below:

dt[, sample := apply(dt, 1, function(x) sample(1:3, 1, prob = x[c('p1', 'p2', 'p3')]))]
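For comparison, a genuinely vectorized 'psample' can be sketched with inverse-transform sampling, a swapped-in technique that nobody in this thread uses: draw one uniform number per row and count how many of that row's cumulative probabilities it exceeds. `psample` is a hypothetical name, and because it consumes the random number stream differently, it will not reproduce base::sample draws for the same set.seed.

```r
# Hypothetical vectorized row-wise sampler (inverse-transform sampling).
# P: numeric matrix (or data.frame), one row per draw, one column per outcome.
# Rows need not sum to 1; they are normalized below.
psample <- function(P) {
  P <- as.matrix(P)
  k <- ncol(P)
  # Row-wise cumulative sums, built column by column (no per-row loop).
  cum <- P
  if (k > 1L) {
    for (j in 2:k) cum[, j] <- cum[, j - 1L] + P[, j]
  }
  cum <- cum / cum[, k]   # divide each row by its total, so the last column is 1
  u <- runif(nrow(P))     # one uniform draw per row, strictly inside (0, 1)
  # `u > cum` recycles u down each column, so row i is compared with u[i];
  # counting the thresholds below u gives the 0-based outcome index.
  as.integer(1L + rowSums(u > cum))
}
```

With the question's table this would be called as `dt[, sample := psample(.SD), .SDcols = c("p1", "p2", "p3")]`.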
Beast answered 16/7, 2022 at 4:3 Comment(4)
Related: Efficiently apply sample() in R. – Korykorzybski
@Korykorzybski Nice. It throws an error, but that is solved there. – Embody
Not an answer, but for the record: if you fail to vectorize your function, common alternatives to apply are (1) by = 1:nrow(dt), or (2) melting to long format. Described e.g. here: Efficient row-wise operations on a data.table; How to do row wise operations on .SD columns in data.table, posts that you should have found, even with a very poor google-fu - "R data.table rowwise" ;) – Korykorzybski
@Korykorzybski Exactly. My psampv was actually inspired by the pmin solution in your first link. – Embody
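The two alternatives mentioned in the comment can be sketched as follows, assuming the question's three-column table. Grouping by row number makes p1, p2 and p3 scalars inside j; the long-format version melts the probabilities and draws once per original row id. Both still call sample.int() once per row, so they mostly tidy the syntax rather than remove the per-row cost. The column names `id`, `outcome`, `p`, `draw` and `sample_long` are illustrative.

```r
library(data.table)

dt <- data.table(p1 = c(0.5, 0.25, 0.1),
                 p2 = c(0.25, 0.5, 0.1),
                 p3 = c(0.25, 0.25, 0.8))

# (1) Group by row number: each group is a single row, so p1, p2, p3 are scalars in j.
dt[, sample := sample.int(3L, 1L, prob = c(p1, p2, p3)), by = seq_len(nrow(dt))]

# (2) Long format: one row per (row id, outcome); melt stacks p1, then p2, then p3,
# so within each id the probabilities appear in outcome order 1, 2, 3.
long <- melt(dt[, .(id = .I, p1, p2, p3)], id.vars = "id",
             variable.name = "outcome", value.name = "p")
draws <- long[, .(draw = sample.int(.N, 1L, prob = p)), by = id]
dt[, sample_long := draws$draw]   # by = id keeps the ids in order 1..nrow(dt)
```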

If you are choosing from 1:n, you could use sample.int, which is faster. Applying over a matrix is also faster, and putting both in a function psamp is faster still.

So, try this (I added dt[, 1:3] so that it won't fail once the column is added):

psamp <- function(x) sample.int(n=3, size=1, prob=x)
dt[, sample :=apply(as.matrix(dt[, 1:3]), 1, psamp)]
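Switching from sample(1:3, ...) to sample.int(3, ...) only removes overhead and does not change the draws: for a vector x of length greater than one, base R's sample() returns x[sample.int(length(x), size, replace, prob)]. A quick check, assuming an arbitrary probability vector p:

```r
p <- c(0.25, 0.5, 0.25)

set.seed(123)
a <- sample(1:3, 1, prob = p)      # generic sampler on the vector 1:3

set.seed(123)
b <- sample.int(3, 1, prob = p)    # the integer sampler that sample() calls internally

identical(a, b)  # TRUE
```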

To get rid of the explicit apply, we could Vectorize psamp and use do.call. Additionally, as @IRTFM suggests in his answer, we should make use of the .SD symbol.

psampv <- Vectorize(function(p1, p2, p3) sample.int(n=3, size=1, replace=TRUE, prob=c(p1, p2, p3)))
dt[, sample := do.call(psampv, .SD), .SDcols=c('p1','p2','p3')]
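Note that Vectorize() does not vectorize the computation itself: the function it returns forwards its arguments to mapply(), so sample.int() is still called once per row. The sketch below, using the question's probabilities as plain vectors, shows that an explicit mapply() call gives the same draws for the same seed.

```r
psampv <- Vectorize(function(p1, p2, p3)
  sample.int(n = 3, size = 1, replace = TRUE, prob = c(p1, p2, p3)))

p1 <- c(0.5, 0.25, 0.1)
p2 <- c(0.25, 0.5, 0.1)
p3 <- c(0.25, 0.25, 0.8)

set.seed(7)
via_vectorize <- psampv(p1, p2, p3)

# The same per-element loop written out with mapply().
set.seed(7)
via_mapply <- mapply(function(x, y, z)
  sample.int(n = 3, size = 1, replace = TRUE, prob = c(x, y, z)),
  p1, p2, p3)

all(via_vectorize == via_mapply)
```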

To improve performance by more than an order of magnitude, as suggested by @Henrik in the comments, we can use Rcpp. I have slightly adapted the code from this answer and use Rcpp::sample, which conveniently gives results identical to base::sample for the same set.seed.

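#include <Rcpp.h>

// Draw one value from choice_set for each row of x,
// using that row as the probability weights.
// [[Rcpp::export]]
Rcpp::IntegerVector sample_matrix1(Rcpp::NumericMatrix x, Rcpp::IntegerVector choice_set) {
  int n = x.nrow();
  Rcpp::IntegerVector result(n);
  for (int i = 0; i < n; ++i) {
    Rcpp::NumericVector z(x(i, Rcpp::_));  // copy row i as the weights
    // draw a single value (replace = false), weighted by z
    result[i] = Rcpp::sample(choice_set, 1, false, z)[0];
  }
  return result;
}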

Save the C++ code as sample_matrix1.cpp, then compile it and use it from R:

Rcpp::sourceCpp("sample_matrix1.cpp")

dt[, sample := sample_matrix1(as.matrix(.SD), 1:3), .SDcols=c('p1','p2','p3')] 
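To check the Rcpp function against base R, a plain-R reference with the same interface can be handy. In the sketch below, `sample_matrix_r` is a hypothetical name. It loops over rows with vapply() and uses sample.int() with its default replace = FALSE, mirroring the `false` argument in the C++ call. If the claim above about Rcpp::sample holds, sample_matrix1(x, 1:3) and sample_matrix_r(x, 1:3) should agree when both start from the same set.seed. That comparison needs the compiled function and is not run here.

```r
# Hypothetical pure-R counterpart of sample_matrix1(), for testing and benchmarking.
# x: numeric matrix, one row of probabilities per draw; choice_set: values to draw from.
sample_matrix_r <- function(x, choice_set) {
  choice_set <- as.integer(choice_set)
  vapply(seq_len(nrow(x)), function(i) {
    # Index into choice_set via sample.int() to avoid sample()'s
    # special case for length-one numeric inputs.
    choice_set[sample.int(length(choice_set), 1L, prob = x[i, ])]
  }, integer(1))
}

x <- rbind(c(0.5, 0.25, 0.25),
           c(0.25, 0.5, 0.25),
           c(0.1, 0.1, 0.8))
set.seed(1)
sample_matrix_r(x, 1:3)
```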

Benchmark (100k rows, 100 repetitions each):

Unit: milliseconds
          expr        min         lq       mean     median         uq       max neval cld
      psamp_:= 1195.16708 1259.06558 1327.19581 1311.17878 1349.98905 1515.1187   100   b
     psamp_.SD 1225.90467 1257.37766 1318.74885 1289.27571 1335.07736 1522.3423   100   b
     psamp_set 1181.44985 1256.73204 1320.29317 1301.75657 1335.22009 1491.3870   100   b
 psamp_do.call 1181.93117 1251.45863 1316.23306 1285.85710 1337.06674 1476.8023   100   b
          rcpp   60.73652   67.15291   72.76073   70.47052   73.91629  127.8278   100  a 
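The benchmark code itself is not shown. As a sketch under stated assumptions, the following builds a hypothetical 100,000-row matrix of normalized random probabilities and times the apply() baseline with base R's system.time(). No timings are claimed here; they depend on the machine.

```r
set.seed(2022)
n <- 1e5
P <- matrix(runif(3 * n), ncol = 3)   # hypothetical random weights
P <- P / rowSums(P)                    # normalize each row to sum to 1
colnames(P) <- c("p1", "p2", "p3")

psamp <- function(x) sample.int(n = 3, size = 1, prob = x)

# system.time() evaluates its argument in the calling frame, so `res` is kept.
timing <- system.time(res <- apply(P, 1, psamp))
print(timing)
```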
Embody answered 16/7, 2022 at 5:56 Comment(0)

I think the proper data.table approach would be to use the .SD facilities:

dt2 <- rbind(dt,dt,dt,dt)
psamp <- function(x) sample.int(n=3, size=1, prob=x) # from jay.sf

dt2[, sample :=apply(.SD, 1, psamp), .SDcols=c('p1','p2','p3')]
> dt2
      p1   p2   p3 sample
 1: 0.50 0.25 0.25      2
 2: 0.25 0.50 0.25      1
 3: 0.10 0.10 0.80      2
 4: 0.50 0.25 0.25      3
 5: 0.25 0.50 0.25      2
 6: 0.10 0.10 0.80      3
 7: 0.50 0.25 0.25      3
 8: 0.25 0.50 0.25      2
 9: 0.10 0.10 0.80      3
10: 0.50 0.25 0.25      1
11: 0.25 0.50 0.25      2
12: 0.10 0.10 0.80      3
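Because .SDcols decides which columns reach psamp, the helper can also be made independent of the number of outcomes by taking n from the length of the probability vector. A minimal sketch, assuming the outcome labels are simply 1..k in .SDcols order (`psamp_any` is a hypothetical name):

```r
# Draw one index in 1..length(x), weighted by the probabilities in x.
psamp_any <- function(x) sample.int(n = length(x), size = 1L, prob = x)

# Works for any number of probability columns, e.g. four:
P4 <- rbind(c(0.1, 0.2, 0.3, 0.4),
            c(0, 0, 0, 1))
set.seed(3)
apply(P4, 1, psamp_any)
```

In data.table the call keeps the same shape, dt2[, sample := apply(.SD, 1, psamp_any), .SDcols = ...], with whichever probability columns are listed in .SDcols.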

A note on style: it's better to refrain from naming R objects with strings that are also names of R functions, such as df (density of the F distribution), dt (density of the t-distribution), and data (function to load a built-in dataset).

Carbylamine answered 16/7, 2022 at 21:55 Comment(3)
I do like this approach because it allows you to use different probability columns without defining a new function. – Beast
Of course @IRTFM, thanks for sharing your wisdom! I kept working on it, trying to get rid of the apply. In the end, Rcpp beats us all, though. – Embody
Oh yeah, I would never want to challenge Rcpp. – Carbylamine

I think this is also an option, and it might be quite fast. Perhaps @jay.sf can let us know, since it also uses psamp (thanks @jay.sf):

set(dt, j = "sample", value = apply(as.matrix(dt[, 1:3]), 1, psamp))
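set() is documented as a low-overhead, loopable version of :=, so besides assigning a whole column at once it can also fill a column row by row without paying the [.data.table overhead on every iteration. A sketch, assuming the question's table and its p1/p2/p3 column names. The result column is created first, so that every set() call assigns into an existing integer column:

```r
library(data.table)

dt <- data.table(p1 = c(0.5, 0.25, 0.1),
                 p2 = c(0.25, 0.5, 0.1),
                 p3 = c(0.25, 0.25, 0.8))

dt[, sample := NA_integer_]           # pre-allocate the result column
P <- as.matrix(dt[, .(p1, p2, p3)])   # extract the probabilities once

for (i in seq_len(nrow(dt))) {
  # set(x, i, j, value) assigns by reference to row i of column "sample"
  set(dt, i = i, j = "sample", value = sample.int(3L, 1L, prob = P[i, ]))
}
```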
Prunella answered 17/7, 2022 at 0:53 Comment(1)
Isn't set equivalent to :=? I included it in the benchmark; it appears to be slightly faster. – Embody
