diff --git a/CHANGELOG.md b/CHANGELOG.md index e279273..d8544f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate - Mark cervo_runtime::BrainId as #[must_use] +- Move the CLI tool to a separate crate `cervo-cli`. The installed name is unchanged. This avoids some dependencies. +- Upgrade tract to 0.20.0 ## [0.4.0] - 2022-11-23 diff --git a/Cargo.lock b/Cargo.lock index 331ee98..9cf51b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.7.20" @@ -321,38 +332,12 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68b0cf012f1230e43cd00ebb729c6bb58707ecfa8ad08b52ef3a4ccd2697fc30" -[[package]] -name = "educe" -version = "0.4.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0188e3c3ba8df5753894d54461f0e39bc91741dc5b22e1c46999ec2c71f4e4" -dependencies = [ - "enum-ordinalize", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "either" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" -[[package]] -name = "enum-ordinalize" -version = "3.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bb1df8b45ecb7ffa78dca1c17a438fb193eb083db0b1b494d2a61bcb5096a" -dependencies = [ - "num-bigint", - "num-traits", - "proc-macro2", - "quote", - "rustc_version", - "syn", -] - [[package]] name = "errno" version = "0.2.8" @@ -436,6 +421,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash", +] + [[package]] name = "heck" version = "0.4.1" @@ -619,16 +613,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" -[[package]] -name = "mapr" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46a28a55dbc005b2f6f123c4058933d57add373d362f6fd3a76aab4fe6973500" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "matrixmultiply" version = "0.3.2" @@ -644,6 +628,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.8.0" @@ -701,17 +694,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "num-bigint" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - [[package]] name = "num-complex" version = "0.4.3" @@ -873,6 +855,15 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "primal-check" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9df7f93fd637f083201473dab4fee2db4c429d32e55e3299980ab3957ab916a0" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1063,12 +1054,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" [[package]] -name = "rustc_version" -version = "0.4.0" +name = "rustfft" +version = "6.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "e17d4f6cbdb180c9f4b2a26bbf01c4e647f1e1dea22fe8eb9db54198b32f9434" dependencies = [ - "semver", + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", + "version_check", ] [[package]] @@ -1109,12 +1106,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" -[[package]] -name = "semver" -version = "1.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" - [[package]] name = "serde" version = "1.0.154" @@ -1167,6 +1158,23 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "string-interner" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648" +dependencies = [ + "cfg-if", + "hashbrown", + "serde", +] + [[package]] name = "strsim" version = "0.10.0" @@ -1249,9 +1257,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.20" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" +checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446" dependencies = [ "itoa", "serde", @@ -1261,15 +1269,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" +checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.8" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36" +checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4" dependencies = [ "time-core", ] @@ -1326,22 +1334,23 @@ dependencies = [ [[package]] name = "tract-core" -version = "0.17.9" +version = "0.20.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "952ec1bab744dab58a1a1e05813ef62d61f238e18622ed4e9e1eb787ce3f3958" +checksum = "eedc3f25ef2e089c5215846d82b413a7724e90310bca05f2340986adf56e4ec1" dependencies = [ "anyhow", "bit-set", "derive-new", "downcast-rs", "dyn-clone", - "educe", "lazy_static", "log", "maplit", "ndarray", + "num-complex", "num-integer", "num-traits", + "rustfft", "smallvec", "tract-data", "tract-linalg", @@ -1349,41 +1358,40 @@ dependencies = [ [[package]] name = "tract-data" -version = "0.17.9" +version = "0.20.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc4c08be3635ebe54fac659d67505a2d94a3d3fb915d386ae6338f9c458d69d" +checksum = "ec5bde8b5da75b623952beb1d4ea27423d11cda8c6b363ebc66a2149ecce1d58" dependencies = [ "anyhow", - "educe", "half", "itertools", "lazy_static", "maplit", "ndarray", - "num-complex", + "nom", "num-integer", "num-traits", "scan_fmt", "smallvec", + "string-interner", ] [[package]] name = "tract-hir" -version = "0.17.9" +version = "0.20.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c4a4d5bdb45c9219d1995cdc4f5d09a289c3b64b34c7752d291a8eff1675ee6" +checksum = "2b3f131093ebddff204870baaf9332ea84ab6bc6457e384eb8a7a0ee965994f4" dependencies = [ "derive-new", - "educe", "log", "tract-core", ] [[package]] name = "tract-linalg" -version = "0.17.9" +version = "0.20.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3b1ea52af343ecf897be8cc710f05f87e0beab5bcabfbd261f84a951368cb6e" +checksum = "dcfc8578923d78efb8543c2098a802d45a30a6114354e7297872e3a26d7388a8" dependencies = [ "cc", "derive-new", @@ -1391,7 +1399,6 @@ dependencies = [ "dyn-clone", "half", "lazy_static", - "libc", "liquid", "liquid-core", "log", @@ -1399,6 +1406,7 @@ dependencies = [ "paste", "scan_fmt", "smallvec", + "time", "tract-data", "unicode-normalization", "walkdir", @@ -1406,9 +1414,9 @@ dependencies = [ [[package]] name = "tract-nnef" -version = "0.17.9" +version = "0.20.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc5f95192ff8848db72caeac27d81ab85b9db87aca9dd23d9160bfc4976112f8" +checksum = "c977d820f9927b51ad4162adcf868f8105bb61b1aea6cccbc3bbbeac3eb152bf" dependencies = [ "byteorder", "flate2", @@ -1421,15 +1429,14 @@ dependencies = [ [[package]] name = "tract-onnx" -version = "0.17.9" +version = "0.20.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3c54f4a330e727137b87780a43fa6201a70aa9f9dbfab2a235ecd8951e5bf70" +checksum = "bcb7e31b8bb3173ba14992916ce4f552d57dbe7034b23143a2a38adf3c11067a" dependencies = [ "bytes", "derive-new", - "educe", "log", - "mapr", + "memmap2", "num-integer", "prost", "smallvec", @@ -1440,16 +1447,28 @@ dependencies = [ [[package]] name = "tract-onnx-opl" -version = "0.17.9" +version = "0.20.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18b584c79839543996ac030663d7b3ba997a63d833d30b6a07d71954cba43ecb" +checksum = "8f0e844f252fe7075ffafe6b5072f4122a14b2bafb345be0249661d51a7e029d" dependencies = [ - "educe", "getrandom", + "log", "rand", + "rand_distr", + "rustfft", "tract-nnef", ] +[[package]] +name = "transpose" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6522d49d03727ffb138ae4cbc1283d3774f0d10aa7f9bf52e6784c45daf9b23" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "typenum" version = "1.16.0" diff --git a/Cargo.toml b/Cargo.toml index 87a5a5a..0852d7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,9 @@ members = [ exclude = [ "benchmarks/perf-test" ] resolver = "2" + +[workspace.dependencies] +tract-core = { version = "0.20" } +tract-hir = { version = "0.20" } +tract-nnef = { version = "0.20" } +tract-onnx = { version = "0.20" } diff --git a/crates/cervo-core/Cargo.toml b/crates/cervo-core/Cargo.toml index ce749af..2217576 100644 --- a/crates/cervo-core/Cargo.toml +++ b/crates/cervo-core/Cargo.toml @@ -16,8 +16,8 @@ readme = "../../README.md" [dependencies] anyhow = { version = "1.0"} -tract-core = { version = "0.17.1" } -tract-hir = { version = "0.17.1" } +tract-core = { workspace = true } +tract-hir = { workspace = true } rand = { version = "0.8.2" } rand_distr = { version = "0.4" } perchance = { version = "0.4", default-features = false } diff --git a/crates/cervo-core/src/inferer/basic.rs b/crates/cervo-core/src/inferer/basic.rs index 29bcf07..39dcb80 100644 --- a/crates/cervo-core/src/inferer/basic.rs +++ b/crates/cervo-core/src/inferer/basic.rs @@ -4,7 +4,7 @@ A basic unbatched inferer that doesn't require a lot of custom setup or manageme use super::Inferer; use crate::{batcher::ScratchPadView, model_api::ModelApi}; use anyhow::Result; -use tract_core::prelude::{tvec, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan}; +use tract_core::prelude::{tvec, TValue, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan}; use tract_hir::prelude::InferenceModel; use super::helpers; @@ -46,7 +46,7 @@ impl BasicInferer { Ok(Self { model, model_api }) } - fn build_inputs(&self, obs: &ScratchPadView<'_>) -> Result> { + fn build_inputs(&self, obs: &ScratchPadView<'_>) -> Result> { let mut inputs = TVec::default(); for (idx, (name, shape)) in self.model_api.inputs.iter().enumerate() { @@ -60,7 +60,7 @@ impl BasicInferer { let tensor = Tensor::from_shape(&full_shape, obs.input_slot(idx))?; - inputs.push(tensor); + inputs.push(tensor.into()); } Ok(inputs) diff --git a/crates/cervo-core/src/inferer/dynamic.rs b/crates/cervo-core/src/inferer/dynamic.rs index 790ef97..4fd7ba6 100644 --- a/crates/cervo-core/src/inferer/dynamic.rs +++ b/crates/cervo-core/src/inferer/dynamic.rs @@ -1,7 +1,7 @@ use super::{helpers, Inferer}; use crate::{batcher::ScratchPadView, model_api::ModelApi}; use anyhow::Result; -use tract_core::prelude::{tvec, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan}; +use tract_core::prelude::{tvec, TValue, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan}; use tract_hir::prelude::InferenceModel; /// The dynamic inferer hits a spot between the raw simplicity of a [`crate::prelude::BasicInferer`] and the spikiness @@ -58,7 +58,7 @@ impl DynamicInferer { Ok(this) } - fn build_inputs(&self, batch: &ScratchPadView<'_>) -> Result> { + fn build_inputs(&self, batch: &ScratchPadView<'_>) -> Result> { let size = batch.len(); let mut inputs = TVec::default(); @@ -76,7 +76,7 @@ impl DynamicInferer { let tensor = Tensor::from_shape(&shape, batch.input_slot(idx))?; - inputs.push(tensor); + inputs.push(tensor.into()); } Ok(inputs) diff --git a/crates/cervo-core/src/inferer/fixed.rs b/crates/cervo-core/src/inferer/fixed.rs index 3aadd9d..d9938fe 100644 --- a/crates/cervo-core/src/inferer/fixed.rs +++ b/crates/cervo-core/src/inferer/fixed.rs @@ -1,7 +1,7 @@ use super::{helpers, Inferer}; use crate::{batcher::ScratchPadView, model_api::ModelApi}; use anyhow::{Context, Result}; -use tract_core::prelude::{tvec, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan}; +use tract_core::prelude::{tvec, TValue, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan}; use tract_hir::prelude::InferenceModel; /// A reliable batched inferer that is a good fit if you know how much data you'll have and want stable performance. @@ -122,7 +122,7 @@ impl BatchedModel { &self, batch: &ScratchPadView<'_>, model_api: &ModelApi, - ) -> Result> { + ) -> Result> { assert_eq!(batch.len(), self.size); let size = self.size; @@ -148,7 +148,7 @@ impl BatchedModel { let tensor = Tensor::from_shape(&shape, batch.input_slot(idx))?; - inputs.push(tensor); + inputs.push(tensor.into()); } Ok(inputs) diff --git a/crates/cervo-core/src/inferer/helpers.rs b/crates/cervo-core/src/inferer/helpers.rs index 08bc51b..4616d7f 100644 --- a/crates/cervo-core/src/inferer/helpers.rs +++ b/crates/cervo-core/src/inferer/helpers.rs @@ -17,16 +17,17 @@ pub(super) fn build_symbolic_model( mut model: InferenceModel, inputs: &[(String, Vec)], ) -> TractResult<(Symbol, TypedModel)> { - let s = Symbol::from('N'); + model.set_output_fact(0, Default::default())?; + let symbol = model.symbol_table.sym("N"); for (idx, (_name, shape)) in inputs.iter().enumerate() { - let mut full_shape = tvec!(s.to_dim()); + let mut full_shape = tvec!(symbol.to_dim()); full_shape.extend(shape.iter().map(|v| (*v as i32).into())); model.set_input_fact(idx, InferenceFact::dt_shape(f32::datum_type(), full_shape))?; } let model = model.into_typed()?.into_decluttered()?; - Ok((s, model)) + Ok((symbol, model)) } pub(super) fn build_model( @@ -34,6 +35,7 @@ pub(super) fn build_model( inputs: &[(String, Vec)], batch_dim: D, ) -> TractResult> { + model.set_output_fact(0, Default::default())?; for (idx, (_name, shape)) in inputs.iter().enumerate() { let mut full_shape = tvec!(batch_dim.to_dim()); @@ -50,16 +52,16 @@ pub(super) fn build_model( pub(super) fn build_symbolic_typed(model: &mut TypedModel) -> TractResult { model.declutter()?; - Ok(Symbol::from('N')) + Ok(model.symbol_table.sym("N")) } pub(super) fn build_typed( model: TypedModel, batch_dim: D, ) -> TractResult> { - let symbol = Symbol::from('N'); + let symbol = model.symbol_table.sym("N"); let model = model.concretize_dims( - &SymbolValues::default().with(symbol, batch_dim.to_dim().to_i64().unwrap()), + &SymbolValues::default().with(&symbol, batch_dim.to_dim().to_i64().unwrap()), )?; model.into_decluttered()?.into_optimized()?.into_runnable() diff --git a/crates/cervo-core/src/inferer/memoizing.rs b/crates/cervo-core/src/inferer/memoizing.rs index 09beeee..19dc208 100644 --- a/crates/cervo-core/src/inferer/memoizing.rs +++ b/crates/cervo-core/src/inferer/memoizing.rs @@ -40,7 +40,7 @@ use tract_hir::prelude::*; /// # Cons /// /// * For small amounts of data and large models the spikes can offset -/// amortized gains signifcantly +/// amortized gains significantly pub struct MemoizingDynamicInferer { symbol: Symbol, @@ -96,7 +96,7 @@ impl MemoizingDynamicInferer { Ok(this) } - fn build_inputs(&self, batch: &ScratchPadView<'_>) -> Result> { + fn build_inputs(&self, batch: &ScratchPadView<'_>) -> Result> { let size = batch.len(); let mut inputs = TVec::default(); @@ -114,7 +114,7 @@ impl MemoizingDynamicInferer { let tensor = Tensor::from_shape(&shape, batch.input_slot(idx))?; - inputs.push(tensor); + inputs.push(tensor.into()); } Ok(inputs) @@ -131,7 +131,7 @@ impl MemoizingDynamicInferer { if let Entry::Vacant(e) = content.entry(size) { let p = self .model - .concretize_dims(&SymbolValues::default().with(self.symbol, size as i64))? + .concretize_dims(&SymbolValues::default().with(&self.symbol, size as i64))? .into_optimized()? .into_decluttered()? .into_runnable()?; diff --git a/crates/cervo-core/src/model_api.rs b/crates/cervo-core/src/model_api.rs index 60722c6..a0c8f1f 100644 --- a/crates/cervo-core/src/model_api.rs +++ b/crates/cervo-core/src/model_api.rs @@ -31,7 +31,8 @@ impl ModelApi { name, input_shape .dims() - .filter_map(|value| value.concretize().map(|v| v.to_i64().unwrap() as usize)) + .filter_map(|value| value.concretize().and_then(|v| v.to_i64().ok())) + .map(|val| val as usize) .collect(), )); } @@ -51,7 +52,8 @@ impl ModelApi { name, output_shape .dims() - .filter_map(|value| value.concretize().map(|v| v.to_i64().unwrap() as usize)) + .filter_map(|value| value.concretize().and_then(|v| v.to_i64().ok())) + .map(|val| val as usize) .collect(), )); } diff --git a/crates/cervo-nnef/Cargo.toml b/crates/cervo-nnef/Cargo.toml index e4c84e7..43a251b 100644 --- a/crates/cervo-nnef/Cargo.toml +++ b/crates/cervo-nnef/Cargo.toml @@ -14,8 +14,9 @@ readme = "../../README.md" [dependencies] anyhow = "1.0.57" -tract-onnx = { version = "0.17.1" } -tract-nnef = { version = "0.17.1" } -tract-hir = { version = "0.17.1" } -cervo-core = { version= "0.4.1-alpha.0", path = "../cervo-core" } +tract-onnx = { workspace = true } +tract-nnef = { workspace = true } +tract-hir = { workspace = true } lazy_static = "1.4" + +cervo-core = { version= "0.4.1-alpha.0", path = "../cervo-core" } diff --git a/crates/cervo-nnef/tests/infer-complex.rs b/crates/cervo-nnef/tests/infer-complex.rs index eba8165..71d8f58 100644 --- a/crates/cervo-nnef/tests/infer-complex.rs +++ b/crates/cervo-nnef/tests/infer-complex.rs @@ -64,7 +64,7 @@ fn test_infer_once_complex_batched_not_loaded() { let shapes = instance.input_shapes().to_vec(); let observations = helpers::build_inputs_from_desc(10, &shapes); let result = instance.infer_batch(observations); - eprintln!("result {:?}", result); + assert!(result.is_ok()); let result = result.unwrap(); diff --git a/crates/cervo-onnx/Cargo.toml b/crates/cervo-onnx/Cargo.toml index cc968b9..396add1 100644 --- a/crates/cervo-onnx/Cargo.toml +++ b/crates/cervo-onnx/Cargo.toml @@ -14,6 +14,6 @@ readme = "../../README.md" [dependencies] anyhow = "1.0.57" -tract-onnx = { version = "0.17.1" } -tract-nnef = { version = "0.17.1" } +tract-onnx = { workspace = true } +tract-nnef = { workspace = true } cervo-core = { version = "0.4.1-alpha.0", path = "../cervo-core" } diff --git a/crates/cervo-onnx/src/lib.rs b/crates/cervo-onnx/src/lib.rs index 3d7f769..0219de0 100644 --- a/crates/cervo-onnx/src/lib.rs +++ b/crates/cervo-onnx/src/lib.rs @@ -90,10 +90,14 @@ pub fn builder(read: T) -> InfererBuilder> { pub fn to_nnef(reader: &mut dyn Read, batch_size: Option) -> Result> { let mut model = model_for_reader(reader)?; - let batch = batch_size.map_or_else(|| Symbol::from('N').to_dim(), |v| v.to_dim()); - + model.set_output_fact(0, Default::default())?; + let symbol = model.symbol_table.sym("N"); let input_outlets = model.input_outlets()?.to_vec(); + let batch = batch_size + .map(|b| b.to_dim()) + .unwrap_or(symbol.clone().to_dim()); + for input_outlet in input_outlets { let input_shape = &model.input_fact(input_outlet.node)?.shape;