use super::monty::monty_modpow;
use super::BigUint;
use crate::big_digit::{self, BigDigit};
use num_integer::Integer;
use num_traits::{One, Pow, ToPrimitive, Zero};
impl<'b> Pow<&'b BigUint> for BigUint {
    type Output = BigUint;

    /// Raises `self` to an arbitrary-precision exponent.
    ///
    /// Trivial cases (`exp == 0`, `self == 1`, `self == 0`) are answered
    /// directly; otherwise the exponent is narrowed to `u64` or `u128` and the
    /// primitive-exponent implementation does the work.
    ///
    /// # Panics
    ///
    /// Panics with `"memory overflow"` when `self` is neither zero nor one and
    /// the exponent does not fit in a `u128`.
    #[inline]
    fn pow(self, exp: &BigUint) -> BigUint {
        if exp.is_zero() || self.is_one() {
            return BigUint::one();
        }
        if self.is_zero() {
            return BigUint::zero();
        }
        // Prefer the narrower exponent type when the value fits.
        if let Some(narrow) = exp.to_u64() {
            return self.pow(narrow);
        }
        match exp.to_u128() {
            Some(wide) => self.pow(wide),
            // Here the base is at least 2 and the exponent did not convert to
            // a `u128`, so no attempt is made to compute the power.
            None => panic!("memory overflow"),
        }
    }
}
impl Pow<BigUint> for BigUint {
    type Output = BigUint;

    /// Raises `self` to an owned big exponent by forwarding to the
    /// borrowed-exponent implementation.
    #[inline]
    fn pow(self, exp: BigUint) -> BigUint {
        <BigUint as Pow<&BigUint>>::pow(self, &exp)
    }
}
impl<'a, 'b> Pow<&'b BigUint> for &'a BigUint {
type Output = BigUint;
#[inline]
fn pow(self, exp: &BigUint) -> BigUint {
if self.is_one() || exp.is_zero() {
BigUint::one()
} else if self.is_zero() {
BigUint::zero()
} else {
self.clone().pow(exp)
}
}
}
impl<'a> Pow<BigUint> for &'a BigUint {
    type Output = BigUint;

