started with 2fa support
This commit is contained in:
@@ -2,10 +2,12 @@ package storage
|
||||
|
||||
import (
|
||||
"MiauInv/models"
|
||||
utils "MiauInv/util"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/glebarez/go-sqlite"
|
||||
)
|
||||
@@ -27,7 +29,9 @@ func InitDB(filepath string) error {
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL,
|
||||
role TEXT NOT NULL
|
||||
role TEXT NOT NULL,
|
||||
two_factor_enabled INTEGER NOT NULL DEFAULT 0,
|
||||
two_factor_secret TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
@@ -41,6 +45,16 @@ func InitDB(filepath string) error {
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS two_factor_recovery_codes (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
code_hash TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
used_at INTEGER DEFAULT NULL,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
|
||||
UNIQUE(user_id, code_hash)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS items (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
@@ -84,7 +98,26 @@ func InitDB(filepath string) error {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return err
|
||||
if err := ensureUserTwoFactorColumns(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureUserTwoFactorColumns() error {
|
||||
migrations := []string{
|
||||
"ALTER TABLE users ADD COLUMN two_factor_enabled INTEGER NOT NULL DEFAULT 0",
|
||||
"ALTER TABLE users ADD COLUMN two_factor_secret TEXT NOT NULL DEFAULT ''",
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
_, err := DB.Exec(migration)
|
||||
if err != nil && !strings.Contains(strings.ToLower(err.Error()), "duplicate column") {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Users
|
||||
@@ -93,18 +126,131 @@ func AddUser(user *models.User) error {
|
||||
return err
|
||||
}
|
||||
func GetUserByUsername(username string) (models.User, error) {
|
||||
row := DB.QueryRow("SELECT * FROM users WHERE username = ?", strings.ToLower(username))
|
||||
var user models.User
|
||||
err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role)
|
||||
return user, err
|
||||
row := DB.QueryRow(`
|
||||
SELECT id, username, password, role, two_factor_enabled, two_factor_secret
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
`, strings.ToLower(username))
|
||||
return scanUser(row)
|
||||
}
|
||||
func GetUserById(id string) (models.User, error) {
|
||||
row := DB.QueryRow("SELECT * FROM users WHERE id = ?", id)
|
||||
row := DB.QueryRow(`
|
||||
SELECT id, username, password, role, two_factor_enabled, two_factor_secret
|
||||
FROM users
|
||||
WHERE id = ?
|
||||
`, id)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
func scanUser(row *sql.Row) (models.User, error) {
|
||||
var user models.User
|
||||
err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role)
|
||||
var twoFactorEnabled int
|
||||
err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role, &twoFactorEnabled, &user.TwoFactorSecret)
|
||||
user.TwoFactorEnabled = twoFactorEnabled == 1
|
||||
return user, err
|
||||
}
|
||||
|
||||
func SetUserTwoFactorSecret(userID, secret string) error {
|
||||
_, err := DB.Exec("UPDATE users SET two_factor_secret = ? WHERE id = ?", secret, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func EnableUserTwoFactorWithRecoveryCodes(userID string, recoveryCodeHashes []string) error {
|
||||
tx, err := DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
for _, codeHash := range recoveryCodeHashes {
|
||||
if _, err := tx.Exec(`
|
||||
INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, utils.GenerateUUID(), userID, codeHash, now); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 1 WHERE id = ?", userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func DisableUserTwoFactor(userID string) error {
|
||||
tx, err := DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 0, two_factor_secret = '' WHERE id = ?", userID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func ReplaceUserRecoveryCodes(userID string, recoveryCodeHashes []string) error {
|
||||
tx, err := DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
for _, codeHash := range recoveryCodeHashes {
|
||||
if _, err := tx.Exec(`
|
||||
INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, utils.GenerateUUID(), userID, codeHash, now); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func UseUserRecoveryCode(userID, codeHash string) (bool, error) {
|
||||
res, err := DB.Exec(`
|
||||
UPDATE two_factor_recovery_codes
|
||||
SET used_at = ?
|
||||
WHERE user_id = ? AND code_hash = ? AND used_at IS NULL
|
||||
`, time.Now().Unix(), userID, codeHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return n == 1, nil
|
||||
}
|
||||
|
||||
func CountUnusedRecoveryCodes(userID string) (int, error) {
|
||||
var count int
|
||||
err := DB.QueryRow(`
|
||||
SELECT COUNT(*)
|
||||
FROM two_factor_recovery_codes
|
||||
WHERE user_id = ? AND used_at IS NULL
|
||||
`, userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Refresh Tokens
|
||||
func AddRefreshToken(token *models.RefreshToken) error {
|
||||
_, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
|
||||
Reference in New Issue
Block a user