use alloc::vec;
use alloc::vec::Vec;
use rand_core::{CryptoRng, RngCore};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use zeroize::Zeroizing;
use crate::errors::{Error, Result};
use crate::hash::Hash;
use crate::key::{self, PrivateKey, PublicKey};
/// Encrypts `msg` with RSA using the PKCS#1 v1.5 encryption padding
/// (RSAES-PKCS1-v1_5).
///
/// The encoded block is `0x00 || 0x02 || PS || 0x00 || msg`, where `PS` is
/// filled with non-zero random bytes drawn from `rng` and is at least eight
/// bytes long (guaranteed by the 11-byte overhead check).
///
/// # Errors
///
/// Returns [`Error::MessageTooLong`] if `msg` plus the 11 bytes of padding
/// overhead does not fit in the modulus size `k`. Propagates any error from
/// [`key::check_public`] or from the raw encryption primitive.
#[inline]
pub fn encrypt<R: RngCore + CryptoRng, PK: PublicKey>(
    rng: &mut R,
    pub_key: &PK,
    msg: &[u8],
) -> Result<Vec<u8>> {
    key::check_public(pub_key)?;

    let k = pub_key.size();
    // Written as an addition rather than `k - 11` so that a modulus shorter
    // than 11 bytes is rejected instead of underflowing `usize`.
    if msg.len() + 11 > k {
        return Err(Error::MessageTooLong);
    }

    // The buffer holds the plaintext, so it is wiped on drop. em[0] stays 0.
    let mut em = Zeroizing::new(vec![0u8; k]);
    let sep = k - msg.len() - 1; // position of the 0x00 separator
    em[1] = 2;
    non_zero_random_bytes(rng, &mut em[2..sep]);
    em[sep] = 0;
    em[sep + 1..].copy_from_slice(msg);

    pub_key.raw_encryption_primitive(&em, k)
}
#[inline]
pub fn decrypt<R: RngCore + CryptoRng, SK: PrivateKey>(
rng: Option<&mut R>,
priv_key: &SK,
ciphertext: &[u8],
) -> Result<Vec<u8>> {
key::check_public(priv_key)?;
let (valid, out, index) = decrypt_inner(rng, priv_key, ciphertext)?;
if valid == 0 {
return Err(Error::Decryption);
}
Ok(out[index as usize..].to_vec())
}
/// Signs `hashed` with RSA using the PKCS#1 v1.5 signature padding
/// (RSASSA-PKCS1-v1_5).
///
/// With `hash = Some(h)`, `hashed` must be exactly `h`'s digest length, and
/// `h`'s ASN.1 prefix is placed in front of it. With `hash = None`, `hashed`
/// is signed as-is with an empty prefix.
///
/// The encoded block is `0x00 || 0x01 || 0xff.. || 0x00 || prefix || hashed`.
///
/// # Errors
///
/// Propagates [`Error::InputNotHashed`] from `hash_info`. Returns
/// [`Error::MessageTooLong`] when `prefix || hashed` plus 11 bytes of
/// overhead does not fit in the modulus. Propagates errors from the
/// private-key primitive.
#[inline]
pub fn sign<R: RngCore + CryptoRng, SK: PrivateKey>(
    rng: Option<&mut R>,
    priv_key: &SK,
    hash: Option<&Hash>,
    hashed: &[u8],
) -> Result<Vec<u8>> {
    let (hash_len, prefix) = hash_info(hash, hashed.len())?;
    let t_len = prefix.len() + hash_len;
    let k = priv_key.size();
    if t_len + 11 > k {
        return Err(Error::MessageTooLong);
    }

    // Offsets of the ASN.1 prefix and of the digest inside the block.
    let prefix_start = k - t_len;
    let digest_start = k - hash_len;

    // Start from all-0xff padding, then write the header, the separator,
    // the prefix and the digest over it.
    let mut em = vec![0xffu8; k];
    em[..2].copy_from_slice(&[0, 1]);
    em[prefix_start - 1] = 0;
    em[prefix_start..digest_start].copy_from_slice(prefix);
    em[digest_start..].copy_from_slice(hashed);

    priv_key.raw_decryption_primitive(rng, &em, k)
}
/// Verifies an RSA PKCS#1 v1.5 signature (RSASSA-PKCS1-v1_5).
///
/// `hash` and `hashed` have the same meaning as in [`sign`]. The signature is
/// run through the public-key primitive. The resulting block must equal
/// `0x00 || 0x01 || 0xff.. || 0x00 || prefix || hashed`. That comparison uses
/// constant-time operations, and the result is only inspected once at the end.
///
/// # Errors
///
/// Returns [`Error::Verification`] in three cases:
/// - the modulus is too small for the encoding;
/// - `sig` is not exactly `k` bytes long;
/// - the decoded block does not match.
///
/// Propagates [`Error::InputNotHashed`] from `hash_info` and any error from
/// the public-key primitive.
#[inline]
pub fn verify<PK: PublicKey>(
    pub_key: &PK,
    hash: Option<&Hash>,
    hashed: &[u8],
    sig: &[u8],
) -> Result<()> {
    let (hash_len, prefix) = hash_info(hash, hashed.len())?;
    let t_len = prefix.len() + hash_len;
    let k = pub_key.size();
    if k < t_len + 11 {
        return Err(Error::Verification);
    }
    // RFC 8017 §8.2.2 step 1: a signature that is not exactly k octets is
    // invalid. Without this check, differently-sized encodings (e.g. with
    // extra leading zero bytes) would be handed to the primitive. That
    // primitive presumably parses `sig` as an integer, so such encodings
    // could pass. The length is public, so branching on it leaks nothing.
    if sig.len() != k {
        return Err(Error::Verification);
    }

    let em = pub_key.raw_encryption_primitive(sig, pub_key.size())?;

    // Accumulate every check into one `Choice` so that timing does not
    // depend on which byte mismatches.
    let mut ok = em[0].ct_eq(&0u8);
    ok &= em[1].ct_eq(&1u8);
    ok &= em[k - hash_len..k].ct_eq(hashed);
    ok &= em[k - t_len..k - hash_len].ct_eq(&prefix);
    ok &= em[k - t_len - 1].ct_eq(&0u8);
    // Padding bytes are em[2..k - t_len - 1], i.e. k - t_len - 3 of them.
    for el in em.iter().skip(2).take(k - t_len - 3) {
        ok &= el.ct_eq(&0xff)
    }

    if ok.unwrap_u8() != 1 {
        return Err(Error::Verification);
    }

    Ok(())
}
#[inline]
fn hash_info(hash: Option<&Hash>, digest_len: usize) -> Result<(usize, &'static [u8])> {
match hash {
Some(hash) => {
let hash_len = hash.size();
if digest_len != hash_len {
return Err(Error::InputNotHashed);
}
Ok((hash_len, hash.asn1_prefix()))
}
None => Ok((digest_len, &[])),
}
}
/// Decrypts `ciphertext` and validates its PKCS#1 v1.5 encryption padding
/// using constant-time selection instead of data-dependent branches.
///
/// Returns `(valid, em, index)`:
/// - `valid` is 1 when the decoded block has the form
///   `0x00 || 0x02 || PS || 0x00 || M`, with `PS` containing no zero bytes
///   and at least 8 bytes long; otherwise it is 0.
/// - `em` is the decoded block returned by the private-key primitive.
/// - `index` is the offset of `M` in `em` when `valid == 1`, and 0 otherwise.
///
/// # Errors
///
/// Returns [`Error::Decryption`] if the modulus is shorter than 11 bytes.
/// Propagates any error from the private-key primitive. A padding failure is
/// *not* reported as an error here; it is reported through `valid` so the
/// caller controls the single decision point.
#[inline]
fn decrypt_inner<R: RngCore + CryptoRng, SK: PrivateKey>(
    rng: Option<&mut R>,
    priv_key: &SK,
    ciphertext: &[u8],
) -> Result<(u8, Vec<u8>, u32)> {
    let k = priv_key.size();
    if k < 11 {
        return Err(Error::Decryption);
    }
    let em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?;
    // Header checks: first byte 0x00, block type 0x02.
    let first_byte_is_zero = em[0].ct_eq(&0u8);
    let second_byte_is_two = em[1].ct_eq(&2u8);
    // `looking_for_index` stays 1 until the first zero byte after the header
    // is seen; `index` records that separator's position. Every byte is
    // visited and only conditional assignments are used, so the loop does not
    // branch on where the separator sits.
    let mut looking_for_index = 1u8;
    let mut index = 0u32;
    for (i, el) in em.iter().enumerate().skip(2) {
        let equals0 = el.ct_eq(&0u8);
        index.conditional_assign(&(i as u32), Choice::from(looking_for_index) & equals0);
        looking_for_index.conditional_assign(&0u8, equals0);
    }
    // Sign bit of (9 - index) is set iff index >= 10, i.e. iff
    // PS = em[2..index] is at least 8 bytes long. Computed arithmetically to
    // avoid a comparison branch.
    let valid_ps = Choice::from((((2i32 + 8i32 - index as i32 - 1i32) >> 31) & 1) as u8);
    // `!looking_for_index & 1` is 1 iff a zero separator was found.
    let valid =
        first_byte_is_zero & second_byte_is_two & Choice::from(!looking_for_index & 1) & valid_ps;
    // On success, skip past the separator to the start of the message;
    // otherwise report 0.
    index = u32::conditional_select(&0, &(index + 1), valid);
    Ok((valid.unwrap_u8(), em, index))
}
/// Fills `data` with random bytes from `rng`, guaranteeing none of them is zero.
///
/// The whole slice is filled in one call first. Each byte that came out as
/// zero is then redrawn on its own until it is non-zero.
#[inline]
fn non_zero_random_bytes<R: RngCore + CryptoRng>(rng: &mut R, data: &mut [u8]) {
    rng.fill_bytes(data);
    for byte in data.iter_mut() {
        while *byte == 0 {
            rng.fill_bytes(core::slice::from_mut(byte));
        }
    }
}
// Unit tests for the PKCS#1 v1.5 encryption and signature routines, using a
// fixed small RSA key and known-answer vectors.
#[cfg(test)]
mod tests {
    use super::*;
    use base64ct::{Base64, Encoding};
    use hex_literal::hex;
    use num_bigint::BigUint;
    use num_traits::FromPrimitive;
    use num_traits::Num;
    use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
    use sha1::{Digest, Sha1};
    use crate::{Hash, PaddingScheme, PublicKey, PublicKeyParts, RsaPrivateKey, RsaPublicKey};

