diff --git a/go.mod b/go.mod index 77a2a6e..abeca52 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/at-wat/ebml-go v0.17.0 - github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang-jwt/jwt/v5 v5.2.0 github.com/gorilla/websocket v1.5.0 github.com/jech/cert v0.0.0-20210819231831-aca735647728 github.com/jech/samplebuilder v0.0.0-20221109182433-6cbba09fc1c9 diff --git a/go.sum b/go.sum index 17608f9..07b75c4 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= -github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= +github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= diff --git a/token/jwt.go b/token/jwt.go index bf6fc31..592ad5d 100644 --- a/token/jwt.go +++ b/token/jwt.go @@ -9,8 +9,9 @@ import ( "net/url" "path" "strings" + "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" ) type JWT jwt.Token @@ -108,9 +109,14 @@ func getKey(header map[string]interface{}, keys []map[string]interface{}) (inter return nil, errors.New("key not found") } -func toStringArray(a []interface{}) ([]string, bool) { - b := make([]string, len(a)) - for i, v := range a { +func toStringArray(a interface{}) ([]string, bool) { + aa, ok := a.([]interface{}) + if !ok { + return nil, false + } + + b := make([]string, len(aa)) + for i, v := range aa { w, ok := v.(string) if !ok { return nil, false @@ -121,9 +127,15 @@ func toStringArray(a []interface{}) ([]string, bool) { } func parseJWT(token string, keys []map[string]interface{}) (*JWT, error) { - t, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { - return getKey(t.Header, keys) - }) + t, err := jwt.Parse( + token, + func(t *jwt.Token) (interface{}, error) { + return getKey(t.Header, keys) + }, + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + jwt.WithLeeway(5*time.Second), + ) if err != nil { return nil, err } @@ -131,34 +143,18 @@ func parseJWT(token string, keys []map[string]interface{}) (*JWT, error) { } func (token *JWT) Check(host, group string, username *string) (string, []string, error) { - claims := token.Claims.(jwt.MapClaims) - - s, ok := claims["sub"] - if !ok { - return "", nil, errors.New("token has no 'sub' field") - } - sub, ok := s.(string) - if !ok { - return "", nil, errors.New("invalid 'sub' field") + sub, err := token.Claims.GetSubject() + if err != nil { + return "", nil, err } // we accept tokens with a different username from the one provided, // and use the token's 'sub' field to override the username - var aud []string - if a, ok := claims["aud"]; ok && a != nil { - switch a := a.(type) { - case string: - aud = []string{a} - case []interface{}: - aud, ok = toStringArray(a) - if !ok { - return "", nil, errors.New("invalid 'aud' field") - } - default: - return "", nil, errors.New("invalid 'aud' field") - } + aud, err := token.Claims.GetAudience() + if err != nil { + return "", nil, err } - ok = false + ok := false for _, u := range aud { url, err := url.Parse(u) if err != nil { @@ -181,13 +177,14 @@ func (token *JWT) Check(host, group string, username *string) (string, []string, return "", nil, errors.New("token for wrong group") } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return "", nil, errors.New("unexpected type for token") + } + var perms []string if p, ok := claims["permissions"]; ok && p != nil { - pp, ok := p.([]interface{}) - if !ok { - return "", nil, errors.New("invalid 'permissions' field") - } - perms, ok = toStringArray(pp) + perms, ok = toStringArray(p) if !ok { return "", nil, errors.New("invalid 'permissions' field") } diff --git a/token/jwt_test.go b/token/jwt_test.go index bb385f8..b19fa13 100644 --- a/token/jwt_test.go +++ b/token/jwt_test.go @@ -125,8 +125,11 @@ func TestJWT(t *testing.T) { t.Errorf("Couldn't parse noSubToken: %v", err) } username, perms, err = tok.Check("galene.org:8443", "auth", &jack) - if err == nil { - t.Errorf("noSubToken is valid") + if err != nil { + t.Errorf("noSubToken is not valid: %v", err) + } + if username != "" || !reflect.DeepEqual(perms, []string{"present"}) { + t.Errorf("Expected \"\", [present], got %v %v", username, perms) } badToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJub25lIn0.eyJzdWIiOiJqb2huIiwiYXVkIjoiaHR0cHM6Ly9nYWxlbmUub3JnOjg0NDMvZ3JvdXAvYXV0aC8iLCJwZXJtaXNzaW9ucyI6WyJwcmVzZW50Il0sImlhdCI6MTY0NTMxMDQ2OSwiZXhwIjoyOTA2NzUwNDY5LCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjEyMzQvIn0."