diff --git a/soroban-sdk/src/num.rs b/soroban-sdk/src/num.rs index e96b729af..7a5b59b75 100644 --- a/soroban-sdk/src/num.rs +++ b/soroban-sdk/src/num.rs @@ -2,8 +2,8 @@ use core::{cmp::Ordering, convert::Infallible, fmt::Debug}; use super::{ env::internal::{ - DurationSmall, DurationVal, Env as _, I256Small, I256Val, TimepointSmall, TimepointVal, - U256Small, U256Val, + DurationSmall, DurationVal, Env as _, I256Object, I256Small, I256Val, TimepointSmall, + TimepointVal, U256Object, U256Small, U256Val, }, Bytes, ConversionError, Env, TryFromVal, TryIntoVal, Val, }; @@ -227,10 +227,9 @@ impl U256 { } pub fn from_u128(env: &Env, u: u128) -> Self { - let lo: Bytes = Bytes::from_array(env, &u.to_be_bytes()); - let mut bytes: Bytes = Bytes::from_array(env, &[0u8; 16]); - bytes.append(&lo); - Self::from_be_bytes(env, &bytes) + let lo_hi = (u >> 64) as u64; + let lo_lo = u as u64; + Self::from_parts(env, 0, 0, lo_hi, lo_lo) } pub fn from_parts(env: &Env, hi_hi: u64, hi_lo: u64, lo_hi: u64, lo_lo: u64) -> Self { @@ -254,14 +253,23 @@ impl U256 { } pub fn to_u128(&self) -> Option { - let be_bytes = self.to_be_bytes(); - let be_bytes_hi: [u8; 16] = be_bytes.slice(0..16).try_into().unwrap(); - let be_bytes_lo: [u8; 16] = be_bytes.slice(16..32).try_into().unwrap(); - if u128::from_be_bytes(be_bytes_hi) == 0 { - Some(u128::from_be_bytes(be_bytes_lo)) - } else { - None + let v = *self.val.as_val(); + + // If v is U256Small it can be converted directly + if let Ok(small) = U256Small::try_from(v) { + return Some(u64::from(small) as u128); + } + + // Otherwise use U256Object and take low sections if high are empty + let obj: U256Object = v.try_into().ok()?; + let hi_hi = self.env.obj_to_u256_hi_hi(obj).unwrap_infallible(); + let hi_lo = self.env.obj_to_u256_hi_lo(obj).unwrap_infallible(); + if hi_hi != 0 || hi_lo != 0 { + return None; } + let lo_hi = self.env.obj_to_u256_lo_hi(obj).unwrap_infallible(); + let lo_lo = self.env.obj_to_u256_lo_lo(obj).unwrap_infallible(); + Some(((lo_hi as u128) << 64) | (lo_lo as u128)) } pub fn to_be_bytes(&self) -> Bytes { @@ -474,16 +482,14 @@ impl I256 { } pub fn from_i128(env: &Env, i: i128) -> Self { - let lo: Bytes = Bytes::from_array(env, &i.to_be_bytes()); - if i < 0 { - let mut i256_bytes: Bytes = Bytes::from_array(env, &[255_u8; 16]); - i256_bytes.append(&lo); - Self::from_be_bytes(env, &i256_bytes) + let lo_hi = (i >> 64) as u64; + let lo_lo = i as u64; + let (hi_hi, hi_lo) = if i < 0 { + (-1_i64, u64::MAX) // sign extend 1 bit } else { - let mut i256_bytes: Bytes = Bytes::from_array(env, &[0_u8; 16]); - i256_bytes.append(&lo); - Self::from_be_bytes(env, &i256_bytes) - } + (0_i64, 0_u64) // sign extend 0 bit + }; + I256::from_parts(env, hi_hi, hi_lo, lo_hi, lo_lo) } pub fn from_parts(env: &Env, hi_hi: i64, hi_lo: u64, lo_hi: u64, lo_lo: u64) -> Self { @@ -507,13 +513,27 @@ impl I256 { } pub fn to_i128(&self) -> Option { - let be_bytes = self.to_be_bytes(); - let be_bytes_hi: [u8; 16] = be_bytes.slice(0..16).try_into().unwrap(); - let be_bytes_lo: [u8; 16] = be_bytes.slice(16..32).try_into().unwrap(); - let i128_hi = i128::from_be_bytes(be_bytes_hi); - let i128_lo = i128::from_be_bytes(be_bytes_lo); - if (i128_hi == 0 && i128_lo >= 0) || (i128_hi == -1 && i128_lo < 0) { - Some(i128_lo) + let v = *self.val.as_val(); + + // If v is I256Small it can be converted directly + if let Ok(small) = I256Small::try_from(v) { + return Some(i64::from(small) as i128); + } + + // Otherwise use I256Object and take low sections if high are either + // all 1 bits or all 0 bits (negative and positive respectively) + let obj: I256Object = v.try_into().ok()?; + let hi_hi = self.env.obj_to_i256_hi_hi(obj).unwrap_infallible(); + let hi_lo = self.env.obj_to_i256_hi_lo(obj).unwrap_infallible(); + let lo_hi = self.env.obj_to_i256_lo_hi(obj).unwrap_infallible(); + let lo_lo = self.env.obj_to_i256_lo_lo(obj).unwrap_infallible(); + // The low 128 bits as an i128 + let lo = (((lo_hi as u128) << 64) | (lo_lo as u128)) as i128; + + if lo < 0 && hi_hi == -1 && hi_lo == u64::MAX { + Some(lo) // if negative low, then high must be all 1 bit + } else if 0 <= lo && hi_hi == 0 && hi_lo == 0 { + Some(lo) // if non-negative low, then high must be all 0 bit } else { None }