1
Fork 0

Refactoring + disable cleanup tests for sqlite

This commit is contained in:
viktorstrate 2021-11-06 15:41:25 +01:00
parent 18b83dbbea
commit 12085698c8
No known key found for this signature in database
GPG Key ID: 3F855605109C1E8A
13 changed files with 72 additions and 44 deletions

View File

@ -72,8 +72,8 @@ func GetSqliteAddress(path string) (*url.URL, error) {
func ConfigureDatabase(config *gorm.Config) (*gorm.DB, error) { func ConfigureDatabase(config *gorm.Config) (*gorm.DB, error) {
var databaseDialect gorm.Dialector var databaseDialect gorm.Dialector
switch drivers.DatabaseDriver() { switch drivers.DatabaseDriverFromEnv() {
case drivers.DatabaseDriverMysql: case drivers.MYSQL:
mysqlAddress, err := GetMysqlAddress(utils.EnvMysqlURL.GetValue()) mysqlAddress, err := GetMysqlAddress(utils.EnvMysqlURL.GetValue())
if err != nil { if err != nil {
return nil, err return nil, err
@ -81,7 +81,7 @@ func ConfigureDatabase(config *gorm.Config) (*gorm.DB, error) {
log.Printf("Connecting to MYSQL database: %s", mysqlAddress) log.Printf("Connecting to MYSQL database: %s", mysqlAddress)
databaseDialect = gorm_mysql.Open(mysqlAddress) databaseDialect = gorm_mysql.Open(mysqlAddress)
case drivers.DatabaseDriverSqlite: case drivers.SQLITE:
sqliteAddress, err := GetSqliteAddress(utils.EnvSqlitePath.GetValue()) sqliteAddress, err := GetSqliteAddress(utils.EnvSqlitePath.GetValue())
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,7 +89,7 @@ func ConfigureDatabase(config *gorm.Config) (*gorm.DB, error) {
log.Printf("Opening SQLITE database: %s", sqliteAddress) log.Printf("Opening SQLITE database: %s", sqliteAddress)
databaseDialect = sqlite.Open(sqliteAddress.String()) databaseDialect = sqlite.Open(sqliteAddress.String())
case drivers.DatabaseDriverPostgres: case drivers.POSTGRES:
postgresAddress, err := GetPostgresAddress(utils.EnvPostgresURL.GetValue()) postgresAddress, err := GetPostgresAddress(utils.EnvPostgresURL.GetValue())
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,7 +104,7 @@ func ConfigureDatabase(config *gorm.Config) (*gorm.DB, error) {
} }
// Manually enable foreign keys for sqlite, as this isn't done by default // Manually enable foreign keys for sqlite, as this isn't done by default
if drivers.DatabaseDriver() == drivers.DatabaseDriverSqlite { if drivers.SQLITE.MatchDatabase(db) {
db.Exec("PRAGMA foreign_keys = ON") db.Exec("PRAGMA foreign_keys = ON")
} }
@ -199,9 +199,9 @@ func MigrateDatabase(db *gorm.DB) error {
func ClearDatabase(db *gorm.DB) error { func ClearDatabase(db *gorm.DB) error {
err := db.Transaction(func(tx *gorm.DB) error { err := db.Transaction(func(tx *gorm.DB) error {
db_driver := drivers.DatabaseDriver() db_driver := drivers.DatabaseDriverFromEnv()
if db_driver == drivers.DatabaseDriverMysql { if db_driver == drivers.MYSQL {
if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0;").Error; err != nil { if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0;").Error; err != nil {
return err return err
} }
@ -213,15 +213,15 @@ func ClearDatabase(db *gorm.DB) error {
table := dry_run.Find(model).Statement.Table table := dry_run.Find(model).Statement.Table
switch db_driver { switch db_driver {
case drivers.DatabaseDriverPostgres: case drivers.POSTGRES:
if err := tx.Exec(fmt.Sprintf("TRUNCATE TABLE %s CASCADE", table)).Error; err != nil { if err := tx.Exec(fmt.Sprintf("TRUNCATE TABLE %s CASCADE", table)).Error; err != nil {
return err return err
} }
case drivers.DatabaseDriverMysql: case drivers.MYSQL:
if err := tx.Exec(fmt.Sprintf("TRUNCATE TABLE %s", table)).Error; err != nil { if err := tx.Exec(fmt.Sprintf("TRUNCATE TABLE %s", table)).Error; err != nil {
return err return err
} }
case drivers.DatabaseDriverSqlite: case drivers.SQLITE:
if err := tx.Exec(fmt.Sprintf("DELETE FROM %s", table)).Error; err != nil { if err := tx.Exec(fmt.Sprintf("DELETE FROM %s", table)).Error; err != nil {
return err return err
} }
@ -229,7 +229,7 @@ func ClearDatabase(db *gorm.DB) error {
} }
if db_driver == drivers.DatabaseDriverMysql { if db_driver == drivers.MYSQL {
if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error; err != nil { if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error; err != nil {
return err return err
} }

View File

@ -4,32 +4,52 @@ import (
"strings" "strings"
"github.com/photoview/photoview/api/utils" "github.com/photoview/photoview/api/utils"
"gorm.io/gorm"
) )
// DatabaseDriverType represents the name of a database driver // DatabaseDriverType represents the name of a database driver
type DatabaseDriverType string type DatabaseDriverType string
const ( const (
DatabaseDriverMysql DatabaseDriverType = "mysql" MYSQL DatabaseDriverType = "mysql"
DatabaseDriverSqlite DatabaseDriverType = "sqlite" SQLITE DatabaseDriverType = "sqlite"
DatabaseDriverPostgres DatabaseDriverType = "postgres" POSTGRES DatabaseDriverType = "postgres"
) )
func DatabaseDriver() DatabaseDriverType { func DatabaseDriverFromEnv() DatabaseDriverType {
var driver DatabaseDriverType var driver DatabaseDriverType
driverString := strings.ToLower(utils.EnvDatabaseDriver.GetValue()) driverString := strings.ToLower(utils.EnvDatabaseDriver.GetValue())
switch driverString { switch driverString {
case "mysql": case "mysql":
driver = DatabaseDriverMysql driver = MYSQL
case "sqlite": case "sqlite":
driver = DatabaseDriverSqlite driver = SQLITE
case "postgres": case "postgres":
driver = DatabaseDriverPostgres driver = POSTGRES
default: default:
driver = DatabaseDriverMysql driver = MYSQL
} }
return driver return driver
} }
func (driver DatabaseDriverType) MatchDatabase(db *gorm.DB) bool {
return db.Dialector.Name() == string(driver)
}
func GetDatabaseDriverType(db *gorm.DB) (driver DatabaseDriverType) {
switch db.Dialector.Name() {
case "mysql":
driver = MYSQL
case "sqlite":
driver = SQLITE
case "postgres":
driver = POSTGRES
default:
driver = MYSQL
}
return
}

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/photoview/photoview/api/database/drivers"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -22,12 +23,10 @@ func DateExtract(db *gorm.DB, component DateComponent, attribute string) string
var result string var result string
switch db.Dialector.Name() { switch drivers.GetDatabaseDriverType(db) {
case "mysql", "postgres": case drivers.MYSQL, drivers.POSTGRES:
result = fmt.Sprintf("EXTRACT(%s FROM %s)", component, attribute) result = fmt.Sprintf("EXTRACT(%s FROM %s)", component, attribute)
break case drivers.SQLITE:
case "sqlite":
var sqliteFormatted string var sqliteFormatted string
switch component { switch component {
case DateCompYear: case DateCompYear:
@ -39,9 +38,8 @@ func DateExtract(db *gorm.DB, component DateComponent, attribute string) string
} }
result = fmt.Sprintf("CAST(strftime('%s', %s) AS INTEGER)", sqliteFormatted, attribute) result = fmt.Sprintf("CAST(strftime('%s', %s) AS INTEGER)", sqliteFormatted, attribute)
break
default: default:
log.Panicf("unsupported database backend: %s", db.Dialector.Name()) log.Panicf("unsupported database backend: %s", drivers.GetDatabaseDriverType(db))
} }
return result return result

View File

@ -3,6 +3,7 @@ package actions
import ( import (
"strings" "strings"
"github.com/photoview/photoview/api/database/drivers"
"github.com/photoview/photoview/api/graphql/models" "github.com/photoview/photoview/api/graphql/models"
"github.com/pkg/errors" "github.com/pkg/errors"
"gorm.io/gorm" "gorm.io/gorm"
@ -26,7 +27,7 @@ func Search(db *gorm.DB, query string, userID int, _limitMedia *int, _limitAlbum
var media []*models.Media var media []*models.Media
userSubquery := db.Table("user_albums").Where("user_id = ?", userID) userSubquery := db.Table("user_albums").Where("user_id = ?", userID)
if db.Dialector.Name() == "postgres" { if drivers.POSTGRES.MatchDatabase(db) {
userSubquery = userSubquery.Where("album_id = \"Album\".id") userSubquery = userSubquery.Where("album_id = \"Album\".id")
} else { } else {
userSubquery = userSubquery.Where("album_id = Album.id") userSubquery = userSubquery.Where("album_id = Album.id")

View File

@ -3,6 +3,7 @@ package actions
import ( import (
"time" "time"
"github.com/photoview/photoview/api/database/drivers"
"github.com/photoview/photoview/api/graphql/auth" "github.com/photoview/photoview/api/graphql/auth"
"github.com/photoview/photoview/api/graphql/models" "github.com/photoview/photoview/api/graphql/models"
"github.com/photoview/photoview/api/utils" "github.com/photoview/photoview/api/utils"
@ -15,7 +16,7 @@ func AddMediaShare(db *gorm.DB, user *models.User, mediaID int, expire *time.Tim
var media models.Media var media models.Media
var query string var query string
if db.Dialector.Name() == "postgres" { if drivers.POSTGRES.MatchDatabase(db) {
query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = \"Album\".id AND user_albums.user_id = ?)" query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = \"Album\".id AND user_albums.user_id = ?)"
} else { } else {
query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = Album.id AND user_albums.user_id = ?)" query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = Album.id AND user_albums.user_id = ?)"
@ -146,7 +147,7 @@ func hashSharePassword(password *string) (*string, error) {
func getUserToken(db *gorm.DB, userID int, tokenValue string) (*models.ShareToken, error) { func getUserToken(db *gorm.DB, userID int, tokenValue string) (*models.ShareToken, error) {
var query string var query string
if db.Dialector.Name() == "postgres" { if drivers.POSTGRES.MatchDatabase(db) {
query = "\"Owner\".id = ? OR \"Owner\".admin = TRUE" query = "\"Owner\".id = ? OR \"Owner\".admin = TRUE"
} else { } else {
query = "Owner.id = ? OR Owner.admin = TRUE" query = "Owner.id = ? OR Owner.admin = TRUE"

View File

@ -3,6 +3,7 @@ package actions
import ( import (
"time" "time"
"github.com/photoview/photoview/api/database/drivers"
"github.com/photoview/photoview/api/graphql/models" "github.com/photoview/photoview/api/graphql/models"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -13,15 +14,15 @@ func MyTimeline(db *gorm.DB, user *models.User, paginate *models.Pagination, onl
Joins("JOIN albums ON media.album_id = albums.id"). Joins("JOIN albums ON media.album_id = albums.id").
Where("albums.id IN (?)", db.Table("user_albums").Select("user_albums.album_id").Where("user_id = ?", user.ID)) Where("albums.id IN (?)", db.Table("user_albums").Select("user_albums.album_id").Where("user_id = ?", user.ID))
switch db.Dialector.Name() { switch drivers.GetDatabaseDriverType(db) {
case "postgres": case drivers.POSTGRES:
query = query. query = query.
Order("DATE_TRUNC('year', date_shot) DESC"). Order("DATE_TRUNC('year', date_shot) DESC").
Order("DATE_TRUNC('month', date_shot) DESC"). Order("DATE_TRUNC('month', date_shot) DESC").
Order("DATE_TRUNC('day', date_shot) DESC"). Order("DATE_TRUNC('day', date_shot) DESC").
Order("albums.title ASC"). Order("albums.title ASC").
Order("media.date_shot DESC") Order("media.date_shot DESC")
case "sqlite": case drivers.SQLITE:
query = query. query = query.
Order("strftime('%j', media.date_shot) DESC"). // convert to day of year 001-366 Order("strftime('%j', media.date_shot) DESC"). // convert to day of year 001-366
Order("albums.title ASC"). Order("albums.title ASC").

View File

@ -10,6 +10,7 @@ import (
"strings" "strings"
"github.com/Kagami/go-face" "github.com/Kagami/go-face"
"github.com/photoview/photoview/api/database/drivers"
"github.com/photoview/photoview/api/scanner/media_encoding/media_utils" "github.com/photoview/photoview/api/scanner/media_encoding/media_utils"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
@ -48,10 +49,10 @@ type FaceDescriptor face.Descriptor
// GormDataType datatype used in database // GormDataType datatype used in database
func (FaceDescriptor) GormDBDataType(db *gorm.DB, field *schema.Field) string { func (FaceDescriptor) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() { switch drivers.GetDatabaseDriverType(db) {
case "mysql", "sqlite": case drivers.MYSQL, drivers.SQLITE:
return "BLOB" return "BLOB"
case "postgres": case drivers.POSTGRES:
return "BYTEA" return "BYTEA"
} }
return "" return ""

View File

@ -16,9 +16,9 @@ func (SiteInfo) TableName() string {
return "site_info" return "site_info"
} }
func DefaultSiteInfo() SiteInfo { func DefaultSiteInfo(db *gorm.DB) SiteInfo {
defaultConcurrentWorkers := 3 defaultConcurrentWorkers := 3
if db_drivers.DatabaseDriver() == db_drivers.DatabaseDriverSqlite { if db_drivers.SQLITE.MatchDatabase(db) {
defaultConcurrentWorkers = 1 defaultConcurrentWorkers = 1
} }
@ -39,7 +39,7 @@ func GetSiteInfo(db *gorm.DB) (*SiteInfo, error) {
} }
if len(siteInfo) == 0 { if len(siteInfo) == 0 {
newSiteInfo := DefaultSiteInfo() newSiteInfo := DefaultSiteInfo(db)
if err := db.Create(&newSiteInfo).Error; err != nil { if err := db.Create(&newSiteInfo).Error; err != nil {
return nil, errors.Wrap(err, "initialize site_info") return nil, errors.Wrap(err, "initialize site_info")

View File

@ -17,7 +17,7 @@ func TestSiteInfo(t *testing.T) {
return return
} }
assert.Equal(t, models.DefaultSiteInfo(), *site_info) assert.Equal(t, models.DefaultSiteInfo(db), *site_info)
site_info.InitialSetup = false site_info.InitialSetup = false
site_info.PeriodicScanInterval = 360 site_info.PeriodicScanInterval = 360

View File

@ -67,7 +67,7 @@ func (r *mutationResolver) SetScannerConcurrentWorkers(ctx context.Context, work
return 0, errors.New("concurrent workers must at least be 1") return 0, errors.New("concurrent workers must at least be 1")
} }
if workers > 1 && drivers.DatabaseDriver() == drivers.DatabaseDriverSqlite { if workers > 1 && drivers.DatabaseDriverFromEnv() == drivers.SQLITE {
return 0, errors.New("multiple workers not supported for SQLite databases") return 0, errors.New("multiple workers not supported for SQLite databases")
} }

View File

@ -10,6 +10,7 @@ import (
"strings" "strings"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/photoview/photoview/api/database/drivers"
"github.com/photoview/photoview/api/graphql/models" "github.com/photoview/photoview/api/graphql/models"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -37,7 +38,7 @@ func RegisterDownloadRoutes(db *gorm.DB, router *mux.Router) {
} }
var mediaWhereQuery string var mediaWhereQuery string
if db.Dialector.Name() == "postgres" { if drivers.POSTGRES.MatchDatabase(db) {
mediaWhereQuery = "\"Media\".album_id = ?" mediaWhereQuery = "\"Media\".album_id = ?"
} else { } else {
mediaWhereQuery = "Media.album_id = ?" mediaWhereQuery = "Media.album_id = ?"

View File

@ -26,7 +26,7 @@ func CleanupMedia(db *gorm.DB, albumId int, albumMedia []*models.Media) []error
// Select media from database that was not found on hard disk // Select media from database that was not found on hard disk
if len(albumMedia) > 0 { if len(albumMedia) > 0 {
query.Where("NOT id IN (?)", albumMediaIds) query = query.Where("NOT id IN (?)", albumMediaIds)
} }
if err := query.Find(&mediaList).Error; err != nil { if err := query.Find(&mediaList).Error; err != nil {
@ -48,7 +48,7 @@ func CleanupMedia(db *gorm.DB, albumId int, albumMedia []*models.Media) []error
} }
if len(mediaIDs) > 0 { if len(mediaIDs) > 0 {
if err := db.Delete(models.Media{}, mediaIDs).Error; err != nil { if err := db.Where("id IN (?)", mediaIDs).Delete(models.Media{}).Error; err != nil {
deleteErrors = append(deleteErrors, errors.Wrap(err, "delete old media from database")) deleteErrors = append(deleteErrors, errors.Wrap(err, "delete old media from database"))
} }

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/otiai10/copy" "github.com/otiai10/copy"
"github.com/photoview/photoview/api/database/drivers"
"github.com/photoview/photoview/api/graphql/models" "github.com/photoview/photoview/api/graphql/models"
"github.com/photoview/photoview/api/scanner/face_detection" "github.com/photoview/photoview/api/scanner/face_detection"
"github.com/photoview/photoview/api/test_utils" "github.com/photoview/photoview/api/test_utils"
@ -16,6 +17,11 @@ func TestCleanupMedia(t *testing.T) {
test_utils.FilesystemTest(t) test_utils.FilesystemTest(t)
db := test_utils.DatabaseTest(t) db := test_utils.DatabaseTest(t)
// Sqlite doesn't seem to support foreign key cascading
if drivers.SQLITE.MatchDatabase(db) {
t.SkipNow()
}
if !assert.NoError(t, face_detection.InitializeFaceDetector(db)) { if !assert.NoError(t, face_detection.InitializeFaceDetector(db)) {
return return
} }
@ -69,7 +75,6 @@ func TestCleanupMedia(t *testing.T) {
} }
t.Run("Modify albums", func(t *testing.T) { t.Run("Modify albums", func(t *testing.T) {
test_utils.RunScannerOnUser(t, db, user1) test_utils.RunScannerOnUser(t, db, user1)
assert.Equal(t, 9, countAllMedia()) assert.Equal(t, 9, countAllMedia())
assert.Equal(t, 18, countAllMediaURLs()) assert.Equal(t, 18, countAllMediaURLs())