From f4c297802ff34cd6a36d077c1271041fc0501cb7 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 17 Mar 2025 18:58:51 -0400 Subject: [PATCH] [NFC] extract autodiff call lowering in cg_llvm into own function --- .../src/builder/autodiff.rs | 201 ++++++++++-------- 1 file changed, 108 insertions(+), 93 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 71705ecb4d0..482af98aa1a 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -28,6 +28,113 @@ fn get_params(fnc: &Value) -> Vec<&Value> { } } +fn match_args_from_caller_to_enzyme<'ll>( + cx: &SimpleCx<'ll>, + args: &mut Vec<&'ll llvm::Value>, + inputs: &[DiffActivity], + outer_args: &[&'ll llvm::Value], +) { + debug!("matching autodiff arguments"); + // We now handle the issue that Rust level arguments not always match the llvm-ir level + // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on + // llvm-ir level. The number of activities matches the number of Rust level arguments, so we + // need to match those. + // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it + // using iterators and peek()? + let mut outer_pos: usize = 0; + let mut activity_pos = 0; + + let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); + let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); + let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); + let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); + + while activity_pos < inputs.len() { + let diff_activity = inputs[activity_pos as usize]; + // Duplicated arguments received a shadow argument, into which enzyme will write the + // gradient. + let (activity, duplicated): (&Metadata, bool) = match diff_activity { + DiffActivity::None => panic!("not a valid input activity"), + DiffActivity::Const => (enzyme_const, false), + DiffActivity::Active => (enzyme_out, false), + DiffActivity::ActiveOnly => (enzyme_out, false), + DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::Duplicated => (enzyme_dup, true), + DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), + DiffActivity::FakeActivitySize => (enzyme_const, false), + }; + let outer_arg = outer_args[outer_pos]; + args.push(cx.get_metadata_value(activity)); + args.push(outer_arg); + if duplicated { + // We know that duplicated args by construction have a following argument, + // so this can not be out of bounds. + let next_outer_arg = outer_args[outer_pos + 1]; + let next_outer_ty = cx.val_ty(next_outer_arg); + // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since + // vectors behind references (&Vec) are already supported. Users can not pass a + // Vec by value for reverse mode, so this would only help forward mode autodiff. + let slice = { + if activity_pos + 1 >= inputs.len() { + // If there is no arg following our ptr, it also can't be a slice, + // since that would lead to a ptr, int pair. + false + } else { + let next_activity = inputs[activity_pos + 1]; + // We analyze the MIR types and add this dummy activity if we visit a slice. + next_activity == DiffActivity::FakeActivitySize + } + }; + if slice { + // A duplicated slice will have the following two outer_fn arguments: + // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call: + // (..., metadata! enzyme_dup, ptr, ptr, int1, ...). + // FIXME(ZuseZ4): We will upstream a safety check later which asserts that + // int2 >= int1, which means the shadow vector is large enough to store the gradient. + assert!(unsafe { + llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer + }); + let next_outer_arg2 = outer_args[outer_pos + 2]; + let next_outer_ty2 = cx.val_ty(next_outer_arg2); + assert!(unsafe { + llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer + }); + let next_outer_arg3 = outer_args[outer_pos + 3]; + let next_outer_ty3 = cx.val_ty(next_outer_arg3); + assert!(unsafe { + llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer + }); + args.push(next_outer_arg2); + args.push(cx.get_metadata_value(enzyme_const)); + args.push(next_outer_arg); + outer_pos += 4; + activity_pos += 2; + } else { + // A duplicated pointer will have the following two outer_fn arguments: + // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: + // (..., metadata! enzyme_dup, ptr, ptr, ...). + if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) + { + assert!( + unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty) } + == llvm::TypeKind::Pointer + ); + } + // In the case of Dual we don't have assumptions, e.g. f32 would be valid. + args.push(next_outer_arg); + outer_pos += 2; + activity_pos += 1; + } + } else { + // We do not differentiate with resprect to this argument. + // We already added the metadata and argument above, so just increase the counters. + outer_pos += 1; + activity_pos += 1; + } + } +} + /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another /// function with expected naming and calling conventions[^1] which will be /// discovered by the enzyme LLVM pass and its body populated with the differentiated @@ -132,12 +239,7 @@ fn generate_enzyme_call<'ll>( let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); - let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); - let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); - let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); - let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); - match output { DiffActivity::Dual => { args.push(cx.get_metadata_value(enzyme_primal_ret)); @@ -148,95 +250,8 @@ fn generate_enzyme_call<'ll>( _ => {} } - debug!("matching autodiff arguments"); - // We now handle the issue that Rust level arguments not always match the llvm-ir level - // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on - // llvm-ir level. The number of activities matches the number of Rust level arguments, so we - // need to match those. - // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it - // using iterators and peek()? - let mut outer_pos: usize = 0; - let mut activity_pos = 0; let outer_args: Vec<&llvm::Value> = get_params(outer_fn); - while activity_pos < inputs.len() { - let diff_activity = inputs[activity_pos as usize]; - // Duplicated arguments received a shadow argument, into which enzyme will write the - // gradient. - let (activity, duplicated): (&Metadata, bool) = match diff_activity { - DiffActivity::None => panic!("not a valid input activity"), - DiffActivity::Const => (enzyme_const, false), - DiffActivity::Active => (enzyme_out, false), - DiffActivity::ActiveOnly => (enzyme_out, false), - DiffActivity::Dual => (enzyme_dup, true), - DiffActivity::DualOnly => (enzyme_dupnoneed, true), - DiffActivity::Duplicated => (enzyme_dup, true), - DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), - DiffActivity::FakeActivitySize => (enzyme_const, false), - }; - let outer_arg = outer_args[outer_pos]; - args.push(cx.get_metadata_value(activity)); - args.push(outer_arg); - if duplicated { - // We know that duplicated args by construction have a following argument, - // so this can not be out of bounds. - let next_outer_arg = outer_args[outer_pos + 1]; - let next_outer_ty = cx.val_ty(next_outer_arg); - // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since - // vectors behind references (&Vec) are already supported. Users can not pass a - // Vec by value for reverse mode, so this would only help forward mode autodiff. - let slice = { - if activity_pos + 1 >= inputs.len() { - // If there is no arg following our ptr, it also can't be a slice, - // since that would lead to a ptr, int pair. - false - } else { - let next_activity = inputs[activity_pos + 1]; - // We analyze the MIR types and add this dummy activity if we visit a slice. - next_activity == DiffActivity::FakeActivitySize - } - }; - if slice { - // A duplicated slice will have the following two outer_fn arguments: - // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call: - // (..., metadata! enzyme_dup, ptr, ptr, int1, ...). - // FIXME(ZuseZ4): We will upstream a safety check later which asserts that - // int2 >= int1, which means the shadow vector is large enough to store the gradient. - assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer); - let next_outer_arg2 = outer_args[outer_pos + 2]; - let next_outer_ty2 = cx.val_ty(next_outer_arg2); - assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer); - let next_outer_arg3 = outer_args[outer_pos + 3]; - let next_outer_ty3 = cx.val_ty(next_outer_arg3); - assert!(llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer); - args.push(next_outer_arg2); - args.push(cx.get_metadata_value(enzyme_const)); - args.push(next_outer_arg); - outer_pos += 4; - activity_pos += 2; - } else { - // A duplicated pointer will have the following two outer_fn arguments: - // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: - // (..., metadata! enzyme_dup, ptr, ptr, ...). - if matches!( - diff_activity, - DiffActivity::Duplicated | DiffActivity::DuplicatedOnly - ) { - assert!( - llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer - ); - } - // In the case of Dual we don't have assumptions, e.g. f32 would be valid. - args.push(next_outer_arg); - outer_pos += 2; - activity_pos += 1; - } - } else { - // We do not differentiate with resprect to this argument. - // We already added the metadata and argument above, so just increase the counters. - outer_pos += 1; - activity_pos += 1; - } - } + match_args_from_caller_to_enzyme(&cx, &mut args, &inputs, &outer_args); let call = builder.call(enzyme_ty, ad_fn, &args, None);