diff --git a/token/jwt.go b/token/jwt.go index 27db536..4a58d6e 100644 --- a/token/jwt.go +++ b/token/jwt.go @@ -126,6 +126,8 @@ func toStringArray(a interface{}) ([]string, bool) { return b, true } +// parseJWT tries to parse a string as a JWT. +// It returns (nil, nil) if the string does not look like a JWT. func parseJWT(token string, keys []map[string]interface{}) (*JWT, error) { t, err := jwt.Parse( token, diff --git a/token/stateful.go b/token/stateful.go index 1d0fb4b..a8e0c25 100644 --- a/token/stateful.go +++ b/token/stateful.go @@ -57,20 +57,23 @@ func SetStatefulFilename(filename string) { tokens.modTime = time.Time{} } -func getStateful(token string) (*Stateful, error) { +// Get fetches a stateful token. +// It returns os.ErrNotExist if the token doesn't exist. +func Get(token string) (*Stateful, error) { tokens.mu.Lock() defer tokens.mu.Unlock() err := tokens.load() if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil, nil - } return nil, err } if tokens.tokens == nil { - return nil, nil + return nil, os.ErrNotExist } - return tokens.tokens[token], nil + t := tokens.tokens[token] + if t == nil { + return nil, os.ErrNotExist + } + return t, nil } func (token *Stateful) Check(host, group string, username *string) (string, []string, error) { diff --git a/token/token.go b/token/token.go index 6b34787..4aaa1d9 100644 --- a/token/token.go +++ b/token/token.go @@ -2,7 +2,6 @@ package token import ( "errors" - "os" ) var ErrUsernameRequired = errors.New("username required") @@ -13,21 +12,15 @@ type Token interface { func Parse(token string, keys []map[string]interface{}) (Token, error) { // both getStateful and parseJWT may return nil, which we - // shouldn't cast into an interface. Be very careful. - s, err1 := getStateful(token) - if err1 == nil && s != nil { - return s, nil + // shouldn't cast into an interface before testing for nil. + jwt, err := parseJWT(token, keys) + if err != nil { + // parses correctly but doesn't validate + return nil, err } - - jwt, err2 := parseJWT(token, keys) - if err2 == nil && jwt != nil { + if jwt != nil { return jwt, nil } - if err1 != nil { - return nil, err1 - } else if err2 != nil { - return nil, err2 - } - return nil, os.ErrNotExist + return Get(token) }