BufWriter: handle possibility of overflow
This commit is contained in:
parent
5fd9372c11
commit
72aecbfd01
1 changed files with 39 additions and 15 deletions
|
@ -190,7 +190,7 @@ impl<W: Write> BufWriter<W> {
|
||||||
/// data. Writes as much as possible without exceeding capacity. Returns
|
/// data. Writes as much as possible without exceeding capacity. Returns
|
||||||
/// the number of bytes written.
|
/// the number of bytes written.
|
||||||
pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize {
|
pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize {
|
||||||
let available = self.buf.capacity() - self.buf.len();
|
let available = self.spare_capacity();
|
||||||
let amt_to_buffer = available.min(buf.len());
|
let amt_to_buffer = available.min(buf.len());
|
||||||
|
|
||||||
// SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
|
// SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
|
||||||
|
@ -353,7 +353,7 @@ impl<W: Write> BufWriter<W> {
|
||||||
// or their write patterns are somewhat pathological.
|
// or their write patterns are somewhat pathological.
|
||||||
#[inline(never)]
|
#[inline(never)]
|
||||||
fn write_cold(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn write_cold(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
if self.buf.len() + buf.len() > self.buf.capacity() {
|
if buf.len() > self.spare_capacity() {
|
||||||
self.flush_buf()?;
|
self.flush_buf()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -371,7 +371,7 @@ impl<W: Write> BufWriter<W> {
|
||||||
|
|
||||||
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
|
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
|
||||||
// we entered this else block because `buf.len() < self.buf.capacity()`.
|
// we entered this else block because `buf.len() < self.buf.capacity()`.
|
||||||
// Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`.
|
// Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`.
|
||||||
unsafe {
|
unsafe {
|
||||||
self.write_to_buffer_unchecked(buf);
|
self.write_to_buffer_unchecked(buf);
|
||||||
}
|
}
|
||||||
|
@ -391,7 +391,8 @@ impl<W: Write> BufWriter<W> {
|
||||||
// by calling `self.get_mut().write_all()` directly, which avoids
|
// by calling `self.get_mut().write_all()` directly, which avoids
|
||||||
// round trips through the buffer in the event of a series of partial
|
// round trips through the buffer in the event of a series of partial
|
||||||
// writes in some circumstances.
|
// writes in some circumstances.
|
||||||
if self.buf.len() + buf.len() > self.buf.capacity() {
|
|
||||||
|
if buf.len() > self.spare_capacity() {
|
||||||
self.flush_buf()?;
|
self.flush_buf()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -409,7 +410,7 @@ impl<W: Write> BufWriter<W> {
|
||||||
|
|
||||||
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
|
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
|
||||||
// we entered this else block because `buf.len() < self.buf.capacity()`.
|
// we entered this else block because `buf.len() < self.buf.capacity()`.
|
||||||
// Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`.
|
// Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`.
|
||||||
unsafe {
|
unsafe {
|
||||||
self.write_to_buffer_unchecked(buf);
|
self.write_to_buffer_unchecked(buf);
|
||||||
}
|
}
|
||||||
|
@ -418,11 +419,11 @@ impl<W: Write> BufWriter<W> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SAFETY: Requires `self.buf.len() + buf.len() <= self.buf.capacity()`,
|
// SAFETY: Requires `buf.len() <= self.buf.capacity() - self.buf.len()`,
|
||||||
// i.e., that input buffer length is less than or equal to spare capacity.
|
// i.e., that input buffer length is less than or equal to spare capacity.
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) {
|
unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) {
|
||||||
debug_assert!(self.buf.len() + buf.len() <= self.buf.capacity());
|
debug_assert!(buf.len() <= self.spare_capacity());
|
||||||
let old_len = self.buf.len();
|
let old_len = self.buf.len();
|
||||||
let buf_len = buf.len();
|
let buf_len = buf.len();
|
||||||
let src = buf.as_ptr();
|
let src = buf.as_ptr();
|
||||||
|
@ -430,6 +431,11 @@ impl<W: Write> BufWriter<W> {
|
||||||
ptr::copy_nonoverlapping(src, dst, buf_len);
|
ptr::copy_nonoverlapping(src, dst, buf_len);
|
||||||
self.buf.set_len(old_len + buf_len);
|
self.buf.set_len(old_len + buf_len);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn spare_capacity(&self) -> usize {
|
||||||
|
self.buf.capacity() - self.buf.len()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[unstable(feature = "bufwriter_into_raw_parts", issue = "80690")]
|
#[unstable(feature = "bufwriter_into_raw_parts", issue = "80690")]
|
||||||
|
@ -505,7 +511,7 @@ impl<W: Write> Write for BufWriter<W> {
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
|
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
|
||||||
// See `write_cold` for details.
|
// See `write_cold` for details.
|
||||||
if self.buf.len() + buf.len() < self.buf.capacity() {
|
if buf.len() < self.spare_capacity() {
|
||||||
// SAFETY: safe by above conditional.
|
// SAFETY: safe by above conditional.
|
||||||
unsafe {
|
unsafe {
|
||||||
self.write_to_buffer_unchecked(buf);
|
self.write_to_buffer_unchecked(buf);
|
||||||
|
@ -521,7 +527,7 @@ impl<W: Write> Write for BufWriter<W> {
|
||||||
fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
|
fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
|
||||||
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
|
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
|
||||||
// See `write_all_cold` for details.
|
// See `write_all_cold` for details.
|
||||||
if self.buf.len() + buf.len() < self.buf.capacity() {
|
if buf.len() < self.spare_capacity() {
|
||||||
// SAFETY: safe by above conditional.
|
// SAFETY: safe by above conditional.
|
||||||
unsafe {
|
unsafe {
|
||||||
self.write_to_buffer_unchecked(buf);
|
self.write_to_buffer_unchecked(buf);
|
||||||
|
@ -537,16 +543,31 @@ impl<W: Write> Write for BufWriter<W> {
|
||||||
// FIXME: Consider applying `#[inline]` / `#[inline(never)]` optimizations already applied
|
// FIXME: Consider applying `#[inline]` / `#[inline(never)]` optimizations already applied
|
||||||
// to `write` and `write_all`. The performance benefits can be significant. See #79930.
|
// to `write` and `write_all`. The performance benefits can be significant. See #79930.
|
||||||
if self.get_ref().is_write_vectored() {
|
if self.get_ref().is_write_vectored() {
|
||||||
let total_len = bufs.iter().map(|b| b.len()).sum::<usize>();
|
// We have to handle the possibility that the total length of the buffers overflows
|
||||||
if self.buf.len() + total_len > self.buf.capacity() {
|
// `usize` (even though this can only happen if multiple `IoSlice`s reference the
|
||||||
|
// same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the
|
||||||
|
// computation overflows, then surely the input cannot fit in our buffer, so we forward
|
||||||
|
// to the inner writer's `write_vectored` method to let it handle it appropriately.
|
||||||
|
let saturated_total_len =
|
||||||
|
bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len()));
|
||||||
|
|
||||||
|
if saturated_total_len > self.spare_capacity() {
|
||||||
|
// Flush if the total length of the input exceeds our buffer's spare capacity.
|
||||||
|
// If we would have overflowed, this condition also holds, and we need to flush.
|
||||||
self.flush_buf()?;
|
self.flush_buf()?;
|
||||||
}
|
}
|
||||||
if total_len >= self.buf.capacity() {
|
|
||||||
|
if saturated_total_len >= self.buf.capacity() {
|
||||||
|
// Forward to our inner writer if the total length of the input is greater than or
|
||||||
|
// equal to our buffer capacity. If we would have overflowed, this condition also
|
||||||
|
// holds, and we punt to the inner writer.
|
||||||
self.panicked = true;
|
self.panicked = true;
|
||||||
let r = self.get_mut().write_vectored(bufs);
|
let r = self.get_mut().write_vectored(bufs);
|
||||||
self.panicked = false;
|
self.panicked = false;
|
||||||
r
|
r
|
||||||
} else {
|
} else {
|
||||||
|
// `saturated_total_len < self.buf.capacity()` implies that we did not saturate.
|
||||||
|
|
||||||
// SAFETY: We checked whether or not the spare capacity was large enough above. If
|
// SAFETY: We checked whether or not the spare capacity was large enough above. If
|
||||||
// it was, then we're safe already. If it wasn't, we flushed, making sufficient
|
// it was, then we're safe already. If it wasn't, we flushed, making sufficient
|
||||||
// room for any input <= the buffer size, which includes this input.
|
// room for any input <= the buffer size, which includes this input.
|
||||||
|
@ -554,14 +575,14 @@ impl<W: Write> Write for BufWriter<W> {
|
||||||
bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
|
bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(total_len)
|
Ok(saturated_total_len)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let mut iter = bufs.iter();
|
let mut iter = bufs.iter();
|
||||||
let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) {
|
let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) {
|
||||||
// This is the first non-empty slice to write, so if it does
|
// This is the first non-empty slice to write, so if it does
|
||||||
// not fit in the buffer, we still get to flush and proceed.
|
// not fit in the buffer, we still get to flush and proceed.
|
||||||
if self.buf.len() + buf.len() > self.buf.capacity() {
|
if buf.len() > self.spare_capacity() {
|
||||||
self.flush_buf()?;
|
self.flush_buf()?;
|
||||||
}
|
}
|
||||||
if buf.len() >= self.buf.capacity() {
|
if buf.len() >= self.buf.capacity() {
|
||||||
|
@ -586,12 +607,15 @@ impl<W: Write> Write for BufWriter<W> {
|
||||||
};
|
};
|
||||||
debug_assert!(total_written != 0);
|
debug_assert!(total_written != 0);
|
||||||
for buf in iter {
|
for buf in iter {
|
||||||
if self.buf.len() + buf.len() <= self.buf.capacity() {
|
if buf.len() <= self.spare_capacity() {
|
||||||
// SAFETY: safe by above conditional.
|
// SAFETY: safe by above conditional.
|
||||||
unsafe {
|
unsafe {
|
||||||
self.write_to_buffer_unchecked(buf);
|
self.write_to_buffer_unchecked(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This cannot overflow `usize`. If we are here, we've written all of the bytes
|
||||||
|
// so far to our buffer, and we've ensured that we never exceed the buffer's
|
||||||
|
// capacity. Therefore, `total_written` <= `self.buf.capacity()` <= `usize::MAX`.
|
||||||
total_written += buf.len();
|
total_written += buf.len();
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue