diff --git a/src/librustc/middle/expr_use_visitor.rs b/src/librustc/middle/expr_use_visitor.rs index 2ca0069560c..a44679b0b3e 100644 --- a/src/librustc/middle/expr_use_visitor.rs +++ b/src/librustc/middle/expr_use_visitor.rs @@ -715,6 +715,7 @@ impl<'a, 'gcx, 'tcx> ExprUseVisitor<'a, 'gcx, 'tcx> { adjustment::Adjust::NeverToAny | adjustment::Adjust::ReifyFnPointer | adjustment::Adjust::UnsafeFnPointer | + adjustment::Adjust::ClosureFnPointer | adjustment::Adjust::MutToConstPointer => { // Creating a closure/fn-pointer or unsizing consumes // the input and stores it into the resulting rvalue. diff --git a/src/librustc/middle/mem_categorization.rs b/src/librustc/middle/mem_categorization.rs index 627753039ba..b0c85e2ef4c 100644 --- a/src/librustc/middle/mem_categorization.rs +++ b/src/librustc/middle/mem_categorization.rs @@ -464,6 +464,7 @@ impl<'a, 'gcx, 'tcx> MemCategorizationContext<'a, 'gcx, 'tcx> { adjustment::Adjust::NeverToAny | adjustment::Adjust::ReifyFnPointer | adjustment::Adjust::UnsafeFnPointer | + adjustment::Adjust::ClosureFnPointer | adjustment::Adjust::MutToConstPointer | adjustment::Adjust::DerefRef {..} => { debug!("cat_expr({:?}): {:?}", diff --git a/src/librustc/mir/mod.rs b/src/librustc/mir/mod.rs index 3403cf04774..d2a657e35b5 100644 --- a/src/librustc/mir/mod.rs +++ b/src/librustc/mir/mod.rs @@ -1022,6 +1022,9 @@ pub enum CastKind { /// Convert unique, zero-sized type for a fn to fn() ReifyFnPointer, + /// Convert non capturing closure to fn() + ClosureFnPointer, + /// Convert safe fn() to unsafe fn() UnsafeFnPointer, diff --git a/src/librustc/ty/adjustment.rs b/src/librustc/ty/adjustment.rs index 333a5c74cb4..34977822bc6 100644 --- a/src/librustc/ty/adjustment.rs +++ b/src/librustc/ty/adjustment.rs @@ -33,6 +33,9 @@ pub enum Adjust<'tcx> { /// Go from a safe fn pointer to an unsafe fn pointer. UnsafeFnPointer, + // Go from a non-capturing closure to an fn pointer. + ClosureFnPointer, + /// Go from a mut raw pointer to a const raw pointer. MutToConstPointer, @@ -120,6 +123,7 @@ impl<'tcx> Adjustment<'tcx> { Adjust::ReifyFnPointer | Adjust::UnsafeFnPointer | + Adjust::ClosureFnPointer | Adjust::MutToConstPointer | Adjust::DerefRef {..} => false, } diff --git a/src/librustc_mir/build/expr/as_lvalue.rs b/src/librustc_mir/build/expr/as_lvalue.rs index 0487e277a33..5abfe084f22 100644 --- a/src/librustc_mir/build/expr/as_lvalue.rs +++ b/src/librustc_mir/build/expr/as_lvalue.rs @@ -99,6 +99,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> { ExprKind::Use { .. } | ExprKind::NeverToAny { .. } | ExprKind::ReifyFnPointer { .. } | + ExprKind::ClosureFnPointer { .. } | ExprKind::UnsafeFnPointer { .. } | ExprKind::Unsize { .. } | ExprKind::Repeat { .. } | diff --git a/src/librustc_mir/build/expr/as_rvalue.rs b/src/librustc_mir/build/expr/as_rvalue.rs index 7adcc0e730b..7f5d9c36ece 100644 --- a/src/librustc_mir/build/expr/as_rvalue.rs +++ b/src/librustc_mir/build/expr/as_rvalue.rs @@ -112,6 +112,10 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> { let source = unpack!(block = this.as_operand(block, source)); block.and(Rvalue::Cast(CastKind::UnsafeFnPointer, source, expr.ty)) } + ExprKind::ClosureFnPointer { source } => { + let source = unpack!(block = this.as_operand(block, source)); + block.and(Rvalue::Cast(CastKind::ClosureFnPointer, source, expr.ty)) + } ExprKind::Unsize { source } => { let source = unpack!(block = this.as_operand(block, source)); block.and(Rvalue::Cast(CastKind::Unsize, source, expr.ty)) diff --git a/src/librustc_mir/build/expr/category.rs b/src/librustc_mir/build/expr/category.rs index 6e57c10964c..35173bb598c 100644 --- a/src/librustc_mir/build/expr/category.rs +++ b/src/librustc_mir/build/expr/category.rs @@ -70,6 +70,7 @@ impl Category { ExprKind::Cast { .. } | ExprKind::Use { .. } | ExprKind::ReifyFnPointer { .. } | + ExprKind::ClosureFnPointer { .. } | ExprKind::UnsafeFnPointer { .. } | ExprKind::Unsize { .. } | ExprKind::Repeat { .. } | diff --git a/src/librustc_mir/build/expr/into.rs b/src/librustc_mir/build/expr/into.rs index e66f2b4e2bf..d9f71e36e21 100644 --- a/src/librustc_mir/build/expr/into.rs +++ b/src/librustc_mir/build/expr/into.rs @@ -244,6 +244,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> { ExprKind::Cast { .. } | ExprKind::Use { .. } | ExprKind::ReifyFnPointer { .. } | + ExprKind::ClosureFnPointer { .. } | ExprKind::UnsafeFnPointer { .. } | ExprKind::Unsize { .. } | ExprKind::Repeat { .. } | diff --git a/src/librustc_mir/hair/cx/expr.rs b/src/librustc_mir/hair/cx/expr.rs index 7eaf1fe1398..52e64463768 100644 --- a/src/librustc_mir/hair/cx/expr.rs +++ b/src/librustc_mir/hair/cx/expr.rs @@ -60,6 +60,15 @@ impl<'tcx> Mirror<'tcx> for &'tcx hir::Expr { kind: ExprKind::UnsafeFnPointer { source: expr.to_ref() }, }; } + Some((ty::adjustment::Adjust::ClosureFnPointer, adjusted_ty)) => { + expr = Expr { + temp_lifetime: temp_lifetime, + temp_lifetime_was_shrunk: was_shrunk, + ty: adjusted_ty, + span: self.span, + kind: ExprKind::ClosureFnPointer { source: expr.to_ref() }, + }; + } Some((ty::adjustment::Adjust::NeverToAny, adjusted_ty)) => { expr = Expr { temp_lifetime: temp_lifetime, diff --git a/src/librustc_mir/hair/mod.rs b/src/librustc_mir/hair/mod.rs index 4ac67cfb2fc..4ab45e14c99 100644 --- a/src/librustc_mir/hair/mod.rs +++ b/src/librustc_mir/hair/mod.rs @@ -152,6 +152,9 @@ pub enum ExprKind<'tcx> { ReifyFnPointer { source: ExprRef<'tcx>, }, + ClosureFnPointer { + source: ExprRef<'tcx>, + }, UnsafeFnPointer { source: ExprRef<'tcx>, }, diff --git a/src/librustc_mir/transform/qualify_consts.rs b/src/librustc_mir/transform/qualify_consts.rs index 4459142cfb2..04e809ef9d8 100644 --- a/src/librustc_mir/transform/qualify_consts.rs +++ b/src/librustc_mir/transform/qualify_consts.rs @@ -619,6 +619,7 @@ impl<'a, 'tcx> Visitor<'tcx> for Qualifier<'a, 'tcx, 'tcx> { Rvalue::CheckedBinaryOp(..) | Rvalue::Cast(CastKind::ReifyFnPointer, ..) | Rvalue::Cast(CastKind::UnsafeFnPointer, ..) | + Rvalue::Cast(CastKind::ClosureFnPointer, ..) | Rvalue::Cast(CastKind::Unsize, ..) => {} Rvalue::Len(_) => { diff --git a/src/librustc_passes/consts.rs b/src/librustc_passes/consts.rs index 0b55513f831..e3772a09968 100644 --- a/src/librustc_passes/consts.rs +++ b/src/librustc_passes/consts.rs @@ -447,6 +447,7 @@ fn check_adjustments<'a, 'tcx>(v: &mut CheckCrateVisitor<'a, 'tcx>, e: &hir::Exp Some(Adjust::NeverToAny) | Some(Adjust::ReifyFnPointer) | Some(Adjust::UnsafeFnPointer) | + Some(Adjust::ClosureFnPointer) | Some(Adjust::MutToConstPointer) => {} Some(Adjust::DerefRef { autoderefs, .. }) => { diff --git a/src/librustc_trans/builder.rs b/src/librustc_trans/builder.rs index f64e581c177..66722f883d7 100644 --- a/src/librustc_trans/builder.rs +++ b/src/librustc_trans/builder.rs @@ -1181,7 +1181,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } assert!(fn_ty.kind() == llvm::TypeKind::Function, - "builder::{} not passed a function", typ); + "builder::{} not passed a function, but {:?}", typ, fn_ty); let param_tys = fn_ty.func_params(); diff --git a/src/librustc_trans/collector.rs b/src/librustc_trans/collector.rs index b5f948442b7..61b2c108fed 100644 --- a/src/librustc_trans/collector.rs +++ b/src/librustc_trans/collector.rs @@ -489,6 +489,20 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirNeighborCollector<'a, 'tcx> { self.output); } } + mir::Rvalue::Cast(mir::CastKind::ClosureFnPointer, ref operand, _) => { + let source_ty = operand.ty(self.mir, self.scx.tcx()); + match source_ty.sty { + ty::TyClosure(def_id, substs) => { + let closure_trans_item = + create_fn_trans_item(self.scx, + def_id, + substs.substs, + self.param_substs); + self.output.push(closure_trans_item); + } + _ => bug!(), + } + } mir::Rvalue::Box(..) => { let exchange_malloc_fn_def_id = self.scx diff --git a/src/librustc_trans/mir/constant.rs b/src/librustc_trans/mir/constant.rs index 7e17ae5f1d3..e7582061b54 100644 --- a/src/librustc_trans/mir/constant.rs +++ b/src/librustc_trans/mir/constant.rs @@ -578,6 +578,35 @@ impl<'a, 'tcx> MirConstContext<'a, 'tcx> { } } } + mir::CastKind::ClosureFnPointer => { + match operand.ty.sty { + ty::TyClosure(def_id, substs) => { + // Get the def_id for FnOnce::call_once + let fn_once = tcx.lang_items.fn_once_trait().unwrap(); + let call_once = tcx + .global_tcx().associated_items(fn_once) + .find(|it| it.kind == ty::AssociatedKind::Method) + .unwrap().def_id; + // Now create its substs [Closure, Tuple] + let input = tcx.closure_type(def_id, substs).sig.input(0); + let substs = Substs::for_item(tcx, + call_once, + |_, _| {bug!()}, + |def, _| { match def.index { + 0 => operand.ty.clone(), + 1 => input.skip_binder(), + _ => bug!(), + } } + ); + + Callee::def(self.ccx, call_once, substs) + .reify(self.ccx) + } + _ => { + bug!("{} cannot be cast to a fn ptr", operand.ty) + } + } + } mir::CastKind::UnsafeFnPointer => { // this is a no-op at the LLVM level operand.llval diff --git a/src/librustc_trans/mir/rvalue.rs b/src/librustc_trans/mir/rvalue.rs index 7d4f542addb..95bb1dd7e0f 100644 --- a/src/librustc_trans/mir/rvalue.rs +++ b/src/librustc_trans/mir/rvalue.rs @@ -12,6 +12,7 @@ use llvm::{self, ValueRef}; use rustc::ty::{self, Ty}; use rustc::ty::cast::{CastTy, IntTy}; use rustc::ty::layout::Layout; +use rustc::ty::subst::Substs; use rustc::mir::tcx::LvalueTy; use rustc::mir; use middle::lang_items::ExchangeMallocFnLangItem; @@ -190,6 +191,36 @@ impl<'a, 'tcx> MirContext<'a, 'tcx> { } } } + mir::CastKind::ClosureFnPointer => { + match operand.ty.sty { + ty::TyClosure(def_id, substs) => { + // Get the def_id for FnOnce::call_once + let fn_once = bcx.tcx().lang_items.fn_once_trait().unwrap(); + let call_once = bcx.tcx() + .global_tcx().associated_items(fn_once) + .find(|it| it.kind == ty::AssociatedKind::Method) + .unwrap().def_id; + // Now create its substs [Closure, Tuple] + let input = bcx.tcx().closure_type(def_id, substs).sig.input(0); + let substs = Substs::for_item(bcx.tcx(), + call_once, + |_, _| {bug!()}, + |def, _| { match def.index { + 0 => operand.ty.clone(), + 1 => input.skip_binder(), + _ => bug!(), + } } + ); + + OperandValue::Immediate( + Callee::def(bcx.ccx, call_once, substs) + .reify(bcx.ccx)) + } + _ => { + bug!("{} cannot be cast to a fn ptr", operand.ty) + } + } + } mir::CastKind::UnsafeFnPointer => { // this is a no-op at the LLVM level operand.val diff --git a/src/librustc_typeck/check/coercion.rs b/src/librustc_typeck/check/coercion.rs index 718c273785a..e9ac0a58d36 100644 --- a/src/librustc_typeck/check/coercion.rs +++ b/src/librustc_typeck/check/coercion.rs @@ -63,13 +63,17 @@ use check::FnCtxt; use rustc::hir; +use rustc::hir::def_id::DefId; use rustc::infer::{Coercion, InferOk, TypeTrace}; use rustc::traits::{self, ObligationCause, ObligationCauseCode}; use rustc::ty::adjustment::{Adjustment, Adjust, AutoBorrow}; -use rustc::ty::{self, LvaluePreference, TypeAndMut, Ty}; +use rustc::ty::{self, LvaluePreference, TypeVariants, TypeAndMut, + Ty, ClosureSubsts}; use rustc::ty::fold::TypeFoldable; use rustc::ty::error::TypeError; use rustc::ty::relate::RelateResult; +use syntax::ast::NodeId; +use syntax::abi; use util::common::indent; use std::cell::RefCell; @@ -196,6 +200,11 @@ impl<'f, 'gcx, 'tcx> Coerce<'f, 'gcx, 'tcx> { // unsafe qualifier. self.coerce_from_fn_pointer(a, a_f, b) } + ty::TyClosure(def_id_a, substs_a) => { + // Non-capturing closures are coercible to + // function pointers + self.coerce_closure_to_fn(a, def_id_a, substs_a, b) + } _ => { // Otherwise, just use unification rules. self.unify_and_identity(a, b) @@ -551,6 +560,52 @@ impl<'f, 'gcx, 'tcx> Coerce<'f, 'gcx, 'tcx> { } } + fn coerce_closure_to_fn(&self, + a: Ty<'tcx>, + def_id_a: DefId, + substs_a: ClosureSubsts<'tcx>, + b: Ty<'tcx>) + -> CoerceResult<'tcx> { + //! Attempts to coerce from the type of a non-capturing closure + //! into a function pointer. + //! + + let b = self.shallow_resolve(b); + + let node_id_a :NodeId = self.tcx.hir.as_local_node_id(def_id_a).unwrap(); + match b.sty { + ty::TyFnPtr(_) if self.tcx.with_freevars(node_id_a, |v| v.is_empty()) => { + // We coerce the closure, which has fn type + // `extern "rust-call" fn((arg0,arg1,...)) -> _` + // to + // `fn(arg0,arg1,...) -> _` + let sig = self.closure_type(def_id_a, substs_a).sig; + let converted_sig = sig.input(0).map_bound(|v| { + let params_iter = match v.sty { + TypeVariants::TyTuple(params, _) => { + params.into_iter().cloned() + } + _ => bug!(), + }; + self.tcx.mk_fn_sig(params_iter, + sig.output().skip_binder(), + sig.variadic()) + }); + let fn_ty = self.tcx.mk_bare_fn(ty::BareFnTy { + unsafety: hir::Unsafety::Normal, + abi: abi::Abi::Rust, + sig: converted_sig, + }); + let pointer_ty = self.tcx.mk_fn_ptr(&fn_ty); + debug!("coerce_closure_to_fn(a={:?}, b={:?}, pty={:?})", + a, b, pointer_ty); + self.unify_and_identity(pointer_ty, b) + .map(|(ty, _)| (ty, Adjust::ClosureFnPointer)) + } + _ => self.unify_and_identity(a, b), + } + } + fn coerce_unsafe_ptr(&self, a: Ty<'tcx>, b: Ty<'tcx>, diff --git a/src/librustc_typeck/check/writeback.rs b/src/librustc_typeck/check/writeback.rs index a2922270583..a25e5f3f283 100644 --- a/src/librustc_typeck/check/writeback.rs +++ b/src/librustc_typeck/check/writeback.rs @@ -412,6 +412,10 @@ impl<'cx, 'gcx, 'tcx> WritebackCx<'cx, 'gcx, 'tcx> { adjustment::Adjust::MutToConstPointer } + adjustment::Adjust::ClosureFnPointer => { + adjustment::Adjust::ClosureFnPointer + } + adjustment::Adjust::UnsafeFnPointer => { adjustment::Adjust::UnsafeFnPointer } diff --git a/src/test/run-pass/closure-to-fn-coercion.rs b/src/test/run-pass/closure-to-fn-coercion.rs new file mode 100644 index 00000000000..c4d0bbdd070 --- /dev/null +++ b/src/test/run-pass/closure-to-fn-coercion.rs @@ -0,0 +1,41 @@ +// Copyright 2017 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// ignore-stage0: new feature, remove this when SNAP + +// #![feature(closure_to_fn_coercion)] + +const FOO :fn(u8) -> u8 = |v: u8| { v }; + +const BAR: [fn(&mut u32); 5] = [ + |_: &mut u32| {}, + |v: &mut u32| *v += 1, + |v: &mut u32| *v += 2, + |v: &mut u32| *v += 3, + |v: &mut u32| *v += 4, +]; +fn func_specific() -> (fn() -> u32) { + || return 42 +} + +fn main() { + // Items + assert_eq!(func_specific()(), 42); + let foo :fn(u8) -> u8 = |v: u8| { v }; + assert_eq!(foo(31), 31); + // Constants + assert_eq!(FOO(31), 31); + let mut a :u32 = 0; + assert_eq!({BAR[0](&mut a); a }, 0); + assert_eq!({BAR[1](&mut a); a }, 1); + assert_eq!({BAR[2](&mut a); a }, 3); + assert_eq!({BAR[3](&mut a); a }, 6); + assert_eq!({BAR[4](&mut a); a }, 10); +}