Pass database as argument to individual face_detector functions
This allows the face detector to use transactions, such that faces can be detected on media that has not been fully commited yet. This solves #214
This commit is contained in:
parent
ba16fc1caa
commit
3ae92086cd
|
@ -361,7 +361,7 @@ func (r *mutationResolver) UserRemoveRootAlbum(ctx context.Context, userID int,
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := face_detection.GlobalFaceDetector.ReloadFacesFromDatabase(); err != nil {
|
if err := face_detection.GlobalFaceDetector.ReloadFacesFromDatabase(tx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,7 @@ func CleanupMedia(db *gorm.DB, albumId int, albumMedia []*models.Media) []error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload faces after deleting media
|
// Reload faces after deleting media
|
||||||
if err := face_detection.GlobalFaceDetector.ReloadFacesFromDatabase(); err != nil {
|
if err := face_detection.GlobalFaceDetector.ReloadFacesFromDatabase(db); err != nil {
|
||||||
deleteErrors = append(deleteErrors, errors.Wrap(err, "reload faces from database"))
|
deleteErrors = append(deleteErrors, errors.Wrap(err, "reload faces from database"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,7 +123,7 @@ func deleteOldUserAlbums(db *gorm.DB, scannedAlbums []*models.Album, user *model
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload faces after deleting albums
|
// Reload faces after deleting albums
|
||||||
if err := face_detection.GlobalFaceDetector.ReloadFacesFromDatabase(); err != nil {
|
if err := face_detection.GlobalFaceDetector.ReloadFacesFromDatabase(db); err != nil {
|
||||||
deleteErrors = append(deleteErrors, err)
|
deleteErrors = append(deleteErrors, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
|
|
||||||
type FaceDetector struct {
|
type FaceDetector struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
db *gorm.DB
|
|
||||||
rec *face.Recognizer
|
rec *face.Recognizer
|
||||||
faceDescriptors []face.Descriptor
|
faceDescriptors []face.Descriptor
|
||||||
faceGroupIDs []int32
|
faceGroupIDs []int32
|
||||||
|
@ -37,7 +36,6 @@ func InitializeFaceDetector(db *gorm.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalFaceDetector = FaceDetector{
|
GlobalFaceDetector = FaceDetector{
|
||||||
db: db,
|
|
||||||
rec: rec,
|
rec: rec,
|
||||||
faceDescriptors: faceDescriptors,
|
faceDescriptors: faceDescriptors,
|
||||||
faceGroupIDs: faceGroupIDs,
|
faceGroupIDs: faceGroupIDs,
|
||||||
|
@ -69,8 +67,8 @@ func getSamplesFromDatabase(db *gorm.DB) (samples []face.Descriptor, faceGroupID
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReloadFacesFromDatabase replaces the in-memory face descriptors with the ones in the database
|
// ReloadFacesFromDatabase replaces the in-memory face descriptors with the ones in the database
|
||||||
func (fd *FaceDetector) ReloadFacesFromDatabase() error {
|
func (fd *FaceDetector) ReloadFacesFromDatabase(db *gorm.DB) error {
|
||||||
faceDescriptors, faceGroupIDs, imageFaceIDs, err := getSamplesFromDatabase(fd.db)
|
faceDescriptors, faceGroupIDs, imageFaceIDs, err := getSamplesFromDatabase(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -86,8 +84,8 @@ func (fd *FaceDetector) ReloadFacesFromDatabase() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DetectFaces finds the faces in the given image and saves them to the database
|
// DetectFaces finds the faces in the given image and saves them to the database
|
||||||
func (fd *FaceDetector) DetectFaces(media *models.Media) error {
|
func (fd *FaceDetector) DetectFaces(tx *gorm.DB, media *models.Media) error {
|
||||||
if err := fd.db.Model(media).Preload("MediaURL").First(&media).Error; err != nil {
|
if err := tx.Model(media).Preload("MediaURL").First(&media).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,7 +116,7 @@ func (fd *FaceDetector) DetectFaces(media *models.Media) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, face := range faces {
|
for _, face := range faces {
|
||||||
fd.classifyFace(&face, media, thumbnailPath)
|
fd.classifyFace(tx, &face, media, thumbnailPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -128,7 +126,7 @@ func (fd *FaceDetector) classifyDescriptor(descriptor face.Descriptor) int32 {
|
||||||
return int32(fd.rec.ClassifyThreshold(descriptor, 0.2))
|
return int32(fd.rec.ClassifyThreshold(descriptor, 0.2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fd *FaceDetector) classifyFace(face *face.Face, media *models.Media, imagePath string) error {
|
func (fd *FaceDetector) classifyFace(tx *gorm.DB, face *face.Face, media *models.Media, imagePath string) error {
|
||||||
fd.mutex.Lock()
|
fd.mutex.Lock()
|
||||||
defer fd.mutex.Unlock()
|
defer fd.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -155,18 +153,18 @@ func (fd *FaceDetector) classifyFace(face *face.Face, media *models.Media, image
|
||||||
ImageFaces: []models.ImageFace{imageFace},
|
ImageFaces: []models.ImageFace{imageFace},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fd.db.Create(&faceGroup).Error; err != nil {
|
if err := tx.Create(&faceGroup).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
log.Println("Found match")
|
log.Println("Found match")
|
||||||
|
|
||||||
if err := fd.db.First(&faceGroup, int(match)).Error; err != nil {
|
if err := tx.First(&faceGroup, int(match)).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fd.db.Model(&faceGroup).Association("ImageFaces").Append(&imageFace); err != nil {
|
if err := tx.Model(&faceGroup).Association("ImageFaces").Append(&imageFace); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
"github.com/photoview/photoview/api/scanner/face_detection"
|
"github.com/photoview/photoview/api/scanner/face_detection"
|
||||||
"github.com/photoview/photoview/api/utils"
|
"github.com/photoview/photoview/api/utils"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sabhiram/go-gitignore"
|
ignore "github.com/sabhiram/go-gitignore"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ func scanAlbum(album *models.Album, cache *AlbumScannerCache, db *gorm.DB) {
|
||||||
|
|
||||||
if media.Type == models.MediaTypePhoto {
|
if media.Type == models.MediaTypePhoto {
|
||||||
go func() {
|
go func() {
|
||||||
if err := face_detection.GlobalFaceDetector.DetectFaces(media); err != nil {
|
if err := face_detection.GlobalFaceDetector.DetectFaces(tx, media); err != nil {
|
||||||
ScannerError("Error detecting faces in image (%s): %s", media.Path, err)
|
ScannerError("Error detecting faces in image (%s): %s", media.Path, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -139,14 +139,14 @@ func findMediaForAlbum(album *models.Album, cache *AlbumScannerCache, db *gorm.D
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get ignore data
|
// Get ignore data
|
||||||
albumIgnore := ignore.CompileIgnoreLines(*cache.GetAlbumIgnore(album.Path)...)
|
albumIgnore := ignore.CompileIgnoreLines(*cache.GetAlbumIgnore(album.Path)...)
|
||||||
|
|
||||||
for _, item := range dirContent {
|
for _, item := range dirContent {
|
||||||
photoPath := path.Join(album.Path, item.Name())
|
photoPath := path.Join(album.Path, item.Name())
|
||||||
|
|
||||||
if !item.IsDir() && isPathMedia(photoPath, cache) {
|
if !item.IsDir() && isPathMedia(photoPath, cache) {
|
||||||
// Match file against ignore data
|
// Match file against ignore data
|
||||||
if (albumIgnore.MatchesPath(item.Name())) {
|
if albumIgnore.MatchesPath(item.Name()) {
|
||||||
log.Printf("File %s ignored\n", item.Name())
|
log.Printf("File %s ignored\n", item.Name())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue