-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] - Add Weights and Biases detector (#3551)
* Add Weights and Biases detector * Add to default detector list
- Loading branch information
Showing
6 changed files
with
393 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
package weightsandbiases | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/base64" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strconv" | ||
|
||
regexp "github.com/wasilibs/go-re2" | ||
|
||
"github.com/trufflesecurity/trufflehog/v3/pkg/common" | ||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors" | ||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb" | ||
) | ||
|
||
// Scanner detects Weights & Biases API keys and can optionally verify them.
type Scanner struct {
	// client is the HTTP client used for verification requests. When it is
	// nil, FromData falls back to the package-level defaultClient.
	client *http.Client
}
|
||
// Ensure the Scanner satisfies the interface at compile time. | ||
var _ detectors.Detector = (*Scanner)(nil) | ||
|
||
var ( | ||
defaultClient = common.SaneHttpClient() | ||
// Make sure that your group is surrounded in boundary characters such as below to reduce false positives. | ||
keyPat = regexp.MustCompile(detectors.PrefixRegex([]string{"wandb"}) + `\b([0-9a-f]{40})\b`) | ||
) | ||
|
||
// Keywords are used for efficiently pre-filtering chunks. | ||
// Use identifiers in the secret preferably, or the provider name. | ||
func (s Scanner) Keywords() []string { return []string{"wandb"} } | ||
|
||
// FromData will find and optionally verify Weightsandbiases secrets in a given set of bytes. | ||
func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (results []detectors.Result, err error) { | ||
dataStr := string(data) | ||
|
||
uniqueMatches := make(map[string]struct{}) | ||
for _, match := range keyPat.FindAllStringSubmatch(dataStr, -1) { | ||
uniqueMatches[match[1]] = struct{}{} | ||
} | ||
|
||
for match := range uniqueMatches { | ||
s1 := detectors.Result{ | ||
DetectorType: detectorspb.DetectorType_WeightsAndBiases, | ||
Raw: []byte(match), | ||
} | ||
|
||
if verify { | ||
client := s.client | ||
if client == nil { | ||
client = defaultClient | ||
} | ||
|
||
isVerified, extraData, verificationErr := verifyMatch(ctx, client, match) | ||
s1.Verified = isVerified | ||
s1.ExtraData = extraData | ||
s1.SetVerificationError(verificationErr, match) | ||
} | ||
|
||
results = append(results, s1) | ||
} | ||
|
||
return | ||
} | ||
|
||
// viewerResponse mirrors the portion of the GraphQL "viewer" query response
// that verification reads.
type viewerResponse struct {
	Data struct {
		Viewer struct {
			ID       string `json:"id"`
			Username string `json:"username"`
			Email    string `json:"email"`
			Admin    bool   `json:"admin"`
		} `json:"viewer"`
	} `json:"data"`
}