    /// Raises a borrowed base to an owned big exponent by forwarding to the
    /// fully-borrowed implementation.
    #[inline]
    fn pow(self, exp: BigUint) -> BigUint {
        <&BigUint as Pow<&BigUint>>::pow(self, &exp)
    }
}
/// Generates the four `Pow` impls (owned/borrowed base crossed with
/// owned/borrowed exponent) for one primitive unsigned exponent type `$T`.
macro_rules! pow_impl {
    ($T:ty) => {
        impl Pow<$T> for BigUint {
            type Output = BigUint;

            /// Raises `self` to `exp` by binary square-and-multiply.
            fn pow(self, mut exp: $T) -> BigUint {
                if exp == 0 {
                    return BigUint::one();
                }

                // Every trailing zero bit of the exponent is a plain squaring.
                let mut square = self;
                while exp % 2 == 0 {
                    square = &square * &square;
                    exp >>= 1;
                }

                // A single remaining bit means the exponent was a power of two.
                if exp == 1 {
                    return square;
                }

                // `exp` is odd and greater than one: seed the product with the
                // current square for the low bit, then walk the higher bits.
                let mut product = square.clone();
                exp >>= 1;
                while exp != 0 {
                    square = &square * &square;
                    if exp % 2 == 1 {
                        product = &product * &square;
                    }
                    exp >>= 1;
                }
                product
            }
        }

        impl<'b> Pow<&'b $T> for BigUint {
            type Output = BigUint;

            /// Dereferences the exponent and forwards to the by-value impl.
            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                <BigUint as Pow<$T>>::pow(self, *exp)
            }
        }

        impl<'a> Pow<$T> for &'a BigUint {
            type Output = BigUint;

            /// Clones the base only when the exponent is nonzero.
            #[inline]
            fn pow(self, exp: $T) -> BigUint {
                if exp == 0 {
                    BigUint::one()
                } else {
                    Pow::pow(self.clone(), exp)
                }
            }
        }

        impl<'a, 'b> Pow<&'b $T> for &'a BigUint {
            type Output = BigUint;

            /// Dereferences the exponent and forwards to the borrowed-base impl.
            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                <&BigUint as Pow<$T>>::pow(self, *exp)
            }
        }
    };
}
// Instantiate the `Pow` impls for each primitive unsigned exponent type.
// The `u64` and `u128` instances are the ones the `Pow<&BigUint>` impl for
// `BigUint` dispatches to after narrowing a big exponent.
pow_impl!(u8);
pow_impl!(u16);
pow_impl!(u32);
pow_impl!(u64);
pow_impl!(usize);
pow_impl!(u128);
/// Computes `x` raised to `exponent`, reduced modulo `modulus`.
///
/// Even moduli are handled by `plain_modpow` on the exponent's raw digits;
/// odd moduli are delegated to `monty_modpow` (presumably a Montgomery-based
/// routine, judging by its module name; confirm in `super::monty`).
///
/// # Panics
///
/// Panics if `modulus` is zero.
pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint {
    assert!(
        !modulus.is_zero(),
        "attempt to calculate with zero modulus!"
    );

    if modulus.is_even() {
        return plain_modpow(x, &exponent.data, modulus);
    }
    monty_modpow(x, exponent, modulus)
}
/// Computes `base` raised to the exponent described by `exp_data`, reduced
/// modulo `modulus`, using binary square-and-multiply with a reduction after
/// every multiplication.
///
/// `exp_data` holds the exponent's digits with index 0 as the low-order end:
/// digits are consumed from index 0 upward, one squaring per bit.
///
/// Returns one when every digit is zero (including an empty slice). That value
/// is not reduced, so it is only congruent to the true result when `modulus`
/// is not 1; `modpow` never sends modulus 1 here because 1 is odd.
///
/// # Panics
///
/// Panics if `modulus` is zero. In debug builds it also asserts that the final
/// digit being consumed is nonzero, i.e. that `exp_data` has no zero digit at
/// its high-order end.
fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint {
assert!(
!modulus.is_zero(),
"attempt to calculate with zero modulus!"
);
// Index of the lowest nonzero digit; an all-zero exponent yields one.
let i = match exp_data.iter().position(|&r| r != 0) {
None => return BigUint::one(),
Some(i) => i,
};
let mut base = base % modulus;
// Each entirely-zero low digit contributes `BITS` squarings and nothing else.
for _ in 0..i {
for _ in 0..big_digit::BITS {
base = &base * &base % modulus;
}
}
// `r` is the digit currently being consumed; `b` counts how many of its bits
// have been shifted out so far.
let mut r = exp_data[i];
let mut b = 0u8;
// Trailing zero bits of the first nonzero digit are squarings only.
while r.is_even() {
base = &base * &base % modulus;
r >>= 1;
b += 1;
}
let mut exp_iter = exp_data[i + 1..].iter();
// Only one bit is set in the whole exponent: the squared base is the answer.
if exp_iter.len() == 0 && r.is_one() {
return base;
}
// Seed the accumulator with the lowest set bit's contribution, then drop
// that bit from `r`.
let mut acc = base.clone();
r >>= 1;
b += 1;
{
// One step per exponent bit: always square `base`, and fold it into `acc`
// when the bit is set.
let mut unit = |exp_is_odd| {
base = &base * &base % modulus;
if exp_is_odd {
acc = &acc * &base % modulus;
}
};
// When higher digits exist, the last one is pulled off so it can be handled
// by the bit-limited loop below instead of a full `BITS`-step pass.
if let Some(&last) = exp_iter.next_back() {
// Finish the remaining bits of `exp_data[i]`.
for _ in b..big_digit::BITS {
unit(r.is_odd());
r >>= 1;
}
// Consume every digit strictly between the first nonzero one and the last.
for &r in exp_iter {
let mut r = r;
for _ in 0..big_digit::BITS {
unit(r.is_odd());
r >>= 1;
}
}
r = last;
}
// Process the final digit only up to its highest set bit, so no squarings are
// spent on its leading zero bits.
debug_assert_ne!(r, 0);
while !r.is_zero() {
unit(r.is_odd());
r >>= 1;
}
}
acc
}
/// Checks `plain_modpow` against `pow` followed by `%` for several exponent
/// digit layouts: zero low digits, a multi-bit high digit, two nonzero
/// digits, and a zero digit in the middle.
#[test]
fn test_plain_modpow() {
    let two = &BigUint::from(2u32);
    let modulus = BigUint::from(0x1100u32);

    // Each entry pairs the exponent digits given to `plain_modpow` with the
    // small exponent used for the reference computation.
    let cases: [(Vec<BigDigit>, u32); 5] = [
        (vec![0, 0b1], 0b1_00000000_u32),
        (vec![0, 0b10], 0b10_00000000_u32),
        (vec![0, 0b110010], 0b110010_00000000_u32),
        (vec![0b1, 0b1], 0b1_00000001_u32),
        (vec![0b1100, 0, 0b1], 0b1_00000000_00001100_u32),
    ];

    for (digits, small_exp) in cases.iter() {
        let expected = two.pow(*small_exp) % &modulus;
        let actual = plain_modpow(two, digits, &modulus);
        assert_eq!(expected, actual);
    }
}
/// Checks the `Pow<BigUint>` impl for an owned base: 5 cubed is 125.
#[test]
fn test_pow_biguint() {
    let expected = BigUint::from(125u8);
    let actual = BigUint::from(5u8).pow(BigUint::from(3u8));
    assert_eq!(expected, actual);
}