    // `non_zero_random_bytes` must never leave a zero byte in the buffer.
    // NOTE(review): the RNG is re-seeded with the same seed inside the loop,
    // so all 10 iterations exercise the identical byte stream.
    #[test]
    fn test_non_zero_bytes() {
        for _ in 0..10 {
            let mut rng = ChaCha8Rng::from_seed([42; 32]);
            let mut b = vec![0u8; 512];
            non_zero_random_bytes(&mut rng, &mut b);
            for el in &b {
                assert_ne!(*el, 0u8);
            }
        }
    }

    // Fixed test key built from decimal components: modulus, public exponent
    // 65537, private exponent, and the two primes.
    fn get_private_key() -> RsaPrivateKey {
        RsaPrivateKey::from_components(
            BigUint::from_str_radix("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077", 10).unwrap(),
            BigUint::from_u64(65537).unwrap(),
            BigUint::from_str_radix("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861", 10).unwrap(),
            vec![
                BigUint::from_str_radix("98920366548084643601728869055592650835572950932266967461790948584315647051443",10).unwrap(),
                BigUint::from_str_radix("94560208308847015747498523884063394671606671904944666360068158221458669711639", 10).unwrap()
            ],
        )
    }

    // Known-answer decryption: each pair is (base64 ciphertext, expected plaintext).
    #[test]
    fn test_decrypt_pkcs1v15() {
        let priv_key = get_private_key();
        let tests = [[
            "gIcUIoVkD6ATMBk/u/nlCZCCWRKdkfjCgFdo35VpRXLduiKXhNz1XupLLzTXAybEq15juc+EgY5o0DHv/nt3yg==",
            "x",
        ], [
            "Y7TOCSqofGhkRb+jaVRLzK8xw2cSo1IVES19utzv6hwvx+M8kFsoWQm5DzBeJCZTCVDPkTpavUuEbgp8hnUGDw==",
            "testing.",
        ], [
            "arReP9DJtEVyV2Dg3dDp4c/PSk1O6lxkoJ8HcFupoRorBZG+7+1fDAwT1olNddFnQMjmkb8vxwmNMoTAT/BFjQ==",
            "testing.\n",
        ], [
            "WtaBXIoGC54+vH0NH0CHHE+dRDOsMc/6BrfFu2lEqcKL9+uDuWaf+Xj9mrbQCjjZcpQuX733zyok/jsnqe/Ftw==",
            "01234567890123456789012345678901234567890123456789012",
        ]];
        for test in &tests {
            let out = priv_key
                .decrypt(
                    PaddingScheme::new_pkcs1v15_encrypt(),
                    &Base64::decode_vec(test[0]).unwrap(),
                )
                .unwrap();
            assert_eq!(out, test[1].as_bytes());
        }
    }

