1
Fork 0
photoview/api/graphql/models/user.go

197 lines
4.8 KiB
Go

package models
import (
"crypto/rand"
"fmt"
"log"
"os"
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// User is a persisted account record.
//
// Password holds a bcrypt hash, or nil for a user registered without a
// password (see RegisterUser). GORM tag settings are separated by ';' —
// a ',' separator (or an unterminated tag) silently disables them.
type User struct {
	gorm.Model
	// Username is the login name. The unique tag asks GORM to create a
	// unique constraint when migrating.
	Username string `gorm:"unique;size:128"`
	// Password is the bcrypt hash of the user's password, or nil when the
	// user has no password.
	Password *string `gorm:"size:256"`
	// RootPath is the filesystem path assigned to the user.
	RootPath string `gorm:"size:512"`
	// Admin marks administrator accounts.
	Admin bool `gorm:"default:false"`
}
// func (u *User) ID() int {
// return u.UserID
// }
// AccessToken is a login token belonging to a user. Value is the random
// token string and Expire is the instant after which it is no longer valid.
type AccessToken struct {
	gorm.Model
	// UserID is the foreign key of the owning user.
	UserID uint
	// User is the owner; the constraint cascades user deletion to its tokens.
	User User `gorm:"constraint:OnDelete:CASCADE;"`
	// Value is the token string (24 characters when produced by
	// GenerateAccessToken).
	Value  string `gorm:"size:24"`
	Expire time.Time
}
var (
	// ErrorInvalidUserCredentials is returned by AuthorizeUser when no user
	// has the given username or the password does not match the stored hash.
	ErrorInvalidUserCredentials = errors.New("invalid credentials")
)
// func NewUserFromRow(row *sql.Row) (*User, error) {
// user := User{}
// if err := row.Scan(&user.UserID, &user.Username, &user.Password, &user.RootPath, &user.Admin); err != nil {
// return nil, errors.Wrap(err, "failed to scan user from database")
// }
// return &user, nil
// }
// func NewUsersFromRows(rows *sql.Rows) ([]*User, error) {
// users := make([]*User, 0)
// for rows.Next() {
// var user User
// if err := rows.Scan(&user.UserID, &user.Username, &user.Password, &user.RootPath, &user.Admin); err != nil {
// return nil, errors.Wrap(err, "failed to scan users from database")
// }
// users = append(users, &user)
// }
// rows.Close()
// return users, nil
// }
// AuthorizeUser looks up the user named username and checks password
// against the stored bcrypt hash.
//
// It returns ErrorInvalidUserCredentials when no such user exists or the
// password does not match. A user whose Password is nil cannot log in and
// gets a distinct "user does not have a password" error. Any other database
// or bcrypt failure is returned wrapped with context.
//
// NOTE(review): an unknown username returns before any bcrypt comparison,
// so response timing may reveal whether a username exists.
func AuthorizeUser(db *gorm.DB, username string, password string) (*User, error) {
	var user User
	if err := db.Where("username = ?", username).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrorInvalidUserCredentials
		}
		return nil, errors.Wrap(err, "failed to get user by username when authorizing")
	}

	if user.Password == nil {
		return nil, errors.New("user does not have a password")
	}

	err := bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(password))
	switch {
	case err == nil:
		return &user, nil
	case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
		// A wrong password is reported the same way as an unknown user.
		return nil, ErrorInvalidUserCredentials
	default:
		return nil, errors.Wrap(err, "compare user password hash")
	}
}
var (
	// ErrorInvalidRootPath is returned by RegisterUser when ValidRootPath
	// rejects the requested root path.
	ErrorInvalidRootPath = errors.New("invalid root path")
)
// ValidRootPath reports whether rootPath can be stat'ed on the local
// filesystem. Any os.Stat failure (missing path, permission denied, ...) is
// logged as a warning and makes the path invalid. It only checks that the
// path exists; it does not require a directory.
func ValidRootPath(rootPath string) bool {
	if _, statErr := os.Stat(rootPath); statErr != nil {
		log.Printf("Warn: invalid root path: '%s'\n%s\n", rootPath, statErr)
		return false
	}
	return true
}
// RegisterUser creates and persists a new user.
//
// rootPath must pass ValidRootPath; otherwise ErrorInvalidRootPath is
// returned and nothing is written. When password is non-nil it is stored as a
// bcrypt hash; when nil the user is created with a nil Password.
func RegisterUser(db *gorm.DB, username string, password *string, rootPath string, admin bool) (*User, error) {
	const passwordHashCost = 12

	if !ValidRootPath(rootPath) {
		return nil, ErrorInvalidRootPath
	}

	user := User{
		Username: username,
		RootPath: rootPath,
		Admin:    admin,
	}

	if password != nil {
		hashed, err := bcrypt.GenerateFromPassword([]byte(*password), passwordHashCost)
		if err != nil {
			return nil, errors.Wrap(err, "failed to hash password")
		}
		hashedPass := string(hashed)
		user.Password = &hashedPass
	}

	if err := db.Create(&user).Error; err != nil {
		// The password is optional, so the message must not claim one was set.
		return nil, errors.Wrap(err, "insert new user into database")
	}
	return &user, nil
}
// GenerateAccessToken creates a random 24-character alphanumeric token for
// user that expires 14 days from now, saves it to the database, and returns it.
//
// Characters are drawn from crypto/rand with rejection sampling, so every
// symbol of the 62-character alphabet is equally likely.
func (user *User) GenerateAccessToken(db *gorm.DB) (*AccessToken, error) {
	const (
		alphabet    = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
		tokenLength = 24
		// Largest multiple of len(alphabet) that is <= 256. Bytes at or above it
		// are discarded: mapping all 256 byte values with b % 62 would make the
		// first 256 % 62 = 8 symbols more likely than the rest.
		unbiasedLimit = 256 - 256%len(alphabet)
	)

	value := make([]byte, 0, tokenLength)
	buf := make([]byte, tokenLength)
	for len(value) < tokenLength {
		if _, err := rand.Read(buf); err != nil {
			return nil, fmt.Errorf("could not generate token: %w", err)
		}
		for _, b := range buf {
			if int(b) >= unbiasedLimit {
				continue
			}
			value = append(value, alphabet[int(b)%len(alphabet)])
			if len(value) == tokenLength {
				break
			}
		}
	}

	token := AccessToken{
		UserID: user.ID,
		Value:  string(value),
		Expire: time.Now().Add(14 * 24 * time.Hour),
	}
	if err := db.Create(&token).Error; err != nil {
		return nil, errors.Wrap(err, "saving access token to database")
	}
	return &token, nil
}
// VerifyTokenAndGetUser returns the user that owns the access token whose
// value is token, provided the token has not yet expired.
//
// Lookup errors are returned unwrapped, so an unknown or expired token
// surfaces as the error from GORM's First (gorm.ErrRecordNotFound when no
// row matches).
func VerifyTokenAndGetUser(db *gorm.DB, token string) (*User, error) {
	// Compare against a time.Time rather than a pre-formatted UTC string:
	// Expire is stored from a local time.Now(), and passing the same Go type
	// should let the driver encode both sides of the comparison the same way.
	now := time.Now()

	var accessToken AccessToken
	if result := db.Where("expire > ? AND value = ?", now, token).First(&accessToken); result.Error != nil {
		return nil, result.Error
	}

	// Resolve the owner through the UserID foreign key. The token row's own
	// primary key (accessToken.ID) would select an unrelated user.
	var user User
	if result := db.First(&user, accessToken.UserID); result.Error != nil {
		return nil, result.Error
	}
	return &user, nil
}