Is it safe to interleave manual realWorld# state passing with an arbitrary Monad?

Consider this function that generates a list for an arbitrary Monad:

generateListM :: Monad m => Int -> (Int -> m a) -> m [a]
generateListM sz f = go 0
  where go i | i < sz = do x <- f i
                           xs <- go (i + 1)
                           return (x:xs)
             | otherwise = pure []

The implementation may not be perfect, but it is presented here solely to demonstrate the desired effect, which is pretty straightforward. For example, if the monad is a list, we'll get a list of lists:

λ> generateListM 3 (\i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]
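
To make the desired effect concrete beyond the list monad, here is a small sketch with Either (checkedSquare is a hypothetical helper, not part of the question). The effects of f run from index 0 upwards and the first Left aborts the whole generation, so a ByteArray-backed replacement has to preserve the same semantics.

```haskell
-- Same generateListM as above, exercised with Either instead of the list monad.
generateListM :: Monad m => Int -> (Int -> m a) -> m [a]
generateListM sz f = go 0
  where
    go i
      | i < sz = do
          x <- f i          -- effect for index i runs first
          xs <- go (i + 1)  -- then the effects for the remaining indices
          return (x : xs)
      | otherwise = pure []

-- | Hypothetical element generator: fails on the first index above the limit.
checkedSquare :: Int -> Int -> Either String Int
checkedSquare limit i
  | i > limit = Left ("index " ++ show i ++ " exceeds " ++ show limit)
  | otherwise = Right (i * i)
```

λ> generateListM 4 (checkedSquare 10)
Right [0,1,4,9]
λ> generateListM 4 (checkedSquare 1)
Left "index 2 exceeds 1"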

What I'd like to do is achieve the same effect, but for a ByteArray instead of a list. As it turns out, this is much trickier than I thought when I first stumbled upon the problem. The end goal is to use that generator to implement mapM in massiv, but that is beside the point.

The approach that requires the least effort is to use the generateM function from the vector package together with a bit of manual conversion. But it turns out there is a way to get at least a 2x performance gain with a neat little trick: handling the state token manually and interleaving it with the monad:

{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples       #-}
import           Data.Primitive.ByteArray
import           Data.Primitive.Types
import qualified Data.Vector.Primitive    as VP
import           GHC.Int
import           GHC.Magic
import           GHC.Prim

-- | Can't `return` unlifted types, so we need a wrapper for the state and MutableByteArray
data MutableByteArrayState s = MutableByteArrayState !(State# s) !(MutableByteArray# s)

generatePrimM :: forall m a . (Prim a, Monad m) => Int -> (Int -> m a) -> m (VP.Vector a)
generatePrimM (I# sz#) f =
  runRW# $ \s0# -> do
    let go i# = do
          case i# <# sz# of
            0# ->
              case newByteArray# (sz# *# sizeOf# (undefined :: a)) (noDuplicate# s0#) of
                (# s1#, mba# #) -> return (MutableByteArrayState s1# mba#)
            _ -> do
              res <- f (I# i#)
              MutableByteArrayState si# mba# <- go (i# +# 1#)
              return (MutableByteArrayState (writeByteArray# mba# i# res si#) mba#)
    MutableByteArrayState s# mba# <- go 0#
    case unsafeFreezeByteArray# mba# s# of
      (# _, ba# #) -> return (VP.Vector 0 (I# sz#) (ByteArray ba#))

We can use it in the same fashion as before, except that now we get a primitive Vector backed by a ByteArray, which is what I really need:

λ> generatePrimM 3 (\i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]

This seems to work great and performs well with GHC 8.0 and 8.2. There is a regression in 8.4 and 8.6, but that issue is orthogonal.
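
For reference, the low-effort baseline is the generateM route mentioned above. The exact conversion used is not shown in the question, so this is only a sketch under the assumption that it goes through a boxed vector (generateViaBoxed is a hypothetical name): run the effects with Data.Vector.generateM, then copy into a primitive vector with Data.Vector.Generic.convert. The intermediate boxed vector and the extra copy are the overhead that generatePrimM avoids by writing straight into the ByteArray.

```haskell
import qualified Data.Vector           as V
import qualified Data.Vector.Generic   as VG
import qualified Data.Vector.Primitive as VP
import           Data.Primitive.Types  (Prim)

-- | Hypothetical baseline: run the effects into a boxed vector first,
-- then copy the elements into a primitive (ByteArray-backed) vector.
generateViaBoxed :: (Monad m, Prim a) => Int -> (Int -> m a) -> m (VP.Vector a)
generateViaBoxed sz f = VG.convert <$> V.generateM sz f
```

λ> fmap VP.toList (generateViaBoxed 3 (\i -> Just i))
Just [0,1,2]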

Finally, to the actual question: is this approach really safe? Is there some edge case I am not aware of that could bite me later? Any other suggestions or opinions regarding the above function are welcome as well.

PS. m doesn't have to be restricted to a Monad; an Applicative would work just fine, but the example is a bit clearer when presented with do syntax.
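
To illustrate the PS, here is generateListM rewritten against Applicative only (a sketch; generateListA and generateListT are hypothetical names). liftA2 (:) sequences the effect for index i before the rest of the list, which makes it equivalent to traverse over the index range.

```haskell
import Control.Applicative (liftA2)

-- | Applicative-only version of generateListM: no do syntax, no (>>=).
generateListA :: Applicative m => Int -> (Int -> m a) -> m [a]
generateListA sz f = go 0
  where
    go i
      | i < sz = liftA2 (:) (f i) (go (i + 1))
      | otherwise = pure []

-- | The same function expressed with traverse, for comparison.
generateListT :: Applicative m => Int -> (Int -> m a) -> m [a]
generateListT sz f = traverse f [0 .. sz - 1]
```

λ> generateListA 2 (\i -> [i, i + 10])
[[0,1],[0,11],[10,1],[10,11]]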

Cavafy asked 13/12, 2018 at 21:4 · Comments (3):
Maybe I am not completely understanding your goals, but have you considered using, say, ST s or any other PrimMonad? I would attempt to achieve the same performance through known-safe methods before resorting to low-level primitives. – Turtledove
@Turtledove That's a good question. It's not totally obvious, but it is impossible to use either PrimMonad or ST to achieve the desired function. – Cavafy
@chi, I stand corrected: primitive does indeed have the function I was looking for. :) See my answer. – Cavafy

TL;DR: From what I have gathered so far, the approach I originally proposed does seem to be a safe way to generate a primitive Vector. Moreover, the use of noDuplicate# is not really necessary, since all of the operations are idempotent and the order of operations does not affect the resulting array(s).

Disclosure: it has been over a year since I first thought about this problem, and only last month did I get back to it. I mention this because, looking at the primitive package now, I noticed a module that was new to me: Data.Primitive.PrimArray. As @chi mentioned in the comments, there is no real need to drop down to low-level primitives when a solution might already exist. That module contains generatePrimArrayA, which is exactly the function I was looking for (below is a slightly simplified copy of its source code):

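{-# LANGUAGE BangPatterns, MagicHash, RankNTypes, ScopedTypeVariables #-}
import Control.Applicative      (liftA2)
import Control.Monad.ST         (ST, runST)
import Data.Primitive.PrimArray (MutablePrimArray (..), PrimArray, newPrimArray,
                                 unsafeFreezePrimArray, writePrimArray)
import Data.Primitive.Types     (Prim)
import GHC.Exts                 (MutableByteArray#)

-- | Receives the raw mutable buffer and eventually yields the frozen array
newtype STA a = STA {_runSTA :: forall s. MutableByteArray# s -> ST s (PrimArray a)}

-- | Unwraps the raw MutableByteArray# from a MutablePrimArray
unMutablePrimArray :: MutablePrimArray s a -> MutableByteArray# s
unMutablePrimArray (MutablePrimArray mba#) = mba#

-- | Allocates the buffer once, then runs the accumulated writes and the final freeze
runSTA :: forall a. Prim a => Int -> STA a -> PrimArray a
runSTA !sz =
  \(STA m) -> runST $ newPrimArray sz >>= \(ar :: MutablePrimArray s a) -> m (unMutablePrimArray ar)

generatePrimArrayA :: (Applicative f, Prim a) => Int -> (Int -> f a) -> f (PrimArray a)
generatePrimArrayA len f =
  let go !i
        | i == len = pure $ STA $ \mary -> unsafeFreezePrimArray (MutablePrimArray mary)
        | otherwise =
          liftA2
            -- write element i, then continue with the writes for the remaining indices
            (\b (STA m) -> STA $ \mary -> writePrimArray (MutablePrimArray mary) i b >> m mary)
            (f i)
            (go (i + 1))
   in runSTA len <$> go 0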
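
Since primitive exports generatePrimArrayA directly, the library version can be used as-is. A short usage sketch, assuming a primitive release that exports generatePrimArrayA and primArrayToList from Data.Primitive.PrimArray (combos and boundedIndices are hypothetical names). It reproduces the list-of-lists example from the question, with one fresh PrimArray per result, and shows that a single Nothing short-circuits the whole array:

```haskell
import Data.Primitive.PrimArray (PrimArray, generatePrimArrayA, primArrayToList)

-- | Hypothetical helper: one PrimArray per combination, via the list Applicative.
combos :: Int -> [PrimArray Int]
combos n = generatePrimArrayA n (\i -> [0 .. i])

-- | Hypothetical helper: Nothing as soon as any index exceeds the limit.
boundedIndices :: Int -> Int -> Maybe (PrimArray Int)
boundedIndices limit n =
  generatePrimArrayA n (\i -> if i <= limit then Just i else Nothing)
```

λ> map primArrayToList (combos 3)
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]
λ> fmap primArrayToList (boundedIndices 1 3)
Nothing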

Just as a fun exercise, if we simplify generatePrimArrayA using the usual reduction rules, we get something very similar to what I had in the first place:

generatePrimArrayA :: forall f a. (Applicative f, Prim a) => Int -> (Int -> f a) -> f (PrimArray a)
generatePrimArrayA !(I# n#) f =
  let go i# = case i# <# n# of
                0# -> pure $ \mary s# ->
                        case unsafeFreezeByteArray# mary s# of
                          (# s'#, arr'# #) -> (# s'#, PrimArray arr'# #)
                _ -> liftA2
                     (\b m ->
                        \mary s ->
                          case writeByteArray# mary i# b s of
                            s'# -> m mary s'#)
                     (f (I# i#))
                     (go (i# +# 1#))
   in (\m -> runRW# $ \s0# ->
                case newByteArray# (n# *# sizeOf# (undefined :: a)) s0# of
                  (# s'#, arr# #) -> case m arr# s'# of
                                       (# _, a #) -> a)
      <$> go 0#

Here is my version adjusted for an Applicative instead of a Monad:

generatePrimM :: forall m a . (Prim a, Applicative m) => Int -> (Int -> m a) -> m (PrimArray a)
generatePrimM (I# sz#) f =
  let go i# = case i# <# sz# of
                0# -> runRW# $ \s0# ->
                      case newByteArray# (sz# *# sizeOf# (undefined :: a)) s0# of
                        (# s1#, mba# #) -> pure (MutableByteArrayState s1# mba#)
                _  -> liftA2
                      (\b (MutableByteArrayState si# mba#) ->
                         MutableByteArrayState (writeByteArray# mba# i# b si#) mba#)
                      (f (I# i#))
                      (go (i# +# 1#))
   in (\(MutableByteArrayState s# mba#) ->
         case unsafeFreezeByteArray# mba# s# of
           (# _, ba# #) -> PrimArray ba#) <$>
      (go 0#)

Functionally and performance-wise the two are very close to each other, and in the end they both produce exactly the same answer. The difference is in what the inner loop go produces. The latter returns an applicative containing a closure that constructs the MutableByteArray#s, which are frozen later, whereas the loop in the former returns an applicative containing an action that creates the frozen ByteArray#s once it is supplied with a MutableByteArray#.

Nevertheless, what makes both approaches safe is that every element of each array produced within the loop is written exactly once, and each MutableByteArray# that is created gets frozen before it is returned by the generating function, but only after all writes to it have finished.
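
For contrast, when m is not arbitrary but ST itself (the case @chi asked about), the same invariant can be followed with primitive's safe API and no manual State# threading. A minimal sketch (generatePrimArrayST is a hypothetical name); it only covers pure generators, so it does not replace the Applicative version above:

```haskell
import Control.Monad.ST         (runST)
import Data.Primitive.PrimArray (PrimArray, newPrimArray, primArrayToList,
                                 unsafeFreezePrimArray, writePrimArray)
import Data.Primitive.Types     (Prim)

-- | Allocate once, write every index exactly once, and freeze only after the last write.
generatePrimArrayST :: Prim a => Int -> (Int -> a) -> PrimArray a
generatePrimArrayST sz f = runST $ do
  marr <- newPrimArray sz
  let fill i
        | i < sz = writePrimArray marr i (f i) >> fill (i + 1)
        | otherwise = pure ()
  fill 0
  -- safe: no further writes to marr happen after this point
  unsafeFreezePrimArray marr
```

λ> primArrayToList (generatePrimArrayST 4 (\i -> i * 10))
[0,10,20,30]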

Cavafy answered 23/1, 2019 at 23:7
