1
Fork 0
mirror of https://github.com/jech/galene.git synced 2024-11-22 16:45:58 +01:00

Honour the kid field in JWT if present.

This commit is contained in:
Juliusz Chroboczek 2024-05-11 12:29:30 +02:00
parent 6c01925342
commit 969354e9e5
3 changed files with 18 additions and 9 deletions

5
README
View file

@ -319,8 +319,9 @@ specify either an authorisation server or an authorisation portal.
"authServer": "https://auth.example.org", "authServer": "https://auth.example.org",
} }
If multiple keys are provided, then they will all be tried in turn (the If multiple keys are provided, then they will all be tried in turn, unless
kid field, if provided, is ignored). the token includes the "kid" header field, in which case only the
specified key will be used.
If an authorisation server is specified, then the default client, after it If an authorisation server is specified, then the default client, after it
prompts for a password, will request a token from the authorisation server prompts for a password, will request a token from the authorisation server

View file

@ -581,7 +581,7 @@ func SetWildcardUser(group string, user *UserDescription) error {
func SetKeys(group string, keys []map[string]any) error { func SetKeys(group string, keys []map[string]any) error {
if keys != nil { if keys != nil {
_, err := token.ParseKeys(keys) _, err := token.ParseKeys(keys, "")
if err != nil { if err != nil {
return err return err
} }

View file

@ -96,14 +96,18 @@ func ParseKey(key map[string]any) (any, error) {
} }
} }
func ParseKeys(keys []map[string]any) ([]jwt.VerificationKey, error) { func ParseKeys(keys []map[string]any, kid string) ([]jwt.VerificationKey, error) {
ks := make([]jwt.VerificationKey, len(keys)) ks := make([]jwt.VerificationKey, 0, len(keys))
for i, ky := range keys { for _, ky := range keys {
// return all keys if kid is not specified
if kid != "" && ky["kid"] != kid {
continue
}
k, err := ParseKey(ky) k, err := ParseKey(ky)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ks[i] = k ks = append(ks, k)
} }
return ks, nil return ks, nil
} }
@ -130,11 +134,15 @@ func toStringArray(a interface{}) ([]string, bool) {
func parseJWT(token string, keys []map[string]any) (*JWT, error) { func parseJWT(token string, keys []map[string]any) (*JWT, error) {
t, err := jwt.Parse( t, err := jwt.Parse(
token, token,
func(t *jwt.Token) (interface{}, error) { func(t *jwt.Token) (any, error) {
ks, err := ParseKeys(keys) kid, _ := t.Header["kid"].(string)
ks, err := ParseKeys(keys, kid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(ks) == 1 {
return ks[0], nil
}
return jwt.VerificationKeySet{Keys: ks}, nil return jwt.VerificationKeySet{Keys: ks}, nil
}, },
jwt.WithExpirationRequired(), jwt.WithExpirationRequired(),