    // Round trip for random messages of growing length, truncated to the
    // maximum k - 11 bytes. Decryption randomly runs with or without the RNG.
    #[test]
    fn test_encrypt_decrypt_pkcs1v15() {
        let mut rng = ChaCha8Rng::from_seed([42; 32]);
        let priv_key = get_private_key();
        let k = priv_key.size();
        for i in 1..100 {
            let mut input = vec![0u8; i * 8];
            rng.fill_bytes(&mut input);
            if input.len() > k - 11 {
                input = input[0..k - 11].to_vec();
            }
            let pub_key: RsaPublicKey = priv_key.clone().into();
            let ciphertext = encrypt(&mut rng, &pub_key, &input).unwrap();
            assert_ne!(input, ciphertext);
            let blind: bool = rng.next_u32() < (1u32 << 31);
            let blinder = if blind { Some(&mut rng) } else { None };
            let plaintext = decrypt(blinder, &priv_key, &ciphertext).unwrap();
            assert_eq!(input, plaintext);
        }
    }

    // Known-answer SHA-1 signature, checked both unblinded and blinded.
    #[test]
    fn test_sign_pkcs1v15() {
        let priv_key = get_private_key();
        let tests = [(
            "Test.\n",
            hex!(
                "a4f3fa6ea93bcdd0c57be020c1193ecbfd6f200a3d95c409769b029578fa0e33"
                "6ad9a347600e40d3ae823b8c7e6bad88cc07c1d54c3a1523cbbb6d58efc362ae"
            ),
        )];
        for (text, expected) in &tests {
            let digest = Sha1::digest(text.as_bytes()).to_vec();
            let out = priv_key
                .sign(PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA1)), &digest)
                .unwrap();
            assert_ne!(out, digest);
            assert_eq!(out, expected);
            let mut rng = ChaCha8Rng::from_seed([42; 32]);
            let out2 = priv_key
                .sign_blinded(
                    &mut rng,
                    PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA1)),
                    &digest,
                )
                .unwrap();
            assert_eq!(out2, expected);
        }
    }

    // The same known-answer signature must verify under the public key.
    #[test]
    fn test_verify_pkcs1v15() {
        let priv_key = get_private_key();
        let tests = [(
            "Test.\n",
            hex!(
                "a4f3fa6ea93bcdd0c57be020c1193ecbfd6f200a3d95c409769b029578fa0e33"
                "6ad9a347600e40d3ae823b8c7e6bad88cc07c1d54c3a1523cbbb6d58efc362ae"
            ),
        )];
        let pub_key: RsaPublicKey = priv_key.into();
        for (text, sig) in &tests {
            let digest = Sha1::digest(text.as_bytes()).to_vec();
            pub_key
                .verify(
                    PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA1)),
                    &digest,
                    sig,
                )
                .expect("failed to verify");
        }
    }

    // Signing with no hash (`None`) signs the raw message with an empty
    // prefix; check the known answer and that it verifies.
    #[test]
    fn test_unpadded_signature() {
        let msg = b"Thu Dec 19 18:06:16 EST 2013\n";
        let expected_sig = Base64::decode_vec("pX4DR8azytjdQ1rtUiC040FjkepuQut5q2ZFX1pTjBrOVKNjgsCDyiJDGZTCNoh9qpXYbhl7iEym30BWWwuiZg==").unwrap();
        let priv_key = get_private_key();
        let sig = priv_key
            .sign(PaddingScheme::new_pkcs1v15_sign(None), msg)
            .unwrap();
        assert_eq!(expected_sig, sig);
        let pub_key: RsaPublicKey = priv_key.into();
        pub_key
            .verify(PaddingScheme::new_pkcs1v15_sign(None), msg, &sig)
            .expect("failed to verify");
    }
}