From 8ebcb0fb79038844601730c73660315ec674cfd7 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Jun 2021 15:18:22 +0100 Subject: [PATCH] Ensure user IDs match the spec (#261) --- event.go | 23 ++++++++++++++++++++--- event_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/event.go b/event.go index d6da9b5f..77104d9b 100644 --- a/event.go +++ b/event.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "reflect" + "regexp" "strings" "time" @@ -709,11 +710,27 @@ func (e *Event) CheckFields() error { // nolint: gocyclo return nil } -func checkID(id, kind string, sigil byte) (domain string, err error) { - domain, err = domainFromID(id) - if err != nil { +// https://matrix.org/docs/spec/appendices#id12 +var validUserIDRegex = regexp.MustCompile(`^[0-9a-z\._=\-/]+$`) + +func checkID(id, kind string, sigil byte) (domain ServerName, err error) { + var user string + if user, domain, err = SplitID(sigil, id); err != nil { + err = fmt.Errorf("gomatrixserverlib: invalid ID %q, invalid format: %w", id, err) return } + // Check that the characters in the ID are valid for the type + switch sigil { + case '@': + if len(id) > 255 { + err = fmt.Errorf("gomatrixserverlib: invalid ID %q, too long", id) + return + } + if !validUserIDRegex.MatchString(user) { + err = fmt.Errorf("gomatrixserverlib: invalid ID %q, user part %q contains invalid characters", id, user) + return + } + } if id[0] != sigil { err = fmt.Errorf( "gomatrixserverlib: invalid %s ID, wanted first byte to be '%c' got '%c'", diff --git a/event_test.go b/event_test.go index f9ef9fc0..b9469d45 100644 --- a/event_test.go +++ b/event_test.go @@ -180,3 +180,28 @@ func TestHeaderedEventToNewEventFromUntrustedJSON(t *testing.T) { t.Fatal("expected an UnexpectedHeaderedEvent error but got:", err) } } + +func TestValidUserID(t *testing.T) { + userIDsShouldPass := []string{ + "@foo:bar.com", + "@foo-baz:bar.com", + "@foo_qux/baz:bar.com", + "@foo.baz:bar.com", + } + userIDsShouldFail := []string{ + "@Foo:bar.com", + "@foo%:bar.com", + "@fooé:bar.com", + "@℉oo:bar.com", + } + for _, id := range userIDsShouldPass { + if _, err := checkID(id, "user", '@'); err != nil { + t.Fatalf("%q should have passed but didn't: %s", id, err) + } + } + for _, id := range userIDsShouldFail { + if _, err := checkID(id, "user", '@'); err == nil { + t.Fatalf("%q should have failed but didn't", id) + } + } +}