Let’s imagine there’s input from some device that produces only zeros and ones, and it needs to be converted into actual integers, something like this:
```rust
let bits: Vec<u8> = vec![0, 0, 0, 0, 0, 1, 0, 1];
let result: u8 = convert(&bits);
assert_eq!(result, 5);
```
In this article, I will show how Rust functions can be generalized using generics, and describe a problem I ran into along the way.
Naive solution
Here’s my first and naive implementation:
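```rust
fn convert(bits: &[u8]) -> u8 {
    let mut result: u8 = 0;
    bits.iter().for_each(|&bit| {
        // Shift everything left, which leaves the rightmost bit at 0...
        result <<= 1;
        // ...then XOR the current bit into that free position.
        result ^= bit;
    });
    result
}
```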
No magic here: shift the result left by one bit, then XOR it with the current bit. After the shift, the rightmost bit is always 0, so XORing with 1 sets it, and XORing with 0 leaves it at 0.
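To make the shift-then-XOR step concrete, here is a small sketch that also records the value of `result` after each bit. The helper name `convert_traced` is hypothetical and exists only for illustration; the loop itself is the same as in `convert` above.

```rust
/// Same loop as `convert`, but also returns the value of `result`
/// after every step, so the shift/XOR mechanics can be inspected.
fn convert_traced(bits: &[u8]) -> (u8, Vec<u8>) {
    let mut result: u8 = 0;
    let mut steps = Vec::with_capacity(bits.len());
    for &bit in bits {
        result <<= 1; // make room: the rightmost bit is now 0
        result ^= bit; // 0 ^ 1 = 1 sets it, 0 ^ 0 = 0 leaves it unset
        steps.push(result);
    }
    (result, steps)
}
```

For the input `[1, 0, 1]` the recorded steps are `1`, `2` (`0b10`) and `5` (`0b101`): each new bit pushes the earlier ones one position to the left.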
But there’s room for improvement: the `fold` function is more idiomatic:
```rust
fn convert(bits: &[u8]) -> u8 {
    bits.iter()
        .fold(0, |result, &bit| {
            (result << 1) ^ bit
        })
}
```
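As a quick sanity check of the most-significant-bit-first ordering, the `fold` version can be compared with the standard library's `u8::from_str_radix`, which parses a string of digits in a given base. The helper `convert_via_str` is hypothetical, and the two only agree on non-empty binary inputs whose value fits in a `u8`:

```rust
/// The `fold` version from above.
fn convert(bits: &[u8]) -> u8 {
    bits.iter().fold(0, |result, &bit| (result << 1) ^ bit)
}

/// Cross-check: spell the bits out as text, most significant bit first,
/// and let the standard library parse that text as a base-2 number.
/// Returns `None` for empty or non-binary input, or a value too large for `u8`.
fn convert_via_str(bits: &[u8]) -> Option<u8> {
    let text: String = bits
        .iter()
        .map(|&bit| match bit {
            0 => '0',
            1 => '1',
            _ => '?', // not a binary digit, so parsing will fail
        })
        .collect();
    u8::from_str_radix(&text, 2).ok()
}
```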
Generic Approach
The last solution is quite good, but it only produces output of type `u8`. What if we want a `u32` or an `i64`? Rust supports generic data types for cases like this, so let’s try!
We will need the `PartialEq`, `BitXor`, `Shl`, and `From<bool>` traits to define the behavior inside the `convert` function:
```rust
use std::cmp::PartialEq;
use std::ops::Shl;
use std::ops::BitXor;

fn convert<T: PartialEq + From<bool> + BitXor + Shl>(bits: &[T]) -> T {
    let zero = T::from(false);
    let one = T::from(true);
    bits.iter()
        .filter(|&&bit| bit == zero || bit == one)
        .fold(zero, |result, bit| (result << one) ^ bit)
}
```
But this implementation gives this error during compilation:
```text
10 |         .fold(zero, |result, bit| (result << one) ^ bit)
   |                                   --------------- ^ --- &T
   |                                   |
   |                                   <T as std::ops::Shl>::Output
   |
   = note: an implementation of `std::ops::BitXor` might be missing for `<T as std::ops::Shl>::Output`
```
Hmm, this doesn’t seem right. According to its declaration, the `Shl` trait defines `Output` as an associated type:
```rust
pub trait Shl<Rhs = Self> {
    type Output;
    fn shl(self, rhs: Rhs) -> Self::Output;
}
```
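To see why the compiler refuses to assume anything about `Output`, here is a sketch of a perfectly legal `Shl` implementation whose output is a different type than its input. `Byte` and `shift_by_one` are hypothetical names used only for illustration:

```rust
use std::ops::Shl;

/// A hypothetical 8-bit wrapper whose `<<` widens the result to 16 bits.
/// This is legal: `Shl::Output` may be any type the implementor chooses.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Byte(u8);

impl Shl<u32> for Byte {
    type Output = u16;

    fn shl(self, rhs: u32) -> u16 {
        // Widen first so no bits are lost by the shift.
        u16::from(self.0) << rhs
    }
}

/// Inside a generic function only the bounds are known, so the result
/// of `x << 1` has the opaque type `<T as Shl<u32>>::Output`.
fn shift_by_one<T: Shl<u32>>(x: T) -> <T as Shl<u32>>::Output {
    x << 1
}
```

Shifting `Byte(0b1000_0000)` this way yields the `u16` value `256`, while shifting `3u8` yields the `u8` value `6`. With a bound of just `T: Shl`, the generic `convert` cannot rule out types like `Byte`, so it cannot apply `^` to the shifted value.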
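The bound `T: Shl` only says that `<<` can be applied to `T`; it says nothing about the resulting `Output` type, which can be anything the implementor chooses. So the compiler cannot assume that `^` is defined for the shifted value. The error also shows a second mismatch: the right-hand side `bit` is a `&T`, not a `T`.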
Thankfully, Rust lets us constrain associated types directly in the trait bounds, for example `Shl<Output = T>` and `BitXor<Output = T>`:
```rust
use std::cmp::PartialEq;
use std::ops::BitXor;
use std::ops::Shl;

fn convert<T: PartialEq + From<bool> + BitXor<Output = T> + Shl<Output = T> + Clone>(
    bits: &[T],
) -> T {
    let zero = T::from(false);
    let one = T::from(true);
    bits.iter()
        // Keep only 0 and 1; comparing by reference needs no clone.
        .filter(|&bit| *bit == zero || *bit == one)
        // `T` is not necessarily `Copy`, so clone instead of moving `one` and `bit`.
        .fold(zero.clone(), |result, bit| {
            (result << one.clone()) ^ bit.clone()
        })
}
```
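Besides the `Output = T` constraints, it includes a couple of other fixes: a `Clone` bound, because a generic `T` is not necessarily `Copy`, so `one` and each `bit` are cloned rather than moved; and a `filter` closure that compares by reference instead of moving values out of the slice. With these changes, this is a working solution for all primitive integer types.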
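A quick usage sketch of the generic version with a few different integer types. The function body is the same as above, written with a `where` clause for readability:

```rust
use std::ops::{BitXor, Shl};

// Copy of the generic `convert` above; `T` is both the element type and the result type.
fn convert<T>(bits: &[T]) -> T
where
    T: PartialEq + From<bool> + BitXor<Output = T> + Shl<Output = T> + Clone,
{
    let zero = T::from(false);
    let one = T::from(true);
    bits.iter()
        // Values other than 0 and 1 are silently skipped.
        .filter(|&bit| *bit == zero || *bit == one)
        .fold(zero.clone(), |result, bit| (result << one.clone()) ^ bit.clone())
}
```

For example, `convert(&[1u32, 0, 1])` returns `5u32` and `convert(&[1i64, 1, 1, 1])` returns `15i64`. Note that `convert(&[1u16, 2, 1])` returns `3`: the filter quietly drops the `2`, which is one motivation for the explicit error handling in the next section.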
Improvements
What if I want to add the following to my function:
- Check whether the input vector has more bits than the target integer can hold
- Check that the vector contains only ones and zeros
- Reduce the memory footprint by making the input always a `u8` slice, with the return type determined by the caller’s declaration
The first thing we need to do is change the output type from plain `T` to `Result<T, ConversionError>`, where `ConversionError` is an enum that holds all the error variants we have:
```rust
#[derive(Debug)]
pub enum ConversionError {
    Overflow,
    NonBinaryInput,
}
```
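The enum above only derives `Debug`. An optional extension, not part of the original code, is to implement `Display` and `std::error::Error`, so the error prints a readable message and can be returned through `Box<dyn Error>`. The message texts below are my own wording:

```rust
use std::error::Error;
use std::fmt;

#[derive(Debug)]
pub enum ConversionError {
    Overflow,
    NonBinaryInput,
}

// Human-readable messages, e.g. for `println!("{}", err)`.
impl fmt::Display for ConversionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ConversionError::Overflow => write!(f, "more bits than the target integer can hold"),
            ConversionError::NonBinaryInput => write!(f, "input contains values other than 0 and 1"),
        }
    }
}

// Marker impl: `Debug` + `Display` are enough, so the error can be boxed as `Box<dyn Error>`.
impl Error for ConversionError {}
```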
To check the size of a generic type, we use `std::mem::size_of`, which returns the size in bytes (hence the multiplication by 8):
```rust
if bits.len() > (std::mem::size_of::<T>() * 8) {
    return Err(ConversionError::Overflow);
}
```
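Here is a small sketch that isolates this check so it can be tried on its own. `max_bits` and `fits` are hypothetical helper names, and the bit count assumes `T` is a primitive integer, whose size in bytes times 8 equals its bit width:

```rust
/// Hypothetical helper: the number of bits a primitive integer type `T` holds.
fn max_bits<T>() -> usize {
    std::mem::size_of::<T>() * 8
}

/// The same guard as in the article, pulled out for illustration:
/// `true` if `bits` is short enough to fit into `T`.
fn fits<T>(bits: &[u8]) -> bool {
    bits.len() <= max_bits::<T>()
}
```

For concrete types the standard library also offers constants such as `u32::BITS`, but those are not reachable through a generic `T` without an extra trait, which is why `size_of` is convenient here.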
Changing the input to always be `u8`, the smallest integer type, is a questionable change because it forces the caller to always declare the return type explicitly, but I believe it is a fair trade for the reduced memory footprint: a `Vec<u64>` of bits would use eight bytes per bit instead of one.
This is how it should be called now:
```rust
let bits: Vec<u8> = vec![0, 0, 0, 0, 0, 1, 0, 1];
let result: Result<u32, ConversionError> = convert(&bits);
assert_eq!(result.unwrap(), 5);
```
The function signature changes mainly from `From<bool>` to `From<u8>` (note that this excludes `i8`, which does not implement `From<u8>`):
```rust
pub fn convert<T: PartialEq + From<u8> + BitXor<Output = T> + Shl<Output = T> + Clone>(
    bits: &[u8],
) -> Result<T, ConversionError> {
```
Checking that the input contains only zeros and ones is optional, but it can be done with simple filtering and counting:
```rust
if bits.iter()
    .filter(|&&bit| bit != 0 && bit != 1).count() > 0 {
    return Err(ConversionError::NonBinaryInput);
}
```
The final solution with all the latest changes:
```rust
use std::cmp::PartialEq;
use std::ops::BitXor;
use std::ops::Shl;

#[derive(Debug)]
pub enum ConversionError {
    Overflow,
    NonBinaryInput,
}

pub fn convert<T: PartialEq + From<u8> + BitXor<Output = T> + Shl<Output = T> + Clone>(
    bits: &[u8],
) -> Result<T, ConversionError> {
    if bits.len() > (std::mem::size_of::<T>() * 8) {
        return Err(ConversionError::Overflow);
    }

    if bits.iter()
        .filter(|&&bit| bit != 0 && bit != 1).count() > 0 {
        return Err(ConversionError::NonBinaryInput);
    }

    Ok(bits.iter()
        .fold(T::from(0), |result, &bit| {
            (result << T::from(1)) ^ T::from(bit)
        }))
}
```
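A usage sketch of the final version, covering a valid input and both error variants. The block repeats the function so it compiles on its own. `describe` is a hypothetical caller; it uses the turbofish syntax `convert::<u16>(...)`, which names the target type at the call site as an alternative to annotating a binding:

```rust
use std::ops::{BitXor, Shl};

#[derive(Debug)]
pub enum ConversionError {
    Overflow,
    NonBinaryInput,
}

// Copy of the final `convert` above.
pub fn convert<T: PartialEq + From<u8> + BitXor<Output = T> + Shl<Output = T> + Clone>(
    bits: &[u8],
) -> Result<T, ConversionError> {
    if bits.len() > (std::mem::size_of::<T>() * 8) {
        return Err(ConversionError::Overflow);
    }
    if bits.iter().filter(|&&bit| bit != 0 && bit != 1).count() > 0 {
        return Err(ConversionError::NonBinaryInput);
    }
    Ok(bits
        .iter()
        .fold(T::from(0), |result, &bit| (result << T::from(1)) ^ T::from(bit)))
}

/// Hypothetical caller: picks `u16` via the turbofish and matches every error variant.
fn describe(bits: &[u8]) -> String {
    match convert::<u16>(bits) {
        Ok(value) => format!("ok: {}", value),
        Err(ConversionError::Overflow) => String::from("too many bits for u16"),
        Err(ConversionError::NonBinaryInput) => String::from("input is not binary"),
    }
}
```

For instance, 17 bits are rejected with `Overflow` for `u16`, and any value other than 0 or 1 yields `NonBinaryInput`.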
Comments

A reader suggested: you can replace `filter().count() > 0` with `any()` (or `all()` with the reversed predicate), and it will short-circuit: it exits as soon as it finds the first value that is neither 0 nor 1, making the whole walk through the vector faster.

Thanks Shamir, it is a good improvement!
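Applied to the article’s input check, the suggestion could look like the sketch below. The function names are hypothetical; inside `convert`, the guard would become `if has_non_binary(bits) { return Err(ConversionError::NonBinaryInput); }`.

```rust
/// `any` stops at the first value that is neither 0 nor 1,
/// instead of counting every offender like `filter().count()`.
fn has_non_binary(bits: &[u8]) -> bool {
    bits.iter().any(|&bit| bit != 0 && bit != 1)
}

/// Equivalent formulation with `all` and the reversed predicate;
/// it also short-circuits, at the first non-binary value.
fn is_binary(bits: &[u8]) -> bool {
    bits.iter().all(|&bit| bit == 0 || bit == 1)
}
```

Both are exact negations of each other, so picking one is a matter of which reads better at the call site. Note that `all` returns `true` for an empty slice.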