1
Fork 0
photoview/api/graphql/models/user.go

180 lines
4.4 KiB
Go
Raw Normal View History

2020-01-31 17:36:48 +01:00
package models
import (
2020-01-31 18:51:24 +01:00
"crypto/rand"
2020-01-31 17:36:48 +01:00
"database/sql"
2020-01-31 18:51:24 +01:00
"fmt"
2020-01-31 23:30:34 +01:00
"log"
2020-02-16 12:22:00 +01:00
"os"
2020-01-31 18:51:24 +01:00
"time"
2020-01-31 17:36:48 +01:00
2020-07-10 18:35:37 +02:00
"github.com/pkg/errors"
2020-01-31 17:36:48 +01:00
"golang.org/x/crypto/bcrypt"
)
// User holds one row of the user table as read by NewUserFromRow and
// NewUsersFromRows.
type User struct {
	// UserID is the user_id column.
	UserID int
	// Username is the login name (username column).
	Username string
	// Password is the stored bcrypt hash, or nil when the account was
	// registered without a password.
	Password *string
	// RootPath is the root_path column; RegisterUser checks it with
	// ValidRootPath before inserting.
	RootPath string
	// Admin is the admin column.
	Admin bool
}
2020-02-09 21:25:33 +01:00
func (u *User) ID() int {
return u.UserID
2020-01-31 23:30:34 +01:00
}
2020-01-31 18:51:24 +01:00
// AccessToken is a generated token value together with the moment it expires,
// as produced by GenerateAccessToken.
type AccessToken struct {
	// Value is the random token string stored in access_token.value.
	Value string
	// Expire is the expiry instant; its UTC form is stored in access_token.expire.
	Expire time.Time
}
2020-01-31 23:30:34 +01:00
var (
	// ErrorInvalidUserCredentials signals that a username/password pair could
	// not be authenticated.
	ErrorInvalidUserCredentials = errors.New("invalid credentials")
)
2020-01-31 17:36:48 +01:00
func NewUserFromRow(row *sql.Row) (*User, error) {
user := User{}
2020-01-31 18:51:24 +01:00
if err := row.Scan(&user.UserID, &user.Username, &user.Password, &user.RootPath, &user.Admin); err != nil {
2020-07-10 18:35:37 +02:00
return nil, errors.Wrap(err, "failed to scan user from database")
2020-01-31 18:51:24 +01:00
}
2020-01-31 17:36:48 +01:00
return &user, nil
}
2020-01-31 23:30:34 +01:00
func NewUsersFromRows(rows *sql.Rows) ([]*User, error) {
users := make([]*User, 0)
for rows.Next() {
var user User
if err := rows.Scan(&user.UserID, &user.Username, &user.Password, &user.RootPath, &user.Admin); err != nil {
2020-07-10 18:35:37 +02:00
return nil, errors.Wrap(err, "failed to scan users from database")
2020-01-31 23:30:34 +01:00
}
users = append(users, &user)
}
2020-02-28 20:57:46 +01:00
rows.Close()
2020-01-31 23:30:34 +01:00
return users, nil
}
2020-01-31 17:36:48 +01:00
func AuthorizeUser(database *sql.DB, username string, password string) (*User, error) {
2020-02-01 14:52:27 +01:00
row := database.QueryRow("SELECT * FROM user WHERE username = ?", username)
2020-01-31 17:36:48 +01:00
user, err := NewUserFromRow(row)
if err != nil {
2020-02-02 00:29:42 +01:00
if err == sql.ErrNoRows {
return nil, ErrorInvalidUserCredentials
} else {
return nil, err
}
2020-01-31 17:36:48 +01:00
}
2020-02-16 12:22:00 +01:00
if user.Password == nil {
return nil, errors.New("user does not have a password")
}
if err := bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(password)); err != nil {
2020-01-31 17:36:48 +01:00
if err == bcrypt.ErrMismatchedHashAndPassword {
2020-01-31 23:30:34 +01:00
return nil, ErrorInvalidUserCredentials
2020-01-31 17:36:48 +01:00
} else {
2020-07-10 18:35:37 +02:00
return nil, errors.Wrap(err, "compare user password hash")
2020-01-31 17:36:48 +01:00
}
}
return user, nil
}
2020-02-16 12:22:00 +01:00
var (
	// ErrorInvalidRootPath is returned by RegisterUser when the requested root
	// path does not pass ValidRootPath.
	ErrorInvalidRootPath = errors.New("invalid root path")
)
// ValidRootPath reports whether rootPath can be used as a user's root path:
// it must be reachable with os.Stat (symlinks are followed) and must be a
// directory. Every rejection is logged as a warning.
func ValidRootPath(rootPath string) bool {
	info, err := os.Stat(rootPath)
	if err != nil {
		log.Printf("Warn: invalid root path: '%s'\n%s\n", rootPath, err)
		return false
	}

	// os.Stat also succeeds for regular files. A root path is expected to be
	// a directory (NOTE(review): confirm against the media scanner), so
	// reject anything else here rather than accepting it silently.
	if !info.IsDir() {
		log.Printf("Warn: invalid root path: '%s'\nnot a directory\n", rootPath)
		return false
	}

	return true
}
func RegisterUser(database *sql.Tx, username string, password *string, rootPath string, admin bool) (*User, error) {
if !ValidRootPath(rootPath) {
return nil, ErrorInvalidRootPath
}
if password != nil {
hashedPassBytes, err := bcrypt.GenerateFromPassword([]byte(*password), 12)
if err != nil {
2020-07-10 18:35:37 +02:00
return nil, errors.Wrap(err, "failed to hash password")
2020-02-16 12:22:00 +01:00
}
hashedPass := string(hashedPassBytes)
if _, err := database.Exec("INSERT INTO user (username, password, root_path, admin) VALUES (?, ?, ?, ?)", username, hashedPass, rootPath, admin); err != nil {
2020-07-10 18:35:37 +02:00
return nil, errors.Wrap(err, "insert new user with password into database")
2020-02-16 12:22:00 +01:00
}
} else {
if _, err := database.Exec("INSERT INTO user (username, root_path, admin) VALUES (?, ?, ?)", username, rootPath, admin); err != nil {
2020-07-10 18:35:37 +02:00
return nil, errors.Wrap(err, "insert user without password into database")
2020-02-16 12:22:00 +01:00
}
2020-01-31 17:36:48 +01:00
}
2020-02-01 14:52:27 +01:00
row := database.QueryRow("SELECT * FROM user WHERE username = ?", username)
2020-01-31 17:36:48 +01:00
if row == nil {
2020-01-31 23:30:34 +01:00
return nil, ErrorInvalidUserCredentials
2020-01-31 17:36:48 +01:00
}
user, err := NewUserFromRow(row)
if err != nil {
return nil, err
}
return user, nil
}
2020-01-31 18:51:24 +01:00
func (user *User) GenerateAccessToken(database *sql.Tx) (*AccessToken, error) {
2020-01-31 18:51:24 +01:00
bytes := make([]byte, 24)
if _, err := rand.Read(bytes); err != nil {
return nil, errors.New(fmt.Sprintf("Could not generate token: %s\n", err.Error()))
}
2020-02-02 00:29:42 +01:00
const CHARACTERS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
2020-01-31 18:51:24 +01:00
for i, b := range bytes {
bytes[i] = CHARACTERS[b%byte(len(CHARACTERS))]
}
token_value := string(bytes)
expire := time.Now().Add(14 * 24 * time.Hour)
expireString := expire.UTC().Format("2006-01-02 15:04:05")
2020-02-01 14:52:27 +01:00
if _, err := database.Exec("INSERT INTO access_token (value, expire, user_id) VALUES (?, ?, ?)", token_value, expireString, user.UserID); err != nil {
2020-01-31 18:51:24 +01:00
return nil, err
}
token := AccessToken{
Value: token_value,
Expire: expire,
}
return &token, nil
}
2020-01-31 23:30:34 +01:00
func VerifyTokenAndGetUser(database *sql.DB, token string) (*User, error) {
now := time.Now().UTC().Format("2006-01-02 15:04:05")
2020-02-01 14:52:27 +01:00
row := database.QueryRow("SELECT (user_id) FROM access_token WHERE expire > ? AND value = ?", now, token)
2020-01-31 23:30:34 +01:00
var userId string
if err := row.Scan(&userId); err != nil {
log.Println(err.Error())
return nil, err
}
2020-02-01 14:52:27 +01:00
row = database.QueryRow("SELECT * FROM user WHERE user_id = ?", userId)
2020-01-31 23:30:34 +01:00
user, err := NewUserFromRow(row)
if err != nil {
return nil, err
}
return user, nil
}