diff --git a/go-backend/internal/gateway/client.go b/go-backend/internal/gateway/client.go index 90b2f4d..4b2d520 100644 --- a/go-backend/internal/gateway/client.go +++ b/go-backend/internal/gateway/client.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net/http" + "sync" "time" "code.cubecraftcreations.com/CubeCraft-Creations/Control-Center/go-backend/internal/handler" @@ -26,8 +27,9 @@ type Client struct { httpClient *http.Client agents repository.AgentRepo broker *handler.Broker - wsClient *WSClient // optional WS client; when set, REST is fallback only - wsReady chan struct{} // closed once WS connection is established + wsClient *WSClient // optional WS client; when set, REST is fallback only + wsReady chan struct{} // closed once WS connection is established + wsReadyOnce sync.Once // protects wsReady close from double-close race } // Config holds gateway client configuration, typically loaded from environment. @@ -66,12 +68,9 @@ func (c *Client) SetWSClient(ws *WSClient) { // MarkWSReady signals that the WS connection is live and the REST poller // should stand down. Called by WSClient after a successful handshake. func (c *Client) MarkWSReady() { - select { - case <-c.wsReady: - // already closed - default: + c.wsReadyOnce.Do(func() { close(c.wsReady) - } + }) } // Start begins the gateway client loop. When a WS client is wired, it diff --git a/go-backend/internal/gateway/events.go b/go-backend/internal/gateway/events.go index d6544b6..a0f660e 100644 --- a/go-backend/internal/gateway/events.go +++ b/go-backend/internal/gateway/events.go @@ -18,8 +18,7 @@ import ( // ── Event payload types ────────────────────────────────────────────────── // sessionChangedPayload represents a single session delta from a -// sessions.changed event. Fields are optional; use json.RawMessage for -// anything we don't strictly need. +// sessions.changed event. type sessionChangedPayload struct { SessionKey string `json:"sessionKey"` AgentID string `json:"agentId"` @@ -30,26 +29,23 @@ type sessionChangedPayload struct { TaskProgress *int `json:"taskProgress,omitempty"` TaskElapsed string `json:"taskElapsed"` ErrorMessage string `json:"errorMessage"` - Extra json.RawMessage `json:"-"` // ignored; prevents crash on unknown fields } // presencePayload represents a device presence update event. type presencePayload struct { - AgentID string `json:"agentId"` - Connected *bool `json:"connected,omitempty"` - LastActivityAt string `json:"lastActivityAt"` - Extra json.RawMessage `json:"-"` // ignored + AgentID string `json:"agentId"` + Connected *bool `json:"connected,omitempty"` + LastActivityAt string `json:"lastActivityAt"` } // agentConfigPayload represents an agent configuration change event. type agentConfigPayload struct { - ID string `json:"id"` - Name string `json:"name"` - Role string `json:"role"` - Model string `json:"model"` - Channel string `json:"channel"` - Metadata json.RawMessage `json:"metadata"` - Extra json.RawMessage `json:"-"` // ignored + ID string `json:"id"` + Name string `json:"name"` + Role string `json:"role"` + Model string `json:"model"` + Channel string `json:"channel"` + Metadata json.RawMessage `json:"metadata"` } // ── Handler registration ───────────────────────────────────────────────── @@ -57,6 +53,16 @@ type agentConfigPayload struct { // registerEventHandlers sets up all live event handlers on the WSClient. // Call this once after a successful handshake + initial sync. func (c *WSClient) registerEventHandlers() { + if c.agents == nil || c.broker == nil { + c.logger.Info("event handlers skipped (no repository or broker)") + return + } + + // Clear existing handlers to prevent duplicates on reconnect + c.mu.Lock() + c.handlers = make(map[string][]eventHandler) + c.mu.Unlock() + c.OnEvent("sessions.changed", c.handleSessionsChanged) c.OnEvent("presence", c.handlePresence) c.OnEvent("agent.config", c.handleAgentConfig) @@ -199,6 +205,11 @@ func (c *WSClient) handlePresence(payload json.RawMessage) { update.Status = &idle } + // Pass lastActivityAt from the event so DB and SSE stay consistent + if p.LastActivityAt != "" { + update.LastActivityAt = &p.LastActivityAt + } + // Update DB first updated, err := c.agents.Update(ctx, p.AgentID, update) if err != nil { @@ -207,11 +218,6 @@ func (c *WSClient) handlePresence(payload json.RawMessage) { return } - // Use reported timestamp if available - if p.LastActivityAt != "" { - updated.LastActivity = p.LastActivityAt - } - // Then broadcast c.broker.Broadcast("agent.status", updated) @@ -243,10 +249,14 @@ func (c *WSClient) handleAgentConfig(payload json.RawMessage) { defer cancel() // Build partial update with available fields. - // Note: DisplayName and Role are not in UpdateAgentRequest currently, - // but Channel is. We update what we can and note the gap. update := models.UpdateAgentRequest{} + if cfg.Name != "" { + update.DisplayName = &cfg.Name + } + if cfg.Role != "" { + update.Role = &cfg.Role + } if cfg.Channel != "" { update.Channel = &cfg.Channel } @@ -259,14 +269,6 @@ func (c *WSClient) handleAgentConfig(payload json.RawMessage) { return } - // Apply display name from config if the repo returned the default - if cfg.Name != "" { - updated.DisplayName = cfg.Name - } - if cfg.Role != "" { - updated.Role = cfg.Role - } - // Then broadcast fleet snapshot allAgents, err := c.agents.List(ctx, "") if err != nil { diff --git a/go-backend/internal/gateway/sync.go b/go-backend/internal/gateway/sync.go index 84d1301..3352ed3 100644 --- a/go-backend/internal/gateway/sync.go +++ b/go-backend/internal/gateway/sync.go @@ -42,6 +42,11 @@ type sessionListItem struct { // persists them, merges session state into agent cards, and broadcasts // the merged fleet as a fleet.update event. func (c *WSClient) initialSync(ctx context.Context) error { + if c.agents == nil { + c.logger.Info("initial sync skipped (no repository)") + return nil + } + c.logger.Info("initial sync starting") // 1. Fetch agents @@ -77,12 +82,12 @@ func (c *WSClient) initialSync(ctx context.Context) error { newName := card.DisplayName newRole := card.Role _, updateErr := c.agents.Update(ctx, card.ID, models.UpdateAgentRequest{ - CurrentTask: &newName, // reuse field for display name update + DisplayName: &newName, + Role: &newRole, }) if updateErr != nil { c.logger.Warn("sync: agent update failed", "id", card.ID, "error", updateErr) } - _ = newRole // role not in UpdateAgentRequest yet, skip silently } } diff --git a/go-backend/internal/gateway/wsclient.go b/go-backend/internal/gateway/wsclient.go index 462bf08..322d551 100644 --- a/go-backend/internal/gateway/wsclient.go +++ b/go-backend/internal/gateway/wsclient.go @@ -55,6 +55,7 @@ type WSClient struct { handlers map[string][]eventHandler connId string // set after successful hello-ok restClient *Client // optional REST client to notify on WS ready + wsReadyOnce sync.Once // ensures MarkWSReady close is one-shot } // NewWSClient returns a WSClient wired to the given repository and broker. @@ -142,8 +143,9 @@ type helloOKResponse struct { // read loop. On disconnect it reconnects with exponential backoff. On // ctx cancellation it performs a clean shutdown. func (c *WSClient) Start(ctx context.Context) { - backoff := 1 * time.Second + initialBackoff := 1 * time.Second maxBackoff := 30 * time.Second + backoff := initialBackoff for { err := c.connectAndRun(ctx) @@ -155,6 +157,9 @@ func (c *WSClient) Start(ctx context.Context) { c.logger.Warn("ws client disconnected, reconnecting", "error", err, "backoff", backoff) + } else { + // Reset backoff on successful connect+run completion + backoff = initialBackoff } select { @@ -189,7 +194,16 @@ func (c *WSClient) connectAndRun(ctx context.Context) error { c.conn = conn c.connMu.Unlock() - // Reset backoff on successful connect + // When context is cancelled, close the conn to unblock ReadJSON in readLoop. + go func() { + <-ctx.Done() + c.connMu.Lock() + if c.conn != nil { + c.conn.Close() + } + c.connMu.Unlock() + }() + defer func() { conn.Close() }() @@ -221,6 +235,9 @@ func (c *WSClient) connectAndRun(ctx context.Context) error { c.logger.Info("ws client notified REST fallback to stand down") } + // Reset wsReadyOnce so MarkWSReady can fire again after a reconnect + c.wsReadyOnce = sync.Once{} + // Step 2b: Initial sync — fetch agents + sessions from gateway if err := c.initialSync(ctx); err != nil { c.logger.Warn("initial sync failed, will continue with read loop", "error", err) @@ -309,25 +326,15 @@ func (c *WSClient) sendConnect(conn *websocket.Conn) (*helloOKResponse, error) { } // readLoop continuously reads frames from the connection and routes them. -// It returns on read error or context cancellation. +// It returns on read error or when the connection is closed by the ctx-done +// goroutine started in connectAndRun. func (c *WSClient) readLoop(ctx context.Context, conn *websocket.Conn) error { for { - select { - case <-ctx.Done(): - // Clean shutdown: send close frame - c.connMu.Lock() - c.conn.WriteControl( - websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutdown"), - time.Now().Add(5*time.Second), - ) - c.connMu.Unlock() - return ctx.Err() - default: - } - var frame wsFrame if err := conn.ReadJSON(&frame); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } // Check if it's a close error if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { c.logger.Info("ws connection closed by server") @@ -398,9 +405,8 @@ func (c *WSClient) handleEvent(frame wsFrame) { // ── Send ───────────────────────────────────────────────────────────────── // Send sends a JSON request to the gateway and returns the response payload. -// It is safe for concurrent use. The caller should check for errors in the -// returned payload. A nil payload with nil error means the gateway sent an -// error response (check via the response frame's error field, which is logged). +// It is safe for concurrent use. Returns an error if the client is not +// connected. func (c *WSClient) Send(method string, params any) (json.RawMessage, error) { reqID := uuid.New().String() @@ -430,6 +436,10 @@ func (c *WSClient) Send(method string, params any) (json.RawMessage, error) { } c.connMu.Lock() + if c.conn == nil { + c.connMu.Unlock() + return nil, fmt.Errorf("gateway: not connected") + } err = c.conn.WriteJSON(frame) c.connMu.Unlock() diff --git a/go-backend/internal/gateway/wsclient_test.go b/go-backend/internal/gateway/wsclient_test.go index c028acf..92a1d66 100644 --- a/go-backend/internal/gateway/wsclient_test.go +++ b/go-backend/internal/gateway/wsclient_test.go @@ -1,11 +1,390 @@ package gateway import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" "testing" + "time" "code.cubecraftcreations.com/CubeCraft-Creations/Control-Center/go-backend/internal/models" + + "github.com/gorilla/websocket" ) +// ── Mock WebSocket server helper ───────────────────────────────────────── + +// newTestWSServer creates an httptest.Server that upgrades to WebSocket and +// delegates each connection to handler. The server URL can be converted to +// a ws:// URL by replacing "http" with "ws". +func newTestWSServer(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server { + t.Helper() + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + handler(conn) + })) + return srv +} + +// wsURL converts an httptest.Server http URL to a ws URL. +func wsURL(srv *httptest.Server) string { + return "ws" + strings.TrimPrefix(srv.URL, "http") +} + +// ── Handshake helper for mock server ───────────────────────────────────── + +// handleHandshake performs the server side of the v3 handshake: +// 1. Send connect.challenge +// 2. Read connect request +// 3. Send hello-ok response +// +// Returns the connect request frame for inspection. +func handleHandshake(t *testing.T, conn *websocket.Conn) map[string]any { + t.Helper() + + // 1. Send connect.challenge + challenge := map[string]any{ + "type": "event", + "event": "connect.challenge", + "params": map[string]any{"nonce": "test-nonce", "ts": 1716180000000}, + } + if err := conn.WriteJSON(challenge); err != nil { + t.Fatalf("server: write challenge: %v", err) + } + + // 2. Read connect request + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Fatalf("server: read connect request: %v", err) + } + + if req["method"] != "connect" { + t.Fatalf("server: expected method=connect, got %v", req["method"]) + } + + // 3. Send hello-ok response + // Note: helloOKResponse expects ConnID at the top level of the result, + // matching the WSClient's JSON struct tags. + result := map[string]any{ + "type": "hello-ok", + "protocol": 3, + "connId": "test-conn-123", + "features": map[string]any{"methods": []string{}, "events": []string{}}, + "auth": map[string]any{"role": "operator", "scopes": []string{"operator.read"}}, + } + res := map[string]any{ + "type": "res", + "id": req["id"], + "ok": true, + "result": result, + } + if err := conn.WriteJSON(res); err != nil { + t.Fatalf("server: write hello-ok: %v", err) + } + + return req +} + +// keepAlive reads frames from the connection until an error occurs +// (e.g., the client disconnects). Used as the default "do nothing" +// server loop after handshake. +func keepAlive(conn *websocket.Conn) { + for { + var m map[string]any + if err := conn.ReadJSON(&m); err != nil { + break + } + } +} + +// ── 1. Test: Full handshake ────────────────────────────────────────────── + +func TestWSClient_Handshake(t *testing.T) { + srv := newTestWSServer(t, func(conn *websocket.Conn) { + handleHandshake(t, conn) + keepAlive(conn) + }) + defer srv.Close() + + client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + done := make(chan struct{}) + go func() { + client.Start(ctx) + close(done) + }() + + // Wait briefly for handshake to complete + time.Sleep(200 * time.Millisecond) + + // Verify connId was set + client.connMu.Lock() + connID := client.connId + client.connMu.Unlock() + + if connID != "test-conn-123" { + t.Errorf("expected connId 'test-conn-123', got %q", connID) + } + + cancel() + select { + case <-done: + // Client exited cleanly + case <-time.After(3 * time.Second): + t.Fatal("WSClient did not shut down after context cancellation") + } +} + +// ── 2. Test: Send() with response matching ─────────────────────────────── + +func TestWSClient_Send(t *testing.T) { + srv := newTestWSServer(t, func(conn *websocket.Conn) { + handleHandshake(t, conn) + + // Read RPC requests and respond to each + for { + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + break + } + reqID, _ := req["id"].(string) + method, _ := req["method"].(string) + + var result any + switch method { + case "agents.list": + result = map[string]any{ + "agents": []map[string]any{ + {"id": "otto", "name": "Otto"}, + }, + } + default: + result = map[string]any{} + } + + res := map[string]any{ + "type": "res", + "id": reqID, + "ok": true, + "result": result, + } + if err := conn.WriteJSON(res); err != nil { + break + } + } + }) + defer srv.Close() + + client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go client.Start(ctx) + + // Give the client time to complete handshake + time.Sleep(300 * time.Millisecond) + + resp, err := client.Send("agents.list", nil) + if err != nil { + t.Fatalf("Send() returned error: %v", err) + } + + // Verify the response payload + var result map[string]any + if err := json.Unmarshal(resp, &result); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + agents, ok := result["agents"].([]any) + if !ok || len(agents) != 1 { + t.Errorf("expected 1 agent in response, got %v", result) + } + + cancel() +} + +// ── 3. Test: Event handler routing ─────────────────────────────────────── + +func TestWSClient_EventRouting(t *testing.T) { + eventReceived := make(chan json.RawMessage, 1) + + srv := newTestWSServer(t, func(conn *websocket.Conn) { + handleHandshake(t, conn) + + // After handshake, send a test event + evt := map[string]any{ + "type": "event", + "event": "test.event", + "params": map[string]any{"greeting": "hello from server"}, + } + if err := conn.WriteJSON(evt); err != nil { + t.Logf("server: write event: %v", err) + return + } + + keepAlive(conn) + }) + defer srv.Close() + + client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) + + // Register event handler BEFORE starting the client + client.OnEvent("test.event", func(payload json.RawMessage) { + eventReceived <- payload + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go client.Start(ctx) + + // Wait for the event handler to fire + select { + case payload := <-eventReceived: + var data map[string]any + if err := json.Unmarshal(payload, &data); err != nil { + t.Fatalf("unmarshal event payload: %v", err) + } + if greeting, _ := data["greeting"].(string); greeting != "hello from server" { + t.Errorf("expected greeting 'hello from server', got %q", greeting) + } + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for event handler to fire") + } + + cancel() +} + +// ── 4. Test: Concurrent Send ───────────────────────────────────────────── + +func TestWSClient_ConcurrentSend(t *testing.T) { + var reqCount atomic.Int32 + + srv := newTestWSServer(t, func(conn *websocket.Conn) { + handleHandshake(t, conn) + + // Read RPC requests and respond to each + for { + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + break + } + reqID, _ := req["id"].(string) + n := reqCount.Add(1) + + res := map[string]any{ + "type": "res", + "id": reqID, + "ok": true, + "result": map[string]any{"index": n, "method": req["method"]}, + } + if err := conn.WriteJSON(res); err != nil { + break + } + } + }) + defer srv.Close() + + client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go client.Start(ctx) + + // Give the client time to complete handshake + time.Sleep(300 * time.Millisecond) + + // Fire 3 concurrent Send() calls + type sendResult struct { + method string + payload json.RawMessage + err error + } + results := make(chan sendResult, 3) + + methods := []string{"agents.list", "sessions.list", "agents.config"} + for _, method := range methods { + go func(m string) { + resp, err := client.Send(m, nil) + results <- sendResult{method: m, payload: resp, err: err} + }(method) + } + + // Collect all results + for i := 0; i < 3; i++ { + select { + case r := <-results: + if r.err != nil { + t.Errorf("Send(%q) returned error: %v", r.method, r.err) + continue + } + var result map[string]any + if err := json.Unmarshal(r.payload, &result); err != nil { + t.Errorf("Send(%q) unmarshal error: %v", r.method, err) + continue + } + gotMethod, _ := result["method"].(string) + if gotMethod != r.method { + t.Errorf("Send(%q) got response for %q (mismatched)", r.method, gotMethod) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for concurrent Send results") + } + } + + cancel() +} + +// ── 5. Test: Clean shutdown ────────────────────────────────────────────── + +func TestWSClient_CleanShutdown(t *testing.T) { + srv := newTestWSServer(t, func(conn *websocket.Conn) { + handleHandshake(t, conn) + keepAlive(conn) + }) + defer srv.Close() + + client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + + done := make(chan struct{}) + go func() { + client.Start(ctx) + close(done) + }() + + // Let the client connect and complete handshake + time.Sleep(200 * time.Millisecond) + + // Cancel context — should trigger clean shutdown + cancel() + + select { + case <-done: + // Client exited cleanly — pass + case <-time.After(3 * time.Second): + t.Fatal("WSClient did not shut down cleanly within timeout") + } +} + +// ── Pure utility tests (from CUB-205) ───────────────────────────────────── + func TestMapSessionStatus(t *testing.T) { tests := []struct { input string diff --git a/go-backend/internal/models/models.go b/go-backend/internal/models/models.go index 8480b41..2844cff 100644 --- a/go-backend/internal/models/models.go +++ b/go-backend/internal/models/models.go @@ -63,12 +63,15 @@ type CreateAgentRequest struct { // UpdateAgentRequest is the payload for PUT /api/agents/{id}. type UpdateAgentRequest struct { - Status *AgentStatus `json:"status,omitempty" validate:"omitempty,agentStatus"` - CurrentTask *string `json:"currentTask,omitempty"` - TaskProgress *int `json:"taskProgress,omitempty" validate:"omitempty,min=0,max=100"` - TaskElapsed *string `json:"taskElapsed,omitempty"` - Channel *string `json:"channel,omitempty" validate:"omitempty,min=1,max=32"` - ErrorMessage *string `json:"errorMessage,omitempty"` + Status *AgentStatus `json:"status,omitempty" validate:"omitempty,agentStatus"` + DisplayName *string `json:"displayName,omitempty"` + Role *string `json:"role,omitempty"` + LastActivityAt *string `json:"lastActivityAt,omitempty"` + CurrentTask *string `json:"currentTask,omitempty"` + TaskProgress *int `json:"taskProgress,omitempty" validate:"omitempty,min=0,max=100"` + TaskElapsed *string `json:"taskElapsed,omitempty"` + Channel *string `json:"channel,omitempty" validate:"omitempty,min=1,max=32"` + ErrorMessage *string `json:"errorMessage,omitempty"` } // AgentStatusHistoryEntry represents a point-in-time status change for an agent.