//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat //@ no-prefer-dynamic //@ needs-enzyme // // In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many // breakages. One benefit is that we match the IR generated by Enzyme only after running it // through LLVM's O3 pipeline, which will remove most of the noise. // However, our integration test could also be affected by changes in how rustc lowers MIR into // LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should // reduce this test to only match the first lines and the ret instructions. #![feature(autodiff)] use std::autodiff::autodiff; #[autodiff(d_square3, Forward, Dual, DualOnly)] #[autodiff(d_square2, Forward, 4, Dual, DualOnly)] #[autodiff(d_square1, Forward, 4, Dual, Dual)] #[no_mangle] fn square(x: &f32) -> f32 { x * x } // d_sqaure2 // CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") // CHECK-NEXT: start: // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 // CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 // CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 // CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 // CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 // CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 // CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 // CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 // CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 // CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 // CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 // CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer // CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 // CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 // CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 // CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 // CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 // CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 // CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 // CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 // CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 // CHECK-NEXT: ret [4 x float] %19 // CHECK-NEXT: } // d_square3, the extra float is the original return value (x * x) // CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") // CHECK-NEXT: start: // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 // CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 // CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 // CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 // CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 // CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val // CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 // CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 // CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 // CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 // CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 // CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 // CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer // CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 // CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 // CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 // CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 // CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 // CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 // CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 // CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 // CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 // CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0 // CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1 // CHECK-NEXT: ret { float, [4 x float] } %21 // CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); let output = square(&x); dbg!(&output); assert_eq!(9.0, output); dbg!(square(&x)); let mut df_dx1 = 1.0; let mut df_dx2 = 2.0; let mut df_dx3 = 3.0; let mut df_dx4 = 0.0; let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); dbg!(o1, o2, o3, o4); let [output2, o1, o2, o3, o4] = d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); dbg!(o1, o2, o3, o4); assert_eq!(output, output2); assert!((6.0 - o1).abs() < 1e-10); assert!((12.0 - o2).abs() < 1e-10); assert!((18.0 - o3).abs() < 1e-10); assert!((0.0 - o4).abs() < 1e-10); assert_eq!(1.0, df_dx1); assert_eq!(2.0, df_dx2); assert_eq!(3.0, df_dx3); assert_eq!(0.0, df_dx4); assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1); assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2); assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3); assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4); }