// verifyMatch checks token by POSTing a GraphQL "viewer" query to the
// Weights & Biases API with HTTP Basic credentials "api:<token>".
//
// It returns:
//   - (true, extraData, nil) when the API answers 200 with a non-empty
//     username; extraData holds the "username", "email" and "admin" fields.
//   - (false, nil, nil) when the API answers 401, or answers 200 without a
//     username.
//   - (false, nil, err) when the request cannot be built or sent, the 200
//     response body cannot be decoded, or any other status code is returned.
func verifyMatch(ctx context.Context, client *http.Client, token string) (bool, map[string]string, error) {
	const (
		baseURL = "https://api.wandb.ai/graphql"
		query   = `{"query": "query Viewer { viewer { id username email admin } }"}`
	)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewBufferString(query))
	if err != nil {
		// Surface the failure so the caller records a verification error
		// instead of silently reporting the token as unverified.
		return false, nil, err
	}

	authHeader := base64.StdEncoding.EncodeToString([]byte("api:" + token))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Basic "+authHeader)

	res, err := client.Do(req)
	if err != nil {
		return false, nil, err
	}
	defer func() {
		// Drain and close the body so the transport can reuse the connection.
		_, _ = io.Copy(io.Discard, res.Body)
		_ = res.Body.Close()
	}()

	switch res.StatusCode {
	case http.StatusOK:
		var viewerResp viewerResponse
		if err := json.NewDecoder(res.Body).Decode(&viewerResp); err != nil {
			return false, nil, err
		}

		// Only consider it verified if we got back a username.
		if viewerResp.Data.Viewer.Username == "" {
			return false, nil, nil
		}

		extraData := map[string]string{
			"username": viewerResp.Data.Viewer.Username,
			"email":    viewerResp.Data.Viewer.Email,
			"admin":    strconv.FormatBool(viewerResp.Data.Viewer.Admin),
		}
		return true, extraData, nil
	case http.StatusUnauthorized:
		return false, nil, nil
	default:
		return false, nil, fmt.Errorf("unexpected HTTP response status %d", res.StatusCode)
	}
}
|
||
func (s Scanner) Description() string { | ||
return "Weights & Biases is a Machine Learning Operations (MLOps) platform that helps track experiments, version datasets, evaluate model performance, and collaborate with team members" | ||
} | ||
|
||
func (s Scanner) Type() detectorspb.DetectorType { | ||
return detectorspb.DetectorType_WeightsAndBiases | ||
} |
167 changes: 167 additions & 0 deletions
167
pkg/detectors/weightsandbiases/weightsandbiases_integration_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
//go:build detectors | ||
// +build detectors | ||
|
||
package weightsandbiases | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"github.com/google/go-cmp/cmp/cmpopts" | ||
|
||
"github.com/trufflesecurity/trufflehog/v3/pkg/common" | ||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors" | ||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb" | ||
) | ||
|
||
func TestWeightsandbiases_FromChunk(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) | ||
defer cancel() | ||
testSecrets, err := common.GetSecret(ctx, "trufflehog-testing", "detectors5") | ||
if err != nil { | ||
t.Fatalf("could not get test secrets from GCP: %s", err) | ||
} | ||
secret := testSecrets.MustGetField("WEIGHTSANDBIASES") | ||
inactiveSecret := testSecrets.MustGetField("WEIGHTSANDBIASES_INACTIVE") | ||
|
||
type args struct { | ||
ctx context.Context | ||
data []byte | ||
verify bool | ||
} | ||
tests := []struct { | ||
name string | ||
s Scanner | ||
args args | ||
want []detectors.Result | ||
wantErr bool | ||
wantVerificationErr bool | ||
}{ | ||
{ | ||
name: "found, verified", | ||
s: Scanner{}, | ||
args: args{ | ||
ctx: context.Background(), | ||
data: []byte(fmt.Sprintf("You can find a weightsandbiases secret wandb %s within", secret)), | ||
verify: true, | ||
}, | ||
want: []detectors.Result{ | ||
{ | ||
DetectorType: detectorspb.DetectorType_WeightsAndBiases, | ||
Verified: true, | ||
ExtraData: map[string]string{ | ||
"admin": "false", | ||
"email": "[email protected]", | ||
"username": "source-integrations", | ||
}, | ||
}, | ||
}, | ||
wantErr: false, | ||
wantVerificationErr: false, | ||
}, | ||
{ | ||
name: "found, unverified", | ||
s: Scanner{}, | ||
args: args{ | ||
ctx: context.Background(), | ||
data: []byte(fmt.Sprintf("You can find a weightsandbiases secret wandb %s within but not valid", inactiveSecret)), // the secret would satisfy the regex but not pass validation | ||
verify: true, | ||
}, | ||
want: []detectors.Result{ | ||
{ | ||
DetectorType: detectorspb.DetectorType_WeightsAndBiases, | ||
Verified: false, | ||
}, | ||
}, | ||
wantErr: false, | ||
wantVerificationErr: false, | ||
}, | ||
{ | ||
name: "not found", | ||
s: Scanner{}, | ||
args: args{ | ||
ctx: context.Background(), | ||
data: []byte("You cannot find the secret within"), | ||
verify: true, | ||
}, | ||
want: nil, | ||
wantErr: false, | ||
wantVerificationErr: false, | ||
}, | ||
{ | ||
name: "found, would be verified if not for timeout", | ||
s: Scanner{client: common.SaneHttpClientTimeOut(1 * time.Microsecond)}, | ||
args: args{ | ||
ctx: context.Background(), | ||
data: []byte(fmt.Sprintf("You can find a weightsandbiases secret wandb %s within", secret)), | ||
verify: true, | ||
}, | ||
want: []detectors.Result{ | ||
{ | ||
DetectorType: detectorspb.DetectorType_WeightsAndBiases, | ||
Verified: false, | ||
}, | ||
}, | ||
wantErr: false, | ||
wantVerificationErr: true, | ||
}, | ||
{ | ||
name: "found, verified but unexpected api surface", | ||
s: Scanner{client: common.ConstantResponseHttpClient(404, "")}, | ||
args: args{ | ||
ctx: context.Background(), | ||
data: []byte(fmt.Sprintf("You can find a weightsandbiases secret wandb %s within", secret)), | ||
verify: true, | ||
}, | ||
want: []detectors.Result{ | ||
{ | ||
DetectorType: detectorspb.DetectorType_WeightsAndBiases, | ||
Verified: false, | ||
}, | ||
}, | ||
wantErr: false, | ||
wantVerificationErr: true, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := tt.s.FromData(tt.args.ctx, tt.args.verify, tt.args.data) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("Weightsandbiases.FromData() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
for i := range got { | ||
if len(got[i].Raw) == 0 { | ||
t.Fatalf("no raw secret present: \n %+v", got[i]) | ||
} | ||
if (got[i].VerificationError() != nil) != tt.wantVerificationErr { | ||
t.Fatalf("wantVerificationError = %v, verification error = %v", tt.wantVerificationErr, got[i].VerificationError()) | ||
} | ||
} | ||
ignoreOpts := cmpopts.IgnoreFields(detectors.Result{}, "Raw", "verificationError") | ||
if diff := cmp.Diff(got, tt.want, ignoreOpts); diff != "" { | ||
t.Errorf("Weightsandbiases.FromData() %s diff: (-got +want)\n%s", tt.name, diff) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func BenchmarkFromData(benchmark *testing.B) { | ||
ctx := context.Background() | ||
s := Scanner{} | ||
for name, data := range detectors.MustGetBenchmarkData() { | ||
benchmark.Run(name, func(b *testing.B) { | ||
b.ResetTimer() | ||
for n := 0; n < b.N; n++ { | ||
_, err := s.FromData(ctx, false, data) | ||
if err != nil { | ||
b.Fatal(err) | ||
} | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package weightsandbiases | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
|
||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors" | ||
"github.com/trufflesecurity/trufflehog/v3/pkg/engine/ahocorasick" | ||
) | ||
|
||
func TestWeightsandbiases_Pattern(t *testing.T) { | ||
d := Scanner{} | ||
ahoCorasickCore := ahocorasick.NewAhoCorasickCore([]detectors.Detector{d}) | ||
tests := []struct { | ||
name string | ||
input string | ||
want []string | ||
}{ | ||
{ | ||
name: "typical pattern", | ||
input: "WANDB_API_KEY = 'eedf1c984f6b995ec40ecc6658356044847ffb31'", | ||
want: []string{"eedf1c984f6b995ec40ecc6658356044847ffb31"}, | ||
}, | ||
{ | ||
name: "finds all matches", | ||
input: `WANDB_API_KEY = 'eedf1c984f6b995ec40ecc6658356044847ffb31' | ||
WANDB_API_KEY= 'eedf1c984f6b995ec40ecc6658356044847ffb32'`, | ||
want: []string{"eedf1c984f6b995ec40ecc6658356044847ffb31", "eedf1c984f6b995ec40ecc6658356044847ffb32"}, | ||
}, | ||
{ | ||
name: "invald pattern", | ||
input: "WANDB_API_KEY = 'e84f6b995ec40ecc6658356044847ffb31'", | ||
want: []string{}, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
matchedDetectors := ahoCorasickCore.FindDetectorMatches([]byte(test.input)) | ||
if len(matchedDetectors) == 0 { | ||
t.Errorf("keywords '%v' not matched by: %s", d.Keywords(), test.input) | ||
return | ||
} | ||
|
||
results, err := d.FromData(context.Background(), false, []byte(test.input)) | ||
if err != nil { | ||
t.Errorf("error = %v", err) | ||
return | ||
} | ||
|
||
if len(results) != len(test.want) { | ||
if len(results) == 0 { | ||
t.Errorf("did not receive result") | ||
} else { | ||
t.Errorf("expected %d results, only received %d", len(test.want), len(results)) | ||
} | ||
return | ||
} | ||
|
||
actual := make(map[string]struct{}, len(results)) | ||
for _, r := range results { | ||
if len(r.RawV2) > 0 { | ||
actual[string(r.RawV2)] = struct{}{} | ||
} else { | ||
actual[string(r.Raw)] = struct{}{} | ||
} | ||
} | ||
expected := make(map[string]struct{}, len(test.want)) | ||
for _, v := range test.want { | ||
expected[v] = struct{}{} | ||
} | ||
|
||
if diff := cmp.Diff(expected, actual); diff != "" { | ||
t.Errorf("%s diff: (-want +got)\n%s", test.name, diff) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.