1
Fork 0

Use a parallel_guard function to handle the parallel guard

This commit is contained in:
John Kåre Alsaker 2023-08-25 22:16:21 +02:00
parent d36393b839
commit c303c8abdd

View file

@ -115,11 +115,6 @@ pub struct ParallelGuard {
} }
impl ParallelGuard { impl ParallelGuard {
#[inline]
pub fn new() -> Self {
ParallelGuard { panic: Mutex::new(None) }
}
pub fn run<R>(&self, f: impl FnOnce() -> R) -> Option<R> { pub fn run<R>(&self, f: impl FnOnce() -> R) -> Option<R> {
catch_unwind(AssertUnwindSafe(f)) catch_unwind(AssertUnwindSafe(f))
.map_err(|err| { .map_err(|err| {
@ -127,13 +122,18 @@ impl ParallelGuard {
}) })
.ok() .ok()
} }
}
#[inline] /// This gives access to a fresh parallel guard in the closure and will unwind any panics
pub fn unwind(self) { /// caught in it after the closure returns.
if let Some(panic) = self.panic.into_inner() { #[inline]
resume_unwind(panic); pub fn parallel_guard<R>(f: impl FnOnce(&ParallelGuard) -> R) -> R {
} let guard = ParallelGuard { panic: Mutex::new(None) };
let ret = f(&guard);
if let Some(panic) = guard.panic.into_inner() {
resume_unwind(panic);
} }
ret
} }
cfg_if! { cfg_if! {
@ -231,38 +231,38 @@ cfg_if! {
where A: FnOnce() -> RA, where A: FnOnce() -> RA,
B: FnOnce() -> RB B: FnOnce() -> RB
{ {
let guard = ParallelGuard::new(); let (a, b) = parallel_guard(|guard| {
let a = guard.run(oper_a); let a = guard.run(oper_a);
let b = guard.run(oper_b); let b = guard.run(oper_b);
guard.unwind(); (a, b)
});
(a.unwrap(), b.unwrap()) (a.unwrap(), b.unwrap())
} }
#[macro_export] #[macro_export]
macro_rules! parallel { macro_rules! parallel {
($($blocks:block),*) => {{ ($($blocks:block),*) => {{
let mut guard = $crate::sync::ParallelGuard::new(); $crate::sync::parallel_guard(|guard| {
$(guard.run(|| $blocks);)* $(guard.run(|| $blocks);)*
guard.unwind(); });
}} }}
} }
pub fn par_for_each_in<T: IntoIterator>(t: T, mut for_each: impl FnMut(T::Item) + Sync + Send) { pub fn par_for_each_in<T: IntoIterator>(t: T, mut for_each: impl FnMut(T::Item) + Sync + Send) {
let guard = ParallelGuard::new(); parallel_guard(|guard| {
t.into_iter().for_each(|i| { t.into_iter().for_each(|i| {
guard.run(|| for_each(i)); guard.run(|| for_each(i));
}); });
guard.unwind(); })
} }
pub fn par_map<T: IntoIterator, R, C: FromIterator<R>>( pub fn par_map<T: IntoIterator, R, C: FromIterator<R>>(
t: T, t: T,
mut map: impl FnMut(<<T as IntoIterator>::IntoIter as Iterator>::Item) -> R, mut map: impl FnMut(<<T as IntoIterator>::IntoIter as Iterator>::Item) -> R,
) -> C { ) -> C {
let guard = ParallelGuard::new(); parallel_guard(|guard| {
let r = t.into_iter().filter_map(|i| guard.run(|| map(i))).collect(); t.into_iter().filter_map(|i| guard.run(|| map(i))).collect()
guard.unwind(); })
r
} }
pub use std::rc::Rc as Lrc; pub use std::rc::Rc as Lrc;
@ -382,10 +382,11 @@ cfg_if! {
let (a, b) = rayon::join(move || FromDyn::from(oper_a.into_inner()()), move || FromDyn::from(oper_b.into_inner()())); let (a, b) = rayon::join(move || FromDyn::from(oper_a.into_inner()()), move || FromDyn::from(oper_b.into_inner()()));
(a.into_inner(), b.into_inner()) (a.into_inner(), b.into_inner())
} else { } else {
let guard = ParallelGuard::new(); let (a, b) = parallel_guard(|guard| {
let a = guard.run(oper_a); let a = guard.run(oper_a);
let b = guard.run(oper_b); let b = guard.run(oper_b);
guard.unwind(); (a, b)
});
(a.unwrap(), b.unwrap()) (a.unwrap(), b.unwrap())
} }
} }
@ -421,10 +422,10 @@ cfg_if! {
// of a single threaded rustc. // of a single threaded rustc.
parallel!(impl $fblock [] [$($blocks),*]); parallel!(impl $fblock [] [$($blocks),*]);
} else { } else {
let guard = $crate::sync::ParallelGuard::new(); $crate::sync::parallel_guard(|guard| {
guard.run(|| $fblock); guard.run(|| $fblock);
$(guard.run(|| $blocks);)* $(guard.run(|| $blocks);)*
guard.unwind(); });
} }
}; };
} }
@ -435,20 +436,18 @@ cfg_if! {
t: T, t: T,
for_each: impl Fn(I) + DynSync + DynSend for_each: impl Fn(I) + DynSync + DynSend
) { ) {
if mode::is_dyn_thread_safe() { parallel_guard(|guard| {
let for_each = FromDyn::from(for_each); if mode::is_dyn_thread_safe() {
let guard = ParallelGuard::new(); let for_each = FromDyn::from(for_each);
t.into_par_iter().for_each(|i| { t.into_par_iter().for_each(|i| {
guard.run(|| for_each(i)); guard.run(|| for_each(i));
}); });
guard.unwind(); } else {
} else { t.into_iter().for_each(|i| {
let guard = ParallelGuard::new(); guard.run(|| for_each(i));
t.into_iter().for_each(|i| { });
guard.run(|| for_each(i)); }
}); });
guard.unwind();
}
} }
pub fn par_map< pub fn par_map<
@ -460,18 +459,14 @@ cfg_if! {
t: T, t: T,
map: impl Fn(I) -> R + DynSync + DynSend map: impl Fn(I) -> R + DynSync + DynSend
) -> C { ) -> C {
if mode::is_dyn_thread_safe() { parallel_guard(|guard| {
let map = FromDyn::from(map); if mode::is_dyn_thread_safe() {
let guard = ParallelGuard::new(); let map = FromDyn::from(map);
let r = t.into_par_iter().filter_map(|i| guard.run(|| map(i))).collect(); t.into_par_iter().filter_map(|i| guard.run(|| map(i))).collect()
guard.unwind(); } else {
r t.into_iter().filter_map(|i| guard.run(|| map(i))).collect()
} else { }
let guard = ParallelGuard::new(); })
let r = t.into_iter().filter_map(|i| guard.run(|| map(i))).collect();
guard.unwind();
r
}
} }
/// This makes locks panic if they are already held. /// This makes locks panic if they are already held.