Consider this function that generates a list for an arbitrary `Monad`:
```haskell
generateListM :: Monad m => Int -> (Int -> m a) -> m [a]
generateListM sz f = go 0
  where
    go i
      | i < sz = do
          x <- f i
          xs <- go (i + 1)
          return (x : xs)
      | otherwise = pure []
```
The implementation may not be perfect, but it is here solely to demonstrate the desired effect, which is pretty straightforward. For example, if the monad is a list, we'll get a list of lists:
```
λ> generateListM 3 (\i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]
```
What I'd like to do is achieve the same effect, but for a `ByteArray` instead of a list. As it turns out, this is much trickier than I thought when I first stumbled upon this problem. The end goal is to use that generator to implement `mapM` in massiv, but that is beside the point.
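To make the connection concrete: once a monadic generator exists, `mapM` over a primitive vector is just a generator that reads its input by index. The sketch below uses vector's own `VP.generateM` as a stand-in for that generator; the name `mapPrimM` is hypothetical, and this is not massiv's actual implementation.

```haskell
import Data.Primitive.Types (Prim)
import qualified Data.Vector.Primitive as VP

-- | Hypothetical mapM for primitive vectors, expressed through a monadic
-- generator (vector's VP.generateM stands in for the generator here).
-- Element i of the result is produced by running f on element i of the input.
mapPrimM :: (Prim a, Prim b, Monad m) => (a -> m b) -> VP.Vector a -> m (VP.Vector b)
mapPrimM f v = VP.generateM (VP.length v) (f . VP.unsafeIndex v)
```

Swapping `VP.generateM` for a faster generator changes only the right-hand side.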
The approach that requires the least effort is to use the `generateM` function from the `vector` package and do a bit of manual conversion. But as it turns out, there is a way to get at least a 2x performance gain with a neat little trick of handling the state token manually and interleaving it with the monad:
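For comparison, the least-effort route might look roughly like the following sketch. The name `generateByteArrayM` is hypothetical, and "manual conversion" is assumed here to mean extracting the `ByteArray` that backs the resulting primitive `Vector`, copying only when the vector does not cover its whole buffer.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Primitive (Prim, sizeOf)
import Data.Primitive.ByteArray (ByteArray, sizeofByteArray)
import qualified Data.Vector.Primitive as VP

-- | Baseline (hypothetical name): let vector's generateM run the monadic
-- generator, then hand out the ByteArray that backs the primitive Vector.
generateByteArrayM :: forall m a. (Prim a, Monad m) => Int -> (Int -> m a) -> m ByteArray
generateByteArrayM sz f = toByteArray <$> VP.generateM sz f
  where
    elemSize = sizeOf (undefined :: a)

    toByteArray :: VP.Vector a -> ByteArray
    toByteArray v = case v of
      -- Zero-copy path: the vector starts at offset 0 and spans the whole buffer.
      VP.Vector 0 n ba | n * elemSize == sizeofByteArray ba -> ba
      -- Otherwise VP.force copies the slice into a fresh, exactly sized buffer.
      _ -> case VP.force v of
        VP.Vector _ _ ba -> ba
```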
```haskell
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples #-}
import Data.Primitive.ByteArray
import Data.Primitive.Types
import qualified Data.Vector.Primitive as VP
import GHC.Int
import GHC.Magic
import GHC.Prim

-- | Can't `return` unlifted types, so we need a wrapper for the state and MutableByteArray
data MutableByteArrayState s = MutableByteArrayState !(State# s) !(MutableByteArray# s)

generatePrimM :: forall m a . (Prim a, Monad m) => Int -> (Int -> m a) -> m (VP.Vector a)
generatePrimM (I# sz#) f =
  runRW# $ \s0# -> do
    let go i# = do
          case i# <# sz# of
            -- i >= sz: no more elements to generate, so allocate the buffer
            0# ->
              case newByteArray# (sz# *# sizeOf# (undefined :: a)) (noDuplicate# s0#) of
                (# s1#, mba# #) -> return (MutableByteArrayState s1# mba#)
            -- i < sz: run the monadic action, then write its result into the
            -- buffer returned by the recursive call
            _ -> do
              res <- f (I# i#)
              MutableByteArrayState si# mba# <- go (i# +# 1#)
              return (MutableByteArrayState (writeByteArray# mba# i# res si#) mba#)
    MutableByteArrayState s# mba# <- go 0#
    -- Nothing writes to the buffer after this point, so freeze it in place
    case unsafeFreezeByteArray# mba# s# of
      (# _, ba# #) -> return (VP.Vector 0 (I# sz#) (ByteArray ba#))
```
We can use it in the same fashion as before, except that now we get a primitive `Vector`, which is backed by a `ByteArray`, which is what I really need:
```
λ> generatePrimM 3 (\i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]
```
This seems to work great and performs well with GHC 8.0 and 8.2. There is a performance regression in 8.4 and 8.6, but that issue is orthogonal.
Finally, the actual question: is this approach really safe? Is there some edge case I am not aware of that could bite me later? Any other suggestions or opinions about the function above are welcome as well.
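One way to avoid depending on how GHC treats a `State#` token that travels through an arbitrary `m` is to keep all mutation out of `m`: let `m` produce only a description of the writes, then allocate, write and freeze inside `runST` once per result. The sketch below does that. `Writes` and `generatePrimSafeM` are hypothetical names, it assumes the `vector` and `primitive` packages, and because it builds a chain of closures it may well lose much of the 2x gain mentioned above (it has not been benchmarked). Its main property is that a non-deterministic monad such as `[]` can never end up with one mutable buffer shared between branches.

```haskell
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Control.Monad.ST (ST, runST)
import Data.Primitive (Prim, sizeOf)
import Data.Primitive.ByteArray
import qualified Data.Vector.Primitive as VP

-- | A deferred batch of writes into a buffer that does not exist yet.
newtype Writes = Writes (forall s. MutableByteArray s -> ST s ())

-- | Sketch of a safer variant: m only collects the writes; allocation,
-- writing and freezing happen afterwards in runST, once per result of m.
generatePrimSafeM :: forall m a. (Prim a, Monad m) => Int -> (Int -> m a) -> m (VP.Vector a)
generatePrimSafeM sz f = build <$> go 0
  where
    go :: Int -> m Writes
    go i
      | i < sz = do
          x <- f i
          Writes rest <- go (i + 1)
          -- Prepend the write for index i to the writes for i + 1 .. sz - 1
          pure (Writes (\mba -> writeByteArray mba i x >> rest mba))
      | otherwise = pure (Writes (\_ -> pure ()))

    build :: Writes -> VP.Vector a
    build (Writes ws) = runST $ do
      let n = max 0 sz
      mba <- newByteArray (n * sizeOf (undefined :: a))
      ws mba
      -- The buffer is never touched again, so freezing in place is fine
      ba <- unsafeFreezeByteArray mba
      pure (VP.Vector 0 n ba)
```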
PS. `m` doesn't have to be restricted to a `Monad`; an `Applicative` would work just fine, but the example is a bit clearer when presented with `do` syntax.
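To illustrate the `Applicative` point with the list version: `generateListM` can be written with `traverse`, which needs only `Applicative`. The name `generateListA` is made up for this sketch.

```haskell
-- | Same results as generateListM, but only Applicative is required:
-- traverse runs f on the indices 0 .. sz - 1 from left to right and
-- collects the results. For sz <= 0 the index list is empty, giving pure [].
generateListA :: Applicative f => Int -> (Int -> f a) -> f [a]
generateListA sz f = traverse f [0 .. sz - 1]
```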
… `ST s` or any other `PrimMonad`? I would attempt to achieve the same performance through known-safe methods before resorting to low-level primitives. – Turtledove

… `PrimMonad` or `ST` in order to achieve the desired function. – Cavafy

`primitive` does indeed have a function I was looking for. :) See my answer. – Cavafy
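The comment does not say which `primitive` function it means. A likely candidate, assumed here rather than confirmed, is `generatePrimArrayA` from `Data.Primitive.PrimArray` (assumed available in primitive 0.6.4 and later), which takes an `Applicative` generator and returns a `PrimArray`. Because a `PrimArray` wraps a `ByteArray#`, the result can be rewrapped as a primitive `Vector` without copying; `generatePrimA` is a hypothetical name.

```haskell
import Data.Primitive.ByteArray (ByteArray (..))
import Data.Primitive.PrimArray (PrimArray (..), generatePrimArrayA, sizeofPrimArray)
import Data.Primitive.Types (Prim)
import qualified Data.Vector.Primitive as VP

-- | Hypothetical wrapper around primitive's Applicative generator that
-- returns a primitive Vector, like generatePrimM. The PrimArray's underlying
-- ByteArray# is reused as-is, so nothing is copied.
generatePrimA :: (Prim a, Applicative f) => Int -> (Int -> f a) -> f (VP.Vector a)
generatePrimA sz f = toVector <$> generatePrimArrayA sz f
  where
    toVector pa@(PrimArray arr) = VP.Vector 0 (sizeofPrimArray pa) (ByteArray arr)
```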