proc_macro: stop using a remote object handle for Literal

This builds on the symbol infrastructure built for `Ident` to replicate
the `LitKind` and `Lit` structures in rustc within the `proc_macro`
client, allowing literals to be fully created and interacted with from
the client thread. Only parsing and subspan operations still require
sync RPC.
This commit is contained in:
Nika Layzell 2022-07-03 01:04:31 -04:00
parent 491fccfbe3
commit b34c79f8f1
5 changed files with 305 additions and 260 deletions

View file

@ -14,9 +14,10 @@ use rustc_span::def_id::CrateNum;
use rustc_span::symbol::{self, sym, Symbol};
use rustc_span::{BytePos, FileName, Pos, SourceFile, Span};
use pm::bridge::{server, DelimSpan, ExpnGlobals, Group, Ident, Punct, TokenTree};
use pm::bridge::{
server, DelimSpan, ExpnGlobals, Group, Ident, LitKind, Literal, Punct, TokenTree,
};
use pm::{Delimiter, Level, LineColumn};
use std::ascii;
use std::ops::Bound;
trait FromInternal<T> {
@ -49,9 +50,40 @@ impl ToInternal<token::Delimiter> for Delimiter {
}
}
impl FromInternal<(TokenStream, &mut Rustc<'_, '_>)>
for Vec<TokenTree<TokenStream, Span, Symbol, Literal>>
{
impl FromInternal<token::LitKind> for LitKind {
fn from_internal(kind: token::LitKind) -> Self {
match kind {
token::Byte => LitKind::Byte,
token::Char => LitKind::Char,
token::Integer => LitKind::Integer,
token::Float => LitKind::Float,
token::Str => LitKind::Str,
token::StrRaw(n) => LitKind::StrRaw(n),
token::ByteStr => LitKind::ByteStr,
token::ByteStrRaw(n) => LitKind::ByteStrRaw(n),
token::Err => LitKind::Err,
token::Bool => unreachable!(),
}
}
}
impl ToInternal<token::LitKind> for LitKind {
fn to_internal(self) -> token::LitKind {
match self {
LitKind::Byte => token::Byte,
LitKind::Char => token::Char,
LitKind::Integer => token::Integer,
LitKind::Float => token::Float,
LitKind::Str => token::Str,
LitKind::StrRaw(n) => token::StrRaw(n),
LitKind::ByteStr => token::ByteStr,
LitKind::ByteStrRaw(n) => token::ByteStrRaw(n),
LitKind::Err => token::Err,
}
}
}
impl FromInternal<(TokenStream, &mut Rustc<'_, '_>)> for Vec<TokenTree<TokenStream, Span, Symbol>> {
fn from_internal((stream, rustc): (TokenStream, &mut Rustc<'_, '_>)) -> Self {
use rustc_ast::token::*;
@ -143,7 +175,14 @@ impl FromInternal<(TokenStream, &mut Rustc<'_, '_>)>
TokenTree::Ident(Ident { sym: ident.name, is_raw: false, span }),
]);
}
Literal(lit) => trees.push(TokenTree::Literal(self::Literal { lit, span })),
Literal(token::Lit { kind, symbol, suffix }) => {
trees.push(TokenTree::Literal(self::Literal {
kind: FromInternal::from_internal(kind),
symbol,
suffix,
span,
}));
}
DocComment(_, attr_style, data) => {
let mut escaped = String::new();
for ch in data.as_str().chars() {
@ -199,9 +238,7 @@ impl FromInternal<(TokenStream, &mut Rustc<'_, '_>)>
}
}
impl ToInternal<TokenStream>
for (TokenTree<TokenStream, Span, Symbol, Literal>, &mut Rustc<'_, '_>)
{
impl ToInternal<TokenStream> for (TokenTree<TokenStream, Span, Symbol>, &mut Rustc<'_, '_>) {
fn to_internal(self) -> TokenStream {
use rustc_ast::token::*;
@ -221,7 +258,9 @@ impl ToInternal<TokenStream>
return tokenstream::TokenTree::token(Ident(sym, is_raw), span).into();
}
TokenTree::Literal(self::Literal {
lit: token::Lit { kind: token::Integer, symbol, suffix },
kind: self::LitKind::Integer,
symbol,
suffix,
span,
}) if symbol.as_str().starts_with('-') => {
let minus = BinOp(BinOpToken::Minus);
@ -232,7 +271,9 @@ impl ToInternal<TokenStream>
return [a, b].into_iter().collect();
}
TokenTree::Literal(self::Literal {
lit: token::Lit { kind: token::Float, symbol, suffix },
kind: self::LitKind::Float,
symbol,
suffix,
span,
}) if symbol.as_str().starts_with('-') => {
let minus = BinOp(BinOpToken::Minus);
@ -242,8 +283,12 @@ impl ToInternal<TokenStream>
let b = tokenstream::TokenTree::token(float, span);
return [a, b].into_iter().collect();
}
TokenTree::Literal(self::Literal { lit, span }) => {
return tokenstream::TokenTree::token(Literal(lit), span).into();
TokenTree::Literal(self::Literal { kind, symbol, suffix, span }) => {
return tokenstream::TokenTree::token(
TokenKind::lit(kind.to_internal(), symbol, suffix),
span,
)
.into();
}
};
@ -292,13 +337,6 @@ impl ToInternal<rustc_errors::Level> for Level {
pub struct FreeFunctions;
// FIXME(eddyb) `Literal` should not expose internal `Debug` impls.
#[derive(Clone, Debug)]
pub struct Literal {
lit: token::Lit,
span: Span,
}
pub(crate) struct Rustc<'a, 'b> {
ecx: &'a mut ExtCtxt<'b>,
def_site: Span,
@ -324,16 +362,11 @@ impl<'a, 'b> Rustc<'a, 'b> {
fn sess(&self) -> &ParseSess {
self.ecx.parse_sess()
}
fn lit(&mut self, kind: token::LitKind, symbol: Symbol, suffix: Option<Symbol>) -> Literal {
Literal { lit: token::Lit::new(kind, symbol, suffix), span: self.call_site }
}
}
impl server::Types for Rustc<'_, '_> {
type FreeFunctions = FreeFunctions;
type TokenStream = TokenStream;
type Literal = Literal;
type SourceFile = Lrc<SourceFile>;
type MultiSpan = Vec<Span>;
type Diagnostic = Diagnostic;
@ -352,6 +385,94 @@ impl server::FreeFunctions for Rustc<'_, '_> {
fn track_path(&mut self, path: &str) {
self.sess().file_depinfo.borrow_mut().insert(Symbol::intern(path));
}
fn literal_from_str(&mut self, s: &str) -> Result<Literal<Self::Span, Self::Symbol>, ()> {
let name = FileName::proc_macro_source_code(s);
let mut parser = rustc_parse::new_parser_from_source_str(self.sess(), name, s.to_owned());
let first_span = parser.token.span.data();
let minus_present = parser.eat(&token::BinOp(token::Minus));
let lit_span = parser.token.span.data();
let token::Literal(mut lit) = parser.token.kind else {
return Err(());
};
// Check no comment or whitespace surrounding the (possibly negative)
// literal, or more tokens after it.
if (lit_span.hi.0 - first_span.lo.0) as usize != s.len() {
return Err(());
}
if minus_present {
// If minus is present, check no comment or whitespace in between it
// and the literal token.
if first_span.hi.0 != lit_span.lo.0 {
return Err(());
}
// Check literal is a kind we allow to be negated in a proc macro token.
match lit.kind {
token::LitKind::Bool
| token::LitKind::Byte
| token::LitKind::Char
| token::LitKind::Str
| token::LitKind::StrRaw(_)
| token::LitKind::ByteStr
| token::LitKind::ByteStrRaw(_)
| token::LitKind::Err => return Err(()),
token::LitKind::Integer | token::LitKind::Float => {}
}
// Synthesize a new symbol that includes the minus sign.
let symbol = Symbol::intern(&s[..1 + lit.symbol.as_str().len()]);
lit = token::Lit::new(lit.kind, symbol, lit.suffix);
}
let token::Lit { kind, symbol, suffix } = lit;
Ok(Literal {
kind: FromInternal::from_internal(kind),
symbol,
suffix,
span: self.call_site,
})
}
fn literal_subspan(
&mut self,
literal: Literal<Self::Span, Self::Symbol>,
start: Bound<usize>,
end: Bound<usize>,
) -> Option<Self::Span> {
let span = literal.span;
let length = span.hi().to_usize() - span.lo().to_usize();
let start = match start {
Bound::Included(lo) => lo,
Bound::Excluded(lo) => lo.checked_add(1)?,
Bound::Unbounded => 0,
};
let end = match end {
Bound::Included(hi) => hi.checked_add(1)?,
Bound::Excluded(hi) => hi,
Bound::Unbounded => length,
};
// Bounds check the values, preventing addition overflow and OOB spans.
if start > u32::MAX as usize
|| end > u32::MAX as usize
|| (u32::MAX - start as u32) < span.lo().to_u32()
|| (u32::MAX - end as u32) < span.lo().to_u32()
|| start >= end
|| end > length
{
return None;
}
let new_lo = span.lo() + BytePos::from_usize(start);
let new_hi = span.lo() + BytePos::from_usize(end);
Some(span.with_lo(new_lo).with_hi(new_hi))
}
}
impl server::TokenStream for Rustc<'_, '_> {
@ -429,7 +550,7 @@ impl server::TokenStream for Rustc<'_, '_> {
fn from_token_tree(
&mut self,
tree: TokenTree<Self::TokenStream, Self::Span, Self::Symbol, Self::Literal>,
tree: TokenTree<Self::TokenStream, Self::Span, Self::Symbol>,
) -> Self::TokenStream {
(tree, &mut *self).to_internal()
}
@ -437,7 +558,7 @@ impl server::TokenStream for Rustc<'_, '_> {
fn concat_trees(
&mut self,
base: Option<Self::TokenStream>,
trees: Vec<TokenTree<Self::TokenStream, Self::Span, Self::Symbol, Self::Literal>>,
trees: Vec<TokenTree<Self::TokenStream, Self::Span, Self::Symbol>>,
) -> Self::TokenStream {
let mut builder = tokenstream::TokenStreamBuilder::new();
if let Some(base) = base {
@ -467,164 +588,11 @@ impl server::TokenStream for Rustc<'_, '_> {
fn into_trees(
&mut self,
stream: Self::TokenStream,
) -> Vec<TokenTree<Self::TokenStream, Self::Span, Self::Symbol, Self::Literal>> {
) -> Vec<TokenTree<Self::TokenStream, Self::Span, Self::Symbol>> {
FromInternal::from_internal((stream, self))
}
}
impl server::Literal for Rustc<'_, '_> {
fn from_str(&mut self, s: &str) -> Result<Self::Literal, ()> {
let name = FileName::proc_macro_source_code(s);
let mut parser = rustc_parse::new_parser_from_source_str(self.sess(), name, s.to_owned());
let first_span = parser.token.span.data();
let minus_present = parser.eat(&token::BinOp(token::Minus));
let lit_span = parser.token.span.data();
let token::Literal(mut lit) = parser.token.kind else {
return Err(());
};
// Check no comment or whitespace surrounding the (possibly negative)
// literal, or more tokens after it.
if (lit_span.hi.0 - first_span.lo.0) as usize != s.len() {
return Err(());
}
if minus_present {
// If minus is present, check no comment or whitespace in between it
// and the literal token.
if first_span.hi.0 != lit_span.lo.0 {
return Err(());
}
// Check literal is a kind we allow to be negated in a proc macro token.
match lit.kind {
token::LitKind::Bool
| token::LitKind::Byte
| token::LitKind::Char
| token::LitKind::Str
| token::LitKind::StrRaw(_)
| token::LitKind::ByteStr
| token::LitKind::ByteStrRaw(_)
| token::LitKind::Err => return Err(()),
token::LitKind::Integer | token::LitKind::Float => {}
}
// Synthesize a new symbol that includes the minus sign.
let symbol = Symbol::intern(&s[..1 + lit.symbol.as_str().len()]);
lit = token::Lit::new(lit.kind, symbol, lit.suffix);
}
Ok(Literal { lit, span: self.call_site })
}
fn to_string(&mut self, literal: &Self::Literal) -> String {
literal.lit.to_string()
}
fn debug_kind(&mut self, literal: &Self::Literal) -> String {
format!("{:?}", literal.lit.kind)
}
fn symbol(&mut self, literal: &Self::Literal) -> String {
literal.lit.symbol.to_string()
}
fn suffix(&mut self, literal: &Self::Literal) -> Option<String> {
literal.lit.suffix.as_ref().map(Symbol::to_string)
}
fn integer(&mut self, n: &str) -> Self::Literal {
self.lit(token::Integer, Symbol::intern(n), None)
}
fn typed_integer(&mut self, n: &str, kind: &str) -> Self::Literal {
self.lit(token::Integer, Symbol::intern(n), Some(Symbol::intern(kind)))
}
fn float(&mut self, n: &str) -> Self::Literal {
self.lit(token::Float, Symbol::intern(n), None)
}
fn f32(&mut self, n: &str) -> Self::Literal {
self.lit(token::Float, Symbol::intern(n), Some(sym::f32))
}
fn f64(&mut self, n: &str) -> Self::Literal {
self.lit(token::Float, Symbol::intern(n), Some(sym::f64))
}
fn string(&mut self, string: &str) -> Self::Literal {
let quoted = format!("{:?}", string);
assert!(quoted.starts_with('"') && quoted.ends_with('"'));
let symbol = &quoted[1..quoted.len() - 1];
self.lit(token::Str, Symbol::intern(symbol), None)
}
fn character(&mut self, ch: char) -> Self::Literal {
let quoted = format!("{:?}", ch);
assert!(quoted.starts_with('\'') && quoted.ends_with('\''));
let symbol = &quoted[1..quoted.len() - 1];
self.lit(token::Char, Symbol::intern(symbol), None)
}
fn byte_string(&mut self, bytes: &[u8]) -> Self::Literal {
let string = bytes
.iter()
.cloned()
.flat_map(ascii::escape_default)
.map(Into::<char>::into)
.collect::<String>();
self.lit(token::ByteStr, Symbol::intern(&string), None)
}
fn span(&mut self, literal: &Self::Literal) -> Self::Span {
literal.span
}
fn set_span(&mut self, literal: &mut Self::Literal, span: Self::Span) {
literal.span = span;
}
fn subspan(
&mut self,
literal: &Self::Literal,
start: Bound<usize>,
end: Bound<usize>,
) -> Option<Self::Span> {
let span = literal.span;
let length = span.hi().to_usize() - span.lo().to_usize();
let start = match start {
Bound::Included(lo) => lo,
Bound::Excluded(lo) => lo.checked_add(1)?,
Bound::Unbounded => 0,
};
let end = match end {
Bound::Included(hi) => hi.checked_add(1)?,
Bound::Excluded(hi) => hi,
Bound::Unbounded => length,
};
// Bounds check the values, preventing addition overflow and OOB spans.
if start > u32::MAX as usize
|| end > u32::MAX as usize
|| (u32::MAX - start as u32) < span.lo().to_u32()
|| (u32::MAX - end as u32) < span.lo().to_u32()
|| start >= end
|| end > length
{
return None;
}
let new_lo = span.lo() + BytePos::from_usize(start);
let new_hi = span.lo() + BytePos::from_usize(end);
Some(span.with_lo(new_lo).with_hi(new_hi))
}
}
impl server::SourceFile for Rustc<'_, '_> {
fn eq(&mut self, file1: &Self::SourceFile, file2: &Self::SourceFile) -> bool {
Lrc::ptr_eq(file1, file2)