How to Idiomatically Test for Overflow when Shifting Left (<<) in Rust?

For most operators that might overflow, Rust provides a checked version. For example, to test if an addition overflows one could use checked_add:

match 255u8.checked_add(1) {
    Some(_) => println!("no overflow"),
    None => println!("overflow!"),
}

This prints "overflow!". There is also a checked_shl, but according to the documentation it only checks if the shift is larger than or equal to the number of bits in self. That means that while this:

match 255u8.checked_shl(8) {
    Some(val) => println!("{}", val),
    None => println!("overflow!"),
}

is caught and prints "overflow!", this:

match 255u8.checked_shl(7) {
    Some(val) => println!("{}", val),
    None => println!("overflow!"),
}

simply prints 128, clearly not catching the overflow. What is the correct way to check for any kind of overflow when shifting left?
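The behavior described above can be checked directly. The following is a small sketch (the helper name `shl_report` is hypothetical) that collects what the standard shift helpers report. `checked_shl`, `overflowing_shl` and `wrapping_shl` all look only at the shift amount, never at the bits that get shifted out:

```rust
/// Hypothetical helper: gathers what the std shift helpers report for `n << shift`.
/// None of them inspects the bits that fall off the top; they only compare
/// `shift` against the bit width (8 for u8).
fn shl_report(n: u8, shift: u32) -> (Option<u8>, (u8, bool), u8) {
    (
        n.checked_shl(shift),     // None only if shift >= 8
        n.overflowing_shl(shift), // flag is true only if shift >= 8
        n.wrapping_shl(shift),    // masks shift to shift % 8, never flags anything
    )
}
```

With `n = 255` and a shift of 7, all three report 128 and no overflow. Only a shift of 8 is flagged: `None` from `checked_shl` and `true` from `overflowing_shl`.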

Norway answered 22/8, 2020 at 15:3

I'm not aware of any idiomatic way of doing this, but implementing your own trait works:

The idea is to check that the number has at least as many leading zeros as the shift amount. If it does, only zero bits are shifted out, so no set bit is lost.

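trait LossCheckedShift {
    fn loss_checked_shl(self, rhs: u32) -> Option<Self>
    where
        Self: Sized;
}

impl LossCheckedShift for u8 {
    fn loss_checked_shl(self, rhs: u32) -> Option<Self> {
        // checked_shl returns None for rhs >= 8 instead of panicking in debug
        // builds; the filter rejects shifts that would push out a set bit.
        self.checked_shl(rhs).filter(|_| rhs <= self.leading_zeros())
    }
}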

fn main() {
    match 255u8.loss_checked_shl(7) {
        Some(val) => println!("{}", val),
        None => println!("overflow!"), // taken: 255 has no leading zeros
    }

    match 127u8.loss_checked_shl(1) {
        Some(val) => println!("{}", val), // taken: prints 254
        None => println!("overflow!"),
    }
    match 127u8.loss_checked_shl(2) {
        Some(val) => println!("{}", val),
        None => println!("overflow!"), // taken: 127 has only one leading zero
    }
}
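The trait above is only implemented for `u8`. Below is a minimal sketch that applies the same leading-zeros check to every unsigned primitive through a hypothetical macro, `impl_loss_checked_shift!`. It uses `checked_shl` so that shifting by the full bit width or more returns `None` instead of panicking. Signed types are deliberately left out, on the assumption that a change of sign should count as overflow:

```rust
/// Left shift that fails if any set bit would be shifted out.
trait LossCheckedShift: Sized {
    fn loss_checked_shl(self, rhs: u32) -> Option<Self>;
}

// Hypothetical helper macro: generates one impl per unsigned primitive.
// Signed types are excluded because, for a positive value, leading_zeros()
// counts the sign bit as a zero, so a shift into the sign bit would be accepted.
macro_rules! impl_loss_checked_shift {
    ($($t:ty),*) => {
        $(
            impl LossCheckedShift for $t {
                fn loss_checked_shl(self, rhs: u32) -> Option<Self> {
                    // checked_shl rejects rhs >= bit width; the filter rejects
                    // shifts larger than the number of leading zeros.
                    self.checked_shl(rhs).filter(|_| rhs <= self.leading_zeros())
                }
            }
        )*
    };
}

impl_loss_checked_shift!(u8, u16, u32, u64, u128, usize);
```

For example, `0x00FFu16.loss_checked_shl(8)` yields `Some(0xFF00)`, while `1u64.loss_checked_shl(64)` yields `None`.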
Ethylethylate answered 22/8, 2020 at 15:35

You could do a complementary right-shift (right-shift by 8 - requested_number_of_bits) and check if 0 remains. If so, it means that no bits would be lost by left-shifting:

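fn safe_shl(n: u8, shift_for: u8) -> Option<u8> {
    // Shifting by the bit width or more is rejected, matching checked_shl.
    if shift_for >= 8 {
        return None;
    }
    // The top `shift_for` bits are the ones shifted out. The guard skips
    // shift_for == 0, where `n >> 8` would itself overflow.
    if shift_for > 0 && n >> (8 - shift_for) != 0 {
        return None; // would lose some data
    }
    Some(n << shift_for)
}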
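To see why a non-zero result means data loss, consider a small hypothetical helper, `bits_shifted_out`, that returns exactly the high bits a left shift would discard. It computes them with the same complementary right shift and assumes `1 <= shift <= 7`:

```rust
/// Hypothetical helper for illustration: returns the high bits that `n << shift`
/// would push out of a u8, using the complementary right shift `n >> (8 - shift)`.
/// Assumes 1 <= shift <= 7 (a shift of 0 loses nothing, and `n >> 8` is not a valid u8 shift).
fn bits_shifted_out(n: u8, shift: u8) -> u8 {
    assert!((1..=7).contains(&shift), "shift must be in 1..=7");
    n >> (8 - shift)
}
```

For `n = 0b0110_0000` and a shift of 2, it returns `0b01`: the set bit at position 6 would be lost, and indeed `0b0110_0000u8 << 2` is `0b1000_0000`. With a shift of 1 it returns 0, so `0b0110_0000u8 << 1 == 0b1100_0000` keeps every set bit.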

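One can also write a generic version for other fixed-width integer types. It derives the bit width from size_of, so it does not carry over to heap-allocated bigints. In practice it is meant for unsigned primitives: with signed types it misses a flipped sign bit and rejects every negative value.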

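use std::mem::size_of;
use std::ops::{Shl, Shr};

fn safe_shl<T>(n: T, shift_for: u32) -> Option<T>
where
    T: Default + Eq,
    for<'a> &'a T: Shl<u32, Output = T> + Shr<u32, Output = T>,
{
    let bits_in_t = size_of::<T>() as u32 * 8;
    let zero = T::default();
    // Shifting by the full width or more is rejected, as checked_shl does.
    if shift_for >= bits_in_t {
        return None;
    }
    // Skip the check for shift_for == 0: nothing is lost, and `n >> bits_in_t`
    // would itself overflow for primitive integers.
    if shift_for > 0 && &n >> (bits_in_t - shift_for) != zero {
        return None; // would lose some data
    }
    Some(&n << shift_for)
}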

Danielladanielle answered 22/8, 2020 at 15:38
Thanks, I'm kind of surprised there's nothing like this in the standard library, but this is a great solution too. - Norway

I've always shifted the value left, shifted the result back right by the same amount, and compared it to the original number. If they differ, some bits were shifted out and the shift overflowed.
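A minimal sketch of this round-trip check for `i32` follows (the function name `round_trip_shl` is hypothetical). It uses `checked_shl` so that a shift of 32 or more returns `None` rather than panicking in debug builds. For signed types, `>>` is an arithmetic shift, so a left shift that flips the sign bit also fails the round trip:

```rust
/// Round-trip overflow check: shift left, shift back right, compare with the original.
/// Written for i32 (hypothetical name `round_trip_shl`); the same body works for
/// any primitive integer type.
fn round_trip_shl(n: i32, shift: u32) -> Option<i32> {
    // checked_shl rejects shift >= 32, where `<<` itself would panic in debug builds.
    let shifted = n.checked_shl(shift)?;
    // For signed types `>>` sign-extends, so a flipped sign bit also makes
    // the round trip fail.
    if shifted >> shift == n {
        Some(shifted)
    } else {
        None
    }
}
```

For instance, `round_trip_shl(1, 31)` returns `None` because `1 << 31` is `i32::MIN`. By contrast, `round_trip_shl(-3, 2)` returns `Some(-12)`, since -12 fits in an `i32`.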

Mascle answered 3/5, 2023 at 1:13

© 2022 - 2025 — McMap. All rights reserved.