diff --git a/token/jwt.go b/token/jwt.go index a009dd0..bf6fc31 100644 --- a/token/jwt.go +++ b/token/jwt.go @@ -120,7 +120,7 @@ func toStringArray(a []interface{}) ([]string, bool) { return b, true } -func parseJWT(token string, keys []map[string]interface{}) (Token, error) { +func parseJWT(token string, keys []map[string]interface{}) (*JWT, error) { t, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { return getKey(t.Header, keys) }) diff --git a/token/stateful.go b/token/stateful.go index 95e1eae..2a3f170 100644 --- a/token/stateful.go +++ b/token/stateful.go @@ -55,11 +55,14 @@ func SetStatefulFilename(filename string) { tokens.modTime = time.Time{} } -func getStateful(token string) (Token, error) { +func getStateful(token string) (*Stateful, error) { tokens.mu.Lock() defer tokens.mu.Unlock() err := tokens.load() if err != nil { + if os.IsNotExist(err) { + return nil, nil + } return nil, err } if tokens.tokens == nil { diff --git a/token/token.go b/token/token.go index 65e7871..9fbd60a 100644 --- a/token/token.go +++ b/token/token.go @@ -2,6 +2,7 @@ package token import ( "errors" + "os" ) var ErrUsernameRequired = errors.New("username required") @@ -11,9 +12,23 @@ type Token interface { } func Parse(token string, keys []map[string]interface{}) (Token, error) { - t, err := getStateful(token) - if err == nil && t != nil { - return t, nil + // both getStateful and parseJWT may return nil, which we + // shouldn't cast into an interface. Be very careful. + s, err1 := getStateful(token) + if err1 == nil && s != nil { + return s, nil + } + + jwt, err2 := parseJWT(token, keys) + if err2 == nil && jwt != nil { + return jwt, nil + } + + if err1 != nil { + return nil, err1 + } else if err2 != nil { + return nil, err2 + } else { + return nil, os.ErrNotExist } - return parseJWT(token, keys) } diff --git a/token/token_test.go b/token/token_test.go new file mode 100644 index 0000000..1bcad8c --- /dev/null +++ b/token/token_test.go @@ -0,0 +1,31 @@ +package token + +import ( + "testing" + "path/filepath" + "os" +) + +func TestBad(t *testing.T) { + d := t.TempDir() + tokens = state{ + filename: filepath.Join(d, "test.jsonl"), + } + f, err := os.OpenFile(tokens.filename, + os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600, + ) + if err != nil { + t.Fatalf("Create: %v", err) + } + defer f.Close() + + token, err := Parse("foo", nil) + if err == nil { + t.Errorf("Expected error, got %v", token) + } + + token, err = Parse("foo", []map[string]interface{}{}) + if err == nil { + t.Errorf("Expected error, got %v", token) + } +}