1
Fork 0

Implement accessors for stateful tokens.

This commit is contained in:
Juliusz Chroboczek 2024-04-14 20:59:46 +02:00
parent fe15057252
commit 2f5c21d161
5 changed files with 153 additions and 99 deletions

View File

@ -1753,7 +1753,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
now := time.Now() now := time.Now()
tok.IssuedAt = &now tok.IssuedAt = &now
new, err := token.Add(tok) new, err := token.Update(tok, "")
if err != nil { if err != nil {
return terror("error", err.Error()) return terror("error", err.Error())
} }
@ -1783,19 +1783,26 @@ func handleClientMessage(c *webClient, m clientMessage) error {
} }
if tok.Group != "" || tok.Username != nil || if tok.Group != "" || tok.Username != nil ||
tok.Permissions != nil || tok.Permissions != nil ||
tok.NotBefore != nil ||
tok.IssuedBy != nil || tok.IssuedBy != nil ||
tok.IssuedAt != nil { tok.IssuedAt != nil {
return terror( return terror(
"error", "this field cannot be edited", "error", "this field cannot be edited",
) )
} }
if tok.Expires == nil {
return terror("error", "trying to edit nothing") old, etag, err := token.Get(tok.Token)
if err != nil {
return terror("error", err.Error())
} }
new, err := token.Extend( t := old.Clone()
c.group.Name(), tok.Token, *tok.Expires, if tok.Expires != nil {
) t.Expires = tok.Expires
}
if tok.NotBefore != nil {
t.NotBefore = tok.NotBefore
}
new, err := token.Update(t, etag)
if err != nil { if err != nil {
return terror("error", err.Error()) return terror("error", err.Error())
} }
@ -1819,7 +1826,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
!member("token", c.permissions) { !member("token", c.permissions) {
return terror("not-authorised", "not authorised") return terror("not-authorised", "not authorised")
} }
tokens, err := token.List(c.group.Name()) tokens, _, err := token.List(c.group.Name())
if err != nil { if err != nil {
return terror("error", err.Error()) return terror("error", err.Error())
} }

View File

@ -3,6 +3,7 @@ package token
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
@ -11,6 +12,8 @@ import (
"time" "time"
) )
var ErrTagMismatch = errors.New("tag mismatch")
// A stateful token // A stateful token
type Stateful struct { type Stateful struct {
Token string `json:"token"` Token string `json:"token"`
@ -57,26 +60,26 @@ func SetStatefulFilename(filename string) {
tokens.modTime = time.Time{} tokens.modTime = time.Time{}
} }
func (state *state) Get(token string) (*Stateful, error) { func (state *state) Get(token string) (*Stateful, string, error) {
state.mu.Lock() state.mu.Lock()
defer state.mu.Unlock() defer state.mu.Unlock()
err := state.load() etag, err := state.load()
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
if state.tokens == nil { if state.tokens == nil {
return nil, os.ErrNotExist return nil, "", os.ErrNotExist
} }
t := state.tokens[token] t := state.tokens[token]
if t == nil { if t == nil {
return nil, os.ErrNotExist return nil, "", os.ErrNotExist
} }
return t, nil return t, etag, nil
} }
// Get fetches a stateful token. // Get fetches a stateful token.
// It returns os.ErrNotExist if the token doesn't exist. // It returns os.ErrNotExist if the token doesn't exist.
func Get(token string) (*Stateful, error) { func Get(token string) (*Stateful, string, error) {
return tokens.Get(token) return tokens.Get(token)
} }
@ -103,12 +106,13 @@ func (token *Stateful) Check(host, group string, username *string) (string, []st
return user, token.Permissions, nil return user, token.Permissions, nil
} }
// load updates the state from the corresponding file.
// called locked // called locked
func (state *state) load() error { func (state *state) load() (string, error) {
if state.filename == "" { if state.filename == "" {
state.modTime = time.Time{} state.modTime = time.Time{}
state.tokens = nil state.tokens = nil
return nil return state.etag(), nil
} }
fi, err := os.Stat(state.filename) fi, err := os.Stat(state.filename)
@ -117,14 +121,13 @@ func (state *state) load() error {
state.fileSize = 0 state.fileSize = 0
state.tokens = nil state.tokens = nil
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
return nil return "", nil
} }
return err return "", err
} }
if state.modTime.Equal(fi.ModTime()) && if state.modTime.Equal(fi.ModTime()) && state.fileSize == fi.Size() {
state.fileSize == fi.Size() { return state.etag(), nil
return nil
} }
f, err := os.Open(state.filename) f, err := os.Open(state.filename)
@ -133,9 +136,9 @@ func (state *state) load() error {
state.fileSize = 0 state.fileSize = 0
state.tokens = nil state.tokens = nil
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
return nil return state.etag(), nil
} }
return err return "", err
} }
defer f.Close() defer f.Close()
@ -150,7 +153,7 @@ func (state *state) load() error {
} else if err != nil { } else if err != nil {
state.modTime = time.Time{} state.modTime = time.Time{}
state.fileSize = 0 state.fileSize = 0
return err return "", err
} }
ts[t.Token] = &t ts[t.Token] = &t
} }
@ -161,35 +164,108 @@ func (state *state) load() error {
state.fileSize = 0 state.fileSize = 0
state.tokens = nil state.tokens = nil
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
return nil return state.etag(), nil
} }
return err return "", err
} }
state.modTime = fi.ModTime() state.modTime = fi.ModTime()
state.fileSize = fi.Size() state.fileSize = fi.Size()
return nil return state.etag(), nil
} }
func (state *state) Add(token *Stateful) (*Stateful, error) { func (state *state) etag() string {
if state.modTime.Equal(time.Time{}) {
return ""
}
return fmt.Sprintf("\"%v-%v\"",
state.fileSize, state.modTime.UnixNano(),
)
}
// Update adds or updates a token.
// If etag is the empty string, it is added if it didn't exist. If etag
// is not empty, it is added if it matches the state's etag.
func (state *state) Update(token *Stateful, etag string) (*Stateful, error) {
tokens.mu.Lock() tokens.mu.Lock()
defer tokens.mu.Unlock() defer tokens.mu.Unlock()
if state.filename == "" { if state.filename == "" {
if etag != "" {
return nil, ErrTagMismatch
}
return nil, os.ErrNotExist return nil, os.ErrNotExist
} }
err := state.load() _, err := state.load()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if state.tokens != nil { if state.tokens == nil {
if _, ok := state.tokens[token.Token]; ok { state.tokens = make(map[string]*Stateful)
return nil, os.ErrExist
}
} }
err = os.MkdirAll(filepath.Dir(state.filename), 0700) old, ok := state.tokens[token.Token]
if ok {
if etag != state.etag() {
return nil, ErrTagMismatch
}
state.tokens[token.Token] = token
err = state.rewrite()
if err != nil {
state.tokens[token.Token] = old
return nil, err
}
return token, nil
}
if etag != "" {
return nil, ErrTagMismatch
}
return state.add(token)
}
func Delete(token string, etag string) error {
return tokens.Delete(token, etag)
}
func (state *state) Delete(token string, etag string) error {
tokens.mu.Lock()
defer tokens.mu.Unlock()
if state.filename == "" {
return os.ErrNotExist
}
_, err := state.load()
if err != nil {
return err
}
if state.tokens == nil {
return os.ErrNotExist
}
old, ok := state.tokens[token]
if !ok {
return os.ErrNotExist
}
if etag != state.etag() {
return ErrTagMismatch
}
delete(state.tokens, token)
err = state.rewrite()
if err != nil {
state.tokens[token] = old
return err
}
return nil
}
// add unconditionally adds a token, which is assumed to not exist.
// called locked
func (state *state) add(token *Stateful) (*Stateful, error) {
err := os.MkdirAll(filepath.Dir(state.filename), 0700)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -220,49 +296,8 @@ func (state *state) Add(token *Stateful) (*Stateful, error) {
return token, nil return token, nil
} }
func Add(token *Stateful) (*Stateful, error) { func Update(token *Stateful, etag string) (*Stateful, error) {
return tokens.Add(token) return tokens.Update(token, etag)
}
func Extend(group, token string, expires time.Time) (*Stateful, error) {
return tokens.Extend(group, token, expires)
}
func (state *state) Extend(group, token string, expires time.Time) (*Stateful, error) {
tokens.mu.Lock()
defer tokens.mu.Unlock()
return state.extend(group, token, expires)
}
// called locked
func (state *state) extend(group, token string, expires time.Time) (*Stateful, error) {
err := state.load()
if err != nil {
return nil, err
}
if state.tokens == nil {
return nil, os.ErrNotExist
}
old := state.tokens[token]
if old == nil {
return nil, os.ErrNotExist
}
if old.Group != group {
return nil, os.ErrPermission
}
new := old.Clone()
new.Expires = &expires
state.tokens[token] = new
err = state.rewrite()
if err != nil {
state.tokens[token] = old
return nil, err
}
return new, err
} }
// called locked // called locked
@ -280,7 +315,7 @@ func (state *state) rewrite() error {
if err != nil { if err != nil {
return err return err
} }
a, err := state.list("") a, _, err := state.list("")
if err != nil { if err != nil {
os.Remove(tmpfile.Name()) os.Remove(tmpfile.Name())
return err return err
@ -321,15 +356,15 @@ func (state *state) rewrite() error {
} }
// called locked // called locked
func (state *state) list(group string) ([]*Stateful, error) { func (state *state) list(group string) ([]*Stateful, string, error) {
err := state.load() _, err := state.load()
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
a := make([]*Stateful, 0) a := make([]*Stateful, 0)
if state.tokens == nil { if state.tokens == nil {
return a, nil return a, state.etag(), nil
} }
for _, t := range state.tokens { for _, t := range state.tokens {
if group != "" { if group != "" {
@ -348,16 +383,16 @@ func (state *state) list(group string) ([]*Stateful, error) {
} }
return (*a[i].Expires).Before(*a[j].Expires) return (*a[i].Expires).Before(*a[j].Expires)
}) })
return a, nil return a, state.etag(), nil
} }
func (state *state) List(group string) ([]*Stateful, error) { func (state *state) List(group string) ([]*Stateful, string, error) {
state.mu.Lock() state.mu.Lock()
defer state.mu.Unlock() defer state.mu.Unlock()
return state.list(group) return state.list(group)
} }
func List(group string) ([]*Stateful, error) { func List(group string) ([]*Stateful, string, error) {
return tokens.List(group) return tokens.List(group)
} }
@ -365,7 +400,7 @@ func (state *state) Expire() error {
state.mu.Lock() state.mu.Lock()
defer state.mu.Unlock() defer state.mu.Unlock()
err := state.load() _, err := state.load()
if err != nil { if err != nil {
return err return err
} }

View File

@ -252,7 +252,7 @@ func TestTokenStorage(t *testing.T) {
}, },
} }
for i, token := range tokens { for i, token := range tokens {
new, err := s.Add(token) new, err := s.Update(token, "")
if err != nil { if err != nil {
t.Errorf("Add: %v", err) t.Errorf("Add: %v", err)
} }
@ -264,19 +264,30 @@ func TestTokenStorage(t *testing.T) {
} }
s.modTime = time.Time{} s.modTime = time.Time{}
err := s.load() _, err := s.load()
if err != nil { if err != nil {
t.Errorf("Load: %v", err) t.Errorf("Load: %v", err)
} }
expectTokens(t, s.tokens, tokens) expectTokens(t, s.tokens, tokens)
_, err = s.Extend("test2", tokens[1].Token, now.Add(time.Hour)) t1, etag, err := s.Get("tok2")
if err == nil {
t.Errorf("Edit succeeded with wrong group")
}
new, err := s.Extend("test", tokens[1].Token, now.Add(time.Hour))
if err != nil { if err != nil {
t.Errorf("Edit: %v", err) t.Fatalf("Get: %v", err)
}
t2 := t1.Clone()
soon := now.Add(time.Hour)
t2.Expires = &soon
_, err = s.Update(t2, "")
if !errors.Is(err, ErrTagMismatch) {
t.Errorf("Update: got %v, expected ErrTagMismatch", err)
}
_, err = s.Update(t2, "\"bad\"")
if !errors.Is(err, ErrTagMismatch) {
t.Errorf("Update: got %v, expected ErrTagMismatch", err)
}
new, err := s.Update(t2, etag)
if err != nil {
t.Fatalf("Update: %v", err)
} }
tokens[1].Expires = &future tokens[1].Expires = &future
if !equal(new, tokens[1]) { if !equal(new, tokens[1]) {
@ -350,7 +361,7 @@ func TestExpire(t *testing.T) {
} }
for _, token := range tokens { for _, token := range tokens {
_, err := s.Add(token) _, err := s.Update(token, "")
if err != nil { if err != nil {
t.Errorf("Add: %v", err) t.Errorf("Add: %v", err)
} }

View File

@ -22,5 +22,6 @@ func Parse(token string, keys []map[string]interface{}) (Token, error) {
return jwt, nil return jwt, nil
} }
return Get(token) s, _, err := Get(token)
return s, err
} }

View File

@ -23,13 +23,13 @@ func TestToken(t *testing.T) {
future := time.Now().Add(time.Hour) future := time.Now().Add(time.Hour)
user := "user" user := "user"
_, err = Add(&Stateful{ _, err = Update(&Stateful{
Token: "token", Token: "token",
Group: "group", Group: "group",
Username: &user, Username: &user,
Permissions: []string{"present"}, Permissions: []string{"present"},
Expires: &future, Expires: &future,
}) }, "")
if err != nil { if err != nil {
t.Fatalf("Add: %v", err) t.Fatalf("Add: %v", err)
} }