Fix faces not getting scanned
- This fixes #344 - Add integration tests for face recognition - Properly check that the user own the queried album
This commit is contained in:
parent
d03923992c
commit
1029b61a4c
|
@ -36,6 +36,8 @@ models:
|
||||||
resolver: true
|
resolver: true
|
||||||
faces:
|
faces:
|
||||||
resolver: true
|
resolver: true
|
||||||
|
mediaType:
|
||||||
|
resolver: true
|
||||||
MediaURL:
|
MediaURL:
|
||||||
model: github.com/photoview/photoview/api/graphql/models.MediaURL
|
model: github.com/photoview/photoview/api/graphql/models.MediaURL
|
||||||
MediaEXIF:
|
MediaEXIF:
|
||||||
|
@ -60,3 +62,5 @@ models:
|
||||||
model: github.com/photoview/photoview/api/graphql/models.FaceRectangle
|
model: github.com/photoview/photoview/api/graphql/models.FaceRectangle
|
||||||
SiteInfo:
|
SiteInfo:
|
||||||
model: github.com/photoview/photoview/api/graphql/models.SiteInfo
|
model: github.com/photoview/photoview/api/graphql/models.SiteInfo
|
||||||
|
MediaType:
|
||||||
|
model: github.com/photoview/photoview/api/graphql/models.MediaType
|
||||||
|
|
|
@ -11993,13 +11993,19 @@ func (ec *executionContext) marshalNMediaDownload2ᚖgithubᚗcomᚋphotoviewᚋ
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ec *executionContext) unmarshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, v interface{}) (models.MediaType, error) {
|
func (ec *executionContext) unmarshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, v interface{}) (models.MediaType, error) {
|
||||||
var res models.MediaType
|
tmp, err := graphql.UnmarshalString(v)
|
||||||
err := res.UnmarshalGQL(v)
|
res := models.MediaType(tmp)
|
||||||
return res, graphql.ErrorOnPath(ctx, err)
|
return res, graphql.ErrorOnPath(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ec *executionContext) marshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, sel ast.SelectionSet, v models.MediaType) graphql.Marshaler {
|
func (ec *executionContext) marshalNMediaType2githubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaType(ctx context.Context, sel ast.SelectionSet, v models.MediaType) graphql.Marshaler {
|
||||||
return v
|
res := graphql.MarshalString(string(v))
|
||||||
|
if res == graphql.Null {
|
||||||
|
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
|
||||||
|
ec.Errorf(ctx, "must not be null")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ec *executionContext) marshalNMediaURL2ᚖgithubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaURL(ctx context.Context, sel ast.SelectionSet, v *models.MediaURL) graphql.Marshaler {
|
func (ec *executionContext) marshalNMediaURL2ᚖgithubᚗcomᚋphotoviewᚋphotoviewᚋapiᚋgraphqlᚋmodelsᚐMediaURL(ctx context.Context, sel ast.SelectionSet, v *models.MediaURL) graphql.Marshaler {
|
||||||
|
|
|
@ -32,7 +32,10 @@ func (a *Album) BeforeSave(tx *gorm.DB) (err error) {
|
||||||
// GetChildren performs a recursive query to get all the children of the album.
|
// GetChildren performs a recursive query to get all the children of the album.
|
||||||
// An optional filter can be provided that can be used to modify the query on the children.
|
// An optional filter can be provided that can be used to modify the query on the children.
|
||||||
func (a *Album) GetChildren(db *gorm.DB, filter func(*gorm.DB) *gorm.DB) (children []*Album, err error) {
|
func (a *Album) GetChildren(db *gorm.DB, filter func(*gorm.DB) *gorm.DB) (children []*Album, err error) {
|
||||||
// SELECT * FROM sub_albums
|
return GetChildrenFromAlbums(db, filter, []int{a.ID})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetChildrenFromAlbums(db *gorm.DB, filter func(*gorm.DB) *gorm.DB, albumIDs []int) (children []*Album, err error) {
|
||||||
query := db.Model(&Album{}).Table("sub_albums")
|
query := db.Model(&Album{}).Table("sub_albums")
|
||||||
|
|
||||||
if filter != nil {
|
if filter != nil {
|
||||||
|
@ -41,13 +44,13 @@ func (a *Album) GetChildren(db *gorm.DB, filter func(*gorm.DB) *gorm.DB) (childr
|
||||||
|
|
||||||
err = db.Raw(`
|
err = db.Raw(`
|
||||||
WITH recursive sub_albums AS (
|
WITH recursive sub_albums AS (
|
||||||
SELECT * FROM albums AS root WHERE id = ?
|
SELECT * FROM albums AS root WHERE id IN (?)
|
||||||
UNION ALL
|
UNION ALL
|
||||||
SELECT child.* FROM albums AS child JOIN sub_albums ON child.parent_album_id = sub_albums.id
|
SELECT child.* FROM albums AS child JOIN sub_albums ON child.parent_album_id = sub_albums.id
|
||||||
)
|
)
|
||||||
|
|
||||||
?
|
?
|
||||||
`, a.ID, query).Find(&children).Error
|
`, albumIDs, query).Find(&children).Error
|
||||||
|
|
||||||
return children, err
|
return children, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,47 +121,6 @@ func (e LanguageTranslation) MarshalGQL(w io.Writer) {
|
||||||
fmt.Fprint(w, strconv.Quote(e.String()))
|
fmt.Fprint(w, strconv.Quote(e.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
type MediaType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
MediaTypePhoto MediaType = "Photo"
|
|
||||||
MediaTypeVideo MediaType = "Video"
|
|
||||||
)
|
|
||||||
|
|
||||||
var AllMediaType = []MediaType{
|
|
||||||
MediaTypePhoto,
|
|
||||||
MediaTypeVideo,
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e MediaType) IsValid() bool {
|
|
||||||
switch e {
|
|
||||||
case MediaTypePhoto, MediaTypeVideo:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e MediaType) String() string {
|
|
||||||
return string(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *MediaType) UnmarshalGQL(v interface{}) error {
|
|
||||||
str, ok := v.(string)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("enums must be strings")
|
|
||||||
}
|
|
||||||
|
|
||||||
*e = MediaType(str)
|
|
||||||
if !e.IsValid() {
|
|
||||||
return fmt.Errorf("%s is not a valid MediaType", str)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e MediaType) MarshalGQL(w io.Writer) {
|
|
||||||
fmt.Fprint(w, strconv.Quote(e.String()))
|
|
||||||
}
|
|
||||||
|
|
||||||
type NotificationType string
|
type NotificationType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -42,30 +42,19 @@ func (m *Media) BeforeSave(tx *gorm.DB) error {
|
||||||
// Update path hash
|
// Update path hash
|
||||||
m.PathHash = MD5Hash(m.Path)
|
m.PathHash = MD5Hash(m.Path)
|
||||||
|
|
||||||
// Save media type as lowercase for better compatibility
|
|
||||||
m.Type = MediaType(strings.ToLower(string(m.Type)))
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Media) AfterFind(tx *gorm.DB) error {
|
type MediaType string
|
||||||
|
|
||||||
// Convert lowercased media type back
|
const (
|
||||||
lowercasedType := strings.ToLower(string(m.Type))
|
MediaTypePhoto MediaType = "photo"
|
||||||
foundType := false
|
MediaTypeVideo MediaType = "video"
|
||||||
for _, t := range AllMediaType {
|
)
|
||||||
if strings.ToLower(string(t)) == lowercasedType {
|
|
||||||
m.Type = t
|
|
||||||
foundType = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if foundType == false {
|
var AllMediaType = []MediaType{
|
||||||
return errors.New(fmt.Sprintf("Failed to parse media from DB: Invalid media type: %s", m.Type))
|
MediaTypePhoto,
|
||||||
}
|
MediaTypeVideo,
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MediaPurpose string
|
type MediaPurpose string
|
||||||
|
|
|
@ -173,11 +173,29 @@ func (user *User) FillAlbums(db *gorm.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) OwnsAlbum(db *gorm.DB, album *Album) (bool, error) {
|
func (user *User) OwnsAlbum(db *gorm.DB, album *Album) (bool, error) {
|
||||||
// TODO: Implement this
|
|
||||||
panic("not implemented")
|
if err := user.FillAlbums(db); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
albumIDs := make([]int, 0)
|
||||||
|
for _, a := range user.Albums {
|
||||||
|
albumIDs = append(albumIDs, a.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := func(query *gorm.DB) *gorm.DB {
|
||||||
|
return query.Where("id = ?", album.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
ownedAlbum, err := GetChildrenFromAlbums(db, filter, albumIDs)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(ownedAlbum) > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) OwnsMedia(db *gorm.DB, media *Media) (bool, error) {
|
// func (user *User) OwnsMedia(db *gorm.DB, media *Media) (bool, error) {
|
||||||
// TODO: implement this
|
// // TODO: implement this
|
||||||
panic("not implemented")
|
// panic("not implemented")
|
||||||
}
|
// }
|
||||||
|
|
|
@ -69,15 +69,6 @@ func (r faceGroupResolver) ImageFaces(ctx context.Context, obj *models.FaceGroup
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Database.Transaction(func(tx *gorm.DB) error {
|
|
||||||
for i := range imageFaces {
|
|
||||||
if err := imageFaces[i].Media.AfterFind(tx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return imageFaces, nil
|
return imageFaces, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package resolvers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
api "github.com/photoview/photoview/api/graphql"
|
api "github.com/photoview/photoview/api/graphql"
|
||||||
"github.com/photoview/photoview/api/graphql/auth"
|
"github.com/photoview/photoview/api/graphql/auth"
|
||||||
|
@ -108,6 +109,11 @@ func (r *Resolver) Media() api.MediaResolver {
|
||||||
return &mediaResolver{r}
|
return &mediaResolver{r}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *mediaResolver) MediaType(ctx context.Context, media *models.Media) (*models.MediaType, error) {
|
||||||
|
formattedType := models.MediaType(strings.Title(string(media.Type)))
|
||||||
|
return &formattedType, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *mediaResolver) Shares(ctx context.Context, media *models.Media) ([]*models.ShareToken, error) {
|
func (r *mediaResolver) Shares(ctx context.Context, media *models.Media) ([]*models.ShareToken, error) {
|
||||||
var shareTokens []*models.ShareToken
|
var shareTokens []*models.ShareToken
|
||||||
if err := r.Database.Where("media_id = ?", media.ID).Find(&shareTokens).Error; err != nil {
|
if err := r.Database.Where("media_id = ?", media.ID).Find(&shareTokens).Error; err != nil {
|
||||||
|
|
|
@ -2,11 +2,11 @@ package face_detection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/Kagami/go-face"
|
"github.com/Kagami/go-face"
|
||||||
"github.com/photoview/photoview/api/graphql/models"
|
"github.com/photoview/photoview/api/graphql/models"
|
||||||
|
"github.com/photoview/photoview/api/utils"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
@ -25,7 +25,7 @@ func InitializeFaceDetector(db *gorm.DB) error {
|
||||||
|
|
||||||
log.Println("Initializing face detector")
|
log.Println("Initializing face detector")
|
||||||
|
|
||||||
rec, err := face.NewRecognizer(filepath.Join("data", "models"))
|
rec, err := face.NewRecognizer(utils.FaceRecognitionModelsPath())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "initialize facedetect recognizer")
|
return errors.Wrap(err, "initialize facedetect recognizer")
|
||||||
}
|
}
|
||||||
|
@ -107,6 +107,8 @@ func (fd *FaceDetector) DetectFaces(db *gorm.DB, media *models.Media) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("Face thumb path: %v %v\n", thumbnailPath, fd)
|
||||||
|
|
||||||
fd.mutex.Lock()
|
fd.mutex.Lock()
|
||||||
faces, err := fd.rec.RecognizeFile(thumbnailPath)
|
faces, err := fd.rec.RecognizeFile(thumbnailPath)
|
||||||
fd.mutex.Unlock()
|
fd.mutex.Unlock()
|
||||||
|
|
|
@ -3,9 +3,11 @@ package scanner_test
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/photoview/photoview/api/graphql/models"
|
"github.com/photoview/photoview/api/graphql/models"
|
||||||
"github.com/photoview/photoview/api/scanner"
|
"github.com/photoview/photoview/api/scanner"
|
||||||
|
"github.com/photoview/photoview/api/scanner/face_detection"
|
||||||
"github.com/photoview/photoview/api/test_utils"
|
"github.com/photoview/photoview/api/test_utils"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
@ -42,6 +44,10 @@ func TestFullScan(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !assert.NoError(t, face_detection.InitializeFaceDetector(db)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !assert.NoError(t, scanner.AddUserToQueue(user)) {
|
if !assert.NoError(t, scanner.AddUserToQueue(user)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -54,6 +60,32 @@ func TestFullScan(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, 10, len(all_media))
|
assert.Equal(t, 9, len(all_media))
|
||||||
|
|
||||||
|
var all_media_url []*models.MediaURL
|
||||||
|
if !assert.NoError(t, db.Find(&all_media_url).Error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 18, len(all_media_url))
|
||||||
|
|
||||||
|
// Verify that faces was recognized
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
var all_face_groups []*models.FaceGroup
|
||||||
|
if !assert.NoError(t, db.Find(&all_face_groups).Error) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(all_face_groups) == 3
|
||||||
|
}, time.Second*5, time.Millisecond*500)
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
var all_image_faces []*models.ImageFace
|
||||||
|
if !assert.NoError(t, db.Find(&all_image_faces).Error) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(all_image_faces) == 6
|
||||||
|
}, time.Second*5, time.Millisecond*500)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 292 KiB |
|
@ -32,11 +32,12 @@ func UnitTestRun(m *testing.M) int {
|
||||||
func IntegrationTestRun(m *testing.M) int {
|
func IntegrationTestRun(m *testing.M) int {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
_, file, _, ok := runtime.Caller(0)
|
||||||
|
if !ok {
|
||||||
|
log.Fatal("could not get runtime file path")
|
||||||
|
}
|
||||||
|
|
||||||
if *integration_flags.Database {
|
if *integration_flags.Database {
|
||||||
_, file, _, ok := runtime.Caller(0)
|
|
||||||
if !ok {
|
|
||||||
log.Fatal("could not get runtime file path")
|
|
||||||
}
|
|
||||||
|
|
||||||
envPath := path.Join(path.Dir(file), "..", "testing.env")
|
envPath := path.Join(path.Dir(file), "..", "testing.env")
|
||||||
|
|
||||||
|
@ -45,6 +46,9 @@ func IntegrationTestRun(m *testing.M) int {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
faceModelsPath := path.Join(path.Dir(file), "..", "data", "models")
|
||||||
|
utils.ConfigureTestFaceRecognitionModelsPath(faceModelsPath)
|
||||||
|
|
||||||
result := m.Run()
|
result := m.Run()
|
||||||
|
|
||||||
test_dbm.Close()
|
test_dbm.Close()
|
||||||
|
|
|
@ -7,10 +7,11 @@ type EnvironmentVariable string
|
||||||
|
|
||||||
// General options
|
// General options
|
||||||
const (
|
const (
|
||||||
EnvDevelopmentMode EnvironmentVariable = "PHOTOVIEW_DEVELOPMENT_MODE"
|
EnvDevelopmentMode EnvironmentVariable = "PHOTOVIEW_DEVELOPMENT_MODE"
|
||||||
EnvServeUI EnvironmentVariable = "PHOTOVIEW_SERVE_UI"
|
EnvServeUI EnvironmentVariable = "PHOTOVIEW_SERVE_UI"
|
||||||
EnvUIPath EnvironmentVariable = "PHOTOVIEW_UI_PATH"
|
EnvUIPath EnvironmentVariable = "PHOTOVIEW_UI_PATH"
|
||||||
EnvMediaCachePath EnvironmentVariable = "PHOTOVIEW_MEDIA_CACHE"
|
EnvMediaCachePath EnvironmentVariable = "PHOTOVIEW_MEDIA_CACHE"
|
||||||
|
EnvFaceRecognitionModelsPath EnvironmentVariable = "PHOTOVIEW_FACE_RECOGNITION_MODELS_PATH"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Network related
|
// Network related
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GenerateToken() string {
|
func GenerateToken() string {
|
||||||
|
@ -61,3 +62,21 @@ func MediaCachePath() string {
|
||||||
|
|
||||||
return photoCache
|
return photoCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var test_face_recognition_models_path string = ""
|
||||||
|
|
||||||
|
func ConfigureTestFaceRecognitionModelsPath(path string) {
|
||||||
|
test_face_recognition_models_path = path
|
||||||
|
}
|
||||||
|
|
||||||
|
func FaceRecognitionModelsPath() string {
|
||||||
|
if test_face_recognition_models_path != "" {
|
||||||
|
return test_face_recognition_models_path
|
||||||
|
}
|
||||||
|
|
||||||
|
if EnvFaceRecognitionModelsPath.GetValue() == "" {
|
||||||
|
return path.Join("data", "models")
|
||||||
|
}
|
||||||
|
|
||||||
|
return EnvFaceRecognitionModelsPath.GetValue()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue