From b783f42833294eb7ad4eafa129e54ad2aabf5fe2 Mon Sep 17 00:00:00 2001 From: Francois Galilee Date: Thu, 26 Sep 2024 12:22:13 -0400 Subject: [PATCH] go: add a Go library for Magika This change introduces a Go library for Magika. The library implements feature extraction and relies on the ONNX Runtime for inference, wrapping the C API using cgo. --- go/README.md | 17 ++++ go/cli/cli.go | 50 ++++++++++ go/cli/cli_test.go | 31 ++++++ go/cli/main.go | 26 +++++ go/cli/tests_data/magika_test.zip | Bin 0 -> 899 bytes go/cli/tests_data/magika_test_pptx.txt | 3 + go/docker/Dockerfile | 69 +++++++++++++ go/go.mod | 5 + go/go.sum | 2 + go/magika/config.go | 62 ++++++++++++ go/magika/content.go | 44 +++++++++ go/magika/features.go | 130 +++++++++++++++++++++++++ go/magika/features_test.go | 49 ++++++++++ go/magika/scanner.go | 103 ++++++++++++++++++++ go/magika/scanner_test.go | 92 +++++++++++++++++ go/onnx/onnx.go | 7 ++ go/onnx/onnx_runtime.go | 42 ++++++++ go/onnx/onnx_runtime.h | 48 +++++++++ go/onnx/onnx_runtime_test.go | 44 +++++++++ go/onnx/onnx_zero.go | 9 ++ 20 files changed, 833 insertions(+) create mode 100644 go/README.md create mode 100644 go/cli/cli.go create mode 100644 go/cli/cli_test.go create mode 100644 go/cli/main.go create mode 100644 go/cli/tests_data/magika_test.zip create mode 100644 go/cli/tests_data/magika_test_pptx.txt create mode 100644 go/docker/Dockerfile create mode 100644 go/go.mod create mode 100644 go/go.sum create mode 100644 go/magika/config.go create mode 100644 go/magika/content.go create mode 100644 go/magika/features.go create mode 100644 go/magika/features_test.go create mode 100644 go/magika/scanner.go create mode 100644 go/magika/scanner_test.go create mode 100644 go/onnx/onnx.go create mode 100644 go/onnx/onnx_runtime.go create mode 100644 go/onnx/onnx_runtime.h create mode 100644 go/onnx/onnx_runtime_test.go create mode 100644 go/onnx/onnx_zero.go diff --git a/go/README.md b/go/README.md new file mode 100644 index 00000000..ba4f1add --- /dev/null +++ b/go/README.md @@ -0,0 +1,17 @@ +# Go library + +This directory contains the Go library for Magika. + +The inference relies on the [ONNX Runtime](https://onnxruntime.ai/), and it +requires [cgo](https://go.dev/blog/cgo) for interfacing with the ONNX Runtime +[C API](https://onnxruntime.ai/docs/api/c/). + +- [`docker`](./docker) contains a sample docker file that builds a +container image that ties together a Magika CLI, an ONNX Runtime, +and a [model](../assets/models/standard_v2_1). +- [`cli`](./cli) contains a basic CLI that illustrates how to +the Magika go library may be called from within an application. +- [`magika`](./magika) contains the library, that extracts +features from a sequence of bytes. +- [`onnx`](./onnx) wraps the C API of the ONNX Runtime to +provide an inference engine. \ No newline at end of file diff --git a/go/cli/cli.go b/go/cli/cli.go new file mode 100644 index 00000000..4908bbf6 --- /dev/null +++ b/go/cli/cli.go @@ -0,0 +1,50 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "os" + + "github.com/google/magika/magika" +) + +const ( + assetsDirEnv = "MAGIKA_ASSETS_DIR" + modelNameEnv = "MAGIKA_MODEL" +) + +// cli is a basic CLI that infers the content type of the files listed on +// the command line. The assets dir and the model name are given via the +// environment variable MAGIKA_ASSETS_DIR and MAGIKA_MODEL respectively. +func cli(w io.Writer, args ...string) error { + assetsDir := os.Getenv(assetsDirEnv) + if assetsDir == "" { + return fmt.Errorf("%s environment variable not set or empty", assetsDirEnv) + } + modelName := os.Getenv(modelNameEnv) + if modelName == "" { + return fmt.Errorf("%s environment variable not set or empty", modelNameEnv) + } + s, err := magika.NewScanner(assetsDir, modelName) + if err != nil { + return fmt.Errorf("create scanner: %w", err) + } + + // For each filename given as argument, read the file and scan its content. + for _, a := range args { + fmt.Fprintf(w, "%s: ", a) + b, err := os.ReadFile(a) + if err != nil { + fmt.Fprintf(w, "%v\n", err) + continue + } + ct, err := s.Scan(bytes.NewReader(b), len(b)) + if err != nil { + fmt.Fprintf(w, "scan: %v\n", err) + continue + } + fmt.Fprintf(w, "%s\n", ct.Label) + } + return nil +} diff --git a/go/cli/cli_test.go b/go/cli/cli_test.go new file mode 100644 index 00000000..8d16593a --- /dev/null +++ b/go/cli/cli_test.go @@ -0,0 +1,31 @@ +//go:build cgo && onnxruntime + +package main + +import ( + "path" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestCLI(t *testing.T) { + const basicDir = "../../tests_data/basic" + var ( + files = []string{ + path.Join(basicDir, "python/code.py"), + path.Join(basicDir, "zip/magika_test.zip"), + } + b strings.Builder + ) + if err := cli(&b, files...); err != nil { + t.Fatal(err) + } + if d := cmp.Diff(strings.TrimSpace(b.String()), strings.Join([]string{ + "../../tests_data/basic/python/code.py: python", + "../../tests_data/basic/zip/magika_test.zip: zip", + }, "\n")); d != "" { + t.Errorf("mismatch (-want +got):\n%s", d) + } +} diff --git a/go/cli/main.go b/go/cli/main.go new file mode 100644 index 00000000..800f0ad6 --- /dev/null +++ b/go/cli/main.go @@ -0,0 +1,26 @@ +/* +CLI is a simple command line interface for magika. + +It takes a list of files as argument, and infers their types in sequence. +For example: + + $ magika test.go readme.md + test.go: go + readme.md: markdown + +The primary intent is to illustrate how the magika go library can be used +and compiled, using cgo and the ONNX Runtime library. +*/ +package main + +import ( + "fmt" + "os" +) + +func main() { + if err := cli(os.Stdout, os.Args[1:]...); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } +} diff --git a/go/cli/tests_data/magika_test.zip b/go/cli/tests_data/magika_test.zip new file mode 100644 index 0000000000000000000000000000000000000000..dc27e7da9e697ccc21fc36234aa9f1ee34a5eab5 GIT binary patch literal 899 zcmWIWW@Zs#;Nak3Sne$w$$$jJfo$Kz^vvwUkksN5m;B_?+|;}hy^NCFoTJlT=g&3} zIPhPzzxL)5CPhV^R+Z|ETU+=q?^+qWywA}KIZAJu{zgq>~bNahvyOR79n-EL#DT?tP}#>T3P1&vH73U$1_PWOfBZxo2;ZQ zr+s5W5`2yeiCYTX%@lPBlDb^JX2lMNpzc3Y#is4cO=e^D(0*!jcHuOIDBrLn+}44) zD^u^ElknR&_u$j@=bwFl-F@(V|6=}#IU6HC{VP52In~kil-GXo5R<;hnZNdmUJ6h? zqMoZ9{qEVC72A6%&b9vEbls+LTgth{rINqay~@roZtLLg>X}}5(wk$QgF<&k-laK9 zm&ylmS_VJv-H=j$pnMM3^}8)%I`Kb_@yiuGaog?I?)p-Q^{i;iW=Erm;dT3$ru-q#!7g%Xw_UVE9oOhY zwr}55b$0W7&cn%J$E$^I%qdm<{`Tk2Rb}ZX{nd_^eHmwu;=lDjfA10vJ`EzIrY2gyp5r^=L1X z#8sabrLI}+S-V?!*}b4lo!3e%julQjv=~~iMFvc{W;b_j^6J`4OJBwZc(ZfVGwOLg sW@2Di$i=`A;LXS+!hpyr$a0{Zf(mftn*eWCHjqL_Aan)N8q6Rb02u0x{r~^~ literal 0 HcmV?d00001 diff --git a/go/cli/tests_data/magika_test_pptx.txt b/go/cli/tests_data/magika_test_pptx.txt new file mode 100644 index 00000000..6bca6ec4 --- /dev/null +++ b/go/cli/tests_data/magika_test_pptx.txt @@ -0,0 +1,3 @@ +This is a test for Magika! + +Very cool if this can be detected correctly! diff --git a/go/docker/Dockerfile b/go/docker/Dockerfile new file mode 100644 index 00000000..2dd536af --- /dev/null +++ b/go/docker/Dockerfile @@ -0,0 +1,69 @@ +# Sample Dockerfile to build an image that ties together an ONNX Runtime, +# a Magika model, and a Magika CLI. +# +# It expects the root of the repository as build context: +# $ docker build -f go/docker/Dockerfile -t magika-go:latest . +# +# Then, to list the content type of the files in the current directory: +# docker run --rm --name magika-go -v $PWD:$PWD:ro -w $PWD magika-go:latest * + +# Build stage for ONNX Runtime and magika. +FROM golang:latest AS build + +# Work in a clean temp directory. +WORKDIR /tmp + +# Download, check, and install ONNX Runtime (https://onnxruntime.ai/) in +# /opt/onnxruntime. +# Releases are located at https://github.com/microsoft/onnxruntime/releases. +# We need the SDK (/include) for compiling, and the library (/lib) for inference. +ARG ONNX_NAME=onnxruntime +ARG ONNX_ARCH=linux-x64 +ARG ONNX_VERSION=1.19.2 +ARG ONNX_FULLNAME=${ONNX_NAME}-${ONNX_ARCH}-${ONNX_VERSION} +ARG ONNX_TARBALL=${ONNX_FULLNAME}.tgz +ARG ONNX_DIGEST=eb00c64e0041f719913c4080e0fed7d9963dc3aa9b54664df6036d8308dbcd33 + +RUN curl -sL -O https://github.com/microsoft/${ONNX_NAME}/releases/download/v${ONNX_VERSION}/${ONNX_TARBALL} \ + && echo "${ONNX_DIGEST} ${ONNX_TARBALL}" > checksum.txt \ + && sha256sum -c checksum.txt \ + && tar -xzf ${ONNX_TARBALL} -C /opt \ + && ln -s /opt/${ONNX_FULLNAME} /opt/onnxruntime + +# Retrieve the magika go code from the build context, test, and build the cli. +COPY go go/ +COPY tests_data tests_data/ +COPY assets/content_types_kb.min.json assets/content_types_kb.min.json +COPY assets/models/standard_v2_1 assets/models/standard_v2_1/ + +ARG CGO_ENABLED=1 +ARG CGO_CFLAGS=-I/opt/onnxruntime/include +ARG LD_LIBRARY_PATH=/opt/onnxruntime/lib + +# Run the tests. +WORKDIR go +RUN MAGIKA_ASSETS_DIR=../../assets \ + MAGIKA_MODEL=standard_v2_1 \ + go test -tags onnxruntime -ldflags="-linkmode=external -extldflags=-L/opt/onnxruntime/lib" ./... + +# Build the CLI. +WORKDIR cli +RUN go build -tags onnxruntime -ldflags="-linkmode=external -extldflags=-L/opt/onnxruntime/lib" . + + +# Final stage: copy resources from the build and set environment variables. +FROM debian:latest + +# Add the ONNX Runtime. +ENV LD_LIBRARY_PATH=/opt/onnxruntime/lib +COPY --from=build /opt/onnxruntime/lib ${LD_LIBRARY_PATH} + +# Magika model. +ENV MAGIKA_ASSETS_DIR=/opt/magika/assets +ENV MAGIKA_MODEL=standard_v2_1 +COPY assets/models/${MAGIKA_MODEL} ${MAGIKA_ASSETS_DIR}/models/${MAGIKA_MODEL}/ +COPY assets/content_types_kb.min.json ${MAGIKA_ASSETS_DIR}/content_types_kb.min.json + +# Magika CLI. +COPY --from=build /tmp/go/cli/cli /usr/local/bin/magika +ENTRYPOINT ["magika"] \ No newline at end of file diff --git a/go/go.mod b/go/go.mod new file mode 100644 index 00000000..025c1e68 --- /dev/null +++ b/go/go.mod @@ -0,0 +1,5 @@ +module github.com/google/magika + +go 1.22.3 + +require github.com/google/go-cmp v0.6.0 // indirect diff --git a/go/go.sum b/go/go.sum new file mode 100644 index 00000000..5a8d551d --- /dev/null +++ b/go/go.sum @@ -0,0 +1,2 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/go/magika/config.go b/go/magika/config.go new file mode 100644 index 00000000..ca94bd55 --- /dev/null +++ b/go/magika/config.go @@ -0,0 +1,62 @@ +package magika + +import ( + "encoding/json" + "fmt" + "os" + "path" +) + +const ( + configFile = "config.min.json" + contentTypesKBFile = "content_types_kb.min.json" + modelFile = "model.onnx" + modelsDir = "models" +) + +// Config holds the portion of Magika's model configuration that is relevant +// for inference. +type Config struct { + BegSize int `json:"beg_size"` + MidSize int `json:"mid_size"` + EndSize int `json:"end_size"` + UseInputsAtOffsets bool `json:"use_inputs_at_offsets"` + MediumConfidenceThreshold float32 `json:"medium_confidence_threshold"` + MinFileSizeForDl int64 `json:"min_file_size_for_dl"` + PaddingToken int `json:"padding_token"` + BlockSize int `json:"block_size"` + TargetLabelsSpace []string `json:"target_labels_space"` + Thresholds map[string]float32 `json:"thresholds"` +} + +// ReadConfig is a helper that reads and unmarshal a Config, given an assets +// dir and a model name. +func ReadConfig(assetsDir, name string) (Config, error) { + var cfg Config + p := configPath(assetsDir, name) + b, err := os.ReadFile(p) + if err != nil { + return Config{}, fmt.Errorf("read %q: %w", p, err) + } + if err := json.Unmarshal(b, &cfg); err != nil { + return Config{}, fmt.Errorf("unmarshal: %w", err) + } + return cfg, nil +} + +// contentTypesKBPath returns the content types KB path for the given +// asset folder. +func contentTypesKBPath(assetDir string) string { + return path.Join(assetDir, contentTypesKBFile) +} + +// configPath returns the model config for the given asset folder and model +// name. +func configPath(assetDir, name string) string { + return path.Join(assetDir, modelsDir, name, configFile) +} + +// modelPath returns the Onnx model for the given asset folder and model name. +func modelPath(assetDir, name string) string { + return path.Join(assetDir, modelsDir, name, modelFile) +} diff --git a/go/magika/content.go b/go/magika/content.go new file mode 100644 index 00000000..7ad4a36e --- /dev/null +++ b/go/magika/content.go @@ -0,0 +1,44 @@ +package magika + +import ( + "encoding/json" + "fmt" + "os" +) + +const ( + contentTypeLabelEmpty = "empty" + contentTypeLabelTxt = "txt" + contentTypeLabelUnknown = "unknown" +) + +// ContentType holds the definition of a content type. +type ContentType struct { + Label string // As keyed in the content types KB. + MimeType string `json:"mime_type"` + Group string `json:"group"` + Description string `json:"description"` + Extensions []string `json:"extensions"` + IsText bool `json:"is_text"` +} + +// readContentTypesKB is a helper that reads and unmarshal a content types KB, +// given the assets dir. +// It returns a dictionary that maps a label as defined in the model config +// target label space to a content type. +func readContentTypesKB(assetsDir string) (map[string]ContentType, error) { + var ckb map[string]ContentType + p := contentTypesKBPath(assetsDir) + b, err := os.ReadFile(p) + if err != nil { + return nil, fmt.Errorf("read %q: %w", p, err) + } + if err := json.Unmarshal(b, &ckb); err != nil { + return nil, fmt.Errorf("unmarshal: %w", err) + } + for label, ct := range ckb { + ct.Label = label + ckb[label] = ct + } + return ckb, nil +} diff --git a/go/magika/features.go b/go/magika/features.go new file mode 100644 index 00000000..3338d1b3 --- /dev/null +++ b/go/magika/features.go @@ -0,0 +1,130 @@ +package magika + +import ( + "bytes" + "fmt" + "io" +) + +// Features holds the features of a give slice of bytes. +type Features struct { + firstBlock []byte + Beg []int32 `json:"beg"` + Mid []int32 `json:"mid"` + End []int32 `json:"end"` + Offset8000 []int32 `json:"offset_0x8000_0x8007"` + Offset8800 []int32 `json:"offset_0x8800_0x8807"` + Offset9000 []int32 `json:"offset_0x9000_0x9007"` + Offset9800 []int32 `json:"offset_0x9800_0x9807"` +} + +// ExtractFeatures extract the features from the given reader. +// The number of bytes that can be read from the reader is given by size. +func ExtractFeatures(cfg Config, r io.ReaderAt, size int) (Features, error) { + var ( + er = errReader{r: r, sz: size} + beg = er.readAt(0, cfg.BlockSize) + mid = er.readAt((size-cfg.MidSize)/2, cfg.MidSize) + end = er.readAt(size-cfg.BlockSize, cfg.BlockSize) + ) + f := buildFeatures(cfg, beg, mid, end) + + peek := func(off int) []int32 { + b := er.readAt(off, 8) + if len(b) < 8 { + b = nil + } + return padInt32(cfg, b, 0, 8) + } + f.Offset8000 = peek(0x8000) + f.Offset8800 = peek(0x8800) + f.Offset9000 = peek(0x9000) + f.Offset9800 = peek(0x9800) + + if er.err != nil { + return Features{}, er.err + } + return f, nil +} + +// Flatten returns a flattened array of the given features. +func (f Features) Flatten() []int32 { + res := make([]int32, 0, len(f.Beg)+len(f.Mid)+len(f.End)) + res = append(res, f.Beg...) + res = append(res, f.Mid...) + res = append(res, f.End...) + return res +} + +// errReader wraps an io.ReaderAt and accumulates errors that may arise during +// reading. It also silently protects against out of range read. +// This allows for a simpler parsing code flow with a unique error check at +// the end of parsing. +type errReader struct { + r io.ReaderAt + sz int + err error +} + +func (e *errReader) readAt(off, n int) []byte { + if e.err != nil || off >= e.sz { + return nil + } + if off < 0 { + n += off + off = 0 + } + n = min(n, e.sz-off) + b := make([]byte, n) + p, err := e.r.ReadAt(b, int64(max(off, 0))) + if err != nil && err != io.EOF { + e.err = fmt.Errorf("read %d bytes at %d: %w", n, max(off, 0), err) + return nil + } + return b[:p] +} + +// buildFeatures builds features from the beg, mid, and end bytes. +func buildFeatures(cfg Config, beg, mid, end []byte) Features { + firstBlock := beg + + spaces := string([]rune{'\t', '\n', '\v', '\f', '\r', ' '}) + // Trim beg and end, and truncate to BegSize and EndSize. + beg = bytes.TrimLeft(beg, spaces) + end = bytes.TrimRight(end, spaces) + beg = safeSlice(beg, 0, cfg.BegSize) + end = safeSlice(end, len(end)-cfg.EndSize, len(end)) + + return Features{ + firstBlock: firstBlock, + Beg: padInt32(cfg, beg, 0, cfg.BegSize), + Mid: padInt32(cfg, mid, (cfg.MidSize-len(mid))/2, cfg.MidSize), + End: padInt32(cfg, end, cfg.EndSize-len(end), cfg.EndSize), + } +} + +// padInt32 pads and convert the given bytes into int32. +// The len of the returned is the given size. +// if prefix is non-zero, that many padding is add as prefix. +// then the given bytes are converted into int32 +// finally, padding occurs until the returned slice is of the given size. +func padInt32(cfg Config, b []byte, prefix, size int) []int32 { + r := make([]int32, 0, size) + for len(r) < prefix { + r = append(r, int32(cfg.PaddingToken)) + } + for _, bb := range b { + r = append(r, int32(bb)) + } + for len(r) < size { + r = append(r, int32(cfg.PaddingToken)) + } + return r +} + +// safeSlice returns a slice from the given array, silently clipping +// out-of-bound indices. This happens when the given input data contains +// fewer bytes than the sampling size. +func safeSlice(b []byte, from, to int) []byte { + return b[max(from, 0):min(to, len(b))] +} diff --git a/go/magika/features_test.go b/go/magika/features_test.go new file mode 100644 index 00000000..79fa0de1 --- /dev/null +++ b/go/magika/features_test.go @@ -0,0 +1,49 @@ +package magika + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestExtractFeatures(t *testing.T) { + f, err := os.Open("../../tests_data/features_extraction/reference.json.gz") + if err != nil { + t.Fatal(err) + } + r, err := gzip.NewReader(f) + if err != nil { + t.Fatalf("could not uncompress test data: %s", err) + } + b, err := io.ReadAll(r) + if err != nil { + t.Fatalf("could not read uncompress test data: %s", err) + } + + var cases []struct { + TestInfo Config `json:"test_info"` + Content []byte `json:"content"` + FeaturesV2 Features `json:"features_v2"` + } + if err := json.Unmarshal(b, &cases); err != nil { + t.Fatal(err) + } + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ft, err := ExtractFeatures(c.TestInfo, bytes.NewReader(c.Content), len(c.Content)) + if err != nil { + t.Fatal(err) + } + if d := cmp.Diff(ft, c.FeaturesV2, cmpopts.IgnoreUnexported(Features{})); d != "" { + t.Errorf("mismatch (-got +want):\n%s", d) + } + }) + } +} diff --git a/go/magika/scanner.go b/go/magika/scanner.go new file mode 100644 index 00000000..e8ee18e6 --- /dev/null +++ b/go/magika/scanner.go @@ -0,0 +1,103 @@ +package magika + +import ( + "errors" + "fmt" + "io" + "unicode/utf8" + + "github.com/google/magika/onnx" +) + +// Scanner represents a Magika scanner that returns the content type +// of the scanned content running the Magika model using an ONNX Runtime. +// This is a similar scanner interface to licensecheck, that scans +// content to identify licenses. +type Scanner struct { + onnx onnx.Onnx + cfg Config + ckb map[string]ContentType +} + +// NewScanner returns a scanner based on the model of the given name defined +// in the given the assets dir. +func NewScanner(assetsDir, name string) (*Scanner, error) { + cfg, err := ReadConfig(assetsDir, name) + if err != nil { + return nil, fmt.Errorf("read config: %w", err) + } + p := modelPath(assetsDir, name) + ob, err := onnx.NewOnnx(p, len(cfg.TargetLabelsSpace)) + if err != nil { + return nil, fmt.Errorf("new onnx: %w", err) + } + if ob == nil { + return nil, errors.New("new onnx: nil onnx object") + } + ckb, err := readContentTypesKB(assetsDir) + if err != nil { + return nil, fmt.Errorf("read content types KB: %w", err) + } + return &Scanner{ + onnx: ob, + cfg: cfg, + ckb: ckb, + }, nil +} + +// Scan scans the given reader containing the given size of bytes, and +// returns the inferred content type. +// It is safe for concurrent use. +func (s *Scanner) Scan(r io.ReaderAt, size int) (ContentType, error) { + if size == 0 { + return s.ckb[contentTypeLabelEmpty], nil + } + ft, err := ExtractFeatures(s.cfg, r, size) + if err != nil { + return ContentType{}, fmt.Errorf("extract features: %w", err) + } + // Do not use the model for small files. + if ft.Beg[s.cfg.MinFileSizeForDl-1] == int32(s.cfg.PaddingToken) { + if utf8.Valid(ft.firstBlock) { + return s.ckb[contentTypeLabelTxt], nil + } else { + return s.ckb[contentTypeLabelUnknown], nil + } + } + scores, err := s.onnx.Run(ft.Flatten()) + if err != nil { + return ContentType{}, fmt.Errorf("run onnx: %w", err) + } + if len(scores) == 0 { + return ContentType{}, errors.New("run onnx: empty result") + } + best := 0 + for i, v := range scores { + if v > scores[best] { + best = i + } + } + return s.contentType(best, scores[best]) +} + +func (s *Scanner) contentType(best int, score float32) (ContentType, error) { + l := s.cfg.TargetLabelsSpace[best] + ct, ok := s.ckb[l] + if !ok { + return ContentType{}, fmt.Errorf("no content type found for %q", l) + } + th := s.cfg.MediumConfidenceThreshold + if t, ok := s.cfg.Thresholds[l]; ok { + th = t + } + // Return the inferred content type if the threshold is met, otherwise + // falls back to a relevant default. + switch { + case score >= th: + return ct, nil + case ct.IsText: + return s.ckb[contentTypeLabelTxt], nil + default: + return s.ckb[contentTypeLabelUnknown], nil + } +} diff --git a/go/magika/scanner_test.go b/go/magika/scanner_test.go new file mode 100644 index 00000000..b29d3935 --- /dev/null +++ b/go/magika/scanner_test.go @@ -0,0 +1,92 @@ +//go:build cgo && onnxruntime + +package magika + +import ( + "bytes" + "os" + "path" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestScannerBasic(t *testing.T) { + const basicDir = "../../tests_data/basic" + es, err := os.ReadDir(basicDir) + if err != nil { + t.Fatalf("read tests data: %v", err) + } + s := newTestScanner(t) + for _, e := range es { + t.Run(e.Name(), func(t *testing.T) { + dir := path.Join(basicDir, e.Name()) + es, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("read tests data: %v", err) + } + for _, ee := range es { + p := path.Join(dir, ee.Name()) + fi, err := os.Stat(p) + if err != nil { + t.Fatalf("stat %s: %v", p, err) + } + f, err := os.Open(p) + if err != nil { + t.Fatalf("open %s: %v", p, err) + } + ct, err := s.Scan(f, int(fi.Size())) + if err != nil { + t.Fatalf("scan %s: %v", p, err) + } + if d := cmp.Diff(ct.Label, e.Name()); d != "" { + t.Errorf("unexpected content type for %s (-got +want):\n%s", ee.Name(), d) + } + } + }) + } +} + +func TestScannerSmall(t *testing.T) { + s := newTestScanner(t) + for _, c := range []struct { + name string + data []byte + want string + }{{ + name: "empty", + data: []byte{}, + want: contentTypeLabelEmpty, + }, { + name: "small txt", + data: []byte("small"), + want: contentTypeLabelTxt, + }, { + name: "small bin", + data: []byte{0x80, 0x80, 0x80, 0x80}, + want: contentTypeLabelUnknown, + }} { + t.Run(c.name, func(t *testing.T) { + ct, err := s.Scan(bytes.NewReader(c.data), len(c.data)) + if err != nil { + t.Fatalf("scan: %v", err) + } + if d := cmp.Diff(ct, s.ckb[c.want]); d != "" { + t.Errorf("unexpected content type (-got +want):\n%s", d) + } + }) + } +} + +func newTestScanner(t *testing.T) *Scanner { + t.Helper() + const ( + assetsDir = "../../assets" + modelName = "standard_v2_1" + ) + s, err := NewScanner(assetsDir, modelName) + if err != nil { + t.Fatalf("new scanner: %v", err) + } + return s +} diff --git a/go/onnx/onnx.go b/go/onnx/onnx.go new file mode 100644 index 00000000..e2632663 --- /dev/null +++ b/go/onnx/onnx.go @@ -0,0 +1,7 @@ +package onnx + +// Onnx represents something that can run inferences on features. +type Onnx interface { + // Run returns the result of the inference on the given features. + Run(features []int32) ([]float32, error) +} diff --git a/go/onnx/onnx_runtime.go b/go/onnx/onnx_runtime.go new file mode 100644 index 00000000..8aa785f4 --- /dev/null +++ b/go/onnx/onnx_runtime.go @@ -0,0 +1,42 @@ +//go:build cgo && onnxruntime + +package onnx + +// #cgo LDFLAGS: -lonnxruntime +// #include "onnx_runtime.h" +import "C" + +import ( + "fmt" +) + +// NewOnnx returns an onnx that can perform inferences using an ONNX Runtime +// (https://onnxruntime.ai/) and the given model. +// It wraps the C calls to the ONNX Runtime API https://onnxruntime.ai/docs/api/c. +func NewOnnx(modelPath string, sizeTarget int) (Onnx, error) { + ort := &onnxRuntime{ + api: C.GetApiBase(), + sizeTarget: sizeTarget, + } + if err := C.CreateSession(ort.api, C.CString(modelPath), &ort.session, &ort.memory); err != nil { + return nil, fmt.Errorf("create session: %v", C.GoString(C.GetErrorMessage(err))) + } + return ort, nil +} + +// onnxRuntime implements the Onnx interface relying on a cgo call +// to a C ONNX Runtime library. +type onnxRuntime struct { + api *C.OrtApi + session *C.OrtSession + memory *C.OrtMemoryInfo + sizeTarget int +} + +func (ort *onnxRuntime) Run(features []int32) ([]float32, error) { + target := make([]float32, ort.sizeTarget) + if err := C.Run(ort.api, ort.session, ort.memory, (*C.int32_t)(&features[0]), C.int64_t(len(features)), (*C.float)(&target[0]), C.int64_t(len(target))); err != nil { + return nil, fmt.Errorf("run: %v", C.GoString(C.GetErrorMessage(err))) + } + return target, nil +} diff --git a/go/onnx/onnx_runtime.h b/go/onnx/onnx_runtime.h new file mode 100644 index 00000000..6b589c64 --- /dev/null +++ b/go/onnx/onnx_runtime.h @@ -0,0 +1,48 @@ +#include +#include + +#define RETURN_ON_ERROR(expr) { \ + OrtStatus* onnx_status = (expr); \ + if (onnx_status != NULL) { \ + return onnx_status; \ + } \ +} + +const OrtApi *GetApiBase() { + return OrtGetApiBase()->GetApi(ORT_API_VERSION); +} + +OrtStatus *CreateSession(const OrtApi *ort, const char *model, OrtSession **session, OrtMemoryInfo **memory_info) { + OrtEnv *env; + RETURN_ON_ERROR(ort->CreateEnv(ORT_LOGGING_LEVEL_ERROR, "onnx", &env)); + RETURN_ON_ERROR(ort->DisableTelemetryEvents(env)); + OrtSessionOptions *options; + RETURN_ON_ERROR(ort->CreateSessionOptions(&options)); + RETURN_ON_ERROR(ort->EnableCpuMemArena(options)); + RETURN_ON_ERROR(ort->CreateSession(env, model, options, session)); + RETURN_ON_ERROR(ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, memory_info)); + return NULL; +} + +OrtStatus *Run(const OrtApi *ort, OrtSession *session, OrtMemoryInfo *memory_info, int32_t features[], int64_t sizeFeatures, float target[], int64_t sizeTarget) { + const char *input_names[] = {"bytes"}; + const char *output_names[] = {"target_label"}; + const int64_t input_shape[] = {1, sizeFeatures}; + OrtValue *input_tensor = NULL; + RETURN_ON_ERROR(ort->CreateTensorWithDataAsOrtValue(memory_info, features, sizeFeatures * sizeof(int32_t), input_shape, 2, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, &input_tensor)); + OrtValue *output_tensor = NULL; + RETURN_ON_ERROR(ort->Run(session, NULL, input_names, (const OrtValue *const *) &input_tensor, 1, output_names, 1, &output_tensor)); + float *out = NULL; + RETURN_ON_ERROR(ort->GetTensorMutableData(output_tensor, (void **) &out)); + memcpy(target, out, sizeTarget * sizeof(float)); + ort->ReleaseValue(input_tensor); + ort->ReleaseValue(output_tensor); + return NULL; +} + +const char *GetErrorMessage(const OrtStatus* onnx_status) { + if (onnx_status == NULL) { + return ""; + } + return OrtGetApiBase()->GetApi(ORT_API_VERSION)->GetErrorMessage(onnx_status); +} diff --git a/go/onnx/onnx_runtime_test.go b/go/onnx/onnx_runtime_test.go new file mode 100644 index 00000000..13b304c1 --- /dev/null +++ b/go/onnx/onnx_runtime_test.go @@ -0,0 +1,44 @@ +//go:build cgo && onnxruntime + +package onnx_test + +import ( + "math/rand/v2" + "testing" + + "github.com/google/magika/magika" + "github.com/google/magika/onnx" +) + +func TestONNXRuntime(t *testing.T) { + const ( + assetsDir = "../../assets" + modelName = "standard_v2_1" + modelPath = "../../assets/models/" + modelName + "/model.onnx" + ) + + cfg, err := magika.ReadConfig(assetsDir, modelName) + if err != nil { + t.Fatal(err) + } + + rt, err := onnx.NewOnnx(modelPath, len(cfg.TargetLabelsSpace)) + if err != nil { + t.Fatalf("Create onnx: %v", err) + } + + // Initialize a random features tensor. + features := make([]int32, cfg.BegSize+cfg.MidSize+cfg.EndSize) + for i := range features { + features[i] = rand.Int32() + } + + // Get the scores and check its size. + scores, err := rt.Run(features) + if err != nil { + t.Fatalf("Run onnx: %v", err) + } + if n, m := len(scores), len(cfg.TargetLabelsSpace); n != m { + t.Fatalf("Unexpected scores len: got %d, want %d", n, m) + } +} diff --git a/go/onnx/onnx_zero.go b/go/onnx/onnx_zero.go new file mode 100644 index 00000000..a0dd8447 --- /dev/null +++ b/go/onnx/onnx_zero.go @@ -0,0 +1,9 @@ +//go:build !(cgo && onnxruntime) + +package onnx + +// NewOnnx returns a nil Onnx runtime. +// This allows for building and unit testing in a non-cgo context. +func NewOnnx(string, int) (Onnx, error) { + return nil, nil +}