Files
remote-rig/internal/db/db.go
T
overseer c2670a9f33
CI/CD / lint-and-typecheck (pull_request) Successful in 8s
CI/CD / test (pull_request) Successful in 9s
CI/CD / build (pull_request) Failing after 10s
CI/CD / deploy (pull_request) Has been skipped
Merge branch 'dev' into agent/dex/CUB-239-hub-dedup-replay
2026-05-28 06:59:51 -04:00

181 lines
4.0 KiB
Go

// Package db provides SQLite database initialization and schema management.
package db
import (
"database/sql"
_ "embed"
"fmt"
"log"
"os"
"path/filepath"
"strings"
_ "modernc.org/sqlite"
)
// migration001 holds the SQL text of migrations/001_create_tables.sql,
// embedded at build time. Open applies it as schema version 1.
//
//go:embed migrations/001_create_tables.sql
var migration001 string

// migration002 holds the SQL text of migrations/002_dedup_unique_index.sql,
// embedded at build time. Open applies it as schema version 2.
//
//go:embed migrations/002_dedup_unique_index.sql
var migration002 string
// DB is a thin wrapper around a *sql.DB connection pool. The pool's methods
// (Exec, Query, QueryRow, Close, ...) are promoted through embedding, so a *DB
// can be used directly wherever those methods are needed.
type DB struct {
	*sql.DB // pool opened and configured by Open
}
// Open opens (creating it if necessary) the SQLite database at path, switches
// it to WAL journaling, enables foreign-key enforcement, and applies every
// embedded migration whose version is newer than the highest version recorded
// in the schema_version table.
//
// On any error the partially opened pool is closed and a nil *DB is returned.
// Returned errors are wrapped (with %w) with the step that failed.
func Open(path string) (*DB, error) {
	// SQLite creates the database file but not its parent directories.
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return nil, fmt.Errorf("create database directory: %w", err)
	}

	db, err := sql.Open("sqlite", withForeignKeys(path))
	if err != nil {
		return nil, fmt.Errorf("open database %q: %w", path, err)
	}
	// fail closes the pool before returning err, so no error path leaks it.
	fail := func(err error) (*DB, error) {
		_ = db.Close() // best effort; err is the more useful failure to report
		return nil, err
	}

	// Enable WAL for concurrent read/write performance.
	if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
		return fail(fmt.Errorf("enable WAL: %w", err))
	}
	// foreign_keys is a per-connection setting in SQLite. The DSN parameter
	// added by withForeignKeys applies it to every pooled connection; this
	// explicit PRAGMA keeps the previous behavior for the first connection.
	if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil {
		return fail(fmt.Errorf("enable foreign keys: %w", err))
	}

	// schema_version records which migrations have already been applied.
	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY)`); err != nil {
		return fail(fmt.Errorf("create schema_version table: %w", err))
	}

	// An empty schema_version table yields 0: no migration has run yet.
	var currentVersion int
	if err := db.QueryRow(`SELECT COALESCE(MAX(version), 0) FROM schema_version`).Scan(&currentVersion); err != nil {
		return fail(fmt.Errorf("read schema version: %w", err))
	}

	// Migrations are applied in slice order; keep versions ascending.
	type migration struct {
		version int
		sql     string
	}
	migrations := []migration{
		{1, migration001},
		{2, migration002},
	}

	applied := 0
	for _, m := range migrations {
		if m.version <= currentVersion {
			continue
		}
		log.Printf("Running migration %d for %s...", m.version, path)
		// NOTE(review): the migration statements and the schema_version insert
		// do not share a transaction, so a failure part-way through can leave a
		// partially applied migration with no version recorded.
		if err := migrate(db, m.sql); err != nil {
			return fail(fmt.Errorf("migration %d: %w", m.version, err))
		}
		if _, err := db.Exec(`INSERT INTO schema_version (version) VALUES (?)`, m.version); err != nil {
			return fail(fmt.Errorf("record migration %d: %w", m.version, err))
		}
		log.Printf("Migration %d complete", m.version)
		applied++
	}
	if applied > 0 {
		log.Println("Migrations complete")
	}
	return &DB{db}, nil
}

// withForeignKeys appends the modernc.org/sqlite "_pragma" DSN parameter that
// runs PRAGMA foreign_keys=ON on every new connection the pool opens, rather
// than only on whichever connection executes an explicit PRAGMA statement.
func withForeignKeys(path string) string {
	sep := "?"
	if strings.Contains(path, "?") {
		sep = "&"
	}
	return path + sep + "_pragma=foreign_keys(1)"
}
// migrate executes a migration script against db one statement at a time.
// The script is split into statements by splitSQL, and statements that are
// empty after trimming whitespace are skipped. Execution stops at the first
// failing statement. The returned error wraps the driver error and reports
// that statement's 1-based position among the non-empty statements, plus a
// truncated copy of its text.
//
// Statements run individually on the pool, not inside a transaction, so
// statements executed before a failure remain applied.
func migrate(db *sql.DB, script string) error {
	const maxShown = 80 // runes of the failing statement included in errors

	n := 0
	for _, stmt := range splitSQL(script) {
		stmt = strings.TrimSpace(stmt)
		if stmt == "" {
			continue
		}
		n++
		if _, err := db.Exec(stmt); err != nil {
			shown := stmt
			if r := []rune(shown); len(r) > maxShown {
				shown = string(r[:maxShown]) + "..."
			}
			return fmt.Errorf("statement %d %q: %w", n, shown, err)
		}
	}
	return nil
}
// splitSQL splits a SQL script into statements on semicolons that appear
// outside single-, double-, or backtick-quoted spans. "--" line comments are
// removed first by stripSQLLineComments, because collapsing newlines (below)
// would otherwise let a comment swallow the rest of its statement.
//
// Outside quotes, each carriage return, newline, and tab becomes one space;
// quoted text is copied verbatim. A doubled quote inside a literal
// ('it''s') works because it closes and immediately reopens the quoted span.
//
// Fragments are returned untrimmed and may be whitespace-only (e.g. between
// consecutive semicolons). Only a whitespace-only fragment after the last
// semicolon is dropped.
//
// Limitations: /* */ block comments are not recognized, and a semicolon
// inside a CREATE TRIGGER ... BEGIN ... END body splits the trigger.
func splitSQL(script string) []string {
	script = stripSQLLineComments(script)

	var (
		stmts []string
		cur   strings.Builder // avoids quadratic string concatenation
		quote rune            // active quote character; 0 when outside quotes
	)
	for _, r := range script {
		switch {
		case quote != 0:
			cur.WriteRune(r)
			if r == quote {
				quote = 0
			}
		case r == '"' || r == '\'' || r == '`':
			quote = r
			cur.WriteRune(r)
		case r == ';':
			stmts = append(stmts, cur.String())
			cur.Reset()
		case r == '\r' || r == '\n' || r == '\t':
			cur.WriteByte(' ')
		default:
			cur.WriteRune(r)
		}
	}
	if rest := cur.String(); strings.TrimSpace(rest) != "" {
		stmts = append(stmts, rest)
	}
	return stmts
}
// stripSQLLineComments removes "--" line comments from SQL text while leaving
// quoted spans ('...', "...", `...`) untouched, so literals such as '--' and
// identifiers such as "a--b" are preserved. Each comment, from "--" up to but
// not including the next '\n' or '\r' (or the end of input), is replaced by a
// single '\n'. The terminator that ended the comment is copied as usual, so
// statement boundaries and line structure are kept.
//
// The scan is byte-wise: every delimiter it looks for is ASCII, and in UTF-8
// ASCII bytes never appear inside multi-byte sequences.
func stripSQLLineComments(sql string) string {
	var out strings.Builder
	out.Grow(len(sql))

	var quote byte // active quote character; 0 when outside quotes
	for i := 0; i < len(sql); {
		c := sql[i]
		if quote == 0 && c == '-' && i+1 < len(sql) && sql[i+1] == '-' {
			// Skip to the line terminator without consuming it.
			if end := strings.IndexAny(sql[i:], "\r\n"); end >= 0 {
				i += end
			} else {
				i = len(sql)
			}
			out.WriteByte('\n')
			continue
		}
		switch {
		case quote != 0:
			if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"' || c == '`':
			quote = c
		}
		out.WriteByte(c)
		i++
	}
	return out.String()
}