300 lines
7.6 KiB
Go
300 lines
7.6 KiB
Go
package storage
|
|
|
|
import (
|
|
"MiauInv/models"
|
|
utils "MiauInv/util"
|
|
"database/sql"
|
|
"errors"
|
|
"log"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/glebarez/go-sqlite"
|
|
)
|
|
|
|
var ErrNotFound = sql.ErrNoRows
|
|
var DB *sql.DB
|
|
|
|
// InitDB opens the SQLite database at filepath and stores the handle in the
// package-level DB. It then creates any missing tables and runs the users
// two-factor column migration.
//
// Tables are created with CREATE TABLE IF NOT EXISTS, so existing tables are
// never altered here. Columns added later are handled by
// ensureUserTwoFactorColumns.
//
// If creating the schema or running the migration fails, the handle is closed,
// DB is reset to nil, and the error is returned to the caller. Failures are
// reported through the error result instead of terminating the process.
//
// NOTE(review): PRAGMA foreign_keys is a per-connection setting in SQLite, and
// *sql.DB is a connection pool, so executing it once here only configures the
// pooled connection that ran the schema. Foreign-key checks, including the
// ON DELETE CASCADE on two_factor_recovery_codes, may not be enforced on other
// connections. Consider enabling it through the DSN or limiting the pool size.
func InitDB(filepath string) error {
	const schema = `
	PRAGMA foreign_keys = ON;

	CREATE TABLE IF NOT EXISTS users (
		id TEXT PRIMARY KEY,
		username TEXT NOT NULL UNIQUE,
		password TEXT NOT NULL,
		role TEXT NOT NULL,
		two_factor_enabled INTEGER NOT NULL DEFAULT 0,
		two_factor_secret TEXT NOT NULL DEFAULT ''
	);

	CREATE TABLE IF NOT EXISTS refresh_tokens (
		id TEXT PRIMARY KEY,
		user_id TEXT NOT NULL,
		token_hash TEXT NOT NULL,
		expires_at INTEGER NOT NULL,
		created_at INTEGER NOT NULL,
		revoked INTEGER NOT NULL DEFAULT 0,
		device_info TEXT,
		FOREIGN KEY(user_id) REFERENCES users(id)
	);

	CREATE TABLE IF NOT EXISTS two_factor_recovery_codes (
		id TEXT PRIMARY KEY,
		user_id TEXT NOT NULL,
		code_hash TEXT NOT NULL,
		created_at INTEGER NOT NULL,
		used_at INTEGER DEFAULT NULL,
		FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
		UNIQUE(user_id, code_hash)
	);

	CREATE TABLE IF NOT EXISTS items (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		name TEXT NOT NULL,
		category TEXT,
		description TEXT,
		total_quantity INTEGER NOT NULL DEFAULT 0
	);

	CREATE TABLE IF NOT EXISTS locations (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		name TEXT NOT NULL UNIQUE
	);

	CREATE TABLE IF NOT EXISTS projects (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		name TEXT NOT NULL UNIQUE,
		description TEXT
	);

	CREATE TABLE IF NOT EXISTS stock (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		item_id INTEGER NOT NULL,
		location_id INTEGER NOT NULL,
		quantity INTEGER NOT NULL,
		FOREIGN KEY(item_id) REFERENCES items(id),
		FOREIGN KEY(location_id) REFERENCES locations(id)
	);

	CREATE TABLE IF NOT EXISTS project_items (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		item_id INTEGER NOT NULL,
		project_id INTEGER NOT NULL,
		quantity INTEGER NOT NULL,
		FOREIGN KEY(item_id) REFERENCES items(id),
		FOREIGN KEY(project_id) REFERENCES projects(id)
	);
	`

	var err error
	DB, err = sql.Open("sqlite", filepath)
	if err != nil {
		return err
	}

	// abort releases the half-initialized handle so callers never observe a
	// non-nil DB whose schema setup failed.
	abort := func(cause error) error {
		if closeErr := DB.Close(); closeErr != nil {
			// The setup failure is what gets returned; the close error is
			// only logged so it is not silently lost.
			log.Printf("storage: closing database after failed init: %v", closeErr)
		}
		DB = nil
		return cause
	}

	if _, err = DB.Exec(schema); err != nil {
		return abort(err)
	}
	if err = ensureUserTwoFactorColumns(); err != nil {
		return abort(err)
	}
	return nil
}
|
|
|
|
// ensureUserTwoFactorColumns adds the two_factor_enabled and two_factor_secret
// columns to a users table that was created before those columns existed.
//
// Each ALTER is attempted unconditionally. An error whose text contains
// "duplicate column" (case-insensitively) means the column is already present
// and is ignored. Any other error aborts and is returned.
func ensureUserTwoFactorColumns() error {
	alters := [...]string{
		"ALTER TABLE users ADD COLUMN two_factor_enabled INTEGER NOT NULL DEFAULT 0",
		"ALTER TABLE users ADD COLUMN two_factor_secret TEXT NOT NULL DEFAULT ''",
	}

	for i := range alters {
		_, execErr := DB.Exec(alters[i])
		if execErr == nil {
			continue
		}
		alreadyPresent := strings.Contains(strings.ToLower(execErr.Error()), "duplicate column")
		if !alreadyPresent {
			return execErr
		}
	}
	return nil
}
|
|
|
|
// Users
|
|
// AddUser inserts a new users row built from user.
//
// The username is lower-cased before storage, matching the lower-casing that
// GetUserByUsername applies to its argument. The two-factor columns are not
// written, so they take their schema defaults.
//
// NOTE(review): user.Password is stored exactly as given. It is presumably
// hashed by the caller; confirm.
func AddUser(user *models.User) error {
	normalized := strings.ToLower(user.Username)
	_, err := DB.Exec(
		"INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)",
		user.ID, normalized, user.Password, user.Role,
	)
	return err
}
|
|
func GetUserByUsername(username string) (models.User, error) {
|
|
row := DB.QueryRow(`
|
|
SELECT id, username, password, role, two_factor_enabled, two_factor_secret
|
|
FROM users
|
|
WHERE username = ?
|
|
`, strings.ToLower(username))
|
|
return scanUser(row)
|
|
}
|
|
func GetUserById(id string) (models.User, error) {
|
|
row := DB.QueryRow(`
|
|
SELECT id, username, password, role, two_factor_enabled, two_factor_secret
|
|
FROM users
|
|
WHERE id = ?
|
|
`, id)
|
|
return scanUser(row)
|
|
}
|
|
|
|
// scanUser reads one users row into a models.User. The row must be selected
// with the columns (id, username, password, role, two_factor_enabled,
// two_factor_secret), in that order.
//
// two_factor_enabled is stored as an INTEGER and mapped to true only when it
// equals 1.
//
// On any Scan error, including sql.ErrNoRows when the query matched nothing,
// a zero models.User is returned. Scan stops at the first failing column, so
// returning the partially filled struct could leak half-populated data.
func scanUser(row *sql.Row) (models.User, error) {
	var (
		user             models.User
		twoFactorEnabled int
	)
	if err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role, &twoFactorEnabled, &user.TwoFactorSecret); err != nil {
		return models.User{}, err
	}
	user.TwoFactorEnabled = twoFactorEnabled == 1
	return user, nil
}
|
|
|
|
func SetUserTwoFactorSecret(userID, secret string) error {
|
|
_, err := DB.Exec("UPDATE users SET two_factor_secret = ? WHERE id = ?", secret, userID)
|
|
return err
|
|
}
|
|
|
|
// EnableUserTwoFactorWithRecoveryCodes sets two_factor_enabled = 1 for userID
// and replaces all of the user's recovery codes with recoveryCodeHashes. Both
// steps run in one transaction.
//
// Each hash gets its own row with a fresh ID from utils.GenerateUUID. All rows
// share the same created_at value (time.Now().Unix()).
//
// If no users row has id userID, it returns ErrNotFound and commits nothing.
// This follows the same not-found convention as RevokeRefreshToken.
//
// It does not itself check that two_factor_secret is non-empty.
func EnableUserTwoFactorWithRecoveryCodes(userID string, recoveryCodeHashes []string) error {
	tx, err := DB.Begin()
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a no-op (it returns sql.ErrTxDone,
	// deliberately ignored), so this undoes the work on every early return.
	defer tx.Rollback()

	// Flip the flag first so a missing user is detected before any recovery
	// codes are deleted or written.
	res, err := tx.Exec("UPDATE users SET two_factor_enabled = 1 WHERE id = ?", userID)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return ErrNotFound
	}

	if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
		return err
	}

	now := time.Now().Unix()
	for _, codeHash := range recoveryCodeHashes {
		if _, err := tx.Exec(`
			INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
			VALUES (?, ?, ?, ?)
		`, utils.GenerateUUID(), userID, codeHash, now); err != nil {
			return err
		}
	}

	return tx.Commit()
}
|
|
|
|
func DisableUserTwoFactor(userID string) error {
|
|
tx, err := DB.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 0, two_factor_secret = '' WHERE id = ?", userID); err != nil {
|
|
return err
|
|
}
|
|
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func ReplaceUserRecoveryCodes(userID string, recoveryCodeHashes []string) error {
|
|
tx, err := DB.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil {
|
|
return err
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
for _, codeHash := range recoveryCodeHashes {
|
|
if _, err := tx.Exec(`
|
|
INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at)
|
|
VALUES (?, ?, ?, ?)
|
|
`, utils.GenerateUUID(), userID, codeHash, now); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// UseUserRecoveryCode consumes a recovery code for userID. It stamps used_at
// with the current Unix time on the row matching codeHash, but only if that
// row's used_at is still NULL.
//
// It reports true when exactly one row was updated. It reports false, with a
// nil error, when no unused code matched. The "not yet used" test is part of
// the UPDATE's WHERE clause, so a code that has already been stamped will not
// match again.
func UseUserRecoveryCode(userID, codeHash string) (bool, error) {
	usedAt := time.Now().Unix()
	result, err := DB.Exec(`
		UPDATE two_factor_recovery_codes
		SET used_at = ?
		WHERE user_id = ? AND code_hash = ? AND used_at IS NULL
	`, usedAt, userID, codeHash)
	if err != nil {
		return false, err
	}

	affected, err := result.RowsAffected()
	if err != nil {
		return false, err
	}
	consumed := affected == 1
	return consumed, nil
}
|
|
|
|
func CountUnusedRecoveryCodes(userID string) (int, error) {
|
|
var count int
|
|
err := DB.QueryRow(`
|
|
SELECT COUNT(*)
|
|
FROM two_factor_recovery_codes
|
|
WHERE user_id = ? AND used_at IS NULL
|
|
`, userID).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
// Refresh Tokens
|
|
func AddRefreshToken(token *models.RefreshToken) error {
|
|
_, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
|
token.ID, token.UserID, token.Token, token.ExpiresAt, token.CreatedAt, token.Revoked, token.DeviceInfo)
|
|
return err
|
|
}
|
|
func GetRefreshToken(token string) (models.RefreshToken, error) {
|
|
row := DB.QueryRow("SELECT * FROM refresh_tokens WHERE token_hash = ?", token)
|
|
var refresh_token models.RefreshToken
|
|
err := row.Scan(&refresh_token.ID, &refresh_token.UserID, &refresh_token.Token, &refresh_token.ExpiresAt, &refresh_token.CreatedAt, &refresh_token.Revoked, &refresh_token.DeviceInfo)
|
|
return refresh_token, err
|
|
}
|
|
// RevokeRefreshToken sets revoked = 1 on the refresh token whose id is tokenID.
//
// It returns an error if DB has not been initialized. It returns ErrNotFound
// when the UPDATE reports zero affected rows.
func RevokeRefreshToken(tokenID string) error {
	if DB == nil {
		return errors.New("db not initialized")
	}

	result, err := DB.Exec(`
		UPDATE refresh_tokens
		SET revoked = 1
		WHERE id = ?
	`, tokenID)
	if err != nil {
		return err
	}

	switch affected, err := result.RowsAffected(); {
	case err != nil:
		return err
	case affected == 0:
		return ErrNotFound
	default:
		return nil
	}
}
|
|
// RevokeAllRefreshTokensForUser sets revoked = 1 on every refresh token that
// belongs to userID. It returns an error if DB has not been initialized.
//
// Unlike RevokeRefreshToken, the affected-row count is not checked, so a user
// with no tokens is not an error.
func RevokeAllRefreshTokensForUser(userID string) error {
	if DB != nil {
		_, err := DB.Exec(`
			UPDATE refresh_tokens
			SET revoked = 1
			WHERE user_id = ?
		`, userID)
		return err
	}
	return errors.New("db not initialized")
}
|