All checks were successful
test-and-lint / test-and-lint (pull_request) Successful in 2m50s
263 lines
7.2 KiB
Go
263 lines
7.2 KiB
Go
package storage
|
|
|
|
import (
|
|
"MiauInv/models"
|
|
utils "MiauInv/util"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/go-webauthn/webauthn/webauthn"
|
|
)
|
|
|
|
// PasskeyCeremonyRegister is the ceremony label for passkey registration
// challenges, intended as the ceremony argument to SavePasskeyChallenge and
// ConsumePasskeyChallenge.
const PasskeyCeremonyRegister = "register"

// PasskeyCeremonyLogin is the ceremony label for passkey login challenges,
// intended as the ceremony argument to SavePasskeyChallenge and
// ConsumePasskeyChallenge.
const PasskeyCeremonyLogin = "login"
|
|
|
|
func AddPasskeyCredential(userID, name string, credential *webauthn.Credential) (models.PasskeyCredential, error) {
|
|
if DB == nil {
|
|
return models.PasskeyCredential{}, errors.New("db not initialized")
|
|
}
|
|
|
|
credentialJSON, err := json.Marshal(credential)
|
|
if err != nil {
|
|
return models.PasskeyCredential{}, err
|
|
}
|
|
|
|
row := models.PasskeyCredential{
|
|
ID: utils.GenerateUUID(),
|
|
UserID: userID,
|
|
CredentialID: utils.EncodeBase64URL(credential.ID),
|
|
Name: name,
|
|
CredentialData: string(credentialJSON),
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
if row.Name == "" {
|
|
row.Name = "Passkey"
|
|
}
|
|
|
|
_, err = DB.Exec(`
|
|
INSERT INTO passkey_credentials(id, user_id, credential_id, name, credential_data, created_at, last_used_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, NULL)
|
|
`, row.ID, row.UserID, row.CredentialID, row.Name, row.CredentialData, row.CreatedAt)
|
|
if err != nil {
|
|
return models.PasskeyCredential{}, err
|
|
}
|
|
|
|
return row, nil
|
|
}
|
|
|
|
func ListPasskeyCredentials(userID string) ([]models.PasskeyCredential, error) {
|
|
if DB == nil {
|
|
return nil, errors.New("db not initialized")
|
|
}
|
|
|
|
rows, err := DB.Query(`
|
|
SELECT id, user_id, credential_id, name, credential_data, created_at, COALESCE(last_used_at, 0)
|
|
FROM passkey_credentials
|
|
WHERE user_id = ?
|
|
ORDER BY created_at DESC
|
|
`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
credentials := make([]models.PasskeyCredential, 0)
|
|
for rows.Next() {
|
|
var credential models.PasskeyCredential
|
|
if err := rows.Scan(&credential.ID, &credential.UserID, &credential.CredentialID, &credential.Name, &credential.CredentialData, &credential.CreatedAt, &credential.LastUsedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
credentials = append(credentials, credential)
|
|
}
|
|
return credentials, rows.Err()
|
|
}
|
|
|
|
func CountPasskeyCredentials(userID string) (int, error) {
|
|
if DB == nil {
|
|
return 0, errors.New("db not initialized")
|
|
}
|
|
|
|
var count int
|
|
err := DB.QueryRow("SELECT COUNT(*) FROM passkey_credentials WHERE user_id = ?", userID).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
func GetPasskeyCredentialByCredentialID(credentialID string) (models.PasskeyCredential, error) {
|
|
if DB == nil {
|
|
return models.PasskeyCredential{}, errors.New("db not initialized")
|
|
}
|
|
|
|
row := DB.QueryRow(`
|
|
SELECT id, user_id, credential_id, name, credential_data, created_at, COALESCE(last_used_at, 0)
|
|
FROM passkey_credentials
|
|
WHERE credential_id = ?
|
|
`, credentialID)
|
|
return scanPasskeyCredential(row)
|
|
}
|
|
|
|
func GetPasskeyCredentialByID(userID, id string) (models.PasskeyCredential, error) {
|
|
if DB == nil {
|
|
return models.PasskeyCredential{}, errors.New("db not initialized")
|
|
}
|
|
|
|
row := DB.QueryRow(`
|
|
SELECT id, user_id, credential_id, name, credential_data, created_at, COALESCE(last_used_at, 0)
|
|
FROM passkey_credentials
|
|
WHERE user_id = ? AND id = ?
|
|
`, userID, id)
|
|
return scanPasskeyCredential(row)
|
|
}
|
|
|
|
func scanPasskeyCredential(row *sql.Row) (models.PasskeyCredential, error) {
|
|
var credential models.PasskeyCredential
|
|
err := row.Scan(&credential.ID, &credential.UserID, &credential.CredentialID, &credential.Name, &credential.CredentialData, &credential.CreatedAt, &credential.LastUsedAt)
|
|
return credential, err
|
|
}
|
|
|
|
func DecodeWebAuthnCredential(row models.PasskeyCredential) (webauthn.Credential, error) {
|
|
var credential webauthn.Credential
|
|
err := json.Unmarshal([]byte(row.CredentialData), &credential)
|
|
return credential, err
|
|
}
|
|
|
|
func DecodeWebAuthnCredentials(rows []models.PasskeyCredential) ([]webauthn.Credential, error) {
|
|
credentials := make([]webauthn.Credential, 0, len(rows))
|
|
for _, row := range rows {
|
|
credential, err := DecodeWebAuthnCredential(row)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
credentials = append(credentials, credential)
|
|
}
|
|
return credentials, nil
|
|
}
|
|
|
|
func UpdatePasskeyCredentialAfterLogin(userID string, credential *webauthn.Credential) error {
|
|
if DB == nil {
|
|
return errors.New("db not initialized")
|
|
}
|
|
|
|
credentialID := utils.EncodeBase64URL(credential.ID)
|
|
credentialJSON, err := json.Marshal(credential)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
res, err := DB.Exec(`
|
|
UPDATE passkey_credentials
|
|
SET credential_data = ?, last_used_at = ?
|
|
WHERE user_id = ? AND credential_id = ?
|
|
`, string(credentialJSON), time.Now().Unix(), userID, credentialID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func DeletePasskeyCredential(userID, id string) error {
|
|
if DB == nil {
|
|
return errors.New("db not initialized")
|
|
}
|
|
|
|
res, err := DB.Exec("DELETE FROM passkey_credentials WHERE user_id = ? AND id = ?", userID, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func DeleteAllPasskeyCredentials(userID string) error {
|
|
if DB == nil {
|
|
return errors.New("db not initialized")
|
|
}
|
|
|
|
_, err := DB.Exec("DELETE FROM passkey_credentials WHERE user_id = ?", userID)
|
|
return err
|
|
}
|
|
|
|
func SavePasskeyChallenge(token, userID, ceremony string, sessionData webauthn.SessionData, ttl time.Duration) error {
|
|
if DB == nil {
|
|
return errors.New("db not initialized")
|
|
}
|
|
|
|
encodedSession, err := json.Marshal(sessionData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = DB.Exec(`
|
|
INSERT INTO passkey_challenges(token, user_id, ceremony, session_data, expires_at)
|
|
VALUES (?, ?, ?, ?, ?)
|
|
`, token, userID, ceremony, string(encodedSession), time.Now().Add(ttl).Unix())
|
|
return err
|
|
}
|
|
|
|
func ConsumePasskeyChallenge(token, ceremony string) (models.PasskeyChallenge, webauthn.SessionData, error) {
|
|
if DB == nil {
|
|
return models.PasskeyChallenge{}, webauthn.SessionData{}, errors.New("db not initialized")
|
|
}
|
|
|
|
tx, err := DB.Begin()
|
|
if err != nil {
|
|
return models.PasskeyChallenge{}, webauthn.SessionData{}, err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
row := tx.QueryRow(`
|
|
SELECT token, user_id, ceremony, session_data, expires_at
|
|
FROM passkey_challenges
|
|
WHERE token = ? AND ceremony = ?
|
|
`, token, ceremony)
|
|
|
|
var challenge models.PasskeyChallenge
|
|
if err := row.Scan(&challenge.Token, &challenge.UserID, &challenge.Ceremony, &challenge.SessionData, &challenge.ExpiresAt); err != nil {
|
|
return models.PasskeyChallenge{}, webauthn.SessionData{}, err
|
|
}
|
|
|
|
if _, err := tx.Exec("DELETE FROM passkey_challenges WHERE token = ?", token); err != nil {
|
|
return models.PasskeyChallenge{}, webauthn.SessionData{}, err
|
|
}
|
|
|
|
if challenge.ExpiresAt < time.Now().Unix() {
|
|
return models.PasskeyChallenge{}, webauthn.SessionData{}, errors.New("passkey challenge expired")
|
|
}
|
|
|
|
var sessionData webauthn.SessionData
|
|
if err := json.Unmarshal([]byte(challenge.SessionData), &sessionData); err != nil {
|
|
return models.PasskeyChallenge{}, webauthn.SessionData{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return models.PasskeyChallenge{}, webauthn.SessionData{}, err
|
|
}
|
|
|
|
return challenge, sessionData, nil
|
|
}
|
|
|
|
func CleanupExpiredPasskeyChallenges() error {
|
|
if DB == nil {
|
|
return errors.New("db not initialized")
|
|
}
|
|
|
|
_, err := DB.Exec("DELETE FROM passkey_challenges WHERE expires_at < ?", time.Now().Unix())
|
|
return err
|
|
}
|