1
Fork 0

Rewrite MemDecoder around pointers not a slice

This commit is contained in:
Ben Kimock 2023-04-20 23:11:47 -04:00
parent 39c6804b92
commit 1f67ba61a9
11 changed files with 174 additions and 130 deletions

View file

@ -1,3 +1,6 @@
use crate::opaque::MemDecoder;
use crate::serialize::Decoder;
/// Returns the length of the longest LEB128 encoding for `T`, assuming `T` is an integer type
pub const fn max_leb128_len<T>() -> usize {
// The longest LEB128 encoding for an integer uses 7 bits per byte.
@ -50,21 +53,19 @@ impl_write_unsigned_leb128!(write_usize_leb128, usize);
macro_rules! impl_read_unsigned_leb128 {
($fn_name:ident, $int_ty:ty) => {
#[inline]
pub fn $fn_name(slice: &[u8], position: &mut usize) -> $int_ty {
pub fn $fn_name(decoder: &mut MemDecoder<'_>) -> $int_ty {
// The first iteration of this loop is unpeeled. This is a
// performance win because this code is hot and integer values less
// than 128 are very common, typically occurring 50-80% or more of
// the time, even for u64 and u128.
let byte = slice[*position];
*position += 1;
let byte = decoder.read_u8();
if (byte & 0x80) == 0 {
return byte as $int_ty;
}
let mut result = (byte & 0x7F) as $int_ty;
let mut shift = 7;
loop {
let byte = slice[*position];
*position += 1;
let byte = decoder.read_u8();
if (byte & 0x80) == 0 {
result |= (byte as $int_ty) << shift;
return result;
@ -127,14 +128,13 @@ impl_write_signed_leb128!(write_isize_leb128, isize);
macro_rules! impl_read_signed_leb128 {
($fn_name:ident, $int_ty:ty) => {
#[inline]
pub fn $fn_name(slice: &[u8], position: &mut usize) -> $int_ty {
pub fn $fn_name(decoder: &mut MemDecoder<'_>) -> $int_ty {
let mut result = 0;
let mut shift = 0;
let mut byte;
loop {
byte = slice[*position];
*position += 1;
byte = decoder.read_u8();
result |= <$int_ty>::from(byte & 0x7F) << shift;
shift += 7;

View file

@ -16,6 +16,7 @@ Core encoding and decoding interfaces.
#![feature(maybe_uninit_slice)]
#![feature(new_uninit)]
#![feature(allocator_api)]
#![feature(ptr_sub_ptr)]
#![cfg_attr(test, feature(test))]
#![allow(rustc::internal)]
#![deny(rustc::untranslatable_diagnostic)]

View file

@ -2,7 +2,9 @@ use crate::leb128::{self, largest_max_leb128_len};
use crate::serialize::{Decodable, Decoder, Encodable, Encoder};
use std::fs::File;
use std::io::{self, Write};
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::Range;
use std::path::Path;
use std::ptr;
@ -510,38 +512,125 @@ impl Encoder for FileEncoder {
// Decoder
// -----------------------------------------------------------------------------
// Conceptually, `MemDecoder` wraps a `&[u8]` with a cursor into it that is always valid.
// This is implemented with three pointers, two of which represent the original slice and a
// third that is our cursor.
// It is an invariant of this type that start <= current <= end.
// Additionally, the implementation of this type never modifies start and end.
pub struct MemDecoder<'a> {
pub data: &'a [u8],
position: usize,
start: *const u8,
current: *const u8,
end: *const u8,
_marker: PhantomData<&'a u8>,
}
impl<'a> MemDecoder<'a> {
#[inline]
pub fn new(data: &'a [u8], position: usize) -> MemDecoder<'a> {
MemDecoder { data, position }
let Range { start, end } = data.as_ptr_range();
MemDecoder { start, current: data[position..].as_ptr(), end, _marker: PhantomData }
}
#[inline]
pub fn position(&self) -> usize {
self.position
pub fn data(&self) -> &'a [u8] {
// SAFETY: This recovers the original slice, only using members we never modify.
unsafe { std::slice::from_raw_parts(self.start, self.len()) }
}
#[inline]
pub fn set_position(&mut self, pos: usize) {
self.position = pos
pub fn len(&self) -> usize {
// SAFETY: This recovers the length of the original slice, only using members we never modify.
unsafe { self.end.sub_ptr(self.start) }
}
#[inline]
pub fn advance(&mut self, bytes: usize) {
self.position += bytes;
pub fn remaining(&self) -> usize {
// SAFETY: This type guarantees current <= end.
unsafe { self.end.sub_ptr(self.current) }
}
/// Panics with the message "MemDecoder exhausted"; never returns.
///
/// Called by the read paths of this type when a request would go past `end`.
/// It is marked `#[cold]` and `#[inline(never)]`, presumably so the panic
/// machinery stays out of the inlined hot paths.
#[cold]
#[inline(never)]
fn decoder_exhausted() -> ! {
panic!("MemDecoder exhausted")
}
/// Returns the byte under the cursor and advances the cursor by one byte.
///
/// Panics with "MemDecoder exhausted" if the cursor is already at `end`.
#[inline]
fn read_byte(&mut self) -> u8 {
if self.current == self.end {
Self::decoder_exhausted();
}
// SAFETY: This type guarantees current <= end, and we just checked current != end,
// so current points at a readable byte and current + 1 is at most end.
unsafe {
let byte = *self.current;
self.current = self.current.add(1);
byte
}
}
/// Reads exactly `N` bytes at the cursor and returns them as a fixed-size array,
/// advancing the cursor past them.
#[inline]
fn read_array<const N: usize>(&mut self) -> [u8; N] {
let raw = self.read_raw_bytes(N);
// `raw` was requested with length `N`, so the slice-to-array conversion cannot fail.
<[u8; N]>::try_from(raw).unwrap()
}
// The trait method doesn't have a lifetime parameter, so its result borrows from the
// Decoder itself. To implement read_str efficiently we need a version that definitely
// returns a slice tied to the underlying storage (`'a`); this inherent method is it.
#[inline]
fn read_raw_bytes_inherent(&mut self, bytes: usize) -> &'a [u8] {
if self.remaining() < bytes {
Self::decoder_exhausted();
}
let head = self.current;
// SAFETY: The check above guarantees that `bytes` bytes starting at the cursor
// are in-bounds, so both the advanced cursor and the returned slice are valid.
unsafe {
self.current = head.add(bytes);
std::slice::from_raw_parts(head, bytes)
}
}
/// Runs `func` with the cursor temporarily moved to `pos`, a byte offset measured
/// from the start of the underlying data, then puts the previous cursor back.
///
/// While we could manually expose manipulation of the decoder position,
/// all current users of that method would need to reset the position later,
/// incurring the bounds check of set_position twice.
///
/// # Panics
///
/// Panics with "MemDecoder exhausted" if `pos >= self.len()`.
#[inline]
pub fn with_position<F, T>(&mut self, pos: usize, func: F) -> T
where
F: Fn(&mut MemDecoder<'a>) -> T,
{
// Guard that writes the saved cursor back into the decoder when it is dropped.
struct SetOnDrop<'a, 'guarded> {
decoder: &'guarded mut MemDecoder<'a>,
current: *const u8,
}
impl Drop for SetOnDrop<'_, '_> {
fn drop(&mut self) {
self.decoder.current = self.current;
}
}
// NOTE(review): this rejects `pos == self.len()`, although `new` accepts a starting
// position equal to the data length — confirm the stricter bound is intended.
if pos >= self.len() {
Self::decoder_exhausted();
}
let previous = self.current;
// SAFETY: We just checked if this add is in-bounds above.
unsafe {
self.current = self.start.add(pos);
}
// `guard` is dropped when this function exits (after `func` returns, or while
// unwinding out of it), which restores `previous` as the cursor.
let guard = SetOnDrop { current: previous, decoder: self };
func(guard.decoder)
}
}
macro_rules! read_leb128 {
($dec:expr, $fun:ident) => {{ leb128::$fun($dec.data, &mut $dec.position) }};
($dec:expr, $fun:ident) => {{ leb128::$fun($dec) }};
}
impl<'a> Decoder for MemDecoder<'a> {
/// Returns the cursor's offset, in bytes, from the start of the underlying data.
#[inline]
fn position(&self) -> usize {
// SAFETY: This type guarantees start <= current
unsafe { self.current.sub_ptr(self.start) }
}
#[inline]
fn read_u128(&mut self) -> u128 {
read_leb128!(self, read_u128_leb128)
@ -559,17 +648,12 @@ impl<'a> Decoder for MemDecoder<'a> {
#[inline]
fn read_u16(&mut self) -> u16 {
let bytes = [self.data[self.position], self.data[self.position + 1]];
let value = u16::from_le_bytes(bytes);
self.position += 2;
value
u16::from_le_bytes(self.read_array())
}
#[inline]
fn read_u8(&mut self) -> u8 {
let value = self.data[self.position];
self.position += 1;
value
self.read_byte()
}
#[inline]
@ -594,17 +678,12 @@ impl<'a> Decoder for MemDecoder<'a> {
#[inline]
fn read_i16(&mut self) -> i16 {
let bytes = [self.data[self.position], self.data[self.position + 1]];
let value = i16::from_le_bytes(bytes);
self.position += 2;
value
i16::from_le_bytes(self.read_array())
}
#[inline]
fn read_i8(&mut self) -> i8 {
let value = self.data[self.position];
self.position += 1;
value as i8
self.read_byte() as i8
}
#[inline]
@ -625,22 +704,26 @@ impl<'a> Decoder for MemDecoder<'a> {
}
#[inline]
fn read_str(&mut self) -> &'a str {
fn read_str(&mut self) -> &str {
let len = self.read_usize();
let sentinel = self.data[self.position + len];
assert!(sentinel == STR_SENTINEL);
let s = unsafe {
std::str::from_utf8_unchecked(&self.data[self.position..self.position + len])
};
self.position += len + 1;
s
let bytes = self.read_raw_bytes_inherent(len + 1);
assert!(bytes[len] == STR_SENTINEL);
unsafe { std::str::from_utf8_unchecked(&bytes[..len]) }
}
#[inline]
fn read_raw_bytes(&mut self, bytes: usize) -> &'a [u8] {
let start = self.position;
self.position += bytes;
&self.data[start..self.position]
fn read_raw_bytes(&mut self, bytes: usize) -> &[u8] {
self.read_raw_bytes_inherent(bytes)
}
/// Returns the byte under the cursor without advancing the cursor.
///
/// Panics with "MemDecoder exhausted" if the cursor is already at `end`.
#[inline]
fn peek_byte(&self) -> u8 {
if self.current == self.end {
Self::decoder_exhausted();
}
// SAFETY: This type guarantees current is inbounds or one-past-the-end, which is end.
// Since we just checked current != end, the current pointer must be inbounds.
unsafe { *self.current }
}
}

View file

@ -84,6 +84,8 @@ pub trait Decoder {
fn read_char(&mut self) -> char;
fn read_str(&mut self) -> &str;
fn read_raw_bytes(&mut self, len: usize) -> &[u8];
fn peek_byte(&self) -> u8;
fn position(&self) -> usize;
}
/// Trait for types that can be serialized