Efficient implementation of Markov chains in Julia

I want to simulate the movement of a random walker in a network as efficiently as possible. Below I show a toy model with the three approaches I have tried so far. Note that in my original problem the edges of the network are fixed, but the weights of the edges may be updated (i.e. the list of neighbours stays the same while the weights may change).

using QuantEcon
using LightGraphs
using Distributions
using StatsBase

n = 700 #number of nodes
#setting an arbitrary network and its transition matrix
G_erdos = erdos_renyi(n, 15/n)
A_erdos = adjacency_matrix(G_erdos) + eye(n, n);
A_transition = A_erdos ./ sum(A_erdos, 2);

##Method 1
#using QuantEcon library
function QE_markov_draw(i::Int, A::Array{Float64,2})
    d = DiscreteRV(A[i, :]);
    return rand(d, 1)   
end

##Method 2
#using a simple random draw
function matrix_draw(i::Int, A::Array{Float64,2}, choices::Array{Int64,1})
    return sample(choices, Weights(A[i, :]))
end

##Method 3
# The matrix may be sparse. Therefore I obtain first list of neighbors and weights
#for each node. Then run sample using the list of neighbors and weights.
function neighbor_weight_list(A::Array{Float64,2}, i::Int)
    n = size(A)[1]
    neighbor_list = Int[]
    weight_list = Float64[]
    for i = 1:n
        for j = 1:n
            if A[i, j] > 0
                push!(neighbor_list, j)
                push!(weight_list, A[i, j])
            end
        end
    end
    return neighbor_list, weight_list
end
#Using sample on the reduced list.
function neigh_weights_draw(i::Int, neighs::Array{Int,1}, weigh::Array{Float64,1})
    return sample(neighs, Weights(weigh))
end

neighbor_list, weight_list = neighbor_weight_list(A_transition, 1)
states = [i for i = 1:n];

println("Method 1")
@time for t = 1:100000
    QE_markov_draw(3, A_transition)
end

println("Method 2")
@time for t = 1:100000
    matrix_draw(3, A_transition, states)
end

println("Method 3")
@time for t = 1:100000
    neigh_weights_draw(3, neighbor_list, weight_list)
end

The results (after the first iteration) show that method 2 is the fastest. Method 3 uses the least memory, followed by method 2; however, this might be because they "feed" on neighbor_list and states.

Method 1
  0.327805 seconds (500.00 k allocations: 1.086 GiB, 14.70% gc time)
Method 2
  0.227060 seconds (329.47 k allocations: 554.344 MiB, 11.24% gc time)
Method 3
  1.224682 seconds (128.19 k allocations: 3.482 MiB)

I was wondering which implementation would be most efficient and if there was a way of improving it.

Embryotomy answered 10/5, 2018 at 14:11

Here are some recommendations I can give:

In method 2, use a view instead of copying the row, and work on the transpose of the matrix (so that you read columns, not rows):

# here A should be a (dense) transpose of your original A,
# so that column i holds the transition probabilities out of node i
function matrix_draw(i::Int, A::Array{Float64,2}, choices::Array{Int64,1})
    return sample(choices, Weights(view(A, :, i)))
end

This gives almost 7x speedup in my tests.
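
The likely reason the transpose helps is that Julia arrays are stored column-major, so a column view is contiguous in memory, while `A[i, :]` allocates a fresh copy of the row on every call. A minimal sketch of the idea, assuming Julia 1.x syntax (`sum(A, dims=2)`, `permutedims`) and a small random matrix standing in for `A_transition`; the names `A_t` and `column_draw` are made up for this illustration:

```julia
using StatsBase

n = 5
A = rand(n, n)
A ./= sum(A, dims=2)        # each row now sums to 1 (row-stochastic)
A_t = permutedims(A)        # dense transpose: column i of A_t is row i of A
states = collect(1:n)

# Draw the next state from node i using a contiguous column view (no copy).
column_draw(i::Int, A_t::Matrix{Float64}, choices::Vector{Int}) =
    sample(choices, Weights(view(A_t, :, i)))

next_state = column_draw(3, A_t, states)
```

`permutedims` is used here rather than `A'` because, in Julia 1.x, `A'` returns a lazy wrapper rather than a plain `Array{Float64,2}`.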

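In general, though, method 3 should be the fastest; it just seems to be implemented incorrectly. The `for i = 1:n` loop overrides the `i` argument, so the function collects the neighbors of every node into one flat list instead of only those of node i. Here is a fixed approach: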

function neighbor_weight_list(A::Array{Float64,2})
    n = size(A, 1)
    neighbor_list = Vector{Int}[]
    weight_list = Vector{Float64}[]
    for i = 1:n
        push!(neighbor_list, Int[])
        push!(weight_list, Float64[])
        for j = 1:n
            if A[i, j] > 0
                push!(neighbor_list[end], j)
                push!(weight_list[end], A[i, j])
            end
        end
    end
    return neighbor_list, weight_list
end

function neigh_weights_draw(i::Int, neighs::Vector{Vector{Int}}, weigh::Vector{Vector{Float64}})
    return sample(neighs[i], Weights(weigh[i]))
end

neighbor_list, weight_list = neighbor_weight_list(A_transition)

When I run this code, it is 4x faster than the fixed method 2. Also note that you could use method 3 without creating the adjacency matrix at all, by working directly with the neighbors function from LightGraphs.
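
A rough sketch of that last idea, assuming the same toy setup as in the question (an Erdős–Rényi graph with a self-loop on every node and uniform weights). The method `neighbor_weight_list(g::AbstractGraph)` is a hypothetical variant written for this illustration, not part of LightGraphs:

```julia
using LightGraphs
using StatsBase

# Build per-node neighbor and weight lists straight from the graph,
# without ever forming the adjacency matrix.
function neighbor_weight_list(g::AbstractGraph)
    neighbor_list = Vector{Int}[]
    weight_list = Vector{Float64}[]
    for v in vertices(g)
        # Prepend v itself to mirror the `+ eye(n, n)` self-loop in the question;
        # vcat also copies, so the graph's internal lists are never modified.
        neighs = vcat(v, neighbors(g, v))
        push!(neighbor_list, neighs)
        push!(weight_list, fill(1 / length(neighs), length(neighs)))
    end
    return neighbor_list, weight_list
end

n = 700
g = erdos_renyi(n, 15 / n)
neighbor_list, weight_list = neighbor_weight_list(g)
next_state = sample(neighbor_list[3], Weights(weight_list[3]))

# The edges are fixed, so a weight update only overwrites entries of weight_list[v].
# Weights(...) computes the total itself, so the entries need not sum to 1.
weight_list[3][1] = 0.5
updated_state = sample(neighbor_list[3], Weights(weight_list[3]))
```

Since only the weight vectors change, the neighbor lists are built once and reused for every draw.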

Len answered 10/5, 2018 at 21:10
