diff --git a/service/database/db-users.go b/service/database/db-users.go index 3d0bd80..003e635 100644 --- a/service/database/db-users.go +++ b/service/database/db-users.go @@ -2,6 +2,7 @@ package database import ( "database/sql" + "errors" "github.com/gofrs/uuid" "github.com/notherealmarco/WASAPhoto/service/database/db_errors" @@ -47,6 +48,16 @@ func (db *appdbimpl) GetUserID(name string) (string, error) { // Create a new user func (db *appdbimpl) CreateUser(name string) (string, error) { + + // check if username is taken (case insensitive) + exists, err := db.nameExists(name) + + if err != nil { + return "", err + } else if exists { + return "", errors.New("username already exists") + } + uid, err := uuid.NewV4() if err != nil { return "", err @@ -55,14 +66,30 @@ func (db *appdbimpl) CreateUser(name string) (string, error) { return uid.String(), err } +// Check if username exists +func (db *appdbimpl) nameExists(name string) (bool, error) { + var cnt int + err := db.c.QueryRow(`SELECT COUNT(*) FROM "users" WHERE "name" LIKE ?`, name).Scan(&cnt) + if err != nil { + return false, err + } + return cnt > 0, nil +} + // Update username func (db *appdbimpl) UpdateUsername(uid string, name string) (QueryResult, error) { - _, err := db.c.Exec(`UPDATE "users" SET "name" = ? WHERE "uid" = ?`, name, uid) - if db_errors.UniqueViolation(err) { + // check if username is taken (case insensitive) + exists, err := db.nameExists(name) + + if err != nil { + return ERR_INTERNAL, err + } else if exists { return ERR_EXISTS, nil } + _, err = db.c.Exec(`UPDATE "users" SET "name" = ? WHERE "uid" = ?`, name, uid) + if err != nil { return ERR_INTERNAL, err }