From 17cfbc6fa37b7b82a4b4f2fb1637082fec2967ad Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Fri, 16 Aug 2024 12:03:42 +0200 Subject: [PATCH] FD: remove big surrounding RefCell, simplify socketpair --- src/tools/miri/src/shims/unix/fd.rs | 103 ++++++------- src/tools/miri/src/shims/unix/fs.rs | 48 +++--- src/tools/miri/src/shims/unix/linux/epoll.rs | 24 ++- .../miri/src/shims/unix/linux/eventfd.rs | 37 +++-- src/tools/miri/src/shims/unix/socket.rs | 144 ++++++++---------- .../miri/tests/pass-dep/libc/libc-epoll.rs | 4 + 6 files changed, 171 insertions(+), 189 deletions(-) diff --git a/src/tools/miri/src/shims/unix/fd.rs b/src/tools/miri/src/shims/unix/fd.rs index 98a124b9a56..1fd1cf4d99e 100644 --- a/src/tools/miri/src/shims/unix/fd.rs +++ b/src/tools/miri/src/shims/unix/fd.rs @@ -2,9 +2,9 @@ //! standard file descriptors (stdin/stdout/stderr). use std::any::Any; -use std::cell::{Ref, RefCell, RefMut}; use std::collections::BTreeMap; use std::io::{self, ErrorKind, IsTerminal, Read, SeekFrom, Write}; +use std::ops::Deref; use std::rc::Rc; use std::rc::Weak; @@ -27,7 +27,7 @@ pub trait FileDescription: std::fmt::Debug + Any { /// Reads as much as possible into the given buffer, and returns the number of bytes read. fn read<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _fd_id: FdId, _bytes: &mut [u8], @@ -38,7 +38,7 @@ pub trait FileDescription: std::fmt::Debug + Any { /// Writes as much as possible from the given buffer, and returns the number of bytes written. fn write<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _fd_id: FdId, _bytes: &[u8], @@ -50,7 +50,7 @@ pub trait FileDescription: std::fmt::Debug + Any { /// Reads as much as possible into the given buffer from a given offset, /// and returns the number of bytes read. fn pread<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _bytes: &mut [u8], _offset: u64, @@ -62,7 +62,7 @@ pub trait FileDescription: std::fmt::Debug + Any { /// Writes as much as possible from the given buffer starting at a given offset, /// and returns the number of bytes written. fn pwrite<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _bytes: &[u8], _offset: u64, @@ -74,7 +74,7 @@ pub trait FileDescription: std::fmt::Debug + Any { /// Seeks to the given offset (which can be relative to the beginning, end, or current position). /// Returns the new position from the start of the stream. fn seek<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _offset: SeekFrom, ) -> InterpResult<'tcx, io::Result> { @@ -111,14 +111,9 @@ pub trait FileDescription: std::fmt::Debug + Any { impl dyn FileDescription { #[inline(always)] - pub fn downcast_ref(&self) -> Option<&T> { + pub fn downcast(&self) -> Option<&T> { (self as &dyn Any).downcast_ref() } - - #[inline(always)] - pub fn downcast_mut(&mut self) -> Option<&mut T> { - (self as &mut dyn Any).downcast_mut() - } } impl FileDescription for io::Stdin { @@ -127,7 +122,7 @@ impl FileDescription for io::Stdin { } fn read<'tcx>( - &mut self, + &self, communicate_allowed: bool, _fd_id: FdId, bytes: &mut [u8], @@ -137,7 +132,7 @@ impl FileDescription for io::Stdin { // We want isolation mode to be deterministic, so we have to disallow all reads, even stdin. helpers::isolation_abort_error("`read` from stdin")?; } - Ok(Read::read(self, bytes)) + Ok(Read::read(&mut { self }, bytes)) } fn is_tty(&self, communicate_allowed: bool) -> bool { @@ -151,14 +146,14 @@ impl FileDescription for io::Stdout { } fn write<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _fd_id: FdId, bytes: &[u8], _ecx: &mut MiriInterpCx<'tcx>, ) -> InterpResult<'tcx, io::Result> { // We allow writing to stderr even with isolation enabled. - let result = Write::write(self, bytes); + let result = Write::write(&mut { self }, bytes); // Stdout is buffered, flush to make sure it appears on the // screen. This is the write() syscall of the interpreted // program, we want it to correspond to a write() syscall on @@ -180,7 +175,7 @@ impl FileDescription for io::Stderr { } fn write<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _fd_id: FdId, bytes: &[u8], @@ -206,7 +201,7 @@ impl FileDescription for NullOutput { } fn write<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _fd_id: FdId, bytes: &[u8], @@ -221,26 +216,23 @@ impl FileDescription for NullOutput { #[derive(Clone, Debug)] pub struct FileDescWithId { id: FdId, - file_description: RefCell>, + file_description: Box, } #[derive(Clone, Debug)] pub struct FileDescriptionRef(Rc>); +impl Deref for FileDescriptionRef { + type Target = dyn FileDescription; + + fn deref(&self) -> &Self::Target { + &*self.0.file_description + } +} + impl FileDescriptionRef { fn new(fd: impl FileDescription, id: FdId) -> Self { - FileDescriptionRef(Rc::new(FileDescWithId { - id, - file_description: RefCell::new(Box::new(fd)), - })) - } - - pub fn borrow(&self) -> Ref<'_, dyn FileDescription> { - Ref::map(self.0.file_description.borrow(), |fd| fd.as_ref()) - } - - pub fn borrow_mut(&self) -> RefMut<'_, dyn FileDescription> { - RefMut::map(self.0.file_description.borrow_mut(), |fd| fd.as_mut()) + FileDescriptionRef(Rc::new(FileDescWithId { id, file_description: Box::new(fd) })) } pub fn close<'tcx>( @@ -256,7 +248,7 @@ impl FileDescriptionRef { // Remove entry from the global epoll_event_interest table. ecx.machine.epoll_interests.remove(id); - RefCell::into_inner(fd.file_description).close(communicate_allowed, ecx) + fd.file_description.close(communicate_allowed, ecx) } None => Ok(Ok(())), } @@ -277,7 +269,7 @@ impl FileDescriptionRef { ecx: &mut InterpCx<'tcx, MiriMachine<'tcx>>, ) -> InterpResult<'tcx, ()> { use crate::shims::unix::linux::epoll::EvalContextExt; - ecx.check_and_update_readiness(self.get_id(), || self.borrow_mut().get_epoll_ready_events()) + ecx.check_and_update_readiness(self.get_id(), || self.get_epoll_ready_events()) } } @@ -334,11 +326,20 @@ impl FdTable { fds } - /// Insert a new file description to the FdTable. - pub fn insert_new(&mut self, fd: impl FileDescription) -> i32 { + pub fn new_ref(&mut self, fd: impl FileDescription) -> FileDescriptionRef { let file_handle = FileDescriptionRef::new(fd, self.next_file_description_id); self.next_file_description_id = FdId(self.next_file_description_id.0.strict_add(1)); - self.insert_ref_with_min_fd(file_handle, 0) + file_handle + } + + /// Insert a new file description to the FdTable. + pub fn insert_new(&mut self, fd: impl FileDescription) -> i32 { + let fd_ref = self.new_ref(fd); + self.insert(fd_ref) + } + + pub fn insert(&mut self, fd_ref: FileDescriptionRef) -> i32 { + self.insert_ref_with_min_fd(fd_ref, 0) } /// Insert a file description, giving it a file descriptor that is at least `min_fd`. @@ -368,17 +369,7 @@ impl FdTable { new_fd } - pub fn get(&self, fd: i32) -> Option> { - let fd = self.fds.get(&fd)?; - Some(fd.borrow()) - } - - pub fn get_mut(&self, fd: i32) -> Option> { - let fd = self.fds.get(&fd)?; - Some(fd.borrow_mut()) - } - - pub fn get_ref(&self, fd: i32) -> Option { + pub fn get(&self, fd: i32) -> Option { let fd = self.fds.get(&fd)?; Some(fd.clone()) } @@ -397,7 +388,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { fn dup(&mut self, old_fd: i32) -> InterpResult<'tcx, Scalar> { let this = self.eval_context_mut(); - let Some(dup_fd) = this.machine.fds.get_ref(old_fd) else { + let Some(dup_fd) = this.machine.fds.get(old_fd) else { return Ok(Scalar::from_i32(this.fd_not_found()?)); }; Ok(Scalar::from_i32(this.machine.fds.insert_ref_with_min_fd(dup_fd, 0))) @@ -406,7 +397,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { fn dup2(&mut self, old_fd: i32, new_fd: i32) -> InterpResult<'tcx, Scalar> { let this = self.eval_context_mut(); - let Some(dup_fd) = this.machine.fds.get_ref(old_fd) else { + let Some(dup_fd) = this.machine.fds.get(old_fd) else { return Ok(Scalar::from_i32(this.fd_not_found()?)); }; if new_fd != old_fd { @@ -492,7 +483,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } let start = this.read_scalar(&args[2])?.to_i32()?; - match this.machine.fds.get_ref(fd) { + match this.machine.fds.get(fd) { Some(dup_fd) => Ok(Scalar::from_i32(this.machine.fds.insert_ref_with_min_fd(dup_fd, start))), None => Ok(Scalar::from_i32(this.fd_not_found()?)), @@ -565,7 +556,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let communicate = this.machine.communicate(); // We temporarily dup the FD to be able to retain mutable access to `this`. - let Some(fd) = this.machine.fds.get_ref(fd) else { + let Some(fd) = this.machine.fds.get(fd) else { trace!("read: FD not found"); return Ok(Scalar::from_target_isize(this.fd_not_found()?, this)); }; @@ -576,14 +567,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // `usize::MAX` because it is bounded by the host's `isize`. let mut bytes = vec![0; usize::try_from(count).unwrap()]; let result = match offset { - None => fd.borrow_mut().read(communicate, fd.get_id(), &mut bytes, this), + None => fd.read(communicate, fd.get_id(), &mut bytes, this), Some(offset) => { let Ok(offset) = u64::try_from(offset) else { let einval = this.eval_libc("EINVAL"); this.set_last_error(einval)?; return Ok(Scalar::from_target_isize(-1, this)); }; - fd.borrow_mut().pread(communicate, &mut bytes, offset, this) + fd.pread(communicate, &mut bytes, offset, this) } }; @@ -629,19 +620,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let bytes = this.read_bytes_ptr_strip_provenance(buf, Size::from_bytes(count))?.to_owned(); // We temporarily dup the FD to be able to retain mutable access to `this`. - let Some(fd) = this.machine.fds.get_ref(fd) else { + let Some(fd) = this.machine.fds.get(fd) else { return Ok(Scalar::from_target_isize(this.fd_not_found()?, this)); }; let result = match offset { - None => fd.borrow_mut().write(communicate, fd.get_id(), &bytes, this), + None => fd.write(communicate, fd.get_id(), &bytes, this), Some(offset) => { let Ok(offset) = u64::try_from(offset) else { let einval = this.eval_libc("EINVAL"); this.set_last_error(einval)?; return Ok(Scalar::from_target_isize(-1, this)); }; - fd.borrow_mut().pwrite(communicate, &bytes, offset, this) + fd.pwrite(communicate, &bytes, offset, this) } }; diff --git a/src/tools/miri/src/shims/unix/fs.rs b/src/tools/miri/src/shims/unix/fs.rs index 9da36e64a0f..e6076abedbd 100644 --- a/src/tools/miri/src/shims/unix/fs.rs +++ b/src/tools/miri/src/shims/unix/fs.rs @@ -31,29 +31,29 @@ impl FileDescription for FileHandle { } fn read<'tcx>( - &mut self, + &self, communicate_allowed: bool, _fd_id: FdId, bytes: &mut [u8], _ecx: &mut MiriInterpCx<'tcx>, ) -> InterpResult<'tcx, io::Result> { assert!(communicate_allowed, "isolation should have prevented even opening a file"); - Ok(self.file.read(bytes)) + Ok((&mut &self.file).read(bytes)) } fn write<'tcx>( - &mut self, + &self, communicate_allowed: bool, _fd_id: FdId, bytes: &[u8], _ecx: &mut MiriInterpCx<'tcx>, ) -> InterpResult<'tcx, io::Result> { assert!(communicate_allowed, "isolation should have prevented even opening a file"); - Ok(self.file.write(bytes)) + Ok((&mut &self.file).write(bytes)) } fn pread<'tcx>( - &mut self, + &self, communicate_allowed: bool, bytes: &mut [u8], offset: u64, @@ -63,13 +63,13 @@ impl FileDescription for FileHandle { // Emulates pread using seek + read + seek to restore cursor position. // Correctness of this emulation relies on sequential nature of Miri execution. // The closure is used to emulate `try` block, since we "bubble" `io::Error` using `?`. + let file = &mut &self.file; let mut f = || { - let cursor_pos = self.file.stream_position()?; - self.file.seek(SeekFrom::Start(offset))?; - let res = self.file.read(bytes); + let cursor_pos = file.stream_position()?; + file.seek(SeekFrom::Start(offset))?; + let res = file.read(bytes); // Attempt to restore cursor position even if the read has failed - self.file - .seek(SeekFrom::Start(cursor_pos)) + file.seek(SeekFrom::Start(cursor_pos)) .expect("failed to restore file position, this shouldn't be possible"); res }; @@ -77,7 +77,7 @@ impl FileDescription for FileHandle { } fn pwrite<'tcx>( - &mut self, + &self, communicate_allowed: bool, bytes: &[u8], offset: u64, @@ -87,13 +87,13 @@ impl FileDescription for FileHandle { // Emulates pwrite using seek + write + seek to restore cursor position. // Correctness of this emulation relies on sequential nature of Miri execution. // The closure is used to emulate `try` block, since we "bubble" `io::Error` using `?`. + let file = &mut &self.file; let mut f = || { - let cursor_pos = self.file.stream_position()?; - self.file.seek(SeekFrom::Start(offset))?; - let res = self.file.write(bytes); + let cursor_pos = file.stream_position()?; + file.seek(SeekFrom::Start(offset))?; + let res = file.write(bytes); // Attempt to restore cursor position even if the write has failed - self.file - .seek(SeekFrom::Start(cursor_pos)) + file.seek(SeekFrom::Start(cursor_pos)) .expect("failed to restore file position, this shouldn't be possible"); res }; @@ -101,12 +101,12 @@ impl FileDescription for FileHandle { } fn seek<'tcx>( - &mut self, + &self, communicate_allowed: bool, offset: SeekFrom, ) -> InterpResult<'tcx, io::Result> { assert!(communicate_allowed, "isolation should have prevented even opening a file"); - Ok(self.file.seek(offset)) + Ok((&mut &self.file).seek(offset)) } fn close<'tcx>( @@ -580,7 +580,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let communicate = this.machine.communicate(); - let Some(mut file_description) = this.machine.fds.get_mut(fd) else { + let Some(file_description) = this.machine.fds.get(fd) else { return Ok(Scalar::from_i64(this.fd_not_found()?)); }; let result = file_description @@ -1276,7 +1276,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // FIXME: Support ftruncate64 for all FDs let FileHandle { file, writable } = - file_description.downcast_ref::().ok_or_else(|| { + file_description.downcast::().ok_or_else(|| { err_unsup_format!("`ftruncate64` is only supported on file-backed file descriptors") })?; @@ -1328,7 +1328,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { }; // Only regular files support synchronization. let FileHandle { file, writable } = - file_description.downcast_ref::().ok_or_else(|| { + file_description.downcast::().ok_or_else(|| { err_unsup_format!("`fsync` is only supported on file-backed file descriptors") })?; let io_result = maybe_sync_file(file, *writable, File::sync_all); @@ -1353,7 +1353,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { }; // Only regular files support synchronization. let FileHandle { file, writable } = - file_description.downcast_ref::().ok_or_else(|| { + file_description.downcast::().ok_or_else(|| { err_unsup_format!("`fdatasync` is only supported on file-backed file descriptors") })?; let io_result = maybe_sync_file(file, *writable, File::sync_data); @@ -1401,7 +1401,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { }; // Only regular files support synchronization. let FileHandle { file, writable } = - file_description.downcast_ref::().ok_or_else(|| { + file_description.downcast::().ok_or_else(|| { err_unsup_format!( "`sync_data_range` is only supported on file-backed file descriptors" ) @@ -1708,7 +1708,7 @@ impl FileMetadata { }; let file = &file_description - .downcast_ref::() + .downcast::() .ok_or_else(|| { err_unsup_format!( "obtaining metadata is only supported on file-backed file descriptors" diff --git a/src/tools/miri/src/shims/unix/linux/epoll.rs b/src/tools/miri/src/shims/unix/linux/epoll.rs index 89616bd0d07..a1b943f2e82 100644 --- a/src/tools/miri/src/shims/unix/linux/epoll.rs +++ b/src/tools/miri/src/shims/unix/linux/epoll.rs @@ -12,7 +12,7 @@ use crate::*; struct Epoll { /// A map of EpollEventInterests registered under this epoll instance. /// Each entry is differentiated using FdId and file descriptor value. - interest_list: BTreeMap<(FdId, i32), Rc>>, + interest_list: RefCell>>>, /// A map of EpollEventInstance that will be returned when `epoll_wait` is called. /// Similar to interest_list, the entry is also differentiated using FdId /// and file descriptor value. @@ -226,18 +226,17 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } // Check if epfd is a valid epoll file descriptor. - let Some(epfd) = this.machine.fds.get_ref(epfd_value) else { + let Some(epfd) = this.machine.fds.get(epfd_value) else { return Ok(Scalar::from_i32(this.fd_not_found()?)); }; - let mut binding = epfd.borrow_mut(); - let epoll_file_description = &mut binding - .downcast_mut::() + let epoll_file_description = epfd + .downcast::() .ok_or_else(|| err_unsup_format!("non-epoll FD passed to `epoll_ctl`"))?; - let interest_list = &mut epoll_file_description.interest_list; + let mut interest_list = epoll_file_description.interest_list.borrow_mut(); let ready_list = &epoll_file_description.ready_list; - let Some(file_descriptor) = this.machine.fds.get_ref(fd) else { + let Some(file_descriptor) = this.machine.fds.get(fd) else { return Ok(Scalar::from_i32(this.fd_not_found()?)); }; let id = file_descriptor.get_id(); @@ -399,16 +398,15 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { throw_unsup_format!("epoll_wait: timeout value can only be 0"); } - let Some(epfd) = this.machine.fds.get_ref(epfd) else { + let Some(epfd) = this.machine.fds.get(epfd) else { return Ok(Scalar::from_i32(this.fd_not_found()?)); }; - let mut binding = epfd.borrow_mut(); - let epoll_file_description = &mut binding - .downcast_mut::() + let epoll_file_description = epfd + .downcast::() .ok_or_else(|| err_unsup_format!("non-epoll FD passed to `epoll_wait`"))?; - let binding = epoll_file_description.get_ready_list(); - let mut ready_list = binding.borrow_mut(); + let ready_list = epoll_file_description.get_ready_list(); + let mut ready_list = ready_list.borrow_mut(); let mut num_of_events: i32 = 0; let mut array_iter = this.project_array_fields(&event)?; diff --git a/src/tools/miri/src/shims/unix/linux/eventfd.rs b/src/tools/miri/src/shims/unix/linux/eventfd.rs index 8a11f225b22..cede3967bc8 100644 --- a/src/tools/miri/src/shims/unix/linux/eventfd.rs +++ b/src/tools/miri/src/shims/unix/linux/eventfd.rs @@ -1,4 +1,5 @@ //! Linux `eventfd` implementation. +use std::cell::{Cell, RefCell}; use std::io; use std::io::{Error, ErrorKind}; use std::mem; @@ -27,9 +28,9 @@ const MAX_COUNTER: u64 = u64::MAX - 1; struct Event { /// The object contains an unsigned 64-bit integer (uint64_t) counter that is maintained by the /// kernel. This counter is initialized with the value specified in the argument initval. - counter: u64, + counter: Cell, is_nonblock: bool, - clock: VClock, + clock: RefCell, } impl FileDescription for Event { @@ -42,8 +43,8 @@ impl FileDescription for Event { // need to be supported in the future, the check should be added here. Ok(EpollReadyEvents { - epollin: self.counter != 0, - epollout: self.counter != MAX_COUNTER, + epollin: self.counter.get() != 0, + epollout: self.counter.get() != MAX_COUNTER, ..EpollReadyEvents::new() }) } @@ -58,7 +59,7 @@ impl FileDescription for Event { /// Read the counter in the buffer and return the counter if succeeded. fn read<'tcx>( - &mut self, + &self, _communicate_allowed: bool, fd_id: FdId, bytes: &mut [u8], @@ -69,7 +70,8 @@ impl FileDescription for Event { return Ok(Err(Error::from(ErrorKind::InvalidInput))); }; // Block when counter == 0. - if self.counter == 0 { + let counter = self.counter.get(); + if counter == 0 { if self.is_nonblock { return Ok(Err(Error::from(ErrorKind::WouldBlock))); } else { @@ -78,13 +80,13 @@ impl FileDescription for Event { } } else { // Synchronize with all prior `write` calls to this FD. - ecx.acquire_clock(&self.clock); + ecx.acquire_clock(&self.clock.borrow()); // Return the counter in the host endianness using the buffer provided by caller. *bytes = match ecx.tcx.sess.target.endian { - Endian::Little => self.counter.to_le_bytes(), - Endian::Big => self.counter.to_be_bytes(), + Endian::Little => counter.to_le_bytes(), + Endian::Big => counter.to_be_bytes(), }; - self.counter = 0; + self.counter.set(0); // When any of the event happened, we check and update the status of all supported event // types for current file description. @@ -114,7 +116,7 @@ impl FileDescription for Event { /// supplied buffer is less than 8 bytes, or if an attempt is /// made to write the value 0xffffffffffffffff. fn write<'tcx>( - &mut self, + &self, _communicate_allowed: bool, fd_id: FdId, bytes: &[u8], @@ -135,13 +137,13 @@ impl FileDescription for Event { } // If the addition does not let the counter to exceed the maximum value, update the counter. // Else, block. - match self.counter.checked_add(num) { + match self.counter.get().checked_add(num) { Some(new_count @ 0..=MAX_COUNTER) => { // Future `read` calls will synchronize with this write, so update the FD clock. if let Some(clock) = &ecx.release_clock() { - self.clock.join(clock); + self.clock.borrow_mut().join(clock); } - self.counter = new_count; + self.counter.set(new_count); } None | Some(u64::MAX) => { if self.is_nonblock { @@ -219,8 +221,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let fds = &mut this.machine.fds; - let fd_value = - fds.insert_new(Event { counter: val.into(), is_nonblock, clock: VClock::default() }); + let fd_value = fds.insert_new(Event { + counter: Cell::new(val.into()), + is_nonblock, + clock: RefCell::new(VClock::default()), + }); Ok(Scalar::from_i32(fd_value)) } diff --git a/src/tools/miri/src/shims/unix/socket.rs b/src/tools/miri/src/shims/unix/socket.rs index 0f40d9776bb..2694391dfbb 100644 --- a/src/tools/miri/src/shims/unix/socket.rs +++ b/src/tools/miri/src/shims/unix/socket.rs @@ -1,8 +1,7 @@ -use std::cell::RefCell; +use std::cell::{OnceCell, RefCell}; use std::collections::VecDeque; use std::io; use std::io::{Error, ErrorKind, Read}; -use std::rc::{Rc, Weak}; use crate::shims::unix::fd::{FdId, WeakFileDescriptionRef}; use crate::shims::unix::linux::epoll::EpollReadyEvents; @@ -17,15 +16,12 @@ const MAX_SOCKETPAIR_BUFFER_CAPACITY: usize = 212992; /// Pair of connected sockets. #[derive(Debug)] struct SocketPair { - // By making the write link weak, a `write` can detect when all readers are - // gone, and trigger EPIPE as appropriate. - writebuf: Weak>, - readbuf: Rc>, - /// When a socketpair instance is created, two socketpair file descriptions are generated. - /// The peer_fd field holds a weak reference to the file description of peer socketpair. - // TODO: It might be possible to retrieve writebuf from peer_fd and remove the writebuf - // field above. - peer_fd: WeakFileDescriptionRef, + /// The buffer we are reading from. + readbuf: RefCell, + /// The `SocketPair` file descriptor that is our "peer", and that holds the buffer we are + /// writing to. This is a weak reference because the other side may be closed before us; all + /// future writes will then trigger EPIPE. + peer_fd: OnceCell, is_nonblock: bool, } @@ -39,6 +35,18 @@ struct Buffer { buf_has_writer: bool, } +impl Buffer { + fn new() -> Self { + Buffer { buf: VecDeque::new(), clock: VClock::default(), buf_has_writer: true } + } +} + +impl SocketPair { + fn peer_fd(&self) -> &WeakFileDescriptionRef { + self.peer_fd.get().unwrap() + } +} + impl FileDescription for SocketPair { fn name(&self) -> &'static str { "socketpair" @@ -49,29 +57,29 @@ impl FileDescription for SocketPair { // need to be supported in the future, the check should be added here. let mut epoll_ready_events = EpollReadyEvents::new(); - let readbuf = self.readbuf.borrow(); // Check if it is readable. + let readbuf = self.readbuf.borrow(); if !readbuf.buf.is_empty() { epoll_ready_events.epollin = true; } // Check if is writable. - if let Some(writebuf) = self.writebuf.upgrade() { - let writebuf = writebuf.borrow(); + if let Some(peer_fd) = self.peer_fd().upgrade() { + let writebuf = &peer_fd.downcast::().unwrap().readbuf.borrow(); let data_size = writebuf.buf.len(); let available_space = MAX_SOCKETPAIR_BUFFER_CAPACITY.strict_sub(data_size); if available_space != 0 { epoll_ready_events.epollout = true; } - } - - // Check if the peer_fd closed - if self.peer_fd.upgrade().is_none() { + } else { + // Peer FD has been closed. epoll_ready_events.epollrdhup = true; - // This is an edge case. Whenever epollrdhup is triggered, epollin will be added - // even though there is no data in the buffer. + // This is an edge case. Whenever epollrdhup is triggered, epollin and epollout will be + // added even though there is no data in the buffer. + // FIXME: Figure out why. This looks like a bug. epoll_ready_events.epollin = true; + epoll_ready_events.epollout = true; } Ok(epoll_ready_events) } @@ -81,15 +89,13 @@ impl FileDescription for SocketPair { _communicate_allowed: bool, ecx: &mut MiriInterpCx<'tcx>, ) -> InterpResult<'tcx, io::Result<()>> { - // This is used to signal socketfd of other side that there is no writer to its readbuf. - // If the upgrade fails, there is no need to update as all read ends have been dropped. - if let Some(writebuf) = self.writebuf.upgrade() { - writebuf.borrow_mut().buf_has_writer = false; - }; + if let Some(peer_fd) = self.peer_fd().upgrade() { + // This is used to signal socketfd of other side that there is no writer to its readbuf. + // If the upgrade fails, there is no need to update as all read ends have been dropped. + peer_fd.downcast::().unwrap().readbuf.borrow_mut().buf_has_writer = false; - // Notify peer fd that closed has happened. - if let Some(peer_fd) = self.peer_fd.upgrade() { - // When any of the event happened, we check and update the status of all supported events + // Notify peer fd that closed has happened. + // When any of the events happened, we check and update the status of all supported events // types of peer fd. peer_fd.check_and_update_readiness(ecx)?; } @@ -97,20 +103,20 @@ impl FileDescription for SocketPair { } fn read<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _fd_id: FdId, bytes: &mut [u8], ecx: &mut MiriInterpCx<'tcx>, ) -> InterpResult<'tcx, io::Result> { let request_byte_size = bytes.len(); - let mut readbuf = self.readbuf.borrow_mut(); // Always succeed on read size 0. if request_byte_size == 0 { return Ok(Ok(0)); } + let mut readbuf = self.readbuf.borrow_mut(); if readbuf.buf.is_empty() { if !readbuf.buf_has_writer { // Socketpair with no writer and empty buffer. @@ -141,8 +147,7 @@ impl FileDescription for SocketPair { // Conveniently, `read` exists on `VecDeque` and has exactly the desired behavior. let actual_read_size = readbuf.buf.read(bytes).unwrap(); - // The readbuf needs to be explicitly dropped because it will cause panic when - // check_and_update_readiness borrows it again. + // Need to drop before others can access the readbuf again. drop(readbuf); // A notification should be provided for the peer file description even when it can @@ -152,7 +157,7 @@ impl FileDescription for SocketPair { // don't know what that *certain number* is, we will provide a notification every time // a read is successful. This might result in our epoll emulation providing more // notifications than the real system. - if let Some(peer_fd) = self.peer_fd.upgrade() { + if let Some(peer_fd) = self.peer_fd().upgrade() { peer_fd.check_and_update_readiness(ecx)?; } @@ -160,7 +165,7 @@ impl FileDescription for SocketPair { } fn write<'tcx>( - &mut self, + &self, _communicate_allowed: bool, _fd_id: FdId, bytes: &[u8], @@ -173,12 +178,13 @@ impl FileDescription for SocketPair { return Ok(Ok(0)); } - let Some(writebuf) = self.writebuf.upgrade() else { + // We are writing to our peer's readbuf. + let Some(peer_fd) = self.peer_fd().upgrade() else { // If the upgrade from Weak to Rc fails, it indicates that all read ends have been // closed. return Ok(Err(Error::from(ErrorKind::BrokenPipe))); }; - let mut writebuf = writebuf.borrow_mut(); + let mut writebuf = peer_fd.downcast::().unwrap().readbuf.borrow_mut(); let data_size = writebuf.buf.len(); let available_space = MAX_SOCKETPAIR_BUFFER_CAPACITY.strict_sub(data_size); if available_space == 0 { @@ -198,13 +204,12 @@ impl FileDescription for SocketPair { let actual_write_size = write_size.min(available_space); writebuf.buf.extend(&bytes[..actual_write_size]); - // The writebuf needs to be explicitly dropped because it will cause panic when - // check_and_update_readiness borrows it again. + // Need to stop accessing peer_fd so that it can be notified. drop(writebuf); + // Notification should be provided for peer fd as it became readable. - if let Some(peer_fd) = self.peer_fd.upgrade() { - peer_fd.check_and_update_readiness(ecx)?; - } + peer_fd.check_and_update_readiness(ecx)?; + return Ok(Ok(actual_write_size)); } } @@ -268,51 +273,30 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { ); } - let buffer1 = Rc::new(RefCell::new(Buffer { - buf: VecDeque::new(), - clock: VClock::default(), - buf_has_writer: true, - })); - - let buffer2 = Rc::new(RefCell::new(Buffer { - buf: VecDeque::new(), - clock: VClock::default(), - buf_has_writer: true, - })); - - let socketpair_0 = SocketPair { - writebuf: Rc::downgrade(&buffer1), - readbuf: Rc::clone(&buffer2), - peer_fd: WeakFileDescriptionRef::default(), - is_nonblock: is_sock_nonblock, - }; - let socketpair_1 = SocketPair { - writebuf: Rc::downgrade(&buffer2), - readbuf: Rc::clone(&buffer1), - peer_fd: WeakFileDescriptionRef::default(), - is_nonblock: is_sock_nonblock, - }; - - // Insert the file description to the fd table. + // Generate file descriptions. let fds = &mut this.machine.fds; - let sv0 = fds.insert_new(socketpair_0); - let sv1 = fds.insert_new(socketpair_1); + let fd0 = fds.new_ref(SocketPair { + readbuf: RefCell::new(Buffer::new()), + peer_fd: OnceCell::new(), + is_nonblock: is_sock_nonblock, + }); + let fd1 = fds.new_ref(SocketPair { + readbuf: RefCell::new(Buffer::new()), + peer_fd: OnceCell::new(), + is_nonblock: is_sock_nonblock, + }); - // Get weak file descriptor and file description id value. - let fd_ref0 = fds.get_ref(sv0).unwrap(); - let fd_ref1 = fds.get_ref(sv1).unwrap(); - let weak_fd_ref0 = fd_ref0.downgrade(); - let weak_fd_ref1 = fd_ref1.downgrade(); + // Make the file descriptions point to each other. + fd0.downcast::().unwrap().peer_fd.set(fd1.downgrade()).unwrap(); + fd1.downcast::().unwrap().peer_fd.set(fd0.downgrade()).unwrap(); - // Update peer_fd and id field. - fd_ref1.borrow_mut().downcast_mut::().unwrap().peer_fd = weak_fd_ref0; + // Insert the file description to the fd table, generating the file descriptors. + let sv0 = fds.insert(fd0); + let sv1 = fds.insert(fd1); - fd_ref0.borrow_mut().downcast_mut::().unwrap().peer_fd = weak_fd_ref1; - - // Return socketpair file description value to the caller. + // Return socketpair file descriptors to the caller. let sv0 = Scalar::from_int(sv0, sv.layout.size); let sv1 = Scalar::from_int(sv1, sv.layout.size); - this.write_scalar(sv0, &sv)?; this.write_scalar(sv1, &sv.offset(sv.layout.size, sv.layout, this)?)?; diff --git a/src/tools/miri/tests/pass-dep/libc/libc-epoll.rs b/src/tools/miri/tests/pass-dep/libc/libc-epoll.rs index 11a0257dc4e..95fdf2f6035 100644 --- a/src/tools/miri/tests/pass-dep/libc/libc-epoll.rs +++ b/src/tools/miri/tests/pass-dep/libc/libc-epoll.rs @@ -19,6 +19,7 @@ fn main() { test_socketpair_read(); } +#[track_caller] fn check_epoll_wait( epfd: i32, mut expected_notifications: Vec<(u32, u64)>, @@ -28,6 +29,9 @@ fn check_epoll_wait( let maxsize = N; let array_ptr = array.as_mut_ptr(); let res = unsafe { libc::epoll_wait(epfd, array_ptr, maxsize.try_into().unwrap(), 0) }; + if res < 0 { + panic!("epoll_wait failed: {}", std::io::Error::last_os_error()); + } assert_eq!(res, expected_notifications.len().try_into().unwrap()); let slice = unsafe { std::slice::from_raw_parts(array_ptr, res.try_into().unwrap()) }; let mut return_events = slice.iter();