From 9601724b11bbd9081b1bee6f7e478a5d2b9ace41 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Sun, 4 Oct 2020 12:39:39 +0000 Subject: [PATCH] Avoid unchecked casts in net parser --- library/std/src/net/parser.rs | 69 +++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/library/std/src/net/parser.rs b/library/std/src/net/parser.rs index 0570a7c41bf..da94a573503 100644 --- a/library/std/src/net/parser.rs +++ b/library/std/src/net/parser.rs @@ -6,11 +6,34 @@ #[cfg(test)] mod tests; +use crate::convert::TryInto as _; use crate::error::Error; use crate::fmt; use crate::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use crate::str::FromStr; +trait ReadNumberHelper: crate::marker::Sized { + const ZERO: Self; + fn checked_mul(&self, other: u32) -> Option; + fn checked_add(&self, other: u32) -> Option; +} + +macro_rules! impl_helper { + ($($t:ty)*) => ($(impl ReadNumberHelper for $t { + const ZERO: Self = 0; + #[inline] + fn checked_mul(&self, other: u32) -> Option { + Self::checked_mul(*self, other.try_into().ok()?) + } + #[inline] + fn checked_add(&self, other: u32) -> Option { + Self::checked_add(*self, other.try_into().ok()?) + } + })*) +} + +impl_helper! { u8 u16 } + struct Parser<'a> { // parsing as ASCII, so can use byte array state: &'a [u8], @@ -59,7 +82,7 @@ impl<'a> Parser<'a> { fn read_char(&mut self) -> Option { self.state.split_first().map(|(&b, tail)| { self.state = tail; - b as char + char::from(b) }) } @@ -84,25 +107,26 @@ impl<'a> Parser<'a> { }) } - // Read a single digit in the given radix. For instance, 0-9 in radix 10; - // 0-9A-F in radix 16. - fn read_digit(&mut self, radix: u32) -> Option { - self.read_atomically(move |p| p.read_char()?.to_digit(radix)) - } - // Read a number off the front of the input in the given radix, stopping // at the first non-digit character or eof. Fails if the number has more - // digits than max_digits, or the value is >= upto, or if there is no number. - fn read_number(&mut self, radix: u32, max_digits: u32, upto: u32) -> Option { + // digits than max_digits or if there is no number. + fn read_number( + &mut self, + radix: u32, + max_digits: Option, + ) -> Option { self.read_atomically(move |p| { - let mut result = 0; + let mut result = T::ZERO; let mut digit_count = 0; - while let Some(digit) = p.read_digit(radix) { - result = (result * radix) + digit; + while let Some(digit) = p.read_atomically(|p| p.read_char()?.to_digit(radix)) { + result = result.checked_mul(radix)?; + result = result.checked_add(digit)?; digit_count += 1; - if digit_count > max_digits || result >= upto { - return None; + if let Some(max_digits) = max_digits { + if digit_count > max_digits { + return None; + } } } @@ -116,7 +140,7 @@ impl<'a> Parser<'a> { let mut groups = [0; 4]; for (i, slot) in groups.iter_mut().enumerate() { - *slot = p.read_separator('.', i, |p| p.read_number(10, 3, 0x100))? as u8; + *slot = p.read_separator('.', i, |p| p.read_number(10, None))?; } Some(groups.into()) @@ -140,17 +164,17 @@ impl<'a> Parser<'a> { let ipv4 = p.read_separator(':', i, |p| p.read_ipv4_addr()); if let Some(v4_addr) = ipv4 { - let octets = v4_addr.octets(); - groups[i + 0] = ((octets[0] as u16) << 8) | (octets[1] as u16); - groups[i + 1] = ((octets[2] as u16) << 8) | (octets[3] as u16); + let [one, two, three, four] = v4_addr.octets(); + groups[i + 0] = u16::from_be_bytes([one, two]); + groups[i + 1] = u16::from_be_bytes([three, four]); return (i + 2, true); } } - let group = p.read_separator(':', i, |p| p.read_number(16, 4, 0x10000)); + let group = p.read_separator(':', i, |p| p.read_number(16, Some(4))); match group { - Some(g) => *slot = g as u16, + Some(g) => *slot = g, None => return (i, false), } } @@ -195,12 +219,11 @@ impl<'a> Parser<'a> { self.read_ipv4_addr().map(IpAddr::V4).or_else(move || self.read_ipv6_addr().map(IpAddr::V6)) } - /// Read a : followed by a port in base 10 + /// Read a : followed by a port in base 10. fn read_port(&mut self) -> Option { self.read_atomically(|p| { let _ = p.read_given_char(':')?; - let port = p.read_number(10, 5, 0x10000)?; - Some(port as u16) + p.read_number(10, None) }) }