Rollup merge of #117953 - farnoy:masked-load-store, r=workingjubilee
Add more SIMD platform-intrinsics

- [x] simd_masked_load
  - [x] LLVM codegen - llvm.masked.load
  - [x] cranelift codegen - implemented but untested
- [ ] simd_masked_store
  - [x] LLVM codegen - llvm.masked.store
  - [ ] cranelift codegen

Also added a run-pass test to test both intrinsics, and additional build-fail & check-fail tests to cover validation for both intrinsics.
This commit is contained in:
commit
c57b0549af
11 changed files with 594 additions and 1 deletions
|
@ -1492,6 +1492,198 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
|
|||
return Ok(v);
|
||||
}
|
||||
|
||||
if name == sym::simd_masked_load {
|
||||
// simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
|
||||
// * N: number of elements in the input vectors
|
||||
// * T: type of the element to load
|
||||
// * M: any integer width is supported, will be truncated to i1
|
||||
// Loads contiguous elements from memory behind `pointer`, but only for
|
||||
// those lanes whose `mask` bit is enabled.
|
||||
// The memory addresses corresponding to the “off” lanes are not accessed.
|
||||
|
||||
// The element type of the "mask" argument must be a signed integer type of any width
|
||||
let mask_ty = in_ty;
|
||||
let (mask_len, mask_elem) = (in_len, in_elem);
|
||||
|
||||
// The second argument must be a pointer matching the element type
|
||||
let pointer_ty = arg_tys[1];
|
||||
|
||||
// The last argument is a passthrough vector providing values for disabled lanes
|
||||
let values_ty = arg_tys[2];
|
||||
let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
|
||||
|
||||
require_simd!(ret_ty, SimdReturn);
|
||||
|
||||
// The mask and values vectors must be of the same length:
|
||||
require!(
|
||||
values_len == mask_len,
|
||||
InvalidMonomorphization::ThirdArgumentLength {
|
||||
span,
|
||||
name,
|
||||
in_len: mask_len,
|
||||
in_ty: mask_ty,
|
||||
arg_ty: values_ty,
|
||||
out_len: values_len
|
||||
}
|
||||
);
|
||||
|
||||
// The return type must match the last argument type
|
||||
require!(
|
||||
ret_ty == values_ty,
|
||||
InvalidMonomorphization::ExpectedReturnType { span, name, in_ty: values_ty, ret_ty }
|
||||
);
|
||||
|
||||
require!(
|
||||
matches!(
|
||||
pointer_ty.kind(),
|
||||
ty::RawPtr(p) if p.ty == values_elem && p.ty.kind() == values_elem.kind()
|
||||
),
|
||||
InvalidMonomorphization::ExpectedElementType {
|
||||
span,
|
||||
name,
|
||||
expected_element: values_elem,
|
||||
second_arg: pointer_ty,
|
||||
in_elem: values_elem,
|
||||
in_ty: values_ty,
|
||||
mutability: ExpectedPointerMutability::Not,
|
||||
}
|
||||
);
|
||||
|
||||
require!(
|
||||
matches!(mask_elem.kind(), ty::Int(_)),
|
||||
InvalidMonomorphization::ThirdArgElementType {
|
||||
span,
|
||||
name,
|
||||
expected_element: values_elem,
|
||||
third_arg: mask_ty,
|
||||
}
|
||||
);
|
||||
|
||||
// Alignment of T, must be a constant integer value:
|
||||
let alignment_ty = bx.type_i32();
|
||||
let alignment = bx.const_i32(bx.align_of(values_ty).bytes() as i32);
|
||||
|
||||
// Truncate the mask vector to a vector of i1s:
|
||||
let (mask, mask_ty) = {
|
||||
let i1 = bx.type_i1();
|
||||
let i1xn = bx.type_vector(i1, mask_len);
|
||||
(bx.trunc(args[0].immediate(), i1xn), i1xn)
|
||||
};
|
||||
|
||||
let llvm_pointer = bx.type_ptr();
|
||||
|
||||
// Type of the vector of elements:
|
||||
let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);
|
||||
let llvm_elem_vec_str = llvm_vector_str(bx, values_elem, values_len);
|
||||
|
||||
let llvm_intrinsic = format!("llvm.masked.load.{llvm_elem_vec_str}.p0");
|
||||
let fn_ty = bx
|
||||
.type_func(&[llvm_pointer, alignment_ty, mask_ty, llvm_elem_vec_ty], llvm_elem_vec_ty);
|
||||
let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
|
||||
let v = bx.call(
|
||||
fn_ty,
|
||||
None,
|
||||
None,
|
||||
f,
|
||||
&[args[1].immediate(), alignment, mask, args[2].immediate()],
|
||||
None,
|
||||
);
|
||||
return Ok(v);
|
||||
}
|
||||
|
||||
if name == sym::simd_masked_store {
|
||||
// simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
|
||||
// * N: number of elements in the input vectors
|
||||
// * T: type of the element to store
|
||||
// * M: any integer width is supported, will be truncated to i1
|
||||
// Stores contiguous elements to memory behind `pointer`, but only for
|
||||
// those lanes whose `mask` bit is enabled.
|
||||
// The memory addresses corresponding to the “off” lanes are not accessed.
|
||||
|
||||
// The element type of the "mask" argument must be a signed integer type of any width
|
||||
let mask_ty = in_ty;
|
||||
let (mask_len, mask_elem) = (in_len, in_elem);
|
||||
|
||||
// The second argument must be a pointer matching the element type
|
||||
let pointer_ty = arg_tys[1];
|
||||
|
||||
// The last argument specifies the values to store to memory
|
||||
let values_ty = arg_tys[2];
|
||||
let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
|
||||
|
||||
// The mask and values vectors must be of the same length:
|
||||
require!(
|
||||
values_len == mask_len,
|
||||
InvalidMonomorphization::ThirdArgumentLength {
|
||||
span,
|
||||
name,
|
||||
in_len: mask_len,
|
||||
in_ty: mask_ty,
|
||||
arg_ty: values_ty,
|
||||
out_len: values_len
|
||||
}
|
||||
);
|
||||
|
||||
// The second argument must be a mutable pointer type matching the element type
|
||||
require!(
|
||||
matches!(
|
||||
pointer_ty.kind(),
|
||||
ty::RawPtr(p) if p.ty == values_elem && p.ty.kind() == values_elem.kind() && p.mutbl.is_mut()
|
||||
),
|
||||
InvalidMonomorphization::ExpectedElementType {
|
||||
span,
|
||||
name,
|
||||
expected_element: values_elem,
|
||||
second_arg: pointer_ty,
|
||||
in_elem: values_elem,
|
||||
in_ty: values_ty,
|
||||
mutability: ExpectedPointerMutability::Mut,
|
||||
}
|
||||
);
|
||||
|
||||
require!(
|
||||
matches!(mask_elem.kind(), ty::Int(_)),
|
||||
InvalidMonomorphization::ThirdArgElementType {
|
||||
span,
|
||||
name,
|
||||
expected_element: values_elem,
|
||||
third_arg: mask_ty,
|
||||
}
|
||||
);
|
||||
|
||||
// Alignment of T, must be a constant integer value:
|
||||
let alignment_ty = bx.type_i32();
|
||||
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
|
||||
|
||||
// Truncate the mask vector to a vector of i1s:
|
||||
let (mask, mask_ty) = {
|
||||
let i1 = bx.type_i1();
|
||||
let i1xn = bx.type_vector(i1, in_len);
|
||||
(bx.trunc(args[0].immediate(), i1xn), i1xn)
|
||||
};
|
||||
|
||||
let ret_t = bx.type_void();
|
||||
|
||||
let llvm_pointer = bx.type_ptr();
|
||||
|
||||
// Type of the vector of elements:
|
||||
let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);
|
||||
let llvm_elem_vec_str = llvm_vector_str(bx, values_elem, values_len);
|
||||
|
||||
let llvm_intrinsic = format!("llvm.masked.store.{llvm_elem_vec_str}.p0");
|
||||
let fn_ty = bx.type_func(&[llvm_elem_vec_ty, llvm_pointer, alignment_ty, mask_ty], ret_t);
|
||||
let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
|
||||
let v = bx.call(
|
||||
fn_ty,
|
||||
None,
|
||||
None,
|
||||
f,
|
||||
&[args[2].immediate(), args[1].immediate(), alignment, mask],
|
||||
None,
|
||||
);
|
||||
return Ok(v);
|
||||
}
|
||||
|
||||
if name == sym::simd_scatter {
|
||||
// simd_scatter(values: <N x T>, pointers: <N x *mut T>,
|
||||
// mask: <N x i{M}>) -> ()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue