package storage import ( "MiauInv/models" utils "MiauInv/util" "database/sql" "errors" "log" "strings" "time" _ "github.com/glebarez/go-sqlite" ) var ErrNotFound = sql.ErrNoRows var DB *sql.DB func InitDB(filepath string) error { var err error DB, err = sql.Open("sqlite", filepath) if err != nil { return err } schema := ` PRAGMA foreign_keys = ON; CREATE TABLE IF NOT EXISTS users ( id TEXT PRIMARY KEY, username TEXT NOT NULL UNIQUE, password TEXT NOT NULL, role TEXT NOT NULL, two_factor_enabled INTEGER NOT NULL DEFAULT 0, two_factor_secret TEXT NOT NULL DEFAULT '' ); CREATE TABLE IF NOT EXISTS refresh_tokens ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, token_hash TEXT NOT NULL, expires_at INTEGER NOT NULL, created_at INTEGER NOT NULL, revoked INTEGER NOT NULL DEFAULT 0, device_info TEXT, FOREIGN KEY(user_id) REFERENCES users(id) ); CREATE TABLE IF NOT EXISTS two_factor_recovery_codes ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, code_hash TEXT NOT NULL, created_at INTEGER NOT NULL, used_at INTEGER DEFAULT NULL, FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE, UNIQUE(user_id, code_hash) ); CREATE TABLE IF NOT EXISTS items ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, category TEXT, description TEXT, total_quantity INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE IF NOT EXISTS locations ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL UNIQUE ); CREATE TABLE IF NOT EXISTS projects ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL UNIQUE, description TEXT ); CREATE TABLE IF NOT EXISTS stock ( id INTEGER PRIMARY KEY AUTOINCREMENT, item_id INTEGER NOT NULL, location_id INTEGER NOT NULL, quantity INTEGER NOT NULL, FOREIGN KEY(item_id) REFERENCES items(id), FOREIGN KEY(location_id) REFERENCES locations(id) ); CREATE TABLE IF NOT EXISTS project_items ( id INTEGER PRIMARY KEY AUTOINCREMENT, item_id INTEGER NOT NULL, project_id INTEGER NOT NULL, quantity INTEGER NOT NULL, FOREIGN KEY(item_id) REFERENCES items(id), FOREIGN KEY(project_id) REFERENCES projects(id) ); ` _, err = DB.Exec(schema) if err != nil { log.Fatal(err) } if err := ensureUserTwoFactorColumns(); err != nil { return err } return nil } func ensureUserTwoFactorColumns() error { migrations := []string{ "ALTER TABLE users ADD COLUMN two_factor_enabled INTEGER NOT NULL DEFAULT 0", "ALTER TABLE users ADD COLUMN two_factor_secret TEXT NOT NULL DEFAULT ''", } for _, migration := range migrations { _, err := DB.Exec(migration) if err != nil && !strings.Contains(strings.ToLower(err.Error()), "duplicate column") { return err } } return nil } // Users func AddUser(user *models.User) error { _, err := DB.Exec("INSERT INTO users(id, username, password, role) VALUES (?, ?, ?, ?)", user.ID, strings.ToLower(user.Username), user.Password, user.Role) return err } func GetUserByUsername(username string) (models.User, error) { row := DB.QueryRow(` SELECT id, username, password, role, two_factor_enabled, two_factor_secret FROM users WHERE username = ? `, strings.ToLower(username)) return scanUser(row) } func GetUserById(id string) (models.User, error) { row := DB.QueryRow(` SELECT id, username, password, role, two_factor_enabled, two_factor_secret FROM users WHERE id = ? `, id) return scanUser(row) } func scanUser(row *sql.Row) (models.User, error) { var user models.User var twoFactorEnabled int err := row.Scan(&user.ID, &user.Username, &user.Password, &user.Role, &twoFactorEnabled, &user.TwoFactorSecret) user.TwoFactorEnabled = twoFactorEnabled == 1 return user, err } func SetUserTwoFactorSecret(userID, secret string) error { _, err := DB.Exec("UPDATE users SET two_factor_secret = ? WHERE id = ?", secret, userID) return err } func EnableUserTwoFactorWithRecoveryCodes(userID string, recoveryCodeHashes []string) error { tx, err := DB.Begin() if err != nil { return err } defer tx.Rollback() if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil { return err } now := time.Now().Unix() for _, codeHash := range recoveryCodeHashes { if _, err := tx.Exec(` INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at) VALUES (?, ?, ?, ?) `, utils.GenerateUUID(), userID, codeHash, now); err != nil { return err } } if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 1 WHERE id = ?", userID); err != nil { return err } return tx.Commit() } func DisableUserTwoFactor(userID string) error { tx, err := DB.Begin() if err != nil { return err } defer tx.Rollback() if _, err := tx.Exec("UPDATE users SET two_factor_enabled = 0, two_factor_secret = '' WHERE id = ?", userID); err != nil { return err } if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil { return err } return tx.Commit() } func ReplaceUserRecoveryCodes(userID string, recoveryCodeHashes []string) error { tx, err := DB.Begin() if err != nil { return err } defer tx.Rollback() if _, err := tx.Exec("DELETE FROM two_factor_recovery_codes WHERE user_id = ?", userID); err != nil { return err } now := time.Now().Unix() for _, codeHash := range recoveryCodeHashes { if _, err := tx.Exec(` INSERT INTO two_factor_recovery_codes(id, user_id, code_hash, created_at) VALUES (?, ?, ?, ?) `, utils.GenerateUUID(), userID, codeHash, now); err != nil { return err } } return tx.Commit() } func UseUserRecoveryCode(userID, codeHash string) (bool, error) { res, err := DB.Exec(` UPDATE two_factor_recovery_codes SET used_at = ? WHERE user_id = ? AND code_hash = ? AND used_at IS NULL `, time.Now().Unix(), userID, codeHash) if err != nil { return false, err } n, err := res.RowsAffected() if err != nil { return false, err } return n == 1, nil } func CountUnusedRecoveryCodes(userID string) (int, error) { var count int err := DB.QueryRow(` SELECT COUNT(*) FROM two_factor_recovery_codes WHERE user_id = ? AND used_at IS NULL `, userID).Scan(&count) return count, err } // Refresh Tokens func AddRefreshToken(token *models.RefreshToken) error { _, err := DB.Exec("INSERT INTO refresh_tokens(id, user_id, token_hash, expires_at, created_at, revoked, device_info) VALUES (?, ?, ?, ?, ?, ?, ?)", token.ID, token.UserID, token.Token, token.ExpiresAt, token.CreatedAt, token.Revoked, token.DeviceInfo) return err } func GetRefreshToken(token string) (models.RefreshToken, error) { row := DB.QueryRow("SELECT * FROM refresh_tokens WHERE token_hash = ?", token) var refresh_token models.RefreshToken err := row.Scan(&refresh_token.ID, &refresh_token.UserID, &refresh_token.Token, &refresh_token.ExpiresAt, &refresh_token.CreatedAt, &refresh_token.Revoked, &refresh_token.DeviceInfo) return refresh_token, err } func RevokeRefreshToken(tokenID string) error { if DB == nil { return errors.New("db not initialized") } res, err := DB.Exec(` UPDATE refresh_tokens SET revoked = 1 WHERE id = ? `, tokenID) if err != nil { return err } n, err := res.RowsAffected() if err != nil { return err } if n == 0 { return ErrNotFound } return nil } func RevokeAllRefreshTokensForUser(userID string) error { if DB == nil { return errors.New("db not initialized") } _, err := DB.Exec(` UPDATE refresh_tokens SET revoked = 1 WHERE user_id = ? `, userID) return err }