diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml new file mode 100644 index 0000000000000..711f4de7629b4 --- /dev/null +++ b/.github/workflows/enzyme-ci.yml @@ -0,0 +1,90 @@ +name: Rust CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + merge_group: + +jobs: + build: + name: Rust Integration CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [openstack22] + + timeout-minutes: 600 + steps: + - name: Checkout Rust source + uses: actions/checkout@v4 + with: + submodules: true # check out all submodules so the cache can work correctly + fetch-depth: 2 + - uses: dtolnay/rust-toolchain@nightly + - name: Get LLVM commit hash + id: llvm-commit + run: echo "HEAD=$(git rev-parse HEAD:src/llvm-project)" >> $GITHUB_OUTPUT + - name: Cache LLVM + id: cache-llvm + uses: actions/cache@v4 + with: + path: build/build/x86_64-unknown-linux-gnu/llvm + key: ${{ matrix.os }}-llvm-${{ steps.llvm-commit.outputs.HEAD }} + - name: Get Enzyme commit hash + id: enzyme-commit + run: echo "HEAD=$(git rev-parse HEAD:src/tools/enzyme)" >> $GITHUB_OUTPUT + - name: Cache Enzyme + id: cache-enzyme + uses: actions/cache@v4 + with: + path: build/build/x86_64-unknown-linux-gnu/enzyme + key: ${{ matrix.os }}-enzyme-${{ steps.enzyme-commit.outputs.HEAD }} + - name: Cache bootstrap/stage0 artifacts for incremental builds + uses: actions/cache@v4 + with: + path: | + build/build/bootstrap/ + build/build/x86_64-unknown-linux-gnu/stage0-rustc/ + build/build/x86_64-unknown-linux-gnu/stage0-std/ + build/build/x86_64-unknown-linux-gnu/stage0-tools/ + build/build/x86_64-unknown-linux-gnu/stage1-std/ + # Approximate stable hash. It doesn't matter too much when this goes out of sync as it just caches + # some stage0/stage1 dependencies and stdlibs which *hopefully* are hash-keyed. 
+ key: enzyme-rust-incremental-${{ runner.os }}-${{ hashFiles('src/**/Cargo.lock', 'Cargo.lock') }} + restore-keys: | + enzyme-rust-incremental-${{ runner.os }} + - name: Build + run: | + wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-17 main" || true + sudo apt-get -y update + sudo apt-get install -y lld-17 + mkdir lld-path-manipulation + ln -s "$(which lld-17)" lld-path-manipulation/lld + ln -s "$(which lld-17)" lld-path-manipulation/ld + ln -s "$(which lld-17)" lld-path-manipulation/ld.lld + ln -s "$(which lld-17)" lld-path-manipulation/lld-17 + export PATH="$PWD/lld-path-manipulation:$PATH" + mkdir -p build + cd build + rm -f config.toml + ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --set=rust.use-lld=true --release-channel=nightly --enable-llvm-assertions --enable-option-checking --enable-ninja --disable-docs + ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc + rustup toolchain link enzyme build/host/stage1 + - name: checkout Enzyme/rustbook + uses: actions/checkout@v4 + with: + repository: EnzymeAD/rustbook + ref: main + path: rustbook + - name: test Enzyme/rustbook + working-directory: rustbook + run: | + RUSTFLAGS='-Z unstable-options' cargo +enzyme test + ENZYME_LOOSE_TYPES=1 RUSTFLAGS='-Z unstable-options' cargo +enzyme test -p samples-loose-types diff --git a/.gitmodules b/.gitmodules index 9bb68b37081f5..52517890060b8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,7 @@ path = library/backtrace url = https://github.com/rust-lang/backtrace-rs.git shallow = true +[submodule "src/tools/enzyme"] + path = src/tools/enzyme + url = https://github.com/enzymeAD/enzyme + shallow = true diff --git a/Cargo.lock b/Cargo.lock index 5d78e29de0e08..847f11573b811 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4261,6 +4261,7 @@ dependencies = [ name = 
"rustc_monomorphize" version = "0.0.0" dependencies = [ + "rustc_ast", "rustc_data_structures", "rustc_errors", "rustc_fluent_macro", @@ -4269,6 +4270,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/README.md b/README.md index 5d5beaf1b7a20..062126c109d73 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,41 @@ -# The Rust Programming Language +# The Rust Programming Language + AutoDiff [![Rust Community](https://img.shields.io/badge/Rust_Community%20-Join_us-brightgreen?style=plastic&logo=rust)](https://www.rust-lang.org/community) This is the main source code repository for [Rust]. It contains the compiler, -standard library, and documentation. +standard library, and documentation. It is modified to use Enzyme for AutoDiff. + + +Please configure this fork using the following command: + +``` +./configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs +``` + +Afterwards you can build rustc using: +``` +./x.py build --stage 1 library +``` + +Afterwards, `rustup toolchain link` will allow you to use it through cargo: +``` +rustup toolchain link enzyme build/host/stage1 +rustup toolchain install nightly # enables -Z unstable-options +``` + +You can then run an example from our [docs](https://enzyme.mit.edu/index.fcgi/rust/usage/usage.html) using: + +```bash +cargo +enzyme run --release +``` + +## Bug reporting +Bugs are pretty much expected at this point of the development process. +Any type of bug report is therefore highly appreciated. 
Documentation +on how to best debug this fork and how to report errors can be found +on our [website](https://enzyme.mit.edu/index.fcgi/rust/Debugging.html). + + [Rust]: https://www.rust-lang.org/ diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index 3496cfc38c84e..259f76abd0816 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -2574,6 +2574,12 @@ impl FnRetTy { FnRetTy::Ty(ty) => ty.span, } } + pub fn has_ret(&self) -> bool { + match self { + FnRetTy::Default(_) => false, + FnRetTy::Ty(_) => true, + } + } } #[derive(Clone, Copy, PartialEq, Encodable, Decodable, Debug)] diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs new file mode 100644 index 0000000000000..8fb5e41577330 --- /dev/null +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -0,0 +1,274 @@ +use crate::expand::typetree::TypeTree; +use std::str::FromStr; +use std::fmt::{self, Display, Formatter}; +use crate::ptr::P; +use crate::{Ty, TyKind}; + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffMode { + Inactive, + Source, + Forward, + Reverse, + ForwardFirst, + ReverseFirst, +} + +pub fn is_rev(mode: DiffMode) -> bool { + match mode { + DiffMode::Reverse | DiffMode::ReverseFirst => true, + _ => false, + } +} +pub fn is_fwd(mode: DiffMode) -> bool { + match mode { + DiffMode::Forward | DiffMode::ForwardFirst => true, + _ => false, + } +} + +impl Display for DiffMode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + DiffMode::Inactive => write!(f, "Inactive"), + DiffMode::Source => write!(f, "Source"), + DiffMode::Forward => write!(f, "Forward"), + DiffMode::Reverse => write!(f, "Reverse"), + DiffMode::ForwardFirst => write!(f, "ForwardFirst"), + DiffMode::ReverseFirst => write!(f, "ReverseFirst"), + } + } +} + +pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { + 
if activity == DiffActivity::None { + // Only valid if primal returns (), but we can't check that here. + return true; + } + match mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + DiffMode::Forward | DiffMode::ForwardFirst => { + activity == DiffActivity::Dual || + activity == DiffActivity::DualOnly || + activity == DiffActivity::Const + } + DiffMode::Reverse | DiffMode::ReverseFirst => { + activity == DiffActivity::Const || + activity == DiffActivity::Active || + activity == DiffActivity::ActiveOnly + } + } +} +fn is_ptr_or_ref(ty: &Ty) -> bool { + match ty.kind { + TyKind::Ptr(_) | TyKind::Ref(_, _) => true, + _ => false, + } +} +// TODO We should make this more robust to also +// accept aliases of f32 and f64 +//fn is_float(ty: &Ty) -> bool { +// false +//} +pub fn valid_ty_for_activity(ty: &P, activity: DiffActivity) -> bool { + if is_ptr_or_ref(ty) { + return activity == DiffActivity::Dual || + activity == DiffActivity::DualOnly || + activity == DiffActivity::Duplicated || + activity == DiffActivity::DuplicatedOnly || + activity == DiffActivity::Const; + } + true + //if is_scalar_ty(&ty) { + // return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly || + // activity == DiffActivity::Const; + //} +} +pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { + return match mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + DiffMode::Forward | DiffMode::ForwardFirst => { + // These are the only valid cases + activity == DiffActivity::Dual || + activity == DiffActivity::DualOnly || + activity == DiffActivity::Const + } + DiffMode::Reverse | DiffMode::ReverseFirst => { + // These are the only valid cases + activity == DiffActivity::Active || + activity == DiffActivity::ActiveOnly || + activity == DiffActivity::Const || + activity == DiffActivity::Duplicated || + activity == DiffActivity::DuplicatedOnly + } + }; +} +pub fn invalid_input_activities(mode: DiffMode, activity_vec: 
&[DiffActivity]) -> Option { + for i in 0..activity_vec.len() { + if !valid_input_activity(mode, activity_vec[i]) { + return Some(i); + } + } + None +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffActivity { + None, + Const, + Active, + ActiveOnly, + Dual, + DualOnly, + Duplicated, + DuplicatedOnly, + FakeActivitySize +} + +impl Display for DiffActivity { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DiffActivity::None => write!(f, "None"), + DiffActivity::Const => write!(f, "Const"), + DiffActivity::Active => write!(f, "Active"), + DiffActivity::ActiveOnly => write!(f, "ActiveOnly"), + DiffActivity::Dual => write!(f, "Dual"), + DiffActivity::DualOnly => write!(f, "DualOnly"), + DiffActivity::Duplicated => write!(f, "Duplicated"), + DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), + DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"), + } + } +} + +impl FromStr for DiffMode { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "Inactive" => Ok(DiffMode::Inactive), + "Source" => Ok(DiffMode::Source), + "Forward" => Ok(DiffMode::Forward), + "Reverse" => Ok(DiffMode::Reverse), + "ForwardFirst" => Ok(DiffMode::ForwardFirst), + "ReverseFirst" => Ok(DiffMode::ReverseFirst), + _ => Err(()), + } + } +} +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "ActiveOnly" => Ok(DiffActivity::ActiveOnly), + "Const" => Ok(DiffActivity::Const), + "Dual" => Ok(DiffActivity::Dual), + "DualOnly" => Ok(DiffActivity::DualOnly), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly), + _ => Err(()), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffAttrs { + pub mode: DiffMode, + pub 
ret_activity: DiffActivity, + pub input_activity: Vec, +} + +impl AutoDiffAttrs { + pub fn has_ret_activity(&self) -> bool { + match self.ret_activity { + DiffActivity::None => false, + _ => true, + } + } + pub fn has_active_only_ret(&self) -> bool { + match self.ret_activity { + DiffActivity::ActiveOnly => true, + _ => false, + } + } +} + +impl AutoDiffAttrs { + pub fn inactive() -> Self { + AutoDiffAttrs { + mode: DiffMode::Inactive, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + pub fn source() -> Self { + AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + _ => { + true + } + } + } + + pub fn is_source(&self) -> bool { + match self.mode { + DiffMode::Source => true, + _ => false, + } + } + pub fn apply_autodiff(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + _ => { + true + } + } + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffItem { + pub source: String, + pub target: String, + pub attrs: AutoDiffAttrs, + pub inputs: Vec, + pub output: TypeTree, +} + +impl fmt::Display for AutoDiffItem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Differentiating {} -> {}", self.source, self.target)?; + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) + } +} + + diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs index 942347383ce31..29cc9121aebbd 100644 --- a/compiler/rustc_ast/src/expand/mod.rs +++ 
b/compiler/rustc_ast/src/expand/mod.rs @@ -5,6 +5,8 @@ use rustc_span::{def_id::DefId, symbol::Ident}; use crate::MetaItem; pub mod allocator; +pub mod autodiff_attrs; +pub mod typetree; #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)] pub struct StrippedCfgItem { diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs new file mode 100644 index 0000000000000..fa04ed253ac21 --- /dev/null +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -0,0 +1,62 @@ +use std::fmt; + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct TypeTree(pub Vec); + +impl TypeTree { + pub fn new() -> Self { + Self(Vec::new()) + } + pub fn all_ints() -> Self { + Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }]) + } + pub fn int(size: usize) -> Self { + let mut ints = Vec::with_capacity(size); + for i in 0..size { + ints.push(Type { offset: i as isize, size: 1, kind: Kind::Integer, child: TypeTree::new() }); + } + Self(ints) + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct FncTree { + pub args: Vec, + pub ret: TypeTree, +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_ast/src/lib.rs b/compiler/rustc_ast/src/lib.rs index 
7e713a49a8cfa..e69d222c418a0 100644 --- a/compiler/rustc_ast/src/lib.rs +++ b/compiler/rustc_ast/src/lib.rs @@ -50,6 +50,7 @@ pub mod ptr; pub mod token; pub mod tokenstream; pub mod visit; +//pub mod autodiff_attrs; pub use self::ast::*; pub use self::ast_traits::{AstDeref, AstNodeWrapper, HasAttrs, HasNodeId, HasSpan, HasTokens}; diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index dda466b026d91..d5d7562f36a11 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -1,6 +1,14 @@ builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function builtin_macros_alloc_must_statics = allocators must be statics +builtin_macros_autodiff_unknown_activity = did not recognize activity {$act} +builtin_macros_autodiff = autodiff must be applied to function +builtin_macros_autodiff_not_build = this rustc version does not support autodiff +builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode +builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found} +builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. 
Use `Forward` or `Reverse` +builtin_macros_autodiff_ty_activity = {$act} can not be used for this type + builtin_macros_asm_clobber_abi = clobber_abi builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs builtin_macros_asm_clobber_outputs = generic outputs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs new file mode 100644 index 0000000000000..7f47b62776e93 --- /dev/null +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -0,0 +1,702 @@ +#![allow(unused_imports)] +//use crate::util::check_builtin_macro_attribute; +//use crate::util::check_autodiff; + +use crate::errors; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, is_fwd, is_rev, valid_input_activity, valid_ty_for_activity}; +use rustc_ast::ptr::P; +use rustc_ast::token::{Token, TokenKind}; +use rustc_ast::tokenstream::*; +use rustc_ast::FnRetTy; +use rustc_ast::{self as ast, FnHeader, FnSig, Generics, MetaItemKind, NestedMetaItem, StmtKind}; +use rustc_ast::{BindingAnnotation, ByRef}; +use rustc_ast::{Fn, ItemKind, PatKind, Stmt, TyKind, Unsafe}; +use rustc_expand::base::{Annotatable, ExtCtxt}; +use rustc_span::symbol::{kw, sym, Ident}; +use rustc_span::Span; +use rustc_span::Symbol; +use std::string::String; +use thin_vec::{thin_vec, ThinVec}; +use std::str::FromStr; + +use rustc_ast::AssocItemKind; + +#[cfg(not(llvm_enzyme))] +pub fn expand( + ecx: &mut ExtCtxt<'_>, + _expand_span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, +) -> Vec { + ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span }); + return vec![item]; +} + +#[cfg(llvm_enzyme)] +fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident { + let segments = &x.meta_item().unwrap().path.segments; + assert!(segments.len() == 1); + segments[0].ident +} + +#[cfg(llvm_enzyme)] +fn name(x: &NestedMetaItem) -> String { + first_ident(x).name.to_string() +} + 
+#[cfg(llvm_enzyme)] +pub fn from_ast(ecx: &mut ExtCtxt<'_>, meta_item: &ThinVec, has_ret: bool) -> AutoDiffAttrs { + + let mode = name(&meta_item[1]); + let mode = match DiffMode::from_str(&mode) { + Ok(x) => x, + Err(_) => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode}); + return AutoDiffAttrs::inactive(); + }, + }; + let mut activities: Vec = vec![]; + for x in &meta_item[2..] { + let activity_str = name(&x); + let res = DiffActivity::from_str(&activity_str); + match res { + Ok(x) => activities.push(x), + Err(_) => { + ecx.sess.dcx().emit_err(errors::AutoDiffUnknownActivity { span: x.span(), act: activity_str }); + return AutoDiffAttrs::inactive(); + } + }; + } + + // If a return type exists, the last activity belongs to it; otherwise `None` is the placeholder. + // An empty list cannot be split, so keep the placeholder and let the activity-count check in gen_enzyme_decl report it. + let (ret_activity, input_activity) = if has_ret && !activities.is_empty() { + activities.split_last().unwrap() + } else { + (&DiffActivity::None, activities.as_slice()) + }; + + AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() } +} + +#[cfg(llvm_enzyme)] +pub fn expand( + ecx: &mut ExtCtxt<'_>, + expand_span: Span, + meta_item: &ast::MetaItem, + mut item: Annotatable, +) -> Vec { + //check_builtin_macro_attribute(ecx, meta_item, sym::alloc_error_handler); + + // first get the annotatable item: + let (sig, is_impl): (FnSig, bool) = match &item { + Annotatable::Item(ref iitem) => { + let sig = match &iitem.kind { + ItemKind::Fn(box ast::Fn { sig, .. }) => sig, + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + (sig.clone(), false) + }, + Annotatable::ImplItem(ref assoc_item) => { + let sig = match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, .. 
}) => sig, + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + (sig.clone(), true) + }, + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + + let meta_item_vec: ThinVec = match meta_item.kind { + ast::MetaItemKind::List(ref vec) => vec.clone(), + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + + let has_ret = sig.decl.output.has_ret(); + let sig_span = ecx.with_call_site_ctxt(sig.span); + + let (vis, primal) = match &item { + Annotatable::Item(ref iitem) => { + (iitem.vis.clone(), iitem.ident.clone()) + }, + Annotatable::ImplItem(ref assoc_item) => { + (assoc_item.vis.clone(), assoc_item.ident.clone()) + }, + _ => { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + + // create TokenStream from vec elemtents: + // meta_item doesn't have a .tokens field + let comma: Token = Token::new(TokenKind::Comma, Span::default()); + let mut ts: Vec = vec![]; + for t in meta_item_vec.clone()[1..].iter() { + let val = first_ident(t); + let t = Token::from_ast_ident(val); + ts.push(TokenTree::Token(t, Spacing::Joint)); + ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); + } + if !sig.decl.output.has_ret() { + // We don't want users to provide a return activity if the function doesn't return anything. + // For simplicity, we just add a dummy token to the end of the list. + let t = Token::new(TokenKind::Ident(sym::None, false), Span::default()); + ts.push(TokenTree::Token(t, Spacing::Joint)); + } + let ts: TokenStream = TokenStream::from_iter(ts); + + let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); + if !x.is_active() { + // We encountered an error, so we return the original item. + // This allows us to potentially parse other attributes. 
+ return vec![item]; + } + let span = ecx.with_def_site_ctxt(expand_span); + + let n_active: u32 = x.input_activity.iter() + .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) + .count() as u32; + let (d_sig, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span); + let new_decl_span = d_sig.span; + let d_body = gen_enzyme_body( + ecx, + &x, + n_active, + &sig, + &d_sig, + primal, + &new_args, + span, + sig_span, + new_decl_span, + idents, + ); + let d_ident = first_ident(&meta_item_vec[0]); + + // The first element of it is the name of the function to be generated + let asdf = Box::new(ast::Fn { + defaultness: ast::Defaultness::Final, + sig: d_sig, + generics: Generics::default(), + body: Some(d_body), + }); + let mut rustc_ad_attr = + P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); + let ts2: Vec = vec![ + TokenTree::Token( + Token::new(TokenKind::Ident(sym::never, false), span), + Spacing::Joint, + )]; + let never_arg = ast::DelimArgs { + dspan: ast::tokenstream::DelimSpan::from_single(span), + delim: ast::token::Delimiter::Parenthesis, + tokens: ast::tokenstream::TokenStream::from_iter(ts2), + }; + let inline_item = ast::AttrItem { + path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)), + args: ast::AttrArgs::Delimited(never_arg), + tokens: None, + }; + let inline_never_attr = P(ast::NormalAttr { + item: inline_item, + tokens: None, + }); + let mut attr: ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), + //id: ast::DUMMY_TR_ID, + id: ast::AttrId::from_u32(12341), // TODO: fix + style: ast::AttrStyle::Outer, + span, + }; + let inline_never : ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(inline_never_attr), + //id: ast::DUMMY_TR_ID, + id: ast::AttrId::from_u32(12342), // TODO: fix + style: ast::AttrStyle::Outer, + span, + }; + + // Don't add it multiple times: + let orig_annotatable: Annotatable = match item { + Annotatable::Item(ref mut 
iitem) => { + if !iitem.attrs.iter().any(|a| a.id == attr.id) { + iitem.attrs.push(attr.clone()); + } + if !iitem.attrs.iter().any(|a| a.id == inline_never.id) { + iitem.attrs.push(inline_never.clone()); + } + Annotatable::Item(iitem.clone()) + }, + Annotatable::ImplItem(ref mut assoc_item) => { + if !assoc_item.attrs.iter().any(|a| a.id == attr.id) { + assoc_item.attrs.push(attr.clone()); + } + if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) { + assoc_item.attrs.push(inline_never.clone()); + } + Annotatable::ImplItem(assoc_item.clone()) + }, + _ => { + panic!("not supported"); + } + }; + + // Now update for d_fn + rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { + dspan: DelimSpan::dummy(), + delim: rustc_ast::token::Delimiter::Parenthesis, + tokens: ts, + }); + attr.kind = ast::AttrKind::Normal(rustc_ad_attr); + + let d_annotatable = if is_impl { + let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); + let d_fn = P(ast::AssocItem { + attrs: thin_vec![attr.clone(), inline_never], + id: ast::DUMMY_NODE_ID, + span, + vis, + ident: d_ident, + kind: assoc_item, + tokens: None, + }); + Annotatable::ImplItem(d_fn) + } else { + let mut d_fn = ecx.item(span, d_ident, thin_vec![attr.clone(), inline_never], ItemKind::Fn(asdf)); + d_fn.vis = vis; + Annotatable::Item(d_fn) + }; + + return vec![orig_annotatable, d_annotatable]; +} + +// shadow arguments in reverse mode must be mutable references or ptrs, because Enzyme will write into them. +#[cfg(llvm_enzyme)] +fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { + let mut ty = ty.clone(); + match ty.kind { + TyKind::Ptr(ref mut mut_ty) => { + mut_ty.mutbl = ast::Mutability::Mut; + } + TyKind::Ref(_, ref mut mut_ty) => { + mut_ty.mutbl = ast::Mutability::Mut; + } + _ => { + panic!("unsupported type: {:?}", ty); + } + } + ty +} + + +// The body of our generated functions will consist of two black_Box calls. +// The first will call the primal function with the original arguments. 
+// The second will just take a tuple containing the new arguments. +// This way we prevent rustc from optimizing any argument away. +// The last line will 'loop {}', to match the return type of the new function +#[cfg(llvm_enzyme)] +fn gen_enzyme_body( + ecx: &ExtCtxt<'_>, + x: &AutoDiffAttrs, + n_active: u32, + sig: &ast::FnSig, + d_sig: &ast::FnSig, + primal: Ident, + new_names: &[String], + span: Span, + sig_span: Span, + new_decl_span: Span, + idents: Vec, +) -> P { + let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]); + let empty_loop_block = ecx.block(span, ThinVec::new()); + let noop = ast::InlineAsm { + template: vec![ast::InlineAsmTemplatePiece::String("NOP".to_string())], + template_strs: Box::new([]), + operands: vec![], + clobber_abis: vec![], + options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, + line_spans: vec![], + }; + let noop_expr = ecx.expr_asm(span, P(noop)); + let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); + let unsf_block = ast::Block { + stmts: thin_vec![ecx.stmt_semi(noop_expr)], + id: ast::DUMMY_NODE_ID, + tokens: None, + rules: unsf, + span, + could_be_bare_literal: false, + }; + let unsf_expr = ecx.expr_block(P(unsf_block)); + let _loop_expr = ecx.expr_loop(span, empty_loop_block); + let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); + let primal_call = gen_primal_call(ecx, span, primal, idents); + let black_box_primal_call = + ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]); + let tup_args = new_names + .iter() + .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) + .collect(); + + let black_box_remaining_args = ecx.expr_call( + sig_span, + blackbox_call_expr.clone(), + thin_vec![ecx.expr_tuple(sig_span, tup_args)], + ); + + let mut body = ecx.block(span, ThinVec::new()); + body.stmts.push(ecx.stmt_semi(unsf_expr)); + 
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); + body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); + + + if !d_sig.decl.output.has_ret() { + // there is no return type that we have to match, () works fine. + return body; + } + + // having an active-only return means we'll drop the original return type. + // So that can be treated identical to not having one in the first place. + let primal_ret = sig.decl.output.has_ret() && !x.has_active_only_ret(); + + if primal_ret && n_active == 0 && is_rev(x.mode) { + // We only have the primal ret. + body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone())); + return body; + } + + if !primal_ret && n_active == 1 { + // Again no tuple return, so return default float val. + let ty = match d_sig.decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let arg = ty.kind.is_simple_path().unwrap(); + let sl: Vec = vec![arg, Symbol::intern("default")]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + body.stmts.push(ecx.stmt_expr(default_call_expr)); + return body; + } + + let mut exprs = ThinVec::>::new(); + if primal_ret { + // We have both primal ret and active floats. + // primal ret is first, by construction. + exprs.push(primal_call.clone()); + } + + // Now construct default placeholder for each active float. + // Is there something nicer than f32::default() and f64::default()? + let d_ret_ty = match d_sig.decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let mut d_ret_ty = match d_ret_ty.kind.clone() { + TyKind::Tup(ref tys) => { + tys.clone() + } + TyKind::Path(_, rustc_ast::Path { segments, .. 
}) => { + if segments.len() == 1 && segments[0].args.is_none() { + let id = vec![segments[0].ident]; + let kind = TyKind::Path(None, ecx.path(span, id)); + let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }); + thin_vec![ty] + } else { + panic!("Expected tuple or simple path return type"); + } + } + _ => { + // We messed up construction of d_sig + panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty); + } + }; + if is_fwd(x.mode) { + if x.ret_activity == DiffActivity::Dual { + assert!(d_ret_ty.len() == 2); + // both should be identical, by construction + let arg = d_ret_ty[0].kind.is_simple_path().unwrap(); + let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap(); + assert!(arg == arg2); + let sl: Vec = vec![arg, Symbol::intern("default")]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs.push(default_call_expr); + } + } else { + assert!(is_rev(x.mode)); + + if primal_ret { + // We have extra handling above for the primal ret + d_ret_ty = d_ret_ty[1..].to_vec().into(); + } + + for arg in d_ret_ty.iter() { + let arg = arg.kind.is_simple_path().unwrap(); + let sl: Vec = vec![arg, Symbol::intern("default")]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs.push(default_call_expr); + }; + } + + let ret : P; + if exprs.len() > 1 { + let ret_tuple: P = ecx.expr_tuple(span, exprs); + ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]); + } else if exprs.len() == 1 { + let ret_scal = exprs.pop().unwrap(); + ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_scal]); + } else { + assert!(!d_sig.decl.output.has_ret()); + // We don't have to match the return type. 
+ return body; + } + assert!(d_sig.decl.output.has_ret()); + body.stmts.push(ecx.stmt_expr(ret)); + + body +} + +#[cfg(llvm_enzyme)] +fn gen_primal_call( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + idents: Vec, +) -> P { + let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + if has_self { + let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); + let self_expr = ecx.expr_self(span); + ecx.expr_method_call(span, self_expr, primal, args.clone()) + } else { + let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); + let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + ecx.expr_call(span, primal_call_expr, args) + } +} + +// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must +// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer. +// Active arguments must be scalars. Their shadow argument is added to the return type (and will be +// zero-initialized by Enzyme). Active arguments are not handled yet. +// Each argument of the primal function (and the return type if existing) must be annotated with an +// activity. 
+#[cfg(llvm_enzyme)] +fn gen_enzyme_decl( + ecx: &ExtCtxt<'_>, + sig: &ast::FnSig, + x: &AutoDiffAttrs, + span: Span, +) -> (ast::FnSig, Vec, Vec) { + let sig_args = sig.decl.inputs.len() + if sig.decl.output.has_ret() { 1 } else { 0 }; + let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 }; + if sig_args != num_activities { + ecx.sess.dcx().emit_fatal(errors::AutoDiffInvalidNumberActivities { + span, + expected: sig_args, + found: num_activities, + }); + } + assert!(sig.decl.inputs.len() == x.input_activity.len()); + assert!(sig.decl.output.has_ret() == x.has_ret_activity()); + let mut d_decl = sig.decl.clone(); + let mut d_inputs = Vec::new(); + let mut new_inputs = Vec::new(); + let mut idents = Vec::new(); + let mut act_ret = ThinVec::new(); + for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { + d_inputs.push(arg.clone()); + if !valid_input_activity(x.mode, *activity) { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplicationModeAct { + span, + mode: x.mode.to_string(), + act: activity.to_string() + }); + } + if !valid_ty_for_activity(&arg.ty, *activity) { + ecx.sess.dcx().emit_err(errors::AutoDiffInvalidTypeForActivity { + span: arg.ty.span, + act: activity.to_string() + }); + } + match activity { + DiffActivity::Active => { + act_ret.push(arg.ty.clone()); + } + DiffActivity::Duplicated => { + let mut shadow_arg = arg.clone(); + // We += into the shadow in reverse mode. 
+ shadow_arg.ty = P(assure_mut_ref(&arg.ty)); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { + ident.name + } else { + dbg!(&shadow_arg.pat); + panic!("not an ident?"); + }; + let name: String = format!("d{}", old_name); + new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); + shadow_arg.pat = P(ast::Pat { + // TODO: Check id + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + d_inputs.push(shadow_arg); + } + DiffActivity::Dual => { + let mut shadow_arg = arg.clone(); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { + ident.name + } else { + dbg!(&shadow_arg.pat); + panic!("not an ident?"); + }; + let name: String = format!("b{}", old_name); + new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); + shadow_arg.pat = P(ast::Pat { + // TODO: Check id + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + d_inputs.push(shadow_arg); + } + DiffActivity::Const => { + // Nothing to do here. + } + _ => { + dbg!(&activity); + panic!("Not implemented"); + } + } + if let PatKind::Ident(_, ident, _) = arg.pat.kind { + idents.push(ident.clone()); + } else { + panic!("not an ident?"); + } + } + + let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly; + if active_only_ret { + assert!(is_rev(x.mode)); + } + + // If we return a scalar in the primal and the scalar is active, + // then add it as last arg to the inputs. 
+ if is_rev(x.mode) { + match x.ret_activity { + DiffActivity::Active | DiffActivity::ActiveOnly => { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let name = "dret".to_string(); + let ident = Ident::from_str_and_span(&name, ty.span); + let shadow_arg = ast::Param { + attrs: ThinVec::new(), + ty: ty.clone(), + pat: P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingAnnotation::NONE, ident, None), + span: ty.span, + tokens: None, + }), + id: ast::DUMMY_NODE_ID, + span: ty.span, + is_placeholder: false, + }; + d_inputs.push(shadow_arg); + new_inputs.push(name); + } + _ => {} + } + } + d_decl.inputs = d_inputs.into(); + + if is_fwd(x.mode) { + if let DiffActivity::Dual = x.ret_activity { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + // Dual can only be used for f32/f64 ret. + // In that case we return now a tuple with two floats. + let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]); + let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); + d_decl.output = FnRetTy::Ty(ty); + } + if let DiffActivity::DualOnly = x.ret_activity { + // No need to change the return type, + // we will just return the shadow in place + // of the primal return. + } + } + + // If we use ActiveOnly, drop the original return value. + d_decl.output = if active_only_ret { + FnRetTy::Default(span) + } else { + d_decl.output.clone() + }; + + trace!("act_ret: {:?}", act_ret); + + // If we have an active input scalar, add its gradient to the + // return type. This might require changing the return type to a + // tuple. 
+ if act_ret.len() > 0 { + let ret_ty = match d_decl.output { + FnRetTy::Ty(ref ty) => { + if !active_only_ret { + act_ret.insert(0, ty.clone()); + } + let kind = TyKind::Tup(act_ret); + P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }) + } + FnRetTy::Default(span) => { + if act_ret.len() == 1 { + act_ret[0].clone() + } else { + let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect()); + P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }) + } + } + }; + d_decl.output = FnRetTy::Ty(ret_ty); + } + + let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span }; + trace!("Generated signature: {:?}", d_sig); + (d_sig, new_inputs, idents) +} diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index e07eb2e490b71..9d4e64027c14f 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -163,6 +163,58 @@ pub(crate) struct AllocMustStatics { #[primary_span] pub(crate) span: Span, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_unknown_activity)] +pub(crate) struct AutoDiffUnknownActivity { + #[primary_span] + pub(crate) span: Span, + pub(crate) act: String, +} +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_ty_activity)] +pub(crate) struct AutoDiffInvalidTypeForActivity { + #[primary_span] + pub(crate) span: Span, + pub(crate) act: String, +} +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_number_activities)] +pub(crate) struct AutoDiffInvalidNumberActivities { + #[primary_span] + pub(crate) span: Span, + pub(crate) expected: usize, + pub(crate) found: usize, +} +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_mode_activity)] +pub(crate) struct AutoDiffInvalidApplicationModeAct { + #[primary_span] + pub(crate) span: Span, + pub(crate) mode: String, + pub(crate) act: String, +} + +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_mode)] +pub(crate) struct AutoDiffInvalidMode { + 
#[primary_span] + pub(crate) span: Span, + pub(crate) mode: String, +} + +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff)] +pub(crate) struct AutoDiffInvalidApplication { + #[primary_span] + pub(crate) span: Span, +} + +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff_not_build)] +pub(crate) struct AutoDiffSupportNotBuild { + #[primary_span] + pub(crate) span: Span, +} #[derive(Diagnostic)] #[diag(builtin_macros_concat_bytes_invalid)] diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index f60b73fbe9b13..7c64c148b8fd2 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -14,6 +14,7 @@ #![feature(lint_reasons)] #![feature(proc_macro_internals)] #![feature(proc_macro_quote)] +#![cfg_attr(not(bootstrap), feature(autodiff))] #![recursion_limit = "256"] extern crate proc_macro; @@ -29,6 +30,7 @@ use rustc_span::symbol::sym; mod alloc_error_handler; mod assert; +mod autodiff; mod cfg; mod cfg_accessible; mod cfg_eval; @@ -105,6 +107,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) { register_attr! 
{ alloc_error_handler: alloc_error_handler::expand, + autodiff: autodiff::expand, bench: test::expand_bench, cfg_accessible: cfg_accessible::Expander, cfg_eval: cfg_eval::expand, diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index 7a86ddc7556a0..23872d7fa1dcd 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -60,6 +60,11 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO codegen_llvm_run_passes = failed to run LLVM passes codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err} +codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} + +codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto + codegen_llvm_sanitizer_memtag_requires_mte = `-Zsanitizer=memtag` requires `-Ctarget-feature=+mte` diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index 97dc401251cf0..4f36f6c825373 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -248,6 +248,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(self.layout.size.bytes()), MemFlags::empty(), + None, ); bx.lifetime_end(llscratch, scratch_size); diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index 481741bb1277a..7be288886d039 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -1,5 +1,6 @@ //! Set and unset common attributes on LLVM values. 
+use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_codegen_ssa::traits::*; use rustc_hir::def_id::DefId; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; @@ -294,6 +295,7 @@ pub fn from_fn_attrs<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + let autodiff_attrs: &AutoDiffAttrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -311,6 +313,8 @@ pub fn from_fn_attrs<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + } else if autodiff_attrs.is_active() { + InlineAttr::Never } else { codegen_fn_attrs.inline }; diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index e9e8ade09b77b..014e38f11be65 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -616,7 +616,11 @@ pub(crate) fn run_pass_manager( } let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?; + // We will run this again with different values in the context of automatic differentiation. + let first_run = true; + let noop = false; + debug!("running llvm pm opt pipeline"); + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop)?; } debug!("lto done"); Ok(()) @@ -714,7 +718,12 @@ pub unsafe fn optimize_thin_module( let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = parse_module(llcx, module_name, thin_module.data(), &dcx)? 
as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }, + module_llvm: ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + typetrees: Default::default(), + }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 75f99f964d0a9..ad4bd05a08b46 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1,3 +1,22 @@ +#![allow(unused_imports)] +#![allow(unused_variables)] +use crate::llvm::LLVMGetFirstBasicBlock; +use crate::llvm::LLVMBuildCondBr; +use crate::llvm::LLVMBuildICmp; +use crate::llvm::LLVMBuildRetVoid; +use crate::llvm::LLVMRustEraseInstBefore; +use crate::llvm::LLVMRustHasDbgMetadata; +use crate::llvm::LLVMRustHasMetadata; +use crate::llvm::LLVMRustRemoveFncAttr; +use crate::llvm::LLVMMetadataAsValue; +use crate::llvm::LLVMRustGetLastInstruction; +use crate::llvm::LLVMRustDIGetInstMetadata; +use crate::llvm::LLVMRustDIGetInstMetadataOfTy; +use crate::llvm::LLVMRustgetFirstNonPHIOrDbgOrLifetime; +use crate::llvm::LLVMRustGetTerminator; +use crate::llvm::LLVMRustEraseInstFromParent; +use crate::llvm::LLVMRustEraseBBFromParent; +//use crate::llvm::LLVMEraseFromParent; use crate::back::lto::ThinBuffer; use crate::back::owned_target_machine::OwnedTargetMachine; use crate::back::profiling::{ @@ -10,13 +29,33 @@ use crate::errors::{ WithLlvmError, WriteBytecode, }; use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{ + enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind, BasicBlock, FreeTypeAnalysis, + CreateEnzymeLogic, CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, + LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, + LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, + 
LLVMCreateStringAttribute, LLVMDeleteFunction, LLVMDisposeBuilder, LLVMDumpModule, + LLVMGetBasicBlockTerminator, LLVMGetFirstFunction, LLVMGetModuleContext, + LLVMGetNextFunction, LLVMGetParams, LLVMGetReturnType, LLVMRustGetFunctionType, LLVMGetStringAttributeAtIndex, + LLVMGlobalGetValueType, LLVMIsEnumAttribute, LLVMIsStringAttribute, LLVMPositionBuilderAtEnd, + LLVMRemoveStringAttributeAtIndex, LLVMReplaceAllUsesWith, LLVMRustAddEnumAttributeAtIndex, + LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex, + LLVMRustRemoveEnumAttributeAtIndex, LLVMSetValueName2, LLVMVerifyFunction, + LLVMVoidTypeInContext, Value, +}; use crate::llvm_util; use crate::type_::Type; +use crate::typetree::to_enzyme_typetree; +use crate::DiffTypeTree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; +use llvm::IntPredicate; +use llvm::LLVMRustDISetInstMetadata; use llvm::{ - LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, + LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock, }; +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::FncTree; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::write::{ BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig, @@ -24,6 +63,7 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; use rustc_errors::{DiagCtxt, FatalError, Level}; @@ -37,7 +77,7 @@ use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, Tl use crate::llvm::diagnostic::OptimizationDiagnosticKind; use libc::{c_char, c_int, c_uint, c_void, size_t}; -use std::ffi::CString; +use std::ffi::{CStr, 
CString}; use std::fs; use std::io::{self, Write}; use std::path::{Path, PathBuf}; @@ -510,9 +550,35 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, + first_run: bool, + noop: bool, ) -> Result<(), FatalError> { - let unroll_loops = + if noop { + return Ok(()); + } + // Enzyme: + // We want to simplify / optimize functions before AD. + // However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore deactivate them first, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // RIP compile time. + + let unroll_loops; + let vectorize_slp; + let vectorize_loop; + + if first_run { + unroll_loops = false; + vectorize_slp = false; + vectorize_loop = false; + } else { + unroll_loops = opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + vectorize_slp = config.vectorize_slp; + vectorize_loop = config.vectorize_loop; + dbg!("Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}", unroll_loops, vectorize_slp, vectorize_loop); + } + let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -567,8 +633,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.emit_lifetime_markers, sanitizer_options.as_ref(), pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), @@ -589,6 +655,623 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses)) } +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + 
LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +// DESIGN: +// Today we have our placeholder function, and our Enzyme generated one. +// We create a wrapper function and delete the placeholder body. +// We then call the wrapper from the placeholder. +// +// Soon, we won't delete the whole placeholder, but just the loop, +// and the two inline asm sections. For now we can still call the wrapper. +// In the future we call our Enzyme generated function directly and unwrap the return +// struct in our original placeholder. +// +// define internal double @_ZN2ad3bar17ha38374e821680177E(ptr align 8 %0, ptr align 8 %1, double %2) unnamed_addr #17 !dbg !13678 { +// %4 = alloca double, align 8 +// %5 = alloca ptr, align 8 +// %6 = alloca ptr, align 8 +// %7 = alloca { ptr, double }, align 8 +// store ptr %0, ptr %6, align 8 +// call void @llvm.dbg.declare(metadata ptr %6, metadata !13682, metadata !DIExpression()), !dbg !13685 +// store ptr %1, ptr %5, align 8 +// call void @llvm.dbg.declare(metadata ptr %5, metadata !13683, metadata !DIExpression()), !dbg !13685 +// store double %2, ptr %4, align 8 +// call void @llvm.dbg.declare(metadata ptr %4, metadata !13684, metadata !DIExpression()), !dbg !13686 +// call void asm sideeffect alignstack inteldialect "NOP", "~{dirflag},~{fpsr},~{flags},~{memory}"(), !dbg !13687, !srcloc !23 +// %8 = call double @_ZN2ad3foo17h95b548a9411653b2E(ptr align 8 %0), !dbg !13687 +// %9 = call double @_ZN4core4hint9black_box17h7bd67a41b0f12bdfE(double %8), !dbg !13687 +// store ptr %1, ptr %7, align 8, !dbg !13687 +// %10 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687 +// store double %2, ptr %10, align 8, !dbg !13687 +// %11 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 0, !dbg !13687 +// %12 = load ptr, ptr %11, align 8, !dbg !13687, !nonnull !23, !align !1047, !noundef !23 +// %13 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg 
!13687 +// %14 = load double, ptr %13, align 8, !dbg !13687, !noundef !23 +// %15 = call { ptr, double } @_ZN4core4hint9black_box17h669f3b22afdcb487E(ptr align 8 %12, double %14), !dbg !13687 +// %16 = extractvalue { ptr, double } %15, 0, !dbg !13687 +// %17 = extractvalue { ptr, double } %15, 1, !dbg !13687 +// br label %18, !dbg !13687 +// +//18: ; preds = %18, %3 +// br label %18, !dbg !13687 + + +unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, + llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) { + + // first, remove all calls from fnc + let bb = LLVMGetFirstBasicBlock(tgt); + let br = LLVMRustGetTerminator(bb); + LLVMRustEraseInstFromParent(br); + + // now add a call to inner. + // append call to src at end of bb. + let f_ty = LLVMRustGetFunctionType(src); + + let inner_param_num = LLVMCountParams(src); + let outer_param_num = LLVMCountParams(tgt); + let outer_args: Vec<&Value> = get_params(tgt); + let inner_args: Vec<&Value> = get_params(src); + let mut call_args: Vec<&Value> = vec![]; + + let mut safety_vals = vec![]; + let builder = LLVMCreateBuilderInContext(llcx); + let last_inst = LLVMRustGetLastInstruction(bb).unwrap(); + LLVMPositionBuilderAtEnd(builder, bb); + + let safety_run_checks; + if std::env::var("ENZYME_NO_SAFETY_CHECKS").is_ok() { + safety_run_checks = false; + } else { + safety_run_checks = true; + } + + if inner_param_num == outer_param_num { + call_args = outer_args; + } else { + trace!("Different number of args, adjusting"); + let mut outer_pos: usize = 0; + let mut inner_pos: usize = 0; + // copy over if they are identical. + // If not, skip the outer arg (and assert it's int). 
+ while outer_pos < outer_param_num as usize { + let inner_arg = inner_args[inner_pos]; + let outer_arg = outer_args[outer_pos]; + let inner_arg_ty = llvm::LLVMTypeOf(inner_arg); + let outer_arg_ty = llvm::LLVMTypeOf(outer_arg); + if inner_arg_ty == outer_arg_ty { + call_args.push(outer_arg); + inner_pos += 1; + outer_pos += 1; + } else { + // out: (ptr, <>int1, ptr, int2) + // inner: (ptr, <>ptr, int) + // goal: (ptr, ptr, int1), skipping int2 + // we are here: <> + assert!(llvm::LLVMRustGetTypeKind(outer_arg_ty) == llvm::TypeKind::Integer); + assert!(llvm::LLVMRustGetTypeKind(inner_arg_ty) == llvm::TypeKind::Pointer); + let next_outer_arg = outer_args[outer_pos + 1]; + let next_inner_arg = inner_args[inner_pos + 1]; + let next_outer_arg_ty = llvm::LLVMTypeOf(next_outer_arg); + let next_inner_arg_ty = llvm::LLVMTypeOf(next_inner_arg); + assert!(llvm::LLVMRustGetTypeKind(next_outer_arg_ty) == llvm::TypeKind::Pointer); + assert!(llvm::LLVMRustGetTypeKind(next_inner_arg_ty) == llvm::TypeKind::Integer); + let next2_outer_arg = outer_args[outer_pos + 2]; + let next2_outer_arg_ty = llvm::LLVMTypeOf(next2_outer_arg); + assert!(llvm::LLVMRustGetTypeKind(next2_outer_arg_ty) == llvm::TypeKind::Integer); + call_args.push(next_outer_arg); + call_args.push(outer_arg); + + outer_pos += 3; + inner_pos += 2; + + + if safety_run_checks { + + // Now we assert if int1 <= int2 + let res = LLVMBuildICmp( + builder, + IntPredicate::IntULE as u32, + outer_arg, + next2_outer_arg, + "safety_check".as_ptr() as *const c_char); + safety_vals.push(res); + } + } + } + } + + if inner_param_num as usize != call_args.len() { + panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, call_args.len()); + } + + // Now add the safety checks. + if !safety_vals.is_empty() { + dbg!("Adding safety checks"); + assert!(safety_run_checks); + // first we create one bb per check and two more for the fail and success case. 
+ let fail_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_fail".as_ptr() as *const c_char); + let success_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_success".as_ptr() as *const c_char); + for i in 1..safety_vals.len() { + // 'or' all safety checks together + // Doing some binary tree style or'ing here would be more efficient, + // but I assume LLVM will opt it anyway + let prev = safety_vals[i - 1]; + let curr = safety_vals[i]; + let res = llvm::LLVMBuildOr(builder, prev, curr, "safety_check".as_ptr() as *const c_char); + safety_vals[i] = res; + } + LLVMBuildCondBr(builder, safety_vals.last().unwrap(), success_bb, fail_bb); + LLVMPositionBuilderAtEnd(builder, fail_bb); + + + let panic_name: CString = get_panic_name(llmod); + + let mut arg_vec = vec![add_panic_msg_to_global(llmod, llcx)]; + + let fnc1 = llvm::LLVMGetNamedFunction(llmod, panic_name.as_ptr() as *const c_char); + assert!(fnc1.is_some()); + let fnc1 = fnc1.unwrap(); + let ty = LLVMRustGetFunctionType(fnc1); + let call = LLVMBuildCall2(builder, ty, fnc1, arg_vec.as_mut_ptr(), arg_vec.len(), panic_name.as_ptr() as *const c_char); + llvm::LLVMSetTailCall(call, 1); + llvm::LLVMBuildUnreachable(builder); + LLVMPositionBuilderAtEnd(builder, success_bb); + } + + let inner_fnc_name = llvm::get_value_name(src); + let c_inner_fnc_name = CString::new(inner_fnc_name).unwrap(); + + let mut struct_ret = LLVMBuildCall2( + builder, + f_ty, + src, + call_args.as_mut_ptr(), + call_args.len(), + c_inner_fnc_name.as_ptr(), + ); + + // Add dummy dbg info to our newly generated call, if we have any. 
+ let inst = LLVMRustgetFirstNonPHIOrDbgOrLifetime(bb).unwrap(); + let md_ty = llvm::LLVMGetMDKindIDInContext( + llcx, + "dbg".as_ptr() as *const c_char, + "dbg".len() as c_uint, + ); + + if LLVMRustHasMetadata(last_inst, md_ty) { + let md = LLVMRustDIGetInstMetadata(last_inst); + let md_val = LLVMMetadataAsValue(llcx, md); + let md2 = llvm::LLVMSetMetadata(struct_ret, md_ty, md_val); + } else { + trace!("No dbg info"); + } + + // Now clean up placeholder code. + LLVMRustEraseInstBefore(bb, last_inst); + + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); + let f_is_struct = llvm::LLVMRustIsStructType(f_return_type); + let void_type = LLVMVoidTypeInContext(llcx); + // Now unwrap the struct_ret if it's actually a struct + if f_is_struct { + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + } + } + if f_return_type != void_type { + let _ret = LLVMBuildRet(builder, struct_ret); + } else { + let _ret = LLVMBuildRetVoid(builder); + } + LLVMDisposeBuilder(builder); + let _fnc_ok = + LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); +} + +unsafe fn get_panic_name(llmod: &llvm::Module) -> CString { + // The names are mangled and their ending changes based on a hash, so just take whichever. 
+ let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let fnc_name = llvm::get_value_name(lf); + let fnc_name: String = String::from_utf8(fnc_name.to_vec()).unwrap(); + if fnc_name.starts_with("_ZN4core9panicking14panic_explicit") { + return CString::new(fnc_name).unwrap(); + } else if fnc_name.starts_with("_RN4core9panicking14panic_explicit") { + return CString::new(fnc_name).unwrap(); + } + } else { + break; + } + } + panic!("Could not find panic function"); +} +unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::Context) -> &'a llvm::Value { + use llvm::*; + + // Convert the message to a CString + let msg = "autodiff safety check failed!"; + let cmsg = CString::new(msg).unwrap(); + + let msg_global_name = "ad_safety_msg".to_string(); + let cmsg_global_name = CString::new(msg_global_name).unwrap(); + + // Get the length of the message + let msg_len = msg.len(); + + // Create the array type + let i8_array_type = LLVMRustArrayType(LLVMInt8TypeInContext(llcx), msg_len as u64); + + // Create the string constant + let string_const_val = LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const i8, msg_len as u32, 0); + + // Create the array initializer + let mut array_elems: Vec<_> = Vec::with_capacity(msg_len); + for i in 0..msg_len { + let char_value = LLVMConstInt(LLVMInt8TypeInContext(llcx), cmsg.as_bytes()[i] as u64, 0); + array_elems.push(char_value); + } + let array_initializer = LLVMConstArray(LLVMInt8TypeInContext(llcx), array_elems.as_mut_ptr(), msg_len as u32); + + // Create the struct type + let global_type = LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0); + + // Create the struct initializer + let struct_initializer = LLVMConstStructInContext(llcx, [array_initializer].as_mut_ptr(), 1, 0); + + // Add the global variable to the module + let global_var = LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const i8); + LLVMRustSetLinkage(global_var, 
Linkage::PrivateLinkage); + LLVMSetInitializer(global_var, struct_initializer); + + global_var +} + +// As unsafe as it can be. +#[allow(unused_variables)] +#[allow(unused)] +pub(crate) unsafe fn enzyme_ad( + llmod: &llvm::Module, + llcx: &llvm::Context, + diag_handler: &DiagCtxt, + item: AutoDiffItem, + logic_ref: EnzymeLogicRef, +) -> Result<(), FatalError> { + let autodiff_mode = item.attrs.mode; + let rust_name = item.source; + let rust_name2 = &item.target; + + let args_activity = item.attrs.input_activity.clone(); + let ret_activity: DiffActivity = item.attrs.ret_activity; + + // get target and source function + let name = CString::new(rust_name.to_owned()).unwrap(); + let name2 = CString::new(rust_name2.clone()).unwrap(); + let src_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()); + let src_fnc = match src_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err( + diag_handler, + LlvmError::PrepareAutoDiff { + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find src function".to_owned(), + }, + )); + } + }; + let target_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()); + let target_fnc = match target_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err( + diag_handler, + LlvmError::PrepareAutoDiff { + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find target function".to_owned(), + }, + )); + } + }; + let src_num_args = llvm::LLVMCountParams(src_fnc); + let target_num_args = llvm::LLVMCountParams(target_fnc); + // A really simple check + assert!(src_num_args <= target_num_args); + + // create enzyme typetrees + let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + let input_tts = + item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); + let 
output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + + let type_analysis: EnzymeTypeAnalysisRef = + CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); + + llvm::set_strict_aliasing(false); + + if std::env::var("ENZYME_PRINT_TA").is_ok() { + llvm::set_print_type(true); + } + if std::env::var("ENZYME_PRINT_AA").is_ok() { + llvm::set_print_activity(true); + } + if std::env::var("ENZYME_PRINT_PERF").is_ok() { + llvm::set_print_perf(true); + } + if std::env::var("ENZYME_PRINT").is_ok() { + llvm::set_print(true); + } + + let mode = match autodiff_mode { + DiffMode::Forward => DiffMode::Forward, + DiffMode::Reverse => DiffMode::Reverse, + DiffMode::ForwardFirst => DiffMode::Forward, + DiffMode::ReverseFirst => DiffMode::Reverse, + _ => unreachable!(), + }; + + let void_type = LLVMVoidTypeInContext(llcx); + let return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src_fnc)); + let void_ret = void_type == return_type; + let mut tmp = match mode { + DiffMode::Forward => enzyme_rust_forward_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + input_tts, + output_tt, + void_ret, + ), + DiffMode::Reverse => enzyme_rust_reverse_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + input_tts, + output_tt, + ), + _ => unreachable!(), + }; + let mut res: &Value = tmp.0; + let size_positions: Vec = tmp.1; + + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); + + let rev_mode = item.attrs.mode == DiffMode::Reverse; + create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions); + // TODO: implement drop for wrapper type? 
+ FreeTypeAnalysis(type_analysis); + + Ok(()) +} + + +pub(crate) unsafe fn differentiate( + module: &ModuleCodegen, + cgcx: &CodegenContext, + diff_items: Vec, + _typetrees: FxHashMap, + config: &ModuleConfig, +) -> Result<(), FatalError> { + for item in &diff_items { + trace!("{}", item); + } + + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_dcx(); + + llvm::set_strict_aliasing(false); + + if std::env::var("ENZYME_LOOSE_TYPES").is_ok() { + dbg!("Setting loose types to true"); + llvm::set_loose_types(true); + } + + // Before dumping the module, we want all the tt to become part of the module. + for (i, item) in diff_items.iter().enumerate() { + let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let tt: FncTree = FncTree { + args: item.inputs.clone(), + ret: item.output.clone(), + }; + let name = CString::new(item.source.clone()).unwrap(); + let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap(); + crate::builder::add_tt2(llmod, llcx, fn_def, tt); + + // Before dumping the module, we also might want to add dummy functions, which will + // trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary. + // This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in + // Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions? 
+ if std::env::var("ENZYME_OPT").is_ok() { + dbg!("Enable extra debug helper to debug Enzyme through the opt plugin"); + crate::builder::add_opt_dbg_helper(llmod, llcx, fn_def, item.attrs.clone(), i); + } + } + + if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() || std::env::var("ENZYME_OPT").is_ok(){ + unsafe { + LLVMDumpModule(llmod); + } + } + + if std::env::var("ENZYME_INLINE").is_ok() { + dbg!("Setting inline to true"); + llvm::set_inline(true); + } + + if std::env::var("ENZYME_TT_DEPTH").is_ok() { + let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); + let depth = depth.parse::().unwrap(); + assert!(depth >= 1); + llvm::set_max_int_offset(depth); + } + if std::env::var("ENZYME_TT_WIDTH").is_ok() { + let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); + let width = width.parse::().unwrap(); + assert!(width >= 1); + llvm::set_max_type_offset(width); + } + + if std::env::var("ENZYME_RUNTIME_ACTIVITY").is_ok() { + dbg!("Setting runtime activity check to true"); + llvm::set_runtime_activity_check(true); + } + + let differentiate = !diff_items.is_empty(); + let mut first_order_items: Vec = vec![]; + let mut higher_order_items: Vec = vec![]; + for item in diff_items { + if item.attrs.mode == DiffMode::ForwardFirst || item.attrs.mode == DiffMode::ReverseFirst{ + first_order_items.push(item); + } else { + // default + higher_order_items.push(item); + } + } + + let mut fnc_opt = false; + if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() { + dbg!("Enable extra optimizations for Enzyme"); + fnc_opt = true; + } + + // If a function is a base for some higher order ad, always optimize + let fnc_opt_base = true; + let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt_base as u8); + + for item in first_order_items { + let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt); + assert!(res.is_ok()); + } + + // For the rest, follow the user choice on debug vs release. 
+ // Reuse the opt one if possible for better compile time (Enzyme internal caching). + let logic_ref = match fnc_opt { + true => logic_ref_opt, + false => CreateEnzymeLogic(fnc_opt as u8), + }; + for item in higher_order_items { + let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref); + assert!(res.is_ok()); + } + + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + } else { + LLVMRustRemoveEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + if std::env::var("ENZYME_PRINT_MOD_AFTER_ENZYME").is_ok() { + unsafe { + LLVMDumpModule(llmod); + } + } + + + if std::env::var("ENZYME_NO_MOD_OPT_AFTER").is_ok() || !differentiate { + trace!("Skipping module optimization after automatic differentiation"); + } else { + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let mut first_run = false; + dbg!("Running Module Optimization after differentiation"); + if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() { + // disables vectorization and loop unrolling + first_run = true; + } + if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + dbg!("Running first postAD optimization"); + first_run = true; + } + let noop = false; + llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?; + } + if 
std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + dbg!("Running Second postAD optimization"); + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let mut first_run = false; + dbg!("Running Module Optimization after differentiation"); + if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() { + // enables vectorization and loop unrolling + first_run = false; + } + let noop = false; + llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?; + } + } + } + + if std::env::var("ENZYME_PRINT_MOD_AFTER_OPTS").is_ok() { + unsafe { + LLVMDumpModule(llmod); + } + } + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -611,6 +1294,46 @@ pub(crate) unsafe fn optimize( llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()); } + // This code enables Enzyme to differentiate code containing Rust enums. + // By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing + // away the enums and allows Enzyme to understand why a value can be of different types in + // different code sections. We remove this attribute after Enzyme is done, to not affect the + // rest of the compilation. 
+ { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMRustGetEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute( + llcx, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + myhwv.as_ptr() as *const c_char, + myhwv.as_bytes().len() as c_uint, + ); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + LLVMRustAddEnumAttributeAtIndex( + llcx, + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, @@ -618,7 +1341,16 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage); + // Second run only relevant for AD + let first_run = true; + let noop; + if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + noop = true; + dbg!("Skipping PreAD optimization"); + } else { + noop = false; + } + return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop); } Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index 5dc271ccddb7c..344c5763483c8 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -27,11 +27,13 @@ use rustc_data_structures::small_c_str::SmallCStr; use rustc_middle::dep_graph; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir::mono::{Linkage, Visibility}; -use rustc_middle::ty::TyCtxt; use rustc_session::config::DebugInfo; use rustc_span::symbol::Symbol; use rustc_target::spec::SanitizerSet; +use 
rustc_middle::mir::mono::MonoItem; +use rustc_middle::ty::{ParamEnv, TyCtxt, fnc_typetrees}; + use std::time::Instant; pub struct ValueIter<'ll> { @@ -86,6 +88,15 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx); for &(mono_item, data) in &mono_items { mono_item.predefine::>(&cx, data.linkage, data.visibility); + let inst = match mono_item { + MonoItem::Fn(instance) => instance, + _ => continue, + }; + let fn_ty = inst.ty(tcx, ParamEnv::empty()); + let _fnc_tree = fnc_typetrees(tcx, fn_ty, &mut vec![], None); + //trace!("codegen_module: predefine fn {}", inst); + //trace!("{} \n {:?} \n {:?}", inst, fn_ty, _fnc_tree); + // Manuel: TODO } // ... and now that we have everything pre-defined, fill out those definitions. diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 8f60175a6031c..78cb56b46916b 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -8,6 +8,7 @@ use crate::type_::Type; use crate::type_of::LayoutLlvmExt; use crate::value::Value; use libc::{c_char, c_uint}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity}; use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::PlaceRef; @@ -30,6 +31,10 @@ use std::iter; use std::ops::Deref; use std::ptr; +use rustc_ast::expand::autodiff_attrs::DiffMode; +use crate::typetree::to_enzyme_typetree; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; + // All Builders must have an llfn associated with them #[must_use] pub struct Builder<'a, 'll, 'tcx> { @@ -134,6 +139,179 @@ macro_rules! 
builder_methods_for_value_instructions { } } +pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) { + let inputs = tt.args; + let ret_tt: TypeTree = tt.ret; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr); + } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } + let ret_attr = unsafe { + let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + attr + }; + unsafe { + llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr); + } +} + +#[allow(unused)] +pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, attrs: AutoDiffAttrs, i: usize) { + //pub mode: DiffMode, + //pub ret_activity: DiffActivity, + //pub input_activity: Vec, + let inputs = attrs.input_activity; + let outputs = attrs.ret_activity; + let ad_name = match attrs.mode { + 
DiffMode::Forward => "__enzyme_fwddiff", + DiffMode::Reverse => "__enzyme_autodiff", + DiffMode::ForwardFirst => "__enzyme_fwddiff", + DiffMode::ReverseFirst => "__enzyme_autodiff", + _ => panic!("Why are we here?"), + }; + + // Assuming that our val is the fnc square, want to generate the following llvm-ir: + // declare double @__enzyme_autodiff(...) + // + // define double @dsquare(double %x) { + // entry: + // %0 = tail call double (...) @__enzyme_autodiff(double (double)* nonnull @square, double %x) + // ret double %0 + // } + + let mut final_num_args; + unsafe { + let fn_ty = llvm::LLVMRustGetFunctionType(val); + let ret_ty = llvm::LLVMGetReturnType(fn_ty); + + // First we add the declaration of the __enzyme function + let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); + let ad_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + ad_name.as_ptr() as *const c_char, + ad_name.len().try_into().unwrap(), + enzyme_ty, + ); + + let wrapper_name = String::from("enzyme_opt_helper_") + i.to_string().as_str(); + let wrapper_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + wrapper_name.as_ptr() as *const c_char, + wrapper_name.len().try_into().unwrap(), + fn_ty, + ); + let entry = llvm::LLVMAppendBasicBlockInContext(llcx, wrapper_fn, "entry".as_ptr() as *const c_char); + let builder = llvm::LLVMCreateBuilderInContext(llcx); + llvm::LLVMPositionBuilderAtEnd(builder, entry); + let num_args = llvm::LLVMCountParams(wrapper_fn); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(val); + // metadata !"enzyme_const" + let enzyme_const = llvm::LLVMMDStringInContext(llcx, "enzyme_const".as_ptr() as *const c_char, 12); + let enzyme_out = llvm::LLVMMDStringInContext(llcx, "enzyme_out".as_ptr() as *const c_char, 10); + let enzyme_dup = llvm::LLVMMDStringInContext(llcx, "enzyme_dup".as_ptr() as *const c_char, 10); + let enzyme_dupnoneed = llvm::LLVMMDStringInContext(llcx, "enzyme_dupnoneed".as_ptr() as *const c_char, 16); + final_num_args = 
num_args * 2 + 1; + for i in 0..num_args { + let arg = llvm::LLVMGetParam(wrapper_fn, i); + let activity = inputs[i as usize]; + let (activity, duplicated): (&Value, bool) = match activity { + DiffActivity::None => panic!(), + DiffActivity::Const => (enzyme_const, false), + DiffActivity::Active => (enzyme_out, false), + DiffActivity::ActiveOnly => (enzyme_out, false), + DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::Duplicated => (enzyme_dup, true), + DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), + DiffActivity::FakeActivitySize => (enzyme_const, false), + }; + args.push(activity); + args.push(arg); + if duplicated { + final_num_args += 1; + args.push(arg); + } + } + + // declare void @__enzyme_autodiff(...) + + // define void @enzyme_opt_helper_0(ptr %0, ptr %1) { + // call void (...) @__enzyme_autodiff(ptr @ffff, ptr %0, ptr %1) + // ret void + // } + + let call = llvm::LLVMBuildCall2(builder, enzyme_ty, ad_fn, args.as_mut_ptr(), final_num_args as usize, ad_name.as_ptr() as *const c_char); + let void_ty = llvm::LLVMVoidTypeInContext(llcx); + if llvm::LLVMTypeOf(call) != void_ty { + llvm::LLVMBuildRet(builder, call); + } else { + llvm::LLVMBuildRetVoid(builder); + } + llvm::LLVMDisposeBuilder(builder); + + let _fnc_ok = + llvm::LLVMVerifyFunction(wrapper_fn, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + } + +} + +fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) { + let inputs = tt.args; + let _ret: TypeTree = tt.ret; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = 
to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddParamAttr(val, i as u32, attr); + } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } +} + + impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fn build(cx: &'a CodegenCx<'ll, 'tcx>, llbb: &'ll BasicBlock) -> Self { let bx = Builder::with_cx(cx); @@ -445,12 +623,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { - unsafe { + fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align, tt: Option) -> &'ll Value { + let load = unsafe { let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); llvm::LLVMSetAlignment(load, align.bytes() as c_uint); load + }; + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, load, tt); } + load } fn volatile_load(&mut self, ty: &'ll Type, ptr: &'ll Value) -> &'ll Value { @@ -645,8 +827,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value { - self.store_with_flags(val, ptr, align, MemFlags::empty()) + fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align, tt: Option) -> &'ll Value { + let store = self.store_with_flags(val, ptr, align, MemFlags::empty()); + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, store, tt); + } + store } fn store_with_flags( @@ -874,11 +1060,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { 
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -887,7 +1074,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); + } else { + trace!("builder: no tt"); } } @@ -899,11 +1093,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemMove( self.llbuilder, dst, @@ -912,7 +1107,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); } } @@ -923,9 +1123,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { size: &'ll Value, align: Align, flags: MemFlags, + tt: Option, ) { let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let val = unsafe { llvm::LLVMRustBuildMemSet( self.llbuilder, ptr, @@ -933,7 +1134,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fill_byte, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + let llmod = self.cx.llmod; + let llcx = self.cx.llcx; + add_tt(llmod, llcx, val, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 
3053c4e0daaa9..c58373bf168f6 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -639,6 +639,10 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> { None } } + + fn create_autodiff(&self) -> Vec { + return vec![]; + } } impl<'ll> CodegenCx<'ll, '_> { diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs index 33bfde03a31c3..3fb4eaea16b01 100644 --- a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs +++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs @@ -409,7 +409,7 @@ fn add_unused_functions(cx: &CodegenCx<'_, '_>) { /// All items participating in code generation together with (instrumented) /// items inlined into them. fn codegenned_and_inlined_items(tcx: TyCtxt<'_>) -> DefIdSet { - let (items, cgus) = tcx.collect_and_partition_mono_items(()); + let (items, _, cgus) = tcx.collect_and_partition_mono_items(()); let mut visited = DefIdSet::default(); let mut result = items.clone(); diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 8db97d577ca76..7b7ee00959bdb 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -97,6 +97,11 @@ pub(crate) struct DlltoolFailImportLibrary<'a> { #[note] pub(crate) struct DynamicLinkingWithLTO; +#[derive(Diagnostic)] +#[diag(codegen_llvm_autodiff_without_lto)] +#[note] +pub(crate) struct AutoDiffWithoutLTO; + pub(crate) struct ParseTargetMachineConfig<'a>(pub LlvmError<'a>); impl IntoDiagnostic<'_, G> for ParseTargetMachineConfig<'_> { @@ -182,6 +187,8 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { src: String, target: String, error: String }, } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -203,6 +210,7 @@ impl IntoDiagnostic<'_, G> for 
WithLlvmError<'_> { } PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; let mut diag = self.0.into_diagnostic(dcx, level); diag.set_primary_message(msg_with_llvm_err); diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index a81056ed3ad62..7506bbe436785 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -25,12 +25,18 @@ extern crate rustc_macros; #[macro_use] extern crate tracing; +use rustc_session::config::Lto; use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; -use errors::ParseTargetMachineConfig; +use errors::{ParseTargetMachineConfig, AutoDiffWithoutLTO}; + +#[allow(unused_imports)] +use llvm::TypeTree; + pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -38,7 +44,7 @@ use rustc_codegen_ssa::back::write::{ use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::{CodegenResults, CompiledModule}; -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_errors::{DiagCtxt, ErrorGuaranteed, FatalError}; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; @@ -76,6 +82,7 @@ mod debuginfo; mod declare; mod errors; mod intrinsic; +mod typetree; // The following is a workaround that replaces `pub mod llvm;` and that fixes issue 53912. 
#[path = "llvm/mod.rs"] @@ -171,6 +178,7 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; fn print_pass_timings(&self) { unsafe { let mut size = 0; @@ -253,6 +261,26 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + if cgcx.lto != Lto::Fat { + let dcx = cgcx.create_dcx(); + return Err(dcx.emit_almost_fatal(AutoDiffWithoutLTO{})); + } + unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } + } + + // The typetrees contain all information, their order therefore is irrelevant. + #[allow(rustc::potential_query_instability)] + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -405,6 +433,13 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, @@ -412,6 +447,7 @@ pub struct ModuleLlvm { // This field is `ManuallyDrop` because it is important that the `TargetMachine` // is disposed prior to the `Context` being disposed otherwise UAFs can occur. 
tm: ManuallyDrop, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -426,6 +462,7 @@ impl ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(create_target_machine(tcx, mod_name)), + typetrees: Default::default(), } } } @@ -438,6 +475,7 @@ impl ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(create_informational_target_machine(tcx.sess)), + typetrees: Default::default(), } } } @@ -459,7 +497,12 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }) + Ok(ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + typetrees: Default::default(), + }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 81702baa8c053..0bf2eb5fb3a2f 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,6 +1,8 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] +use rustc_ast::expand::autodiff_attrs::DiffActivity; + use super::debuginfo::{ DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, DILocation, DINameSpace, @@ -832,7 +834,243 @@ pub type SelfProfileAfterPassCallback = unsafe extern "C" fn(*mut c_void); pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void; pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void; +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} + +#[allow(dead_code)] +pub(crate) unsafe fn enzyme_rust_forward_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_diffactivity: Vec, + ret_diffactivity: DiffActivity, + _input_tts: Vec, + _output_tt: TypeTree, + void_ret: bool, +) -> (&Value, Vec) { + let ret_activity = cdiffe_from(ret_diffactivity); + assert!(ret_activity != 
CDIFFE_TYPE::DFT_OUT_DIFF); + let mut input_activity: Vec = vec![]; + for input in input_diffactivity { + let act = cdiffe_from(input); + assert!( + act == CDIFFE_TYPE::DFT_CONSTANT + || act == CDIFFE_TYPE::DFT_DUP_ARG + || act == CDIFFE_TYPE::DFT_DUP_NONEED + ); + input_activity.push(act); + } + + // if we have void ret, this must be false; + let ret_primary_ret = if void_ret { + false + } else { + match ret_activity { + CDIFFE_TYPE::DFT_CONSTANT => true, + CDIFFE_TYPE::DFT_DUP_ARG => true, + CDIFFE_TYPE::DFT_DUP_NONEED => false, + _ => panic!("Implementation error in enzyme_rust_forward_diff."), + } + }; + trace!("ret_primary_ret: {}", &ret_primary_ret); + + //let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_activity.len()]; + let num_fnc_args = LLVMCountParams(fnc); + trace!("num_fnc_args: {}", num_fnc_args); + trace!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); + + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_activity.len()]; + + //let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()]; + //let ret_tt = std::ptr::null_mut(); + //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; + let ret_tt = TypeTree::new(); + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: ret_tt.inner, + KnownValues: known_values.as_mut_ptr(), + }; + + trace!("ret_activity: {}", &ret_activity); + for i in &input_activity { + trace!("input_activity i: {}", &i); + } + trace!("before calling Enzyme"); + let res = EnzymeCreateForwardDiff( + 
logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + CDerivativeMode::DEM_ForwardMode, // return value, dret_used, top_level which was 1 + 1, // free memory + 1, // vector mode width + Option::None, + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + ); + trace!("after calling Enzyme"); + (res, vec![]) +} + +pub(crate) unsafe fn enzyme_rust_reverse_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + rust_input_activity: Vec, + ret_activity: DiffActivity, + input_tts: Vec, + _output_tt: TypeTree, +) -> (&Value, Vec) { + let (primary_ret, ret_activity) = match ret_activity { + DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), + DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF), + DiffActivity::ActiveOnly => (false, CDIFFE_TYPE::DFT_OUT_DIFF), + DiffActivity::None => (false, CDIFFE_TYPE::DFT_CONSTANT), + _ => panic!("Invalid return activity"), + }; + // This only is needed for split-mode AD, which we don't support. 
+ // See Julia: + // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3132 + // https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3092 + let diff_ret = false; + + let mut primal_sizes = vec![]; + let mut input_activity: Vec = vec![]; + for (i, &x) in rust_input_activity.iter().enumerate() { + if is_size(x) { + primal_sizes.push(i); + input_activity.push(CDIFFE_TYPE::DFT_CONSTANT); + continue; + } + input_activity.push(cdiffe_from(x)); + } + + //let args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_activity.len()]; + let num_fnc_args = LLVMCountParams(fnc); + println!("num_fnc_args: {}", num_fnc_args); + println!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_tts.len()]; + + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_tts.len()]; + //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; + let ret_tt = TypeTree::new(); + //let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()]; + //let ret_tt = std::ptr::null_mut(); + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: ret_tt.inner, + KnownValues: known_values.as_mut_ptr(), + }; + + trace!("primary_ret: {}", &primary_ret); + trace!("ret_activity: {}", &ret_activity); + for i in &input_activity { + trace!("input_activity i: {}", &i); + } + trace!("before calling Enzyme"); + let res = EnzymeCreatePrimalAndGradient( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // 
constant arguments + type_analysis, // type analysis struct + primary_ret as u8, + diff_ret as u8, //0 + CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 + 1, // vector mode width + 1, // free memory + Option::None, + 0, // do not force anonymous tape + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + 0, + ); + trace!("after calling Enzyme"); + (res, primal_sizes) +} + extern "C" { + // TODO: can I just ignore the non void return + // EraseFromParent doesn't exist :( + //pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value; + // Enzyme + pub fn LLVMRustAddFncParamAttr<'a>( + F: &'a Value, + index: c_uint, + Attr: &'a Attribute + ); + + pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute); + pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind); + pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool; + pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; + pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value); + pub fn LLVMRustgetFirstNonPHIOrDbgOrLifetime<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustDIGetInstMetadataOfTy(I: &Value, KindID: c_uint) -> &Metadata; + pub fn LLVMRustDIGetInstMetadata(I: &Value) -> &Metadata; + pub fn LLVMRustDISetInstMetadata<'a>(I: &Value, MD: &'a Metadata); + pub fn LLVMRustEraseBBFromParent(B: &BasicBlock); + pub fn LLVMRustEraseInstFromParent(V: &Value); + pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMRustIsStructType(T: &Type) -> bool; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMDeleteFunction(V: &Value); + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + 
pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetBasicBlockTerminator(B: &BasicBlock) -> &Value; + pub fn LLVMAddFunction<'a>(M: &Module, Name: *const c_char, Ty: &Type) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMGlobalGetValueType(val: &Value) -> &Type; + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; + // Create and destroy contexts. pub fn LLVMContextDispose(C: &'static mut Context); pub fn LLVMGetMDKindIDInContext(C: &Context, Name: *const c_char, SLen: c_uint) -> c_uint; @@ -996,6 +1234,30 @@ extern "C" { Value: *const c_char, ValueLen: c_uint, ) -> &Attribute; + pub fn LLVMRemoveStringAttributeAtIndex(F: &Value, Idx: c_uint, K: *const c_char, KLen: c_uint); + pub fn LLVMGetStringAttributeAtIndex( + F: &Value, + Idx: c_uint, + K: *const c_char, + KLen: c_uint, + ) -> &Attribute; + pub fn LLVMAddAttributeAtIndex(F: &Value, Idx: c_uint, K: &Attribute); + pub fn LLVMRemoveEnumAttributeAtIndex(F: &Value, Idx: c_uint, K: Attribute); + pub fn LLVMGetEnumAttributeAtIndex(F: &Value, Idx: c_uint, K: Attribute) -> &Attribute; + pub fn LLVMIsEnumAttribute(A: &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A: &Attribute) -> bool; + pub fn LLVMRustAddEnumAttributeAtIndex( + C: &Context, + V: &Value, + index: c_uint, + attr: AttributeKind, + ); + pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustGetEnumAttributeAtIndex( + V: &Value, + index: c_uint, + attr: AttributeKind, + ) -> &Attribute; // Operations on functions pub fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint); @@ -1016,6 +1278,7 @@ extern "C" { // Operations on instructions pub fn 
LLVMIsAInstruction(Val: &Value) -> Option<&Value>; pub fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock; + pub fn LLVMGetNextBasicBlock(Fn: &BasicBlock) -> &BasicBlock; // Operations on call sites pub fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint); @@ -1606,6 +1869,12 @@ extern "C" { AttrsLen: size_t, ); + pub fn LLVMRustAddParamAttr<'a>( + Instr: &'a Value, + index: c_uint, + Attr: &'a Attribute + ); + pub fn LLVMRustBuildInvoke<'a>( B: &Builder<'a>, Ty: &'a Type, @@ -2405,3 +2674,528 @@ extern "C" { error_callback: GetSymbolsErrorCallback, ) -> *mut c_void; } + +#[cfg(not(llvm_enzyme))] +pub use self::Fallback_AD::*; + +#[cfg(not(llvm_enzyme))] +pub mod Fallback_AD { + #![allow(unused_variables)] + use super::*; + + pub fn EnzymeNewTypeTree() -> CTypeTreeRef { unimplemented!() } + pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) { unimplemented!() } + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8) { unimplemented!() } + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64) { unimplemented!() } + + pub fn set_inline(val: bool) { unimplemented!() } + pub fn set_runtime_activity_check(check: bool) { unimplemented!() } + pub fn set_max_int_offset(offset: u64) { unimplemented!() } + pub fn set_max_type_offset(offset: u64) { unimplemented!() } + pub fn set_max_type_depth(depth: u64) { unimplemented!() } + pub fn set_print_perf(print: bool) { unimplemented!() } + pub fn set_print_activity(print: bool) { unimplemented!() } + pub fn set_print_type(print: bool) { unimplemented!() } + pub fn set_print(print: bool) { unimplemented!() } + pub fn set_strict_aliasing(strict: bool) { unimplemented!() } + pub fn set_loose_types(loose: bool) { unimplemented!() } + + pub fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: 
EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value { + unimplemented!() + } + pub fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value { + unimplemented!() + } +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} + //pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef) { unimplemented!() } + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef { unimplemented!() } + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef) { unimplemented!() } + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef) { unimplemented!() } + + pub fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef { + unimplemented!() + } + pub fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef { + 
unimplemented!() + } + pub fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool { + unimplemented!() + } + pub fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) { + unimplemented!() + } + pub fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) { + unimplemented!() + } + pub fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ) { + unimplemented!() + } + pub fn EnzymeTypeTreeToStringFree(arg1: *const c_char) { + unimplemented!() + } + pub fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { + unimplemented!() + } +} + + +// Enzyme specific, but doesn't require Enzyme to be build +pub use self::Shared_AD::*; +pub mod Shared_AD { + // Depending on the AD backend (Enzyme or Fallback), some functions might or might not be + // unsafe. So we just allways call them in an unsafe context. + #![allow(unused_unsafe)] + #![allow(unused_variables)] + + use libc::size_t; + use super::Context; + + #[cfg(not(llvm_enzyme))] + use super::Fallback_AD::*; + #[cfg(llvm_enzyme)] + use super::Enzyme_AD::*; + + use core::fmt; + use std::ffi::{CStr, CString}; + use rustc_ast::expand::autodiff_attrs::DiffActivity; + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, + } + + impl fmt::Display for CDIFFE_TYPE { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let value = match self { + CDIFFE_TYPE::DFT_OUT_DIFF => "DFT_OUT_DIFF", + CDIFFE_TYPE::DFT_DUP_ARG => "DFT_DUP_ARG", + CDIFFE_TYPE::DFT_CONSTANT => "DFT_CONSTANT", + CDIFFE_TYPE::DFT_DUP_NONEED => "DFT_DUP_NONEED", + }; + write!(f, "{}", value) + } + } + + pub fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + return match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + 
DiffActivity::ActiveOnly => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Dual => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED, + DiffActivity::FakeActivitySize => panic!("Implementation error"), + }; + } + + pub fn is_size(act: DiffActivity) -> bool { + return act == DiffActivity::FakeActivitySize; + } + + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, + } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], + } + pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], + } + pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], + } + pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct IntList { + pub data: *mut i64, + pub size: size_t, + } + #[repr(u32)] + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] + pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, + } + + pub type CTypeTreeRef = *mut EnzymeTypeTree; + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct EnzymeTypeTree { + _unused: [u8; 0], + } + pub struct TypeTree { + pub inner: CTypeTreeRef, + } + + impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { 
EnzymeNewTypeTreeCT(t, ctx) }; + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + self + } + + #[must_use] + pub fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } + } + + impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } + } + + impl fmt::Display for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } + } + + impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } + } + + impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } + } + + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, + } +} + +#[cfg(llvm_enzyme)] +pub use self::Enzyme_AD::*; + +// Enzyme is an optional component, so we 
do need to provide a fallback when it is ont getting +// compiled. We deny the usage of #[autodiff(..)] on a higher level, so a placeholder implementation +// here is completely fine. +#[cfg(llvm_enzyme)] +pub mod Enzyme_AD { +use super::*; + +use libc::{c_char, size_t}; +use libc::c_void; + +extern "C" { + pub fn EnzymeNewTypeTree() -> CTypeTreeRef; + pub fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); +} + +extern "C" { + static mut MaxIntOffset: c_void; + static mut MaxTypeOffset: c_void; + static mut EnzymeMaxTypeDepth: c_void; + + static mut EnzymeRuntimeActivityCheck: c_void; + static mut EnzymePrintPerf: c_void; + static mut EnzymePrintActivity: c_void; + static mut EnzymePrintType: c_void; + static mut EnzymePrint: c_void; + static mut EnzymeStrictAliasing: c_void; + static mut looseTypeAnalysis: c_void; + static mut EnzymeInline: c_void; +} +pub fn set_runtime_activity_check(check: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeRuntimeActivityCheck), check as u8); + } +} +pub fn set_max_int_offset(offset: u64) { + let offset = offset.try_into().unwrap(); + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxIntOffset), offset); + } +} +pub fn set_max_type_offset(offset: u64) { + let offset = offset.try_into().unwrap(); + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(MaxTypeOffset), offset); + } +} +pub fn set_max_type_depth(depth: u64) { + let depth = depth.try_into().unwrap(); + unsafe { + EnzymeSetCLInteger(std::ptr::addr_of_mut!(EnzymeMaxTypeDepth), depth); + } +} +pub fn set_print_perf(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8); + } +} +pub fn set_print_activity(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8); + } +} +pub fn set_print_type(print: bool) { + unsafe { + 
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); + } +} +pub fn set_print(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); + } +} +pub fn set_strict_aliasing(strict: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); + } +} +pub fn set_loose_types(loose: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); + } +} +pub fn set_inline(val: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8); + } +} + +extern "C" { + pub fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; +} +extern "C" { + pub fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8, // &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; +} +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut 
IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} +extern "C" { + //pub(super) fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} + +extern "C" { + pub(super) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + pub(super) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + pub(super) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + pub(super) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + pub(super) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + pub(super) fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + pub fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + pub fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; +} +} diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..c45f5ec5e7005 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,33 @@ +use crate::llvm; +use rustc_ast::expand::typetree::{Kind, TypeTree}; + +pub fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ 
=> panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} diff --git a/compiler/rustc_codegen_ssa/messages.ftl b/compiler/rustc_codegen_ssa/messages.ftl index 5881c6236ece6..44e7324c829a3 100644 --- a/compiler/rustc_codegen_ssa/messages.ftl +++ b/compiler/rustc_codegen_ssa/messages.ftl @@ -330,3 +330,5 @@ codegen_ssa_use_cargo_directive = use the `cargo:rustc-link-lib` directive to sp codegen_ssa_version_script_write_failure = failed to write version script: {$error} codegen_ssa_visual_studio_not_installed = you may need to install Visual Studio build tools with the "C++ build tools" workload + +codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto diff --git a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs index a5bd10ecb34b5..d182ed88e4d10 100644 --- a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs +++ b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs @@ -46,7 +46,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr } let available_cgus = - tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect(); + tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect(); let mut ams = AssertModuleSource { tcx, diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index cb6244050df24..2d8e230e1976b 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,8 +1,10 @@ use super::write::CodegenContext; +use 
crate::back::write::ModuleConfig; use crate::traits::*; use crate::ModuleCodegen; -use rustc_data_structures::memmap::Mmap; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; +use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; use std::ffi::CString; @@ -76,6 +78,24 @@ impl LtoModuleCodegen { } } + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat { ref module, .. } => { + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; + } + _ => panic!("Unreachable? Autodiff called with non-fat LTO module"), + } + + Ok(self) + } + /// A "gauge" of how costly it is to optimize this module, used to sort /// biggest modules first. pub fn cost(&self) -> u64 { diff --git a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs index 54b523cb6bd54..defbff19c7778 100644 --- a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs +++ b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs @@ -331,7 +331,7 @@ fn exported_symbols_provider_local( // external linkage is enough for monomorphization to be linked to. 
let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib; - let (_, cgus) = tcx.collect_and_partition_mono_items(()); + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); for (mono_item, data) in cgus.iter().flat_map(|cgu| cgu.items().iter()) { if data.linkage != Linkage::External { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 53ae085a72160..fda8330ea8f11 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -9,6 +9,7 @@ use crate::{ }; use jobserver::{Acquired, Client}; use rustc_ast::attr; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_data_structures::memmap::Mmap; use rustc_data_structures::profiling::{SelfProfilerRef, VerboseTimingGuard}; @@ -34,7 +35,7 @@ use rustc_span::symbol::sym; use rustc_span::{BytePos, FileName, InnerSpan, Pos, Span}; use rustc_target::spec::{MergeFunctions, SanitizerSet}; -use crate::errors::ErrorCreatingRemarkDir; +use crate::errors::{ErrorCreatingRemarkDir, AutodiffWithoutLto}; use std::any::Any; use std::borrow::Cow; use std::fs; @@ -374,6 +375,8 @@ impl CodegenContext { fn generate_lto_work( cgcx: &CodegenContext, + autodiff: Vec, + typetrees: FxHashMap, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -382,11 +385,19 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let module = + let mut module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + module = unsafe { module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; + } // We are adding a single work item, so the cost doesn't matter. 
vec![(WorkItem::LTO(module), 0)] } else { + if !autodiff.is_empty() { + let dcx = cgcx.create_dcx(); + dcx.emit_fatal(AutodiffWithoutLto{}); + } assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) .unwrap_or_else(|e| e.raise()); @@ -956,12 +967,18 @@ pub(crate) enum Message { /// The backend has finished processing a work item for a codegen unit. /// Sent from a backend worker thread. - WorkItem { result: Result, Option>, worker_id: usize }, + WorkItem { + result: Result, Option>, + worker_id: usize, + }, /// The frontend has finished generating something (backend IR or a /// post-LTO artifact) for a codegen unit, and it should be passed to the /// backend. Sent from the main thread. - CodegenDone { llvm_work_item: WorkItem, cost: u64 }, + CodegenDone { + llvm_work_item: WorkItem, + cost: u64, + }, /// Similar to `CodegenDone`, but for reusing a pre-LTO artifact /// Sent from the main thread. @@ -970,6 +987,8 @@ pub(crate) enum Message { work_product: WorkProduct, }, + AddAutoDiffItems(Vec), + /// The frontend has finished generating everything for all codegen units. /// Sent from the main thread. CodegenComplete, @@ -1264,6 +1283,7 @@ fn start_executing_work( // This is where we collect codegen units that have gone all the way // through codegen and LLVM. 
+ let mut autodiff_items = Vec::new(); let mut compiled_modules = vec![]; let mut compiled_allocator_module = None; let mut needs_link = Vec::new(); @@ -1271,6 +1291,7 @@ fn start_executing_work( let mut needs_thin_lto = Vec::new(); let mut lto_import_only_modules = Vec::new(); let mut started_lto = false; + let mut typetrees = FxHashMap::::default(); /// Possible state transitions: /// - Ongoing -> Completed @@ -1375,9 +1396,14 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + typetrees.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1490,7 +1516,16 @@ fn start_executing_work( } } - Message::CodegenDone { llvm_work_item, cost } => { + Message::CodegenDone { mut llvm_work_item, cost } => { + //// extract build typetrees + match &mut llvm_work_item { + WorkItem::Optimize(module) => { + let tt = B::typetrees(&mut module.module_llvm); + typetrees.extend(tt); + } + _ => {} + } + // We keep the queue sorted by estimated processing cost, // so that more expensive items are processed earlier. 
This // is good for throughput as it gives the main thread more @@ -1512,6 +1547,10 @@ fn start_executing_work( main_thread_state = MainThreadState::Idle; } + Message::AddAutoDiffItems(mut items) => { + autodiff_items.append(&mut items); + } + Message::CodegenComplete => { if codegen_state != Aborted { codegen_state = Completed; @@ -1981,6 +2020,10 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub fn submit_autodiff_items(&self, items: Vec) { + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index b87f4b6bf8907..8bf68cd56cdcd 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -44,6 +44,9 @@ use std::time::{Duration, Instant}; use itertools::Itertools; +use rustc_middle::ty::typetree_from; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; + pub fn bin_op_to_icmp_predicate(op: hir::BinOpKind, signed: bool) -> IntPredicate { match op { hir::BinOpKind::Eq => IntPredicate::IntEQ, @@ -268,7 +271,7 @@ pub fn coerce_unsized_into<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( OperandValue::Immediate(base) => unsize_ptr(bx, base, src_ty, dst_ty, None), OperandValue::Ref(..) 
| OperandValue::ZeroSized => bug!(), }; - OperandValue::Pair(base, info).store(bx, dst); + OperandValue::Pair(base, info).store(bx, dst, None); } (&ty::Adt(def_a, _), &ty::Adt(def_b, _)) => { @@ -357,6 +360,7 @@ pub fn wants_new_eh_instructions(sess: &Session) -> bool { wants_wasm_eh(sess) || wants_msvc_seh(sess) } +// Manuel TODO pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, dst: Bx::Value, @@ -370,6 +374,13 @@ pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( if size == 0 { return; } + let my_ty = layout.ty; + let tcx: TyCtxt<'_> = bx.cx().tcx(); + let fnc_tree: TypeTree = typetree_from(tcx, my_ty); + let fnc_tree: FncTree = FncTree { + args: vec![fnc_tree.clone(), fnc_tree.clone()], + ret: TypeTree::new(), + }; if flags == MemFlags::empty() && let Some(bty) = bx.cx().scalar_copy_backend_type(layout) @@ -377,8 +388,11 @@ pub fn memcpy_ty<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let temp = bx.load(bty, src, src_align); bx.store(temp, dst, dst_align); } else { - bx.memcpy(dst, dst_align, src, src_align, bx.cx().const_usize(size), flags); + trace!("my_ty: {:?}, enzyme tt: {:?}", my_ty, fnc_tree); + trace!("memcpy_ty: {:?} -> {:?} (size={}, align={:?})", src, dst, size, dst_align); + bx.memcpy(dst, dst_align, src, src_align, bx.cx().const_usize(size), flags, Some(fnc_tree)); } + //let (_args, _ret): (Vec, TypeTree) = (fnc_tree.args, fnc_tree.ret); } pub fn codegen_instance<'a, 'tcx: 'a, Bx: BuilderMethods<'a, 'tcx>>( @@ -594,7 +608,8 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).1; + let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_fncs.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. 
We are not able to re-execute @@ -663,6 +678,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and @@ -980,7 +999,7 @@ pub fn provide(providers: &mut Providers) { config::OptLevel::SizeMin => config::OptLevel::Default, }; - let (defids, _) = tcx.collect_and_partition_mono_items(cratenum); + let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum); let any_for_speed = defids.items().any(|id| { let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index e529956b1baf6..f0b993a3844e1 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,4 +1,5 @@ -use rustc_ast::{ast, attr, MetaItemKind, NestedMetaItem}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_ret_activity, invalid_input_activities}; +use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; use rustc_hir as hir; @@ -14,6 +15,8 @@ use rustc_span::symbol::Ident; use rustc_span::{sym, Span}; use rustc_target::spec::{abi, SanitizerSet}; +use std::str::FromStr; + use crate::errors; use crate::target_features::from_target_feature; use crate::{ @@ -689,6 +692,144 @@ fn check_link_name_xor_ordinal( } } +/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)] +/// macros. There are two forms. The pure one without args to mark primal functions (the functions +/// being differentiated). 
The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the +/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never +/// panic, unless we introduced a bug when parsing the autodiff macro. +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { + let attrs = tcx.get_attrs(id, sym::rustc_autodiff); + + let attrs = attrs + .filter(|attr| attr.name_or_empty() == sym::rustc_autodiff) + .collect::>(); + + // check for exactly one autodiff attribute on placeholder functions. + // There should only be one, since we generate a new placeholder per ad macro. + let msg_once = "cg_ssa: implementation bug. Autodiff attribute can only be applied once"; + let attr = match attrs.len() { + 0 => return AutoDiffAttrs::inactive(), + 1 => attrs.get(0).unwrap(), + _ => { + tcx.sess + .struct_span_err(attrs[1].span, msg_once) + .span_label(attrs[1].span, "more than one") + .emit(); + return AutoDiffAttrs::inactive(); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.len() == 0 { + return AutoDiffAttrs::source(); + } + + let [mode, input_activities @ .., ret_activity] = &list[..] else { + tcx.sess + .struct_span_err(attr.span, msg_once) + .span_label(attr.span, "Implementation bug in autodiff_attrs. Please report this!") + .emit(); + return AutoDiffAttrs::inactive(); + }; + let mode = if let NestedMetaItem::MetaItem(MetaItem { path: ref p1, .. 
}) = mode { + p1.segments.first().unwrap().ident + } else { + let msg = "autodiff attribute must contain autodiff mode"; + tcx.sess + .struct_span_err(attr.span, msg) + .span_label(attr.span, "empty argument list") + .emit(); + return AutoDiffAttrs::inactive(); + }; + + // parse mode + let msg_mode = "mode should be either forward or reverse"; + let mode = match mode.as_str() { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + "ForwardFirst" => DiffMode::ForwardFirst, + "ReverseFirst" => DiffMode::ReverseFirst, + _ => { + tcx.sess + .struct_span_err(attr.span, msg_mode) + .span_label(attr.span, "invalid mode") + .emit(); + return AutoDiffAttrs::inactive(); + } + }; + + // First read the ret symbol from the attribute + let ret_symbol = if let NestedMetaItem::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity { + p1.segments.first().unwrap().ident + } else { + let msg = "autodiff attribute must contain the return activity"; + tcx.sess + .struct_span_err(attr.span, msg) + .span_label(attr.span, "missing return activity") + .emit(); + return AutoDiffAttrs::inactive(); + }; + + // Then parse it into an actual DiffActivity + let msg_unknown_ret_activity = "unknown return activity"; + let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { + Ok(x) => x, + Err(_) => { + tcx.sess + .struct_span_err(attr.span, msg_unknown_ret_activity) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + }; + + // Now parse all the intermediate (input) activities + let msg_arg_activity = "autodiff attribute must contain the return activity"; + let mut arg_activities: Vec = vec![]; + for arg in input_activities { + let arg_symbol = if let NestedMetaItem::MetaItem(MetaItem { path: ref p2, .. 
}) = arg { + p2.segments.first().unwrap().ident + } else { + tcx.sess + .struct_span_err(attr.span, msg_arg_activity) + .span_label(attr.span, "Implementation bug, please report this!") + .emit(); + return AutoDiffAttrs::inactive(); + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + tcx.sess + .struct_span_err(attr.span, msg_unknown_ret_activity) + .span_label(attr.span, "invalid input activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + } + + let mut msg = "".to_string(); + if let Some(i) = invalid_input_activities(mode, &arg_activities) { + msg = format!("Invalid input activity {} for {} mode", arg_activities[i], mode); + } + if !valid_ret_activity(mode, ret_activity) { + msg = format!("Invalid return activity {} for {} mode", ret_activity, mode); + } + if msg != "".to_string() { + tcx.sess + .struct_span_err(attr.span, msg) + .span_label(attr.span, "invalid activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + + AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } +} + pub fn provide(providers: &mut Providers) { - *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; + *providers = + Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers }; } diff --git a/compiler/rustc_codegen_ssa/src/errors.rs b/compiler/rustc_codegen_ssa/src/errors.rs index 2b628d2aa69b7..42fb45a6f42aa 100644 --- a/compiler/rustc_codegen_ssa/src/errors.rs +++ b/compiler/rustc_codegen_ssa/src/errors.rs @@ -35,6 +35,10 @@ pub struct CguNotRecorded<'a> { pub cgu_name: &'a str, } +#[derive(Diagnostic)] +#[diag(codegen_ssa_autodiff_without_lto)] +pub struct AutodiffWithoutLto; + #[derive(Diagnostic)] #[diag(codegen_ssa_unknown_reuse_kind)] pub struct UnknownReuseKind { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index a1662f25e1496..1084031776c3d 
100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -424,7 +424,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let llslot = match op.val { Immediate(_) | Pair(..) => { let scratch = PlaceRef::alloca(bx, self.fn_abi.ret.layout); - op.val.store(bx, scratch); + op.val.store(bx, scratch, None); scratch.llval } Ref(llval, _, align) => { @@ -833,7 +833,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { self.get_caller_location(bx, mir::SourceInfo { span: fn_span, ..source_info }); if let ReturnDest::IndirectOperand(tmp, _) = ret_dest { - location.val.store(bx, tmp); + location.val.store(bx, tmp, None); } self.store_return(bx, ret_dest, &fn_abi.ret, location.immediate()); helper.funclet_br(self, bx, target, mergeable_succ) @@ -993,7 +993,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { | (&mir::Operand::Constant(_), Ref(_, None, _)) => { let tmp = PlaceRef::alloca(bx, op.layout); bx.lifetime_start(tmp.llval, tmp.layout.size); - op.val.store(bx, tmp); + op.val.store(bx, tmp, None); op.val = Ref(tmp.llval, None, tmp.align); copied_constant_arguments.push(tmp); } @@ -1337,12 +1337,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { None => arg.layout.align.abi, }; let scratch = PlaceRef::alloca_aligned(bx, arg.layout, required_align); - op.val.store(bx, scratch); + op.val.store(bx, scratch, None); (scratch.llval, scratch.align, true) } PassMode::Cast { .. 
} => { let scratch = PlaceRef::alloca(bx, arg.layout); - op.val.store(bx, scratch); + op.val.store(bx, scratch, None); (scratch.llval, scratch.align, true) } _ => (op.immediate_or_packed_pair(bx), arg.layout.align.abi, false), @@ -1502,7 +1502,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let slot = self.get_personality_slot(&mut cleanup_bx); slot.storage_live(&mut cleanup_bx); - Pair(exn0, exn1).store(&mut cleanup_bx, slot); + Pair(exn0, exn1).store(&mut cleanup_bx, slot, None); cleanup_bx.br(llbb); cleanup_llbb diff --git a/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs b/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs index 14915e816ee9b..6cd614aa2ea4e 100644 --- a/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs +++ b/compiler/rustc_codegen_ssa/src/mir/debuginfo.rs @@ -259,7 +259,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { if let Some(name) = name { bx.set_var_name(spill_slot.llval, &(name + ".dbg.spill")); } - operand.val.store(bx, spill_slot); + operand.val.store(bx, spill_slot, None); spill_slot } diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index a5bffc33d393c..fbb57ec475e7d 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -16,6 +16,10 @@ use rustc_target::abi::{ WrappingRange, }; +use rustc_middle::ty::typetree_from; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; +use crate::rustc_middle::ty::layout::HasTyCtxt; + fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, allow_overlap: bool, @@ -25,15 +29,23 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( src: Bx::Value, count: Bx::Value, ) { + let tcx: TyCtxt<'_> = bx.cx().tcx(); + let tt: TypeTree = typetree_from(tcx, ty); + let fnc_tree: FncTree = FncTree { + args: vec![tt.clone(), tt.clone(), TypeTree::all_ints()], + ret: TypeTree::new(), + }; + let layout = 
bx.layout_of(ty); let size = layout.size; let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; + trace!("copy: mir ty: {:?}, enzyme tt: {:?}", ty, fnc_tree); if allow_overlap { - bx.memmove(dst, align, src, align, size, flags); + bx.memmove(dst, align, src, align, size, flags, Some(fnc_tree)); } else { - bx.memcpy(dst, align, src, align, size, flags); + bx.memcpy(dst, align, src, align, size, flags, Some(fnc_tree)); } } @@ -45,12 +57,19 @@ fn memset_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( val: Bx::Value, count: Bx::Value, ) { + let tcx: TyCtxt<'_> = bx.cx().tcx(); + let tt: TypeTree = typetree_from(tcx, ty); + let fnc_tree: FncTree = FncTree { + args: vec![tt.clone(), tt.clone(), TypeTree::all_ints()], + ret: TypeTree::new(), + }; + let layout = bx.layout_of(ty); let size = layout.size; let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; - bx.memset(dst, val, size, align, flags); + bx.memset(dst, val, size, align, flags, Some(fnc_tree)); } impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { @@ -506,7 +525,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { } else { OperandRef::from_immediate_or_packed_pair(bx, llval, result.layout) .val - .store(bx, result); + .store(bx, result, None); } } } diff --git a/compiler/rustc_codegen_ssa/src/mir/mod.rs b/compiler/rustc_codegen_ssa/src/mir/mod.rs index a6fcf1fd38c1f..6955d07d0fe8a 100644 --- a/compiler/rustc_codegen_ssa/src/mir/mod.rs +++ b/compiler/rustc_codegen_ssa/src/mir/mod.rs @@ -384,7 +384,7 @@ fn arg_local_refs<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let indirect_operand = OperandValue::Pair(llarg, llextra); let tmp = PlaceRef::alloca_unsized_indirect(bx, arg.layout); - indirect_operand.store(bx, tmp); + indirect_operand.store(bx, tmp, None); 
LocalRef::UnsizedPlace(tmp) } else { let tmp = PlaceRef::alloca(bx, arg.layout); diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index 794cbd315b795..f4587a7727646 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -12,6 +12,8 @@ use rustc_middle::ty::layout::{LayoutOf, TyAndLayout}; use rustc_middle::ty::Ty; use rustc_target::abi::{self, Abi, Align, Size}; +use rustc_ast::expand::typetree::FncTree; + use std::fmt; /// The representation of a Rust value. The enum variant is in fact @@ -373,8 +375,9 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { self, bx: &mut Bx, dest: PlaceRef<'tcx, V>, + tt: Option, ) { - self.store_with_flags(bx, dest, MemFlags::empty()); + self.store_with_flags(bx, dest, MemFlags::empty(), tt); } pub fn volatile_store>( @@ -382,7 +385,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, ) { - self.store_with_flags(bx, dest, MemFlags::VOLATILE); + self.store_with_flags(bx, dest, MemFlags::VOLATILE, None); } pub fn unaligned_volatile_store>( @@ -390,7 +393,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, ) { - self.store_with_flags(bx, dest, MemFlags::VOLATILE | MemFlags::UNALIGNED); + self.store_with_flags(bx, dest, MemFlags::VOLATILE | MemFlags::UNALIGNED, None); } pub fn nontemporal_store>( @@ -398,7 +401,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, ) { - self.store_with_flags(bx, dest, MemFlags::NONTEMPORAL); + self.store_with_flags(bx, dest, MemFlags::NONTEMPORAL, None); } fn store_with_flags>( @@ -406,6 +409,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { bx: &mut Bx, dest: PlaceRef<'tcx, V>, flags: MemFlags, + tt: Option, ) { debug!("OperandRef::store: operand={:?}, dest={:?}", self, dest); match self { @@ -482,11 +486,11 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { let neg_address 
= bx.neg(address); let offset = bx.and(neg_address, align_minus_1); let dst = bx.inbounds_gep(bx.type_i8(), alloca, &[offset]); - bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty()); + bx.memcpy(dst, min_align, llptr, min_align, size, MemFlags::empty(), None); // Store the allocated region and the extra to the indirect place. let indirect_operand = OperandValue::Pair(dst, llextra); - indirect_operand.store(bx, indirect_dest); + indirect_operand.store(bx, indirect_dest, None); } } diff --git a/compiler/rustc_codegen_ssa/src/mir/place.rs b/compiler/rustc_codegen_ssa/src/mir/place.rs index c0bb3ac5661af..f824e6af85117 100644 --- a/compiler/rustc_codegen_ssa/src/mir/place.rs +++ b/compiler/rustc_codegen_ssa/src/mir/place.rs @@ -363,7 +363,7 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> { } else { bx.cx().const_uint_big(niche_llty, niche_value) }; - OperandValue::Immediate(niche_llval).store(bx, niche); + OperandValue::Immediate(niche_llval).store(bx, niche, None); } } } diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 02b51dfe5bf7f..fd136429020ae 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -29,7 +29,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let cg_operand = self.codegen_operand(bx, operand); // FIXME: consider not copying constants through stack. (Fixable by codegen'ing // constants into `OperandValue::Ref`; why don’t we do that yet if we don’t?) - cg_operand.val.store(bx, dest); + cg_operand.val.store(bx, dest, None); } mir::Rvalue::Cast( @@ -43,7 +43,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // Into-coerce of a thin pointer to a fat pointer -- just // use the operand path. 
let temp = self.codegen_rvalue_operand(bx, rvalue); - temp.val.store(bx, dest); + temp.val.store(bx, dest, None); return; } @@ -63,7 +63,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { debug!("codegen_rvalue: creating ugly alloca"); let scratch = PlaceRef::alloca(bx, operand.layout); scratch.storage_live(bx); - operand.val.store(bx, scratch); + operand.val.store(bx, scratch, None); base::coerce_unsized_into(bx, scratch, dest); scratch.storage_dead(bx); } @@ -100,15 +100,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // Use llvm.memset.p0i8.* to initialize all zero arrays if bx.cx().const_to_opt_u128(v, false) == Some(0) { + //let ty = bx.cx().val_ty(v); let fill = bx.cx().const_u8(0); - bx.memset(start, fill, size, dest.align, MemFlags::empty()); + bx.memset(start, fill, size, dest.align, MemFlags::empty(), None); return; } // Use llvm.memset.p0i8.* to initialize byte arrays let v = bx.from_immediate(v); if bx.cx().val_ty(v) == bx.cx().type_i8() { - bx.memset(start, v, size, dest.align, MemFlags::empty()); + //let ty = bx.cx().type_i8(); + bx.memset(start, v, size, dest.align, MemFlags::empty(), None); return; } } @@ -142,7 +144,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { } else { variant_dest.project_field(bx, field_index.as_usize()) }; - op.val.store(bx, field); + op.val.store(bx, field, None); } } dest.codegen_set_discr(bx, variant_index); @@ -151,7 +153,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { _ => { assert!(self.rvalue_creates_operand(rvalue, DUMMY_SP)); let temp = self.codegen_rvalue_operand(bx, rvalue); - temp.val.store(bx, dest); + temp.val.store(bx, dest, None); } } } @@ -167,7 +169,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { debug_assert!(dst.layout.is_sized()); if let Some(val) = self.codegen_transmute_operand(bx, src, dst.layout) { - val.store(bx, dst); + val.store(bx, dst, None); return; } @@ 
-182,7 +184,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { OperandValue::Immediate(..) | OperandValue::Pair(..) => { // When we have immediate(s), the alignment of the source is irrelevant, // so we can store them using the destination's alignment. - src.val.store(bx, PlaceRef::new_sized_aligned(dst.llval, src.layout, dst.align)); + src.val.store(bx, PlaceRef::new_sized_aligned(dst.llval, src.layout, dst.align), None); } } } diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs index a158fc6e26074..90cfd430f2dfd 100644 --- a/compiler/rustc_codegen_ssa/src/mir/statement.rs +++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs @@ -85,7 +85,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let align = pointee_layout.align; let dst = dst_val.immediate(); let src = src_val.immediate(); - bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty()); + bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None); } mir::StatementKind::FakeRead(..) | mir::StatementKind::Retag { .. 
} diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index aa411f002a0c6..af2e85ab87caf 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -22,6 +22,8 @@ use rustc_target::abi::call::FnAbi; use rustc_target::abi::{Abi, Align, Scalar, Size, WrappingRange}; use rustc_target::spec::HasTargetSpec; +use rustc_ast::expand::typetree::FncTree; + #[derive(Copy, Clone)] pub enum OverflowOp { Add, @@ -238,6 +240,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memmove( &mut self, @@ -247,6 +250,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memset( &mut self, @@ -255,6 +259,7 @@ pub trait BuilderMethods<'a, 'tcx>: size: Self::Value, align: Align, flags: MemFlags, + tt: Option, ); fn select( diff --git a/compiler/rustc_codegen_ssa/src/traits/misc.rs b/compiler/rustc_codegen_ssa/src/traits/misc.rs index 04e2b8796c46a..5f64dd3367661 100644 --- a/compiler/rustc_codegen_ssa/src/traits/misc.rs +++ b/compiler/rustc_codegen_ssa/src/traits/misc.rs @@ -19,4 +19,5 @@ pub trait MiscMethods<'tcx>: BackendTypes { fn apply_target_cpu_attr(&self, llfn: Self::Function); /// Declares the extern "C" main function for the entry point. Returns None if the symbol already exists. 
fn declare_c_main(&self, fn_type: Self::Type) -> Option; + fn create_autodiff(&self) -> Vec; } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index 048540894ac9b..18f3855c5d5f9 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -2,6 +2,8 @@ use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig}; use crate::{CompiledModule, ModuleCodegen}; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{DiagCtxt, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -12,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -58,6 +61,15 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { ) -> Result; fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_expand/src/base.rs b/compiler/rustc_expand/src/base.rs index 1fd4d2d55dd81..fff7810ea76a2 100644 --- a/compiler/rustc_expand/src/base.rs +++ b/compiler/rustc_expand/src/base.rs @@ -917,6 +917,8 @@ pub trait ResolverExpand { fn check_unused_macros(&mut self); + //fn autodiff() -> bool; + // Resolver interfaces for specific built-in macros. /// Does `#[derive(...)]` attribute with the given `ExpnId` have built-in `Copy` inside it? 
fn has_derive_copy(&self, expn_id: LocalExpnId) -> bool; diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs index f9bfebee12e92..ac75c942b2b3f 100644 --- a/compiler/rustc_expand/src/build.rs +++ b/compiler/rustc_expand/src/build.rs @@ -157,6 +157,10 @@ impl<'a> ExtCtxt<'a> { ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Expr(expr) } } + pub fn stmt_semi(&self, expr: P) -> ast::Stmt { + ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) } + } + pub fn stmt_let_pat(&self, sp: Span, pat: P, ex: P) -> ast::Stmt { let local = P(ast::Local { pat, @@ -283,6 +287,10 @@ impl<'a> ExtCtxt<'a> { self.expr(sp, ast::ExprKind::Paren(e)) } + pub fn expr_method_call(&self, span: Span, expr: P, ident: Ident, args: ThinVec>) -> P { + let seg = ast::PathSegment::from_ident(ident); + self.expr(span, ast::ExprKind::MethodCall(Box::new(ast::MethodCall { seg, receiver: expr, args, span }))) + } pub fn expr_call( &self, span: Span, @@ -405,6 +413,12 @@ impl<'a> ExtCtxt<'a> { pub fn expr_tuple(&self, sp: Span, exprs: ThinVec>) -> P { self.expr(sp, ast::ExprKind::Tup(exprs)) } + pub fn expr_loop(&self, sp: Span, block: P) -> P { + self.expr(sp, ast::ExprKind::Loop(block, None, sp)) + } + pub fn expr_asm(&self, sp: Span, expr: P) -> P { + self.expr(sp, ast::ExprKind::InlineAsm(expr)) + } pub fn expr_fail(&self, span: Span, msg: Symbol) -> P { self.expr_call_global( diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index 5523543cd4fb9..1725ed887ec6e 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -588,6 +588,8 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ IMPL_DETAIL, ), rustc_attr!(rustc_proc_macro_decls, Normal, template!(Word), WarnFollowing, INTERNAL_UNSTABLE), + rustc_attr!(rustc_autodiff, Normal, template!(Word, List: r#""...""#), DuplicatesOk, INTERNAL_UNSTABLE), + 
rustc_attr!( rustc_macro_transparency, Normal, template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing, diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 0df7b7eed11f9..43ed4c50613a3 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -300,17 +300,60 @@ extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index, AddAttributes(F, Index, Attrs, AttrsLen); } -extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index, - LLVMAttributeRef *Attrs, size_t AttrsLen) { +extern "C" bool LLVMRustIsStructType(LLVMTypeRef Ty) { + return unwrap(Ty)->isStructTy(); +} + +extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, + unsigned Index, + LLVMAttributeRef *Attrs, + size_t AttrsLen) { CallBase *Call = unwrap(Instr); AddAttributes(Call, Index, Attrs, AttrsLen); } -extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, - LLVMRustAttribute RustAttr) { +extern "C" LLVMAttributeRef +LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttribute RustAttr) { return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); } +extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) { + Instruction *ret = unwrap(BB)->getTerminator(); + return wrap(ret); +} + +extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + I->eraseFromParent(); + } +} + +extern "C" void LLVMRustEraseBBFromParent(LLVMBasicBlockRef BB) { + unwrap(BB)->eraseFromParent(); +} + +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + +extern "C" void 
LLVMRustAddEnumAttributeAtIndex(LLVMContextRef C, + LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMAddAttributeAtIndex(F, index, LLVMRustCreateAttrNoValue(C, RustAttr)); +} + +extern "C" LLVMAttributeRef +LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); @@ -761,6 +804,105 @@ extern "C" bool LLVMRustHasModuleFlag(LLVMModuleRef M, const char *Name, return unwrap(M)->getModuleFlag(StringRef(Name, Len)) != nullptr; } +extern "C" LLVMValueRef +LLVMRustgetFirstNonPHIOrDbgOrLifetime(LLVMBasicBlockRef BB) { + if (auto *I = unwrap(BB)->getFirstNonPHIOrDbgOrLifetime()) + return wrap(I); + return nullptr; +} + +// pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; +extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) { + auto Point = unwrap(BB)->rbegin(); + if (Point != unwrap(BB)->rend()) + return wrap(&*Point); + return nullptr; +} + +extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { + auto &BB = *unwrap(bb); + auto &Inst = *unwrap(I); + auto It = BB.begin(); + while (&*It != &Inst) + ++It; + assert(It != BB.end()); + // Delete in rev order to ensure no dangling references. 
+ while (It != BB.begin()) { + auto Prev = std::prev(It); + It->eraseFromParent(); + It = Prev; + } + It->eraseFromParent(); +} + +extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) { + if (auto *I = dyn_cast(unwrap(inst))) { + return I->hasMetadata(kindID); + } + return false; +} + +extern "C" bool LLVMRustHasDbgMetadata(LLVMValueRef inst) { + if (auto *I = dyn_cast(unwrap(inst))) { + return false; + // return I->hasDbgValues(); + } + return false; +} + +extern "C" void LLVMRustRemoveFncAttr(LLVMValueRef F, + LLVMRustAttribute RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->removeFnAttr(fromRust(RustAttr)); + } +} + +extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addParamAttr(i, unwrap(RustAttr)); + } +} + +extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addRetAttr(unwrap(RustAttr)); + } +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) { + if (auto *I = dyn_cast(unwrap(x))) { + // auto *MD = I->getMetadata(LLVMContext::MD_dbg); + auto *MD = I->getDebugLoc().getAsMDNode(); + return wrap(MD); + } + return nullptr; +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadataOfTy(LLVMValueRef x, + unsigned kindID) { + if (auto *I = dyn_cast(unwrap(x))) { + auto *MD = I->getMetadata(kindID); + return wrap(MD); + } + return nullptr; +} + +extern "C" void LLVMRustAddParamAttr(LLVMValueRef call, unsigned i, + LLVMAttributeRef RustAttr) { + if (auto *CI = dyn_cast(unwrap(call))) { + CI->addParamAttr(i, unwrap(RustAttr)); + } +} + +extern "C" void LLVMRustDISetInstMetadata(LLVMValueRef Inst, + LLVMMetadataRef Desc) { + if (auto *I = dyn_cast(unwrap(Inst))) { + I->setMetadata(LLVMContext::MD_dbg, unwrap(Desc)); + } +} + extern "C" void LLVMRustGlobalAddMetadata( LLVMValueRef Global, unsigned Kind, LLVMMetadataRef MD) { 
unwrap(Global)->addMetadata(Kind, *unwrap(MD)); diff --git a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 27d555d7e26c7..a1e8e0b1d8f90 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -1,6 +1,8 @@ middle_adjust_for_foreign_abi_error = target architecture {$arch} does not support `extern {$abi}` ABI +middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe + middle_assert_async_resume_after_panic = `async fn` resumed after panicking middle_assert_async_resume_after_return = `async fn` resumed after completion @@ -85,3 +87,5 @@ middle_unknown_layout = middle_values_too_big = values of the type `{$ty}` are too big for the current architecture + +middle_unsupported_union = we don't support unions yet: '{$ty_name}' diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 52fd494a10db0..92cc1adcc7245 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -88,6 +88,7 @@ macro_rules! 
arena_types { [] upvars_mentioned: rustc_data_structures::fx::FxIndexMap, [] object_safety_violations: rustc_middle::traits::ObjectSafetyViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, + [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [decode] attribute: rustc_ast::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index 3c5536570872a..b362b9facc98b 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -17,6 +17,14 @@ pub struct DropCheckOverflow<'tcx> { pub overflow_ty: Ty<'tcx>, } +#[derive(Diagnostic)] +#[diag(middle_autodiff_unsafe_inner_const_ref)] +pub struct AutodiffUnsafeInnerConstRef { + #[primary_span] + pub span: Span, + pub ty: String, +} + #[derive(Diagnostic)] #[diag(middle_opaque_hidden_type_mismatch)] pub struct OpaqueHiddenTypeMismatch<'tcx> { @@ -151,5 +159,11 @@ pub struct ErroneousConstant { pub span: Span, } +#[derive(Diagnostic)] +#[diag(middle_unsupported_union)] +pub struct UnsupportedUnion { + pub ty_name: String, +} + /// Used by `rustc_const_eval` pub use crate::fluent_generated::middle_adjust_for_foreign_abi_error; diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs index b9200f1abf161..b9476a7ee92d6 100644 --- a/compiler/rustc_middle/src/query/erase.rs +++ b/compiler/rustc_middle/src/query/erase.rs @@ -200,6 +200,10 @@ impl EraseType for (&'_ T0, &'_ [T1]) { type Result = [u8; size_of::<(&'static (), &'static [()])>()]; } +impl EraseType for (&'_ T0, &'_ [T1], &'_ [T2]) { + type Result = [u8; size_of::<(&'static (), &'static [()], &'static [()])>()]; +} + macro_rules! trivial { ($($ty:ty),+ $(,)?) 
=> { $( diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 3a54f5f6b3d01..836387a82442f 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -54,6 +54,7 @@ use crate::ty::{ use crate::ty::{GenericArg, GenericArgsRef}; use rustc_arena::TypedArena; use rustc_ast as ast; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use rustc_ast::expand::{allocator::AllocatorKind, StrippedCfgItem}; use rustc_attr as attr; use rustc_data_structures::fingerprint::Fingerprint; @@ -1227,6 +1228,13 @@ rustc_queries! { separate_provide_extern } + /// The list autodiff extern functions in current crate + query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs { + desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) } + arena_cache + cache_on_disk_if { def_id.is_local() } + } + query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet { desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) } } @@ -1880,7 +1888,7 @@ rustc_queries! { separate_provide_extern } - query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [CodegenUnit<'tcx>]) { + query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [AutoDiffItem], &'tcx [CodegenUnit<'tcx>]) { eval_always desc { "collect_and_partition_mono_items" } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 36b8d387d6966..e15c35d6b0639 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -10,6 +10,10 @@ //! 
["The `ty` module: representing types"]: https://rustc-dev-guide.rust-lang.org/ty.html #![allow(rustc::usage_of_ty_tykind)] +#![allow(unused_imports)] + +use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree}; +use rustc_target::abi::FieldsShape; pub use self::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable}; pub use self::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor}; @@ -17,7 +21,7 @@ pub use self::AssocItemContainer::*; pub use self::BorrowKind::*; pub use self::IntVarValue::*; pub use self::Variance::*; -use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason}; +use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason, UnsupportedUnion, AutodiffUnsafeInnerConstRef}; use crate::metadata::ModChild; use crate::middle::privacy::EffectiveVisibilities; use crate::mir::{Body, CoroutineLayout}; @@ -101,6 +105,7 @@ pub use self::typeck_results::{ CanonicalUserType, CanonicalUserTypeAnnotation, CanonicalUserTypeAnnotations, IsIdentity, TypeckResults, UserType, UserTypeAnnotationIndex, }; +use crate::query::Key; pub mod _match; pub mod abstract_const; @@ -195,6 +200,8 @@ pub struct ResolverAstLowering { pub node_id_to_def_id: NodeMap, pub def_id_to_node_id: IndexVec, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, pub trait_map: NodeMap>, /// List functions and methods for which lifetime elision was successful. pub lifetime_elision_allowed: FxHashSet, @@ -2713,3 +2720,307 @@ mod size_asserts { static_assert_size!(WithCachedTypeInfo>, 56); // tidy-alphabetical-end } + +pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + let mut visited = vec![]; + let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty }; + return TypeTree(vec![tt]); +} + +use rustc_ast::expand::autodiff_attrs::DiffActivity; + +// This function combines three tasks. 
To avoid traversing each type 3x, we combine them.
+// 1. Create a TypeTree from a Ty. This is the main task.
+// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM
+//    lowering. E.g. fat ptr are going to introduce an extra int.
+// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an
+//    autodiff macro on top). Here we want to make sure that shadows are mutable internally.
+//    We know the outermost ref/ptr indirection is mutable - we generate it like that.
+//    We now have to make sure that inner ptr/ref are mutable too, or issue a warning.
+//    Not an error, because it only causes issues if they are actually read, which we don't check
+//    yet. We should add such analysis to reliably either issue an error or accept without warning.
+//    If there only were some research to do that...
+pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>, span: Option<Span>) -> FncTree {
+    if !fn_ty.is_fn() {
+        return FncTree { args: vec![], ret: TypeTree::new() };
+    }
+    let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
+
+    // If rustc compiles the unmodified primal, we know that this copy of the function
+    // also has correct lifetimes. We know that Enzyme won't free the shadow too early
+    // (or actually at all), so let's strip lifetimes when computing the layout.
+    // Recommended by compiler-errors:
+    // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751
+    let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
+
+    let mut new_activities = vec![];
+    let mut new_positions = vec![];
+    let mut visited = vec![];
+    let mut args = vec![];
+    for (i, ty) in x.inputs().iter().enumerate() {
+        // We care about safety checks, if an argument gets duplicated and we write into the
+        // shadow. That's equivalent to Duplicated or DuplicatedOnly.
+        let safety = if !da.is_empty() {
+            assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len());
+            // If we have Activities, we also have spans
+            assert!(span.is_some());
+            matches!(da[i], DiffActivity::DuplicatedOnly | DiffActivity::Duplicated)
+        } else {
+            false
+        };
+
+        visited.clear();
+        if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
+            if ty.is_fn_ptr() {
+                unimplemented!("what to do whith fn ptr?");
+            }
+            let inner_ty = ty.builtin_deref(true).unwrap().ty;
+            if inner_ty.is_slice() {
+                // We know that the length will be passed as extra arg.
+                let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span);
+                let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
+                args.push(TypeTree(vec![tt]));
+                let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
+                args.push(TypeTree(vec![i64_tt]));
+                if !da.is_empty() {
+                    // We are looking at a slice. The length of that slice will become an
+                    // extra integer on llvm level. Integers are always const.
+                    // However, if the slice gets duplicated, we want to know to later check the
+                    // size. So we mark the new size argument as FakeActivitySize.
+                    let activity = match da[i] {
+                        DiffActivity::DualOnly | DiffActivity::Dual |
+                        DiffActivity::DuplicatedOnly | DiffActivity::Duplicated
+                            => DiffActivity::FakeActivitySize,
+                        DiffActivity::Const => DiffActivity::Const,
+                        _ => panic!("unexpected activity for ptr/ref"),
+                    };
+                    new_activities.push(activity);
+                    new_positions.push(i + 1);
+                }
+                trace!("ABI MATCHING!");
+                continue;
+            }
+        }
+        let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span);
+        args.push(arg_tt);
+    }
+
+    // now add the extra activities coming from slices
+    // Reverse order to not invalidate the indices
+    for (pos, activity) in new_positions.into_iter().zip(new_activities).rev() {
+        da.insert(pos, activity);
+    }
+
+    visited.clear();
+    let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span);
+
+    FncTree { args, ret }
+}
+
+/// Recursively builds the `TypeTree` describing the memory layout of `ty`.
+///
+/// * `depth` is the current nesting level; exceeding 20 is only logged, not cut off.
+/// * `safety` requests a warning (at `span`) for every nested (`depth > 0`) pointer, reference
+///   or box that is not mutable.
+/// * `visited` is the stack of types currently being expanded. A type that is already on the
+///   stack is treated as recursive and yields an empty tree.
+///
+/// Every return path after the initial `visited.push(ty)` has to pop `ty` again, otherwise the
+/// pops of enclosing types remove the wrong entry and sibling fields look recursive.
+fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec<Ty<'a>>, span: Option<Span>) -> TypeTree {
+    if depth > 20 {
+        trace!("depth > 20 for ty: {}", &ty);
+    }
+    if visited.contains(&ty) {
+        // recursive type
+        trace!("recursive type: {}", &ty);
+        return TypeTree::new();
+    }
+    visited.push(ty);
+
+    if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
+        if ty.is_fn_ptr() {
+            unimplemented!("what to do whith fn ptr?");
+        }
+
+        let inner_ty_and_mut = ty.builtin_deref(true).unwrap();
+        let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut;
+        let inner_ty = inner_ty_and_mut.ty;
+
+        // Now account for inner mutability.
+        if !is_mut && depth > 0 && safety {
+            let ptr_ty: String = if ty.is_ref() {
+                "ref"
+            } else if ty.is_unsafe_ptr() {
+                "ptr"
+            } else {
+                assert!(ty.is_box());
+                "box"
+            }.to_string();
+
+            // If we have mutability, we also have a span
+            assert!(span.is_some());
+            let span = span.unwrap();
+
+            tcx.sess
+                .dcx()
+                .emit_warning(AutodiffUnsafeInnerConstRef { span, ty: ptr_ty });
+        }
+
+        let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
+        let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
+        visited.pop();
+        return TypeTree(vec![tt]);
+    }
+
+    if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() {
+        visited.pop();
+        return TypeTree::new();
+    }
+
+    if ty.is_scalar() {
+        let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
+            (Kind::Integer, ty.primitive_size(tcx).bytes_usize())
+        } else if ty.is_floating_point() {
+            match ty {
+                x if x == tcx.types.f32 => (Kind::Float, 4),
+                x if x == tcx.types.f64 => (Kind::Double, 8),
+                _ => panic!("floatTy scalar that is neither f32 nor f64"),
+            }
+        } else {
+            panic!("scalar that is neither integral nor floating point");
+        };
+        visited.pop();
+        return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]);
+    }
+
+    let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty };
+
+    let layout = tcx.layout_of(param_env_and);
+    assert!(layout.is_ok());
+
+    let layout = layout.unwrap().layout;
+    let fields = layout.fields();
+    let max_size = layout.size();
+
+    if ty.is_adt() && !ty.is_simd() {
+        let adt_def = ty.ty_adt_def().unwrap();
+
+        if adt_def.is_struct() {
+            let (offsets, _memory_index) = match fields {
+                // Manuel TODO:
+                FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m),
+                // e.g. core::arch::x86_64::__m128i, TODO: later
+                FieldsShape::Array { .. } | FieldsShape::Union(_) | FieldsShape::Primitive => {
+                    // Balance the `visited.push(ty)` above before bailing out.
+                    visited.pop();
+                    return TypeTree::new();
+                }
+            };
+
+            let substs = match ty.kind() {
+                Adt(_, subst_ref) => subst_ref,
+                _ => panic!(""),
+            };
+
+            let fields = adt_def.all_fields();
+            let fields = fields
+                .into_iter()
+                .zip(offsets.into_iter())
+                .filter_map(|(field, offset)| {
+                    let field_ty: Ty<'_> = field.ty(tcx, substs);
+                    let field_ty: Ty<'_> =
+                        tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty);
+
+                    if field_ty.is_phantom_data() {
+                        return None;
+                    }
+
+                    let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0;
+
+                    for c in &mut child {
+                        if c.offset == -1 {
+                            c.offset = offset.bytes() as isize
+                        } else {
+                            c.offset += offset.bytes() as isize;
+                        }
+                    }
+
+                    Some(child)
+                })
+                .flatten()
+                .collect::<Vec<_>>();
+
+            visited.pop();
+            return TypeTree(fields);
+        } else if adt_def.is_enum() {
+            // Enzyme can't represent enums, so let it figure it out itself, without seeding
+            // typetree
+            //unimplemented!("adt that is an enum");
+        } else {
+            //let ty_name = tcx.def_path_debug_str(adt_def.did());
+            //tcx.sess.emit_fatal(UnsupportedUnion { ty_name });
+        }
+    }
+
+    if ty.is_simd() {
+        trace!("simd");
+        let (_size, inner_ty) = ty.simd_size_and_type(tcx);
+        let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
+        //let tt = TypeTree(
+        //    std::iter::repeat(subtt)
+        //        .take(*count as usize)
+        //        .enumerate()
+        //        .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize)))
+        //        .flatten()
+        //        .collect(),
+        //);
+        // TODO
+        visited.pop();
+        return TypeTree::new();
+    }
+
+    if ty.is_array() {
+        let (stride, count) = match fields {
+            FieldsShape::Array { stride: s, count: c } => (s, c),
+            _ => panic!(""),
+        };
+        let byte_stride = stride.bytes_usize();
+        let byte_max_size = max_size.bytes_usize();
+
+        assert!(byte_stride * *count as usize == byte_max_size);
+        if (*count as usize) == 0 {
+            visited.pop();
+            return TypeTree::new();
+        }
+        let sub_ty = ty.builtin_index().unwrap();
+        let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);
+
+        // calculate size of subtree
+        let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty };
+        let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize;
+        let tt = TypeTree(
+            std::iter::repeat(subtt)
+                .take(*count as usize)
+                .enumerate()
+                .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize)))
+                .flatten()
+                .collect(),
+        );
+
+        visited.pop();
+        return tt;
+    }
+
+    if ty.is_slice() {
+        let sub_ty = ty.builtin_index().unwrap();
+        let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);
+
+        visited.pop();
+        return subtt;
+    }
+
+    visited.pop();
+    TypeTree::new()
+}
diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index c7f1b9fa78454..bce8d9c4a9878 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] # tidy-alphabetical-start +rustc_ast = { path = "../rustc_ast" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } rustc_fluent_macro = { path = "../rustc_fluent_macro" } @@ -13,6 +14,7 @@ rustc_macros = { path = "../rustc_macros" } rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } rustc_target = { path = "../rustc_target" } serde = "1" serde_json = "1" diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index a68bfcd06d505..644afb0251407 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ 
b/compiler/rustc_monomorphize/src/collector.rs @@ -206,7 +206,7 @@ pub enum MonoItemCollectionMode { pub struct UsageMap<'tcx> { // Maps every mono item to the mono items used by it. - used_map: FxHashMap, Vec>>, + pub used_map: FxHashMap, Vec>>, // Maps every mono item to the mono items that use it. user_map: FxHashMap, Vec>>, diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index d47d3e5e7d3b5..36a3c4f34a490 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -98,6 +98,7 @@ use std::fs::{self, File}; use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, AutoDiffAttrs}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_data_structures::sync; use rustc_hir::def::DefKind; @@ -111,15 +112,22 @@ use rustc_middle::mir::mono::{ }; use rustc_middle::query::Providers; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; -use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt}; +use rustc_middle::ty::{ + self, visit::TypeVisitableExt, InstanceDef, ParamEnv, TyCtxt, + fnc_typetrees +}; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_session::CodegenUnits; use rustc_span::symbol::Symbol; +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; use crate::collector::UsageMap; use crate::collector::{self, MonoItemCollectionMode}; use crate::errors::{CouldntDumpMonoStats, SymbolAlreadyDefined, UnknownCguCollectionMode}; +use rustc_ast::expand::autodiff_attrs::DiffActivity; + + struct PartitioningCx<'a, 'tcx> { tcx: TyCtxt<'tcx>, usage_map: &'a UsageMap<'tcx>, @@ -248,7 +256,14 @@ where &mut can_be_internalized, export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + + // We can't differentiate something that got inlined. 
+ let autodiff_active = match characteristic_def_id { + Some(def_id) => cx.tcx.autodiff_attrs(def_id).is_active(), + None => false, + }; + + if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1084,7 +1099,10 @@ where } } -fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[CodegenUnit<'_>]) { +fn collect_and_partition_mono_items( + tcx: TyCtxt<'_>, + (): (), +) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { let collection_mode = match tcx.sess.opts.unstable_opts.print_mono_items { Some(ref s) => { let mode = s.to_lowercase(); @@ -1143,6 +1161,61 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co }) .collect(); + let autodiff_items2: Vec<_> = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }).collect(); + let mut autodiff_items: Vec = vec![]; + + for (item, instance) in autodiff_items2 { + let target_id = instance.def_id(); + let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id); + let mut input_activities: Vec = target_attrs.input_activity.clone(); + if target_attrs.is_source() { + trace!("source found: {:?}", target_id); + } + if !target_attrs.apply_autodiff() { + continue; + } + + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + + let source = + usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + if tcx.autodiff_attrs(source_id).is_active() { + return Some(instance_s); + } + None + } + _ => None, + }); + let inst = match source { + Some(source) => source, + None => continue, + }; + + println!("source_id: {:?}", inst.def_id()); + let fn_ty = inst.ty(tcx, ParamEnv::empty()); + assert!(fn_ty.is_fn()); + let span = 
tcx.def_span(inst.def_id()); + let fnc_tree = fnc_typetrees(tcx, fn_ty, &mut input_activities, Some(span)); + let (inputs, output) = (fnc_tree.args, fnc_tree.ret); + //check_types(inst.ty(tcx, ParamEnv::empty()), tcx, &target_attrs.input_activity); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + let mut new_target_attrs = target_attrs.clone(); + new_target_attrs.input_activity = input_activities; + let itm = new_target_attrs.into_item(symb, target_symbol, inputs, output); + autodiff_items.push(itm); + }; + + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -1202,10 +1275,80 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co println!("MONO_ITEM {item}"); } } + if autodiff_items.len() > 0 { + trace!("AUTODIFF ITEMS EXIST"); + for item in &mut *autodiff_items { + trace!("{}", &item); + } + } - (tcx.arena.alloc(mono_items), codegen_units) + (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) } +#[allow(dead_code)] +enum Checks { + Slice(usize), + Enum(usize), +} + +// ty is going to get duplicated. So we need to find DST +// inside to later make sure that it's shadow is at least +// equally large. +//pub fn check_for_types<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, pos: usize) { +// // Find all DST inside of this type. 
+// let johnnie = ty.walk(); +// let mut check_positions = 0; +// let mut check_vec = Vec::new(); +// while let Some(ty) = johnnie.next() { +// if ty.is_trivially_sized(tcx) { +// ty.skip_current_subtree() +// } +// if ty.is_slice() { +// check_vec.push(Checks::Slice(check_positions)); +// } +// if ty.is_enum() { +// check_vec.push(Checks::Enum(check_positions)); +// } +// if ty.is_adt() { +// let adt_def = ty.ty_adt_def().unwrap(); +// if adt_def.is_struct() { +// let fields = adt_def.all_fields(); +// for field in fields { +// let field_ty: Ty<'_> = field.ty(tcx, ty.substs); +// if field_ty.is_phantom_data() { +// continue; +// } +// if field_ty.is_trivially_sized(tcx) { +// continue; +// } +// check_vec.push(Checks::Enum(check_positions)); +// } +// } +// } +// +// //ty::Str | ty::Slice(_) | ty::Dynamic(..) | ty::Foreign(..) => false, +// //ty::Tuple(tys) => tys.iter().all(|ty| ty.is_trivially_sized(tcx)), +// //ty::Adt(def, _args) => def.sized_constraint(tcx).skip_binder().is_empty(), +// check_positions += 1; +// } +//} +//pub fn check_types<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>, activities: &[DiffActivity]) { +// let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); +// +// // TODO: verify. +// let x: ty::FnSig<'_> = fnc_binder.skip_binder(); +// let _inputs = x.inputs(); +// for i in 0..activities.len() { +// match activities[i] { +// DiffActivity::Const => continue, +// DiffActivity::Active => continue, +// DiffActivity::ActiveOnly => continue, +// _ => {}, +// } +// check_for_types(inputs[i], tcx, i); +// } +//} + /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s /// def, to a file in the given output directory. 
fn dump_mono_items_stats<'tcx>( @@ -1288,12 +1431,12 @@ pub fn provide(providers: &mut Providers) { providers.collect_and_partition_mono_items = collect_and_partition_mono_items; providers.is_codegened_item = |tcx, def_id| { - let (all_mono_items, _) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, _) = tcx.collect_and_partition_mono_items(()); all_mono_items.contains(&def_id) }; providers.codegen_unit = |tcx, name| { - let (_, all) = tcx.collect_and_partition_mono_items(()); + let (_, _, all) = tcx.collect_and_partition_mono_items(()); all.iter() .find(|cgu| cgu.name() == name) .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}")) diff --git a/compiler/rustc_passes/messages.ftl b/compiler/rustc_passes/messages.ftl index be50aad13032f..5a8fdc8c9e301 100644 --- a/compiler/rustc_passes/messages.ftl +++ b/compiler/rustc_passes/messages.ftl @@ -13,6 +13,10 @@ passes_abi_ne = passes_abi_of = fn_abi_of({$fn_name}) = {$fn_abi} +passes_autodiff_attr = + `#[autodiff]` should be applied to a function + .label = not a function + passes_allow_incoherent_impl = `rustc_allow_incoherent_impl` attribute should be applied to impl items. .label = the only currently supported targets are inherent methods diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index c5073048be3e8..64243a4c438eb 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -232,6 +232,7 @@ impl CheckAttrVisitor<'_> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } + sym::autodiff => self.check_autodiff(hir_id, attr, span, target), _ => {} } @@ -2382,6 +2383,17 @@ impl CheckAttrVisitor<'_> { self.abort.set(true); } } + + /// Checks if `#[autodiff]` is applied to an item other than a function item. 
+ fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) { + match target { + Target::Fn => {} + _ => { + self.tcx.sess.emit_err(errors::AutoDiffAttr { attr_span: span }); + self.abort.set(true); + } + } + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_passes/src/errors.rs b/compiler/rustc_passes/src/errors.rs index 856256a064149..fb8cb98346e14 100644 --- a/compiler/rustc_passes/src/errors.rs +++ b/compiler/rustc_passes/src/errors.rs @@ -24,6 +24,14 @@ pub struct IncorrectDoNotRecommendLocation { pub span: Span, } +#[derive(Diagnostic)] +#[diag(passes_autodiff_attr)] +pub struct AutoDiffAttr { + #[primary_span] + #[label] + pub attr_span: Span, +} + #[derive(LintDiagnostic)] #[diag(passes_outer_crate_level_attr)] pub struct OuterCrateLevelAttr; diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index a14f3d494fb4d..13c495b7a4368 100644 --- a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1539,6 +1539,7 @@ impl<'a, 'tcx> Resolver<'a, 'tcx> { next_node_id: self.next_node_id, node_id_to_def_id: self.node_id_to_def_id, def_id_to_node_id: self.def_id_to_node_id, + autodiff_map: Default::default(), trait_map: self.trait_map, lifetime_elision_allowed: self.lifetime_elision_allowed, lint_buffer: Steal::new(self.lint_buffer), diff --git a/compiler/rustc_resolve/src/macros.rs b/compiler/rustc_resolve/src/macros.rs index 1001286b6c2d7..7ea8bf31d095b 100644 --- a/compiler/rustc_resolve/src/macros.rs +++ b/compiler/rustc_resolve/src/macros.rs @@ -344,6 +344,8 @@ impl<'a, 'tcx> ResolverExpand for Resolver<'a, 'tcx> { self.containers_deriving_copy.contains(&expn_id) } + // TODO: add autodiff? 
+ fn resolve_derives( &mut self, expn_id: LocalExpnId, diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 0d731e330f3d9..2219fd5e951a8 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -1273,6 +1273,7 @@ fn default_configuration(sess: &Session) -> Cfg { // // NOTE: These insertions should be kept in sync with // `CheckCfg::fill_well_known` below. + ins_none!(sym::autodiff_fallback); if sess.opts.debug_assertions { ins_none!(sym::debug_assertions); @@ -1460,6 +1461,7 @@ impl CheckCfg { // // When adding a new config here you should also update // `tests/ui/check-cfg/well-known-values.rs`. + ins!(sym::autodiff_fallback, no_values); ins!(sym::debug_assertions, no_values); diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 95106cc64c129..c1783c56a0125 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -69,6 +69,7 @@ symbols! { // Keywords that are used in unstable Rust or reserved for future use. Abstract: "abstract", + //Autodiff: "autodiff", Become: "become", Box: "box", Do: "do", @@ -438,6 +439,8 @@ symbols! { attributes, augmented_assignments, auto_traits, + autodiff, + autodiff_fallback, automatically_derived, avx, avx512_target_feature, @@ -491,6 +494,7 @@ symbols! { cfg_accessible, cfg_attr, cfg_attr_multi, + cfg_autodiff_fallback, cfg_doctest, cfg_eval, cfg_hide, @@ -1372,6 +1376,7 @@ symbols! { rustc_allow_incoherent_impl, rustc_allowed_through_unstable_modules, rustc_attrs, + rustc_autodiff, rustc_box, rustc_builtin_macro, rustc_capture_analysis, diff --git a/config.example.toml b/config.example.toml index 4cf7c1e81990c..53995ac7d7202 100644 --- a/config.example.toml +++ b/config.example.toml @@ -149,6 +149,9 @@ # Whether to build the clang compiler. #clang = false +# Whether to build Enzyme as AutoDiff backend. +#enzyme = true + # Whether to enable llvm compilation warnings. 
#enable-warnings = false diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 7f5908e477cfd..041d9990354ca 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1486,6 +1486,20 @@ pub(crate) mod builtin { ($file:expr $(,)?) => {{ /* compiler built-in */ }}; } + /// Attribute macro used to apply derive macros for implementing traits + /// in a const context. Autodiff + /// + /// See [the reference] for more info. + /// + /// [the reference]: ../../../reference/attributes/derive.html + #[unstable(feature = "autodiff", issue = "none")] + #[allow_internal_unstable(rustc_attrs)] + #[rustc_builtin_macro] + #[cfg(not(bootstrap))] + pub macro autodiff($item:item) { + /* compiler built-in */ + } + /// Asserts that a boolean expression is `true` at runtime. /// /// This will invoke the [`panic!`] macro if the provided expression cannot be diff --git a/library/core/src/prelude/v1.rs b/library/core/src/prelude/v1.rs index 10525a16f3a66..a934e62086c8e 100644 --- a/library/core/src/prelude/v1.rs +++ b/library/core/src/prelude/v1.rs @@ -83,6 +83,11 @@ pub use crate::macros::builtin::{ #[unstable(feature = "derive_const", issue = "none")] pub use crate::macros::builtin::derive_const; +#[cfg(not(bootstrap))] +#[unstable(feature = "autodiff", issue = "none")] +#[rustc_builtin_macro] +pub use crate::macros::builtin::autodiff; + #[unstable( feature = "cfg_accessible", issue = "64797", diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs index 6365366297c43..11e21c577dd70 100644 --- a/library/std/src/lib.rs +++ b/library/std/src/lib.rs @@ -255,6 +255,7 @@ #![allow(unused_features)] // // Features: +#![cfg_attr(not(bootstrap), feature(autodiff))] #![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))] #![cfg_attr( all(target_vendor = "fortanix", target_env = "sgx"), @@ -662,6 +663,14 @@ pub use core::{ module_path, option_env, stringify, trace_macros, }; +// #[unstable( +// feature 
= "autodiff", +// issue = "87555", +// reason = "`autodiff` is not stable enough for use and is subject to change" +// )] +// #[cfg(not(bootstrap))] +// pub use core::autodiff; + #[unstable( feature = "concat_bytes", issue = "87555", diff --git a/library/std/src/prelude/v1.rs b/library/std/src/prelude/v1.rs index 7a7a773763559..0be8fa35e100d 100644 --- a/library/std/src/prelude/v1.rs +++ b/library/std/src/prelude/v1.rs @@ -67,6 +67,11 @@ pub use core::prelude::v1::{ #[unstable(feature = "derive_const", issue = "none")] pub use core::prelude::v1::derive_const; +#[unstable(feature = "autodiff", issue = "none")] +#[cfg(not(bootstrap))] +#[rustc_builtin_macro] +pub use core::prelude::v1::autodiff; + // Do not `doc(no_inline)` either. #[unstable( feature = "cfg_accessible", diff --git a/src/bootstrap/configure.py b/src/bootstrap/configure.py index 544a42d9ada1a..dfc07a4b7ddae 100755 --- a/src/bootstrap/configure.py +++ b/src/bootstrap/configure.py @@ -72,6 +72,7 @@ def v(*args): o("optimize-llvm", "llvm.optimize", "build optimized LLVM") o("llvm-assertions", "llvm.assertions", "build LLVM with assertions") o("llvm-plugins", "llvm.plugins", "build LLVM with plugin interface") +o("llvm-enzyme", "llvm.enzyme", "build LLVM with enzyme") o("debug-assertions", "rust.debug-assertions", "build with debugging assertions") o("debug-assertions-std", "rust.debug-assertions-std", "build the standard library with debugging assertions") o("overflow-checks", "rust.overflow-checks", "build with overflow checks") diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index df4d1a43dabc7..8be210a215bc3 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1048,6 +1048,9 @@ pub fn rustc_cargo_env( if builder.config.rust_verify_llvm_ir { cargo.env("RUSTC_VERIFY_LLVM_IR", "1"); } + if builder.config.llvm_enzyme { + cargo.rustflag("--cfg=llvm_enzyme"); + } // Note that this is 
disabled if LLVM itself is disabled or we're in a check // build. If we are in a check build we still go ahead here presuming we've @@ -1576,6 +1579,7 @@ pub struct Assemble { pub target_compiler: Compiler, } +#[allow(unreachable_code)] impl Step for Assemble { type Output = Compiler; const ONLY_HOSTS: bool = true; @@ -1636,6 +1640,30 @@ impl Step for Assemble { return target_compiler; } + // Build enzyme + let enzyme_install = if builder.config.llvm_enzyme { + Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) + } else { + None + }; + + if let Some(enzyme_install) = enzyme_install { + let lib_ext = match env::consts::OS { + "macos" => "dylib", + "windows" => "dll", + _ => "so", + }; + + let src_lib = enzyme_install.join("build/Enzyme/libEnzyme-17").with_extension(lib_ext); + + let libdir = builder.sysroot_libdir(build_compiler, build_compiler.host); + let target_libdir = builder.sysroot_libdir(target_compiler, target_compiler.host); + let dst_lib = libdir.join("libEnzyme-17").with_extension(lib_ext); + let target_dst_lib = target_libdir.join("libEnzyme-17").with_extension(lib_ext); + builder.copy(&src_lib, &dst_lib); + builder.copy(&src_lib, &target_dst_lib); + } + // Build the libraries for this compiler to link to (i.e., the libraries // it uses at runtime). NOTE: Crates the target compiler compiles don't // link to these. (FIXME: Is that correct? It seems to be correct most @@ -1643,9 +1671,10 @@ impl Step for Assemble { // when not performing a full bootstrap). builder.ensure(Rustc::new(build_compiler, target_compiler.host)); - // FIXME: For now patch over problems noted in #90244 by early returning here, even though - // we've not properly assembled the target sysroot. A full fix is pending further investigation, - // for now full bootstrap usage is rare enough that this is OK. + // FIXME: For now patch over problems noted in #90244 by early returning + // here, even though we've not properly assembled the target sysroot. 
A + // full fix is pending further investigation, for now full bootstrap + // usage is rare enough that this is OK. if target_compiler.stage >= 3 && !builder.config.full_bootstrap { return target_compiler; } @@ -1664,9 +1693,8 @@ impl Step for Assemble { let lld_install = if builder.config.lld_enabled { Some(builder.ensure(llvm::Lld { target: target_compiler.host })) - } else { - None - }; + } + else {None}; let stage = target_compiler.stage; let host = target_compiler.host; diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 4b2d3e9ab4b75..9c3e40913187f 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -817,6 +817,94 @@ fn get_var(var_base: &str, host: &str, target: &str) -> Option { .or_else(|| env::var_os(var_base)) } +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Enzyme { + pub target: TargetSelection, +} + +impl Step for Enzyme { + type Output = PathBuf; + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/tools/enzyme/enzyme") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Enzyme { target: run.target }); + } + + /// Compile Enzyme for `target`. 
+ fn run(self, builder: &Builder<'_>) -> PathBuf { + if builder.config.dry_run() { + let out_dir = builder.enzyme_out(self.target); + return out_dir; + } + let target = self.target; + + let LlvmResult { llvm_config, llvm_cmake_dir } = builder.ensure(Llvm { target }); + + static STAMP_HASH_MEMO: OnceLock = OnceLock::new(); + let smart_stamp_hash = STAMP_HASH_MEMO.get_or_init(|| { + generate_smart_stamp_hash( + &builder.config.src.join("src/tools/enzyme"), + &builder.enzyme_info.sha().unwrap_or_default(), + ) + }); + + let out_dir = builder.enzyme_out(target); + let stamp = out_dir.join("enzyme-finished-building"); + let stamp = HashStamp::new(stamp, Some(smart_stamp_hash)); + + if stamp.is_done() { + if stamp.hash.is_none() { + builder.info( + "Could not determine the Enzyme submodule commit hash. \ + Assuming that an Enzyme rebuild is not necessary.", + ); + builder.info(&format!( + "To force Enzyme to rebuild, remove the file `{}`", + stamp.path.display() + )); + } + return out_dir; + } + + builder.info(&format!("Building Enzyme for {}", target)); + t!(stamp.remove()); + let _time = helpers::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + builder.update_submodule(&Path::new("src").join("tools").join("enzyme")); + let mut cfg = cmake::Config::new(builder.src.join("src/tools/enzyme/enzyme/")); + // TODO: Find a nicer way to use Enzyme Debug builds + //cfg.profile("Debug"); + //cfg.define("CMAKE_BUILD_TYPE", "Debug"); + configure_cmake(builder, target, &mut cfg, true, LdFlags::default(), &[]); + + // Re-use the same flags as llvm to control the level of debug information + // generated for lld. 
+ let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + cfg.out_dir(&out_dir) + .profile(profile) + .env("LLVM_CONFIG_REAL", &llvm_config) + .define("LLVM_ENABLE_ASSERTIONS", "ON") + .define("ENZYME_EXTERNAL_SHARED_LIB", "ON") + .define("LLVM_DIR", &llvm_cmake_dir); + + cfg.build(); + + t!(stamp.write()); + + out_dir + } +} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub struct Lld { pub target: TargetSelection, @@ -1123,7 +1211,12 @@ impl HashStamp { fn is_done(&self) -> bool { match fs::read(&self.path) { - Ok(h) => self.hash.as_deref().unwrap_or(b"") == h.as_slice(), + Ok(h) => { + let unwrapped = self.hash.as_deref().unwrap_or(b""); + let res = unwrapped == h.as_slice(); + eprintln!("Result for {:?}: {res:?} for expected '{unwrapped:?}' and read '{h:?}'", self.path); + res + }, Err(e) if e.kind() == io::ErrorKind::NotFound => false, Err(e) => { panic!("failed to read stamp file `{}`: {}", self.path.display(), e); diff --git a/src/bootstrap/src/core/builder.rs b/src/bootstrap/src/core/builder.rs index 753b41abaf489..e7d02653baf78 100644 --- a/src/bootstrap/src/core/builder.rs +++ b/src/bootstrap/src/core/builder.rs @@ -1416,6 +1416,12 @@ impl<'a> Builder<'a> { rustflags.arg(sysroot_str); } + // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc + if self.config.llvm_enzyme { + rustflags.arg("-l"); + rustflags.arg("Enzyme-17"); + } + let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { // If an explicit setting is given, use that diff --git a/src/bootstrap/src/core/config/config.rs b/src/bootstrap/src/core/config/config.rs index f1e1b89d9ba71..1023d3f0a584c 100644 --- a/src/bootstrap/src/core/config/config.rs +++ b/src/bootstrap/src/core/config/config.rs @@ -211,6 +211,7 @@ pub struct Config { 
pub llvm_assertions: bool, pub llvm_tests: bool, pub llvm_plugins: bool, + pub llvm_enzyme: bool, pub llvm_optimize: bool, pub llvm_thin_lto: bool, pub llvm_release_debuginfo: bool, @@ -872,6 +873,7 @@ define_config! { release_debuginfo: Option = "release-debuginfo", assertions: Option = "assertions", tests: Option = "tests", + enzyme: Option = "enzyme", plugins: Option = "plugins", ccache: Option = "ccache", static_libstdcpp: Option = "static-libstdcpp", @@ -1093,6 +1095,7 @@ define_config! { codegen_backends: Option> = "codegen-backends", lld: Option = "lld", lld_mode: Option = "use-lld", + llvm_enzyme: Option = "llvm-enzyme", llvm_tools: Option = "llvm-tools", deny_warnings: Option = "deny-warnings", backtrace_on_ice: Option = "backtrace-on-ice", @@ -1494,6 +1497,7 @@ impl Config { // we'll infer default values for them later let mut llvm_assertions = None; let mut llvm_tests = None; + let mut llvm_enzyme = None; let mut llvm_plugins = None; let mut debug = None; let mut debug_assertions = None; @@ -1542,6 +1546,7 @@ impl Config { save_toolstates, codegen_backends, lld, + llvm_enzyme, llvm_tools, deny_warnings, backtrace_on_ice, @@ -1631,6 +1636,8 @@ impl Config { } set(&mut config.llvm_tools_enabled, llvm_tools); + config.llvm_enzyme = + llvm_enzyme.unwrap_or(config.channel == "dev" || config.channel == "nightly"); config.rustc_parallel = parallel_compiler.unwrap_or(config.channel == "dev" || config.channel == "nightly"); config.rustc_default_linker = default_linker; @@ -1697,6 +1704,7 @@ impl Config { release_debuginfo, assertions, tests, + enzyme, plugins, ccache, static_libstdcpp, @@ -1729,6 +1737,7 @@ impl Config { set(&mut config.ninja_in_file, ninja); llvm_assertions = assertions; llvm_tests = tests; + llvm_enzyme = enzyme; llvm_plugins = plugins; set(&mut config.llvm_optimize, optimize_toml); set(&mut config.llvm_thin_lto, thin_lto); @@ -1885,6 +1894,7 @@ impl Config { config.llvm_assertions = llvm_assertions.unwrap_or(false); config.llvm_tests = 
llvm_tests.unwrap_or(false); + config.llvm_enzyme = llvm_enzyme.unwrap_or(true); config.llvm_plugins = llvm_plugins.unwrap_or(false); config.rust_optimize = optimize.unwrap_or(RustOptimize::Bool(true)); diff --git a/src/bootstrap/src/lib.rs b/src/bootstrap/src/lib.rs index 871318de5955e..0726fc6ced40f 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -76,6 +76,9 @@ const LLD_FILE_NAMES: &[&str] = &["ld.lld", "ld64.lld", "lld-link", "wasm-ld"]; /// (Mode restriction, config name, config values (if any)) const EXTRA_CHECK_CFGS: &[(Option, &str, Option<&[&'static str]>)] = &[ (None, "bootstrap", None), + (Some(Mode::Rustc), "llvm_enzyme", None), + (Some(Mode::Codegen), "llvm_enzyme", None), + (Some(Mode::ToolRustc), "llvm_enzyme", None), (Some(Mode::Rustc), "parallel_compiler", None), (Some(Mode::ToolRustc), "parallel_compiler", None), (Some(Mode::ToolRustc), "rust_analyzer", None), @@ -165,6 +168,7 @@ pub struct Build { clippy_info: GitInfo, miri_info: GitInfo, rustfmt_info: GitInfo, + enzyme_info: GitInfo, in_tree_llvm_info: GitInfo, local_rebuild: bool, fail_fast: bool, @@ -328,6 +332,7 @@ impl Build { let clippy_info = GitInfo::new(omit_git_hash, &src.join("src/tools/clippy")); let miri_info = GitInfo::new(omit_git_hash, &src.join("src/tools/miri")); let rustfmt_info = GitInfo::new(omit_git_hash, &src.join("src/tools/rustfmt")); + let enzyme_info = GitInfo::new(omit_git_hash, &src.join("src/tools/enzyme")); // we always try to use git for LLVM builds let in_tree_llvm_info = GitInfo::new(false, &src.join("src/llvm-project")); @@ -410,6 +415,7 @@ impl Build { clippy_info, miri_info, rustfmt_info, + enzyme_info, in_tree_llvm_info, cc: RefCell::new(HashMap::new()), cxx: RefCell::new(HashMap::new()), @@ -799,6 +805,10 @@ impl Build { self.out.join(&*target.triple).join("lld") } + fn enzyme_out(&self, target: TargetSelection) -> PathBuf { + self.out.join(&*target.triple).join("enzyme") + } + /// Output directory for all documentation for a target 
fn doc_out(&self, target: TargetSelection) -> PathBuf { self.out.join(&*target.triple).join("doc") @@ -1856,6 +1866,9 @@ pub fn generate_smart_stamp_hash(dir: &Path, additional_input: &str) -> String { .map(|o| String::from_utf8(o.stdout).unwrap_or_default()) .unwrap_or_default(); + eprintln!("Computing stamp for {dir:?}"); + eprintln!("Diff output: {diff:?}"); + let status = Command::new("git") .current_dir(dir) .arg("status") @@ -1866,13 +1879,19 @@ pub fn generate_smart_stamp_hash(dir: &Path, additional_input: &str) -> String { .map(|o| String::from_utf8(o.stdout).unwrap_or_default()) .unwrap_or_default(); + eprintln!("Status output: {status:?}"); + eprintln!("Additional input: {additional_input:?}"); let mut hasher = sha2::Sha256::new(); hasher.update(diff); hasher.update(status); hasher.update(additional_input); - hex_encode(hasher.finalize().as_slice()) + let result = hex_encode(hasher.finalize().as_slice()); + + eprintln!("Final hash: {result:?}"); + + result } /// Ensures that the behavior dump directory is properly initialized. diff --git a/src/tools/enzyme b/src/tools/enzyme new file mode 160000 index 0000000000000..4b20c482cb0bc --- /dev/null +++ b/src/tools/enzyme @@ -0,0 +1 @@ +Subproject commit 4b20c482cb0bcffd18b462613e6e68cf23c0a44b