Motivation
I am trying to write my first real program in Rust for a school project. (Rust is not a requirement; I have just been fascinated by it and decided to take the plunge.)
The project is a simple simulation of a robot's decisions based on sensor data, probabilities, predictions of future rewards, and a few other things. The program consists of a main loop that does a lot of math at each time step, for some time horizon into the future. The data carried from one time step to the next is a matrix Y with two columns: each row holds the linear coefficients of one constraint in a set of linear constraints. The coefficients are modified at each time step, and new constraints (rows) are also added to the set at each time step.
Since the program requires lots of element-wise matrix operations and I'm very experienced with NumPy, the ndarray crate seemed like a perfect fit for the job. My plan was to make a mutable 2D array for Y that gets modified in each loop iteration, rather than allocating a new array every time. It has since dawned on me that the number of rows also grows by an unknown amount with each iteration, so this approach may not have been the best idea. My question about the error I'm getting stands regardless.
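One way to get both in-place modification and a growing number of rows is to keep Y in a flat row-major `Vec<f64>` with a fixed column count and append whole rows to it. This avoids ndarray entirely. The sketch below assumes two columns, as described above. The `Constraints` type, its methods, and the `step` function are hypothetical, and the update rule is only a placeholder for the real math. Newer ndarray releases (0.15 series) also provide `Array2::push_row`, which may be worth checking as a closer fit.

```rust
/// Hypothetical container for Y: one row per linear constraint,
/// two coefficients per row, stored row-major in a single Vec.
#[derive(Debug, Clone, PartialEq)]
struct Constraints {
    data: Vec<f64>,
}

impl Constraints {
    const NCOLS: usize = 2;

    fn new() -> Self {
        Constraints { data: Vec::new() }
    }

    fn nrows(&self) -> usize {
        self.data.len() / Self::NCOLS
    }

    /// Append one constraint (row) at the end; existing rows are untouched.
    fn push_row(&mut self, row: [f64; 2]) {
        self.data.extend_from_slice(&row);
    }

    /// Scale column `col` of every existing row in place.
    fn scale_column(&mut self, col: usize, factor: f64) {
        for row in self.data.chunks_exact_mut(Self::NCOLS) {
            row[col] *= factor;
        }
    }

    /// Borrow row `i` as a slice of length NCOLS.
    fn row(&self, i: usize) -> &[f64] {
        &self.data[i * Self::NCOLS..(i + 1) * Self::NCOLS]
    }
}

/// Placeholder for one time step: modify the existing rows, then add a new one.
fn step(y: &mut Constraints, t: usize) {
    y.scale_column(0, 0.5);
    y.push_row([t as f64, 1.0]);
}

fn main() {
    let mut y = Constraints::new();
    for t in 1..4 {
        step(&mut y, t); // Y is modified in place and grows by one row per step
    }
    println!("{} rows, last row: {:?}", y.nrows(), y.row(y.nrows() - 1));
}
```

Because `Vec` grows geometrically, pushing a row per step rarely reallocates, so the "allocate once, modify in place" intent is mostly preserved.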
Question
My question is this: if I want to modify an array at each iteration of a loop by passing a mutable reference to it into several functions that modify its data, how can I also use the same array in basic element-wise arithmetic operations?
Here is a bare-bones example of my code to demonstrate:
```rust
extern crate ndarray;
use ndarray::prelude::*;

fn main() {
    let pz = array![[0.7, 0.3], [0.3, 0.7]]; // measurement probabilities
    let mut Y = Array2::<f64>::zeros((1, 2));
    for i in 1..10 {
        do_some_maths(&mut Y, pz);
        // other functions that will modify Y
    }
    println!("Result: {}", Y);
}

fn do_some_maths(Y: &mut Array2<f64>, pz: Array2<f64>) {
    let Yp = Y * pz.slice(s![.., 0]); // <-- this is the problem
    // do lots of matrix math with Yp
    // ...
    // then modify Y's data using Yp (hence Y needs to be &mut)
}
```
This produces the following compile error:
```text
error[E0369]: binary operation `*` cannot be applied to type `&mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>`
  --> src/main2.rs:21:16
   |
21 |     let Yp = Y * pz.slice(s![.., 0]); // <-- this is the problem
   |              - ^ ------------------- ndarray::ArrayBase<ndarray::ViewRepr<&f64>, _>
   |              |
   |              &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>
   |
   = note: an implementation of `std::ops::Mul` might be missing for `&mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>`
```
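The same error can be reproduced without ndarray, which makes the underlying rule easier to see. An operator impl is chosen by the exact operand types, and `&mut T` is a different type from `&T` or `T`. Unlike method calls, arithmetic operators do not auto-reference or auto-dereference their operands.

In the sketch below, the hypothetical `Mat` type implements `Mul` only for `&Mat * &Mat`, loosely mirroring ndarray's `&A @ &A`. Reborrowing the `&mut Mat` parameter as `&*y` selects that impl. This is a simplified model, not ndarray's real impls. The corresponding ndarray line would presumably be `let Yp = &*Y * &pz.slice(s![.., 0]);`.

```rust
use std::ops::{Mul, MulAssign};

/// Hypothetical stand-in for a matrix; only element-wise ops are modeled,
/// so the storage is a flat Vec and the shape is implicit.
#[derive(Debug, Clone, PartialEq)]
struct Mat {
    data: Vec<f64>,
}

/// `&Mat * &Mat`: element-wise product that allocates a new `Mat`.
impl<'a, 'b> Mul<&'b Mat> for &'a Mat {
    type Output = Mat;
    fn mul(self, rhs: &'b Mat) -> Mat {
        assert_eq!(self.data.len(), rhs.data.len(), "shape mismatch");
        Mat {
            data: self.data.iter().zip(&rhs.data).map(|(a, b)| a * b).collect(),
        }
    }
}

/// `Mat *= &Mat`: in-place element-wise product.
impl<'b> MulAssign<&'b Mat> for Mat {
    fn mul_assign(&mut self, rhs: &'b Mat) {
        for (a, b) in self.data.iter_mut().zip(&rhs.data) {
            *a *= *b;
        }
    }
}

/// Same shape as `do_some_maths`: `y` is `&mut Mat`, the weights are borrowed.
fn do_some_maths(y: &mut Mat, w: &Mat) {
    // let yp = y * w;  // error[E0369]: no `Mul` impl for `&mut Mat`
    let yp = &*y * w; // reborrow `y` as `&Mat`, so the `&Mat * &Mat` impl applies
    // ... more math with `yp` ...
    *y *= &yp; // then update Y's data in place through the `&mut`
}

fn main() {
    let w = Mat { data: vec![2.0, 0.5, 1.0, 1.0] };
    let mut y = Mat { data: vec![1.0, 2.0, 3.0, 4.0] };
    for _ in 0..2 {
        do_some_maths(&mut y, &w); // `w` is borrowed, so it is not moved
    }
    println!("{:?}", y.data);
}
```

The second parameter is taken by reference here for a reason. In the original code, `do_some_maths(&mut Y, pz)` moves `pz` into the function on the first loop iteration. Once the `*` error is fixed, that would likely surface as a "use of moved value" error.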
I have spent many hours trying to understand:
- what the correct approach for my use case is, and
- why the code I have written doesn't work.
I have read several somewhat related questions on this site, but none of them really covered taking an Array reference as a function parameter and performing a binary operation on it.
I've carefully studied the first five chapters of the Rust book and dug deep into the ndarray documentation, and I still can't find an answer. ndarray's documentation of ArrayBase contains the following explanation, which I don't fully understand:
Binary Operations on Two Arrays
Let A be an array or view of any kind. Let B be an array with owned storage (either Array or ArcArray). Let C be an array with mutable data (either Array, ArcArray or ArrayViewMut). The following combinations of operands are supported for an arbitrary binary operator denoted by @ (it can be +, -, *, / and so on).
- &A @ &A which produces a new Array
- B @ A which consumes B, updates it with the result, and returns it
- B @ &A which consumes B, updates it with the result, and returns it
- C @= &A which performs an arithmetic operation in place
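The four rules can be modeled with a hypothetical one-dimensional `Arr` type that implements `Add` in exactly the listed combinations, with `Arr` playing the roles of A, B and C. This is a simplified sketch, not ndarray's actual impls, which are generic over storage and dimension and also broadcast.

```rust
use std::ops::{Add, AddAssign};

/// Hypothetical owned array (plays the roles of A, B and C above).
#[derive(Debug, Clone, PartialEq)]
struct Arr(Vec<f64>);

/// `&A @ &A`: borrows both operands and allocates a new array.
impl<'a, 'b> Add<&'b Arr> for &'a Arr {
    type Output = Arr;
    fn add(self, rhs: &'b Arr) -> Arr {
        Arr(self.0.iter().zip(&rhs.0).map(|(a, b)| a + b).collect())
    }
}

/// `C @= &A`: updates the left operand in place.
impl<'b> AddAssign<&'b Arr> for Arr {
    fn add_assign(&mut self, rhs: &'b Arr) {
        for (a, b) in self.0.iter_mut().zip(&rhs.0) {
            *a += *b;
        }
    }
}

/// `B @ &A`: consumes the owned left operand, reuses its buffer, returns it.
impl<'b> Add<&'b Arr> for Arr {
    type Output = Arr;
    fn add(mut self, rhs: &'b Arr) -> Arr {
        self += rhs;
        self
    }
}

/// `B @ A`: consumes both operands; delegates to `B @ &A`.
impl Add<Arr> for Arr {
    type Output = Arr;
    fn add(self, rhs: Arr) -> Arr {
        self + &rhs
    }
}

fn main() {
    let a = Arr(vec![1.0, 2.0]);
    let b = Arr(vec![10.0, 20.0]);

    let c = &a + &b; // &A @ &A: new array; `a` and `b` stay usable
    let d = a.clone() + &b; // B @ &A: the clone is consumed and its buffer reused
    let e = d + b; // B @ A: `d` and `b` are both moved
    let mut f = c.clone();
    f += &a; // C @= &A: `f` is updated in place
    println!("{:?} {:?} {:?}", c, e, f);
}
```

None of these impls has `&mut Arr` as an operand type. That is the same situation the compiler reports for ndarray in the error above.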
Given this description, and searching through the many trait implementations for Add, Mul, etc., it seems to me that a mutable ndarray::Array cannot be an operand in a binary operation, except in the case of compound assignment.
Is that true, or am I missing something here? I don't want to simply memorize this little tidbit and move on; I really want to understand what is actually going on here and where my understanding is lacking. Please help me wrap my C++/Python-trained brain around this. :)
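The claim can be checked against a plain `i32`, which shows the same difference between a mutable binding (`let mut x`) and a mutable reference (`&mut x`). The standard library implements `Mul` for `i32` and `&i32` operands but not for `&mut i32`. The function names below are hypothetical and exist only for illustration.

```rust
/// `let mut` makes the binding mutable; the value's type is still `i32`,
/// so it can appear in any binary operation an `i32` supports.
fn owned_operand() -> i32 {
    let mut x = 5;
    x += 1; // compound assignment on the mutable binding
    x * 2 // ordinary binary operation on the same binding: `Mul<i32> for i32`
}

/// A `&mut i32` is a different type. std implements `Mul` for `i32` and `&i32`,
/// but not for `&mut i32`, so the reference must be dereferenced or reborrowed.
fn via_mut_ref(r: &mut i32) -> (i32, i32) {
    // let bad = r * 2;      // error[E0369]: no `Mul` impl for `&mut i32`
    let by_deref = *r * 2; // copy the value out (fine here because i32 is Copy)
    let by_reborrow = &*r * 2; // reborrow as `&i32`: `Mul<i32> for &i32`
    *r += 1; // the `&mut` is still usable for updates afterwards
    (by_deref, by_reborrow)
}

fn main() {
    println!("{}", owned_operand());
    let mut y = 7;
    let products = via_mut_ref(&mut y);
    println!("{:?} {}", products, y);
}
```

For an owned, non-`Copy` array, the `*r` route would try to move out of the reference. The shared reborrow `&*r` is therefore the form that carries over to `&mut Array`.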