From 72aecbfd01e73da710ee35826f28f41d5d0cfebe Mon Sep 17 00:00:00 2001 From: Tyson Nottingham Date: Tue, 15 Dec 2020 14:18:02 -0800 Subject: [PATCH] BufWriter: handle possibility of overflow --- library/std/src/io/buffered/bufwriter.rs | 54 +++++++++++++++++------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/library/std/src/io/buffered/bufwriter.rs b/library/std/src/io/buffered/bufwriter.rs index f0ff186d99b..a9fc450de31 100644 --- a/library/std/src/io/buffered/bufwriter.rs +++ b/library/std/src/io/buffered/bufwriter.rs @@ -190,7 +190,7 @@ impl BufWriter { /// data. Writes as much as possible without exceeding capacity. Returns /// the number of bytes written. pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize { - let available = self.buf.capacity() - self.buf.len(); + let available = self.spare_capacity(); let amt_to_buffer = available.min(buf.len()); // SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction. @@ -353,7 +353,7 @@ impl BufWriter { // or their write patterns are somewhat pathological. #[inline(never)] fn write_cold(&mut self, buf: &[u8]) -> io::Result { - if self.buf.len() + buf.len() > self.buf.capacity() { + if buf.len() > self.spare_capacity() { self.flush_buf()?; } @@ -371,7 +371,7 @@ impl BufWriter { // SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and // we entered this else block because `buf.len() < self.buf.capacity()`. - // Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`. + // Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`. unsafe { self.write_to_buffer_unchecked(buf); } @@ -391,7 +391,8 @@ impl BufWriter { // by calling `self.get_mut().write_all()` directly, which avoids // round trips through the buffer in the event of a series of partial // writes in some circumstances. - if self.buf.len() + buf.len() > self.buf.capacity() { + + if buf.len() > self.spare_capacity() { self.flush_buf()?; } @@ -409,7 +410,7 @@ impl BufWriter { // SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and // we entered this else block because `buf.len() < self.buf.capacity()`. - // Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`. + // Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`. unsafe { self.write_to_buffer_unchecked(buf); } @@ -418,11 +419,11 @@ impl BufWriter { } } - // SAFETY: Requires `self.buf.len() + buf.len() <= self.buf.capacity()`, + // SAFETY: Requires `buf.len() <= self.buf.capacity() - self.buf.len()`, // i.e., that input buffer length is less than or equal to spare capacity. #[inline(always)] unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) { - debug_assert!(self.buf.len() + buf.len() <= self.buf.capacity()); + debug_assert!(buf.len() <= self.spare_capacity()); let old_len = self.buf.len(); let buf_len = buf.len(); let src = buf.as_ptr(); @@ -430,6 +431,11 @@ impl BufWriter { ptr::copy_nonoverlapping(src, dst, buf_len); self.buf.set_len(old_len + buf_len); } + + #[inline] + fn spare_capacity(&self) -> usize { + self.buf.capacity() - self.buf.len() + } } #[unstable(feature = "bufwriter_into_raw_parts", issue = "80690")] @@ -505,7 +511,7 @@ impl Write for BufWriter { fn write(&mut self, buf: &[u8]) -> io::Result { // Use < instead of <= to avoid a needless trip through the buffer in some cases. // See `write_cold` for details. - if self.buf.len() + buf.len() < self.buf.capacity() { + if buf.len() < self.spare_capacity() { // SAFETY: safe by above conditional. unsafe { self.write_to_buffer_unchecked(buf); @@ -521,7 +527,7 @@ impl Write for BufWriter { fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { // Use < instead of <= to avoid a needless trip through the buffer in some cases. // See `write_all_cold` for details. - if self.buf.len() + buf.len() < self.buf.capacity() { + if buf.len() < self.spare_capacity() { // SAFETY: safe by above conditional. unsafe { self.write_to_buffer_unchecked(buf); @@ -537,16 +543,31 @@ impl Write for BufWriter { // FIXME: Consider applying `#[inline]` / `#[inline(never)]` optimizations already applied // to `write` and `write_all`. The performance benefits can be significant. See #79930. if self.get_ref().is_write_vectored() { - let total_len = bufs.iter().map(|b| b.len()).sum::(); - if self.buf.len() + total_len > self.buf.capacity() { + // We have to handle the possibility that the total length of the buffers overflows + // `usize` (even though this can only happen if multiple `IoSlice`s reference the + // same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the + // computation overflows, then surely the input cannot fit in our buffer, so we forward + // to the inner writer's `write_vectored` method to let it handle it appropriately. + let saturated_total_len = + bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len())); + + if saturated_total_len > self.spare_capacity() { + // Flush if the total length of the input exceeds our buffer's spare capacity. + // If we would have overflowed, this condition also holds, and we need to flush. self.flush_buf()?; } - if total_len >= self.buf.capacity() { + + if saturated_total_len >= self.buf.capacity() { + // Forward to our inner writer if the total length of the input is greater than or + // equal to our buffer capacity. If we would have overflowed, this condition also + // holds, and we punt to the inner writer. self.panicked = true; let r = self.get_mut().write_vectored(bufs); self.panicked = false; r } else { + // `saturated_total_len < self.buf.capacity()` implies that we did not saturate. + // SAFETY: We checked whether or not the spare capacity was large enough above. If // it was, then we're safe already. If it wasn't, we flushed, making sufficient // room for any input <= the buffer size, which includes this input. @@ -554,14 +575,14 @@ impl Write for BufWriter { bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b)); }; - Ok(total_len) + Ok(saturated_total_len) } } else { let mut iter = bufs.iter(); let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) { // This is the first non-empty slice to write, so if it does // not fit in the buffer, we still get to flush and proceed. - if self.buf.len() + buf.len() > self.buf.capacity() { + if buf.len() > self.spare_capacity() { self.flush_buf()?; } if buf.len() >= self.buf.capacity() { @@ -586,12 +607,15 @@ impl Write for BufWriter { }; debug_assert!(total_written != 0); for buf in iter { - if self.buf.len() + buf.len() <= self.buf.capacity() { + if buf.len() <= self.spare_capacity() { // SAFETY: safe by above conditional. unsafe { self.write_to_buffer_unchecked(buf); } + // This cannot overflow `usize`. If we are here, we've written all of the bytes + // so far to our buffer, and we've ensured that we never exceed the buffer's + // capacity. Therefore, `total_written` <= `self.buf.capacity()` <= `usize::MAX`. total_written += buf.len(); } else { break;