CUB-203: WebSocket client scaffold for OpenClaw gateway v3 #41
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"code.cubecraftcreations.com/CubeCraft-Creations/Control-Center/go-backend/internal/handler"
|
"code.cubecraftcreations.com/CubeCraft-Creations/Control-Center/go-backend/internal/handler"
|
||||||
@@ -28,6 +29,7 @@ type Client struct {
|
|||||||
broker *handler.Broker
|
broker *handler.Broker
|
||||||
wsClient *WSClient // optional WS client; when set, REST is fallback only
|
wsClient *WSClient // optional WS client; when set, REST is fallback only
|
||||||
wsReady chan struct{} // closed once WS connection is established
|
wsReady chan struct{} // closed once WS connection is established
|
||||||
|
wsReadyOnce sync.Once // protects wsReady close from double-close race
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config holds gateway client configuration, typically loaded from environment.
|
// Config holds gateway client configuration, typically loaded from environment.
|
||||||
@@ -66,12 +68,9 @@ func (c *Client) SetWSClient(ws *WSClient) {
|
|||||||
// MarkWSReady signals that the WS connection is live and the REST poller
|
// MarkWSReady signals that the WS connection is live and the REST poller
|
||||||
// should stand down. Called by WSClient after a successful handshake.
|
// should stand down. Called by WSClient after a successful handshake.
|
||||||
func (c *Client) MarkWSReady() {
|
func (c *Client) MarkWSReady() {
|
||||||
select {
|
c.wsReadyOnce.Do(func() {
|
||||||
case <-c.wsReady:
|
|
||||||
// already closed
|
|
||||||
default:
|
|
||||||
close(c.wsReady)
|
close(c.wsReady)
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins the gateway client loop. When a WS client is wired, it
|
// Start begins the gateway client loop. When a WS client is wired, it
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ import (
|
|||||||
// ── Event payload types ──────────────────────────────────────────────────
|
// ── Event payload types ──────────────────────────────────────────────────
|
||||||
|
|
||||||
// sessionChangedPayload represents a single session delta from a
|
// sessionChangedPayload represents a single session delta from a
|
||||||
// sessions.changed event. Fields are optional; use json.RawMessage for
|
// sessions.changed event.
|
||||||
// anything we don't strictly need.
|
|
||||||
type sessionChangedPayload struct {
|
type sessionChangedPayload struct {
|
||||||
SessionKey string `json:"sessionKey"`
|
SessionKey string `json:"sessionKey"`
|
||||||
AgentID string `json:"agentId"`
|
AgentID string `json:"agentId"`
|
||||||
@@ -30,7 +29,6 @@ type sessionChangedPayload struct {
|
|||||||
TaskProgress *int `json:"taskProgress,omitempty"`
|
TaskProgress *int `json:"taskProgress,omitempty"`
|
||||||
TaskElapsed string `json:"taskElapsed"`
|
TaskElapsed string `json:"taskElapsed"`
|
||||||
ErrorMessage string `json:"errorMessage"`
|
ErrorMessage string `json:"errorMessage"`
|
||||||
Extra json.RawMessage `json:"-"` // ignored; prevents crash on unknown fields
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// presencePayload represents a device presence update event.
|
// presencePayload represents a device presence update event.
|
||||||
@@ -38,7 +36,6 @@ type presencePayload struct {
|
|||||||
AgentID string `json:"agentId"`
|
AgentID string `json:"agentId"`
|
||||||
Connected *bool `json:"connected,omitempty"`
|
Connected *bool `json:"connected,omitempty"`
|
||||||
LastActivityAt string `json:"lastActivityAt"`
|
LastActivityAt string `json:"lastActivityAt"`
|
||||||
Extra json.RawMessage `json:"-"` // ignored
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// agentConfigPayload represents an agent configuration change event.
|
// agentConfigPayload represents an agent configuration change event.
|
||||||
@@ -49,7 +46,6 @@ type agentConfigPayload struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Channel string `json:"channel"`
|
Channel string `json:"channel"`
|
||||||
Metadata json.RawMessage `json:"metadata"`
|
Metadata json.RawMessage `json:"metadata"`
|
||||||
Extra json.RawMessage `json:"-"` // ignored
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Handler registration ─────────────────────────────────────────────────
|
// ── Handler registration ─────────────────────────────────────────────────
|
||||||
@@ -57,6 +53,16 @@ type agentConfigPayload struct {
|
|||||||
// registerEventHandlers sets up all live event handlers on the WSClient.
|
// registerEventHandlers sets up all live event handlers on the WSClient.
|
||||||
// Call this once after a successful handshake + initial sync.
|
// Call this once after a successful handshake + initial sync.
|
||||||
func (c *WSClient) registerEventHandlers() {
|
func (c *WSClient) registerEventHandlers() {
|
||||||
|
if c.agents == nil || c.broker == nil {
|
||||||
|
c.logger.Info("event handlers skipped (no repository or broker)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear existing handlers to prevent duplicates on reconnect
|
||||||
|
c.mu.Lock()
|
||||||
|
c.handlers = make(map[string][]eventHandler)
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
c.OnEvent("sessions.changed", c.handleSessionsChanged)
|
c.OnEvent("sessions.changed", c.handleSessionsChanged)
|
||||||
c.OnEvent("presence", c.handlePresence)
|
c.OnEvent("presence", c.handlePresence)
|
||||||
c.OnEvent("agent.config", c.handleAgentConfig)
|
c.OnEvent("agent.config", c.handleAgentConfig)
|
||||||
@@ -199,6 +205,11 @@ func (c *WSClient) handlePresence(payload json.RawMessage) {
|
|||||||
update.Status = &idle
|
update.Status = &idle
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pass lastActivityAt from the event so DB and SSE stay consistent
|
||||||
|
if p.LastActivityAt != "" {
|
||||||
|
update.LastActivityAt = &p.LastActivityAt
|
||||||
|
}
|
||||||
|
|
||||||
// Update DB first
|
// Update DB first
|
||||||
updated, err := c.agents.Update(ctx, p.AgentID, update)
|
updated, err := c.agents.Update(ctx, p.AgentID, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -207,11 +218,6 @@ func (c *WSClient) handlePresence(payload json.RawMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use reported timestamp if available
|
|
||||||
if p.LastActivityAt != "" {
|
|
||||||
updated.LastActivity = p.LastActivityAt
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then broadcast
|
// Then broadcast
|
||||||
c.broker.Broadcast("agent.status", updated)
|
c.broker.Broadcast("agent.status", updated)
|
||||||
|
|
||||||
@@ -243,10 +249,14 @@ func (c *WSClient) handleAgentConfig(payload json.RawMessage) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Build partial update with available fields.
|
// Build partial update with available fields.
|
||||||
// Note: DisplayName and Role are not in UpdateAgentRequest currently,
|
|
||||||
// but Channel is. We update what we can and note the gap.
|
|
||||||
update := models.UpdateAgentRequest{}
|
update := models.UpdateAgentRequest{}
|
||||||
|
|
||||||
|
if cfg.Name != "" {
|
||||||
|
update.DisplayName = &cfg.Name
|
||||||
|
}
|
||||||
|
if cfg.Role != "" {
|
||||||
|
update.Role = &cfg.Role
|
||||||
|
}
|
||||||
if cfg.Channel != "" {
|
if cfg.Channel != "" {
|
||||||
update.Channel = &cfg.Channel
|
update.Channel = &cfg.Channel
|
||||||
}
|
}
|
||||||
@@ -259,14 +269,6 @@ func (c *WSClient) handleAgentConfig(payload json.RawMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply display name from config if the repo returned the default
|
|
||||||
if cfg.Name != "" {
|
|
||||||
updated.DisplayName = cfg.Name
|
|
||||||
}
|
|
||||||
if cfg.Role != "" {
|
|
||||||
updated.Role = cfg.Role
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then broadcast fleet snapshot
|
// Then broadcast fleet snapshot
|
||||||
allAgents, err := c.agents.List(ctx, "")
|
allAgents, err := c.agents.List(ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -42,6 +42,11 @@ type sessionListItem struct {
|
|||||||
// persists them, merges session state into agent cards, and broadcasts
|
// persists them, merges session state into agent cards, and broadcasts
|
||||||
// the merged fleet as a fleet.update event.
|
// the merged fleet as a fleet.update event.
|
||||||
func (c *WSClient) initialSync(ctx context.Context) error {
|
func (c *WSClient) initialSync(ctx context.Context) error {
|
||||||
|
if c.agents == nil {
|
||||||
|
c.logger.Info("initial sync skipped (no repository)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
c.logger.Info("initial sync starting")
|
c.logger.Info("initial sync starting")
|
||||||
|
|
||||||
// 1. Fetch agents
|
// 1. Fetch agents
|
||||||
@@ -77,12 +82,12 @@ func (c *WSClient) initialSync(ctx context.Context) error {
|
|||||||
newName := card.DisplayName
|
newName := card.DisplayName
|
||||||
newRole := card.Role
|
newRole := card.Role
|
||||||
_, updateErr := c.agents.Update(ctx, card.ID, models.UpdateAgentRequest{
|
_, updateErr := c.agents.Update(ctx, card.ID, models.UpdateAgentRequest{
|
||||||
CurrentTask: &newName, // reuse field for display name update
|
DisplayName: &newName,
|
||||||
|
Role: &newRole,
|
||||||
})
|
})
|
||||||
if updateErr != nil {
|
if updateErr != nil {
|
||||||
c.logger.Warn("sync: agent update failed", "id", card.ID, "error", updateErr)
|
c.logger.Warn("sync: agent update failed", "id", card.ID, "error", updateErr)
|
||||||
}
|
}
|
||||||
_ = newRole // role not in UpdateAgentRequest yet, skip silently
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ type WSClient struct {
|
|||||||
handlers map[string][]eventHandler
|
handlers map[string][]eventHandler
|
||||||
connId string // set after successful hello-ok
|
connId string // set after successful hello-ok
|
||||||
restClient *Client // optional REST client to notify on WS ready
|
restClient *Client // optional REST client to notify on WS ready
|
||||||
|
wsReadyOnce sync.Once // ensures MarkWSReady close is one-shot
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWSClient returns a WSClient wired to the given repository and broker.
|
// NewWSClient returns a WSClient wired to the given repository and broker.
|
||||||
@@ -142,8 +143,9 @@ type helloOKResponse struct {
|
|||||||
// read loop. On disconnect it reconnects with exponential backoff. On
|
// read loop. On disconnect it reconnects with exponential backoff. On
|
||||||
// ctx cancellation it performs a clean shutdown.
|
// ctx cancellation it performs a clean shutdown.
|
||||||
func (c *WSClient) Start(ctx context.Context) {
|
func (c *WSClient) Start(ctx context.Context) {
|
||||||
backoff := 1 * time.Second
|
initialBackoff := 1 * time.Second
|
||||||
maxBackoff := 30 * time.Second
|
maxBackoff := 30 * time.Second
|
||||||
|
backoff := initialBackoff
|
||||||
|
|
||||||
for {
|
for {
|
||||||
err := c.connectAndRun(ctx)
|
err := c.connectAndRun(ctx)
|
||||||
@@ -155,6 +157,9 @@ func (c *WSClient) Start(ctx context.Context) {
|
|||||||
c.logger.Warn("ws client disconnected, reconnecting",
|
c.logger.Warn("ws client disconnected, reconnecting",
|
||||||
"error", err,
|
"error", err,
|
||||||
"backoff", backoff)
|
"backoff", backoff)
|
||||||
|
} else {
|
||||||
|
// Reset backoff on successful connect+run completion
|
||||||
|
backoff = initialBackoff
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -189,7 +194,16 @@ func (c *WSClient) connectAndRun(ctx context.Context) error {
|
|||||||
c.conn = conn
|
c.conn = conn
|
||||||
c.connMu.Unlock()
|
c.connMu.Unlock()
|
||||||
|
|
||||||
// Reset backoff on successful connect
|
// When context is cancelled, close the conn to unblock ReadJSON in readLoop.
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
c.connMu.Lock()
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
c.connMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}()
|
}()
|
||||||
@@ -221,6 +235,9 @@ func (c *WSClient) connectAndRun(ctx context.Context) error {
|
|||||||
c.logger.Info("ws client notified REST fallback to stand down")
|
c.logger.Info("ws client notified REST fallback to stand down")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset wsReadyOnce so MarkWSReady can fire again after a reconnect
|
||||||
|
c.wsReadyOnce = sync.Once{}
|
||||||
|
|
||||||
// Step 2b: Initial sync — fetch agents + sessions from gateway
|
// Step 2b: Initial sync — fetch agents + sessions from gateway
|
||||||
if err := c.initialSync(ctx); err != nil {
|
if err := c.initialSync(ctx); err != nil {
|
||||||
c.logger.Warn("initial sync failed, will continue with read loop", "error", err)
|
c.logger.Warn("initial sync failed, will continue with read loop", "error", err)
|
||||||
@@ -309,25 +326,15 @@ func (c *WSClient) sendConnect(conn *websocket.Conn) (*helloOKResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// readLoop continuously reads frames from the connection and routes them.
|
// readLoop continuously reads frames from the connection and routes them.
|
||||||
// It returns on read error or context cancellation.
|
// It returns on read error or when the connection is closed by the ctx-done
|
||||||
|
// goroutine started in connectAndRun.
|
||||||
func (c *WSClient) readLoop(ctx context.Context, conn *websocket.Conn) error {
|
func (c *WSClient) readLoop(ctx context.Context, conn *websocket.Conn) error {
|
||||||
for {
|
for {
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// Clean shutdown: send close frame
|
|
||||||
c.connMu.Lock()
|
|
||||||
c.conn.WriteControl(
|
|
||||||
websocket.CloseMessage,
|
|
||||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "shutdown"),
|
|
||||||
time.Now().Add(5*time.Second),
|
|
||||||
)
|
|
||||||
c.connMu.Unlock()
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
var frame wsFrame
|
var frame wsFrame
|
||||||
if err := conn.ReadJSON(&frame); err != nil {
|
if err := conn.ReadJSON(&frame); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
// Check if it's a close error
|
// Check if it's a close error
|
||||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||||
c.logger.Info("ws connection closed by server")
|
c.logger.Info("ws connection closed by server")
|
||||||
@@ -398,9 +405,8 @@ func (c *WSClient) handleEvent(frame wsFrame) {
|
|||||||
// ── Send ─────────────────────────────────────────────────────────────────
|
// ── Send ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// Send sends a JSON request to the gateway and returns the response payload.
|
// Send sends a JSON request to the gateway and returns the response payload.
|
||||||
// It is safe for concurrent use. The caller should check for errors in the
|
// It is safe for concurrent use. Returns an error if the client is not
|
||||||
// returned payload. A nil payload with nil error means the gateway sent an
|
// connected.
|
||||||
// error response (check via the response frame's error field, which is logged).
|
|
||||||
func (c *WSClient) Send(method string, params any) (json.RawMessage, error) {
|
func (c *WSClient) Send(method string, params any) (json.RawMessage, error) {
|
||||||
reqID := uuid.New().String()
|
reqID := uuid.New().String()
|
||||||
|
|
||||||
@@ -430,6 +436,10 @@ func (c *WSClient) Send(method string, params any) (json.RawMessage, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.connMu.Lock()
|
c.connMu.Lock()
|
||||||
|
if c.conn == nil {
|
||||||
|
c.connMu.Unlock()
|
||||||
|
return nil, fmt.Errorf("gateway: not connected")
|
||||||
|
}
|
||||||
err = c.conn.WriteJSON(frame)
|
err = c.conn.WriteJSON(frame)
|
||||||
c.connMu.Unlock()
|
c.connMu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,390 @@
|
|||||||
package gateway
|
package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"code.cubecraftcreations.com/CubeCraft-Creations/Control-Center/go-backend/internal/models"
|
"code.cubecraftcreations.com/CubeCraft-Creations/Control-Center/go-backend/internal/models"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ── Mock WebSocket server helper ─────────────────────────────────────────
|
||||||
|
|
||||||
|
// newTestWSServer creates an httptest.Server that upgrades to WebSocket and
|
||||||
|
// delegates each connection to handler. The server URL can be converted to
|
||||||
|
// a ws:// URL by replacing "http" with "ws".
|
||||||
|
func newTestWSServer(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler(conn)
|
||||||
|
}))
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsURL converts an httptest.Server http URL to a ws URL.
|
||||||
|
func wsURL(srv *httptest.Server) string {
|
||||||
|
return "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Handshake helper for mock server ─────────────────────────────────────
|
||||||
|
|
||||||
|
// handleHandshake performs the server side of the v3 handshake:
|
||||||
|
// 1. Send connect.challenge
|
||||||
|
// 2. Read connect request
|
||||||
|
// 3. Send hello-ok response
|
||||||
|
//
|
||||||
|
// Returns the connect request frame for inspection.
|
||||||
|
func handleHandshake(t *testing.T, conn *websocket.Conn) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// 1. Send connect.challenge
|
||||||
|
challenge := map[string]any{
|
||||||
|
"type": "event",
|
||||||
|
"event": "connect.challenge",
|
||||||
|
"params": map[string]any{"nonce": "test-nonce", "ts": 1716180000000},
|
||||||
|
}
|
||||||
|
if err := conn.WriteJSON(challenge); err != nil {
|
||||||
|
t.Fatalf("server: write challenge: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Read connect request
|
||||||
|
var req map[string]any
|
||||||
|
if err := conn.ReadJSON(&req); err != nil {
|
||||||
|
t.Fatalf("server: read connect request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req["method"] != "connect" {
|
||||||
|
t.Fatalf("server: expected method=connect, got %v", req["method"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Send hello-ok response
|
||||||
|
// Note: helloOKResponse expects ConnID at the top level of the result,
|
||||||
|
// matching the WSClient's JSON struct tags.
|
||||||
|
result := map[string]any{
|
||||||
|
"type": "hello-ok",
|
||||||
|
"protocol": 3,
|
||||||
|
"connId": "test-conn-123",
|
||||||
|
"features": map[string]any{"methods": []string{}, "events": []string{}},
|
||||||
|
"auth": map[string]any{"role": "operator", "scopes": []string{"operator.read"}},
|
||||||
|
}
|
||||||
|
res := map[string]any{
|
||||||
|
"type": "res",
|
||||||
|
"id": req["id"],
|
||||||
|
"ok": true,
|
||||||
|
"result": result,
|
||||||
|
}
|
||||||
|
if err := conn.WriteJSON(res); err != nil {
|
||||||
|
t.Fatalf("server: write hello-ok: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// keepAlive reads frames from the connection until an error occurs
|
||||||
|
// (e.g., the client disconnects). Used as the default "do nothing"
|
||||||
|
// server loop after handshake.
|
||||||
|
func keepAlive(conn *websocket.Conn) {
|
||||||
|
for {
|
||||||
|
var m map[string]any
|
||||||
|
if err := conn.ReadJSON(&m); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 1. Test: Full handshake ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWSClient_Handshake(t *testing.T) {
|
||||||
|
srv := newTestWSServer(t, func(conn *websocket.Conn) {
|
||||||
|
handleHandshake(t, conn)
|
||||||
|
keepAlive(conn)
|
||||||
|
})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.Start(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait briefly for handshake to complete
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify connId was set
|
||||||
|
client.connMu.Lock()
|
||||||
|
connID := client.connId
|
||||||
|
client.connMu.Unlock()
|
||||||
|
|
||||||
|
if connID != "test-conn-123" {
|
||||||
|
t.Errorf("expected connId 'test-conn-123', got %q", connID)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Client exited cleanly
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("WSClient did not shut down after context cancellation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 2. Test: Send() with response matching ───────────────────────────────
|
||||||
|
|
||||||
|
func TestWSClient_Send(t *testing.T) {
|
||||||
|
srv := newTestWSServer(t, func(conn *websocket.Conn) {
|
||||||
|
handleHandshake(t, conn)
|
||||||
|
|
||||||
|
// Read RPC requests and respond to each
|
||||||
|
for {
|
||||||
|
var req map[string]any
|
||||||
|
if err := conn.ReadJSON(&req); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
reqID, _ := req["id"].(string)
|
||||||
|
method, _ := req["method"].(string)
|
||||||
|
|
||||||
|
var result any
|
||||||
|
switch method {
|
||||||
|
case "agents.list":
|
||||||
|
result = map[string]any{
|
||||||
|
"agents": []map[string]any{
|
||||||
|
{"id": "otto", "name": "Otto"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
result = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
res := map[string]any{
|
||||||
|
"type": "res",
|
||||||
|
"id": reqID,
|
||||||
|
"ok": true,
|
||||||
|
"result": result,
|
||||||
|
}
|
||||||
|
if err := conn.WriteJSON(res); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go client.Start(ctx)
|
||||||
|
|
||||||
|
// Give the client time to complete handshake
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
resp, err := client.Send("agents.list", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Send() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the response payload
|
||||||
|
var result map[string]any
|
||||||
|
if err := json.Unmarshal(resp, &result); err != nil {
|
||||||
|
t.Fatalf("unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
agents, ok := result["agents"].([]any)
|
||||||
|
if !ok || len(agents) != 1 {
|
||||||
|
t.Errorf("expected 1 agent in response, got %v", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 3. Test: Event handler routing ───────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWSClient_EventRouting(t *testing.T) {
|
||||||
|
eventReceived := make(chan json.RawMessage, 1)
|
||||||
|
|
||||||
|
srv := newTestWSServer(t, func(conn *websocket.Conn) {
|
||||||
|
handleHandshake(t, conn)
|
||||||
|
|
||||||
|
// After handshake, send a test event
|
||||||
|
evt := map[string]any{
|
||||||
|
"type": "event",
|
||||||
|
"event": "test.event",
|
||||||
|
"params": map[string]any{"greeting": "hello from server"},
|
||||||
|
}
|
||||||
|
if err := conn.WriteJSON(evt); err != nil {
|
||||||
|
t.Logf("server: write event: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keepAlive(conn)
|
||||||
|
})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default())
|
||||||
|
|
||||||
|
// Register event handler BEFORE starting the client
|
||||||
|
client.OnEvent("test.event", func(payload json.RawMessage) {
|
||||||
|
eventReceived <- payload
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go client.Start(ctx)
|
||||||
|
|
||||||
|
// Wait for the event handler to fire
|
||||||
|
select {
|
||||||
|
case payload := <-eventReceived:
|
||||||
|
var data map[string]any
|
||||||
|
if err := json.Unmarshal(payload, &data); err != nil {
|
||||||
|
t.Fatalf("unmarshal event payload: %v", err)
|
||||||
|
}
|
||||||
|
if greeting, _ := data["greeting"].(string); greeting != "hello from server" {
|
||||||
|
t.Errorf("expected greeting 'hello from server', got %q", greeting)
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for event handler to fire")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 4. Test: Concurrent Send ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWSClient_ConcurrentSend(t *testing.T) {
|
||||||
|
var reqCount atomic.Int32
|
||||||
|
|
||||||
|
srv := newTestWSServer(t, func(conn *websocket.Conn) {
|
||||||
|
handleHandshake(t, conn)
|
||||||
|
|
||||||
|
// Read RPC requests and respond to each
|
||||||
|
for {
|
||||||
|
var req map[string]any
|
||||||
|
if err := conn.ReadJSON(&req); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
reqID, _ := req["id"].(string)
|
||||||
|
n := reqCount.Add(1)
|
||||||
|
|
||||||
|
res := map[string]any{
|
||||||
|
"type": "res",
|
||||||
|
"id": reqID,
|
||||||
|
"ok": true,
|
||||||
|
"result": map[string]any{"index": n, "method": req["method"]},
|
||||||
|
}
|
||||||
|
if err := conn.WriteJSON(res); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go client.Start(ctx)
|
||||||
|
|
||||||
|
// Give the client time to complete handshake
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
// Fire 3 concurrent Send() calls
|
||||||
|
type sendResult struct {
|
||||||
|
method string
|
||||||
|
payload json.RawMessage
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
results := make(chan sendResult, 3)
|
||||||
|
|
||||||
|
methods := []string{"agents.list", "sessions.list", "agents.config"}
|
||||||
|
for _, method := range methods {
|
||||||
|
go func(m string) {
|
||||||
|
resp, err := client.Send(m, nil)
|
||||||
|
results <- sendResult{method: m, payload: resp, err: err}
|
||||||
|
}(method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all results
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
select {
|
||||||
|
case r := <-results:
|
||||||
|
if r.err != nil {
|
||||||
|
t.Errorf("Send(%q) returned error: %v", r.method, r.err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var result map[string]any
|
||||||
|
if err := json.Unmarshal(r.payload, &result); err != nil {
|
||||||
|
t.Errorf("Send(%q) unmarshal error: %v", r.method, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
gotMethod, _ := result["method"].(string)
|
||||||
|
if gotMethod != r.method {
|
||||||
|
t.Errorf("Send(%q) got response for %q (mismatched)", r.method, gotMethod)
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for concurrent Send results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 5. Test: Clean shutdown ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWSClient_CleanShutdown(t *testing.T) {
|
||||||
|
srv := newTestWSServer(t, func(conn *websocket.Conn) {
|
||||||
|
handleHandshake(t, conn)
|
||||||
|
keepAlive(conn)
|
||||||
|
})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.Start(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Let the client connect and complete handshake
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Cancel context — should trigger clean shutdown
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Client exited cleanly — pass
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("WSClient did not shut down cleanly within timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Pure utility tests (from CUB-205) ─────────────────────────────────────
|
||||||
|
|
||||||
func TestMapSessionStatus(t *testing.T) {
|
func TestMapSessionStatus(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
@@ -64,6 +64,9 @@ type CreateAgentRequest struct {
|
|||||||
// UpdateAgentRequest is the payload for PUT /api/agents/{id}.
|
// UpdateAgentRequest is the payload for PUT /api/agents/{id}.
|
||||||
type UpdateAgentRequest struct {
|
type UpdateAgentRequest struct {
|
||||||
Status *AgentStatus `json:"status,omitempty" validate:"omitempty,agentStatus"`
|
Status *AgentStatus `json:"status,omitempty" validate:"omitempty,agentStatus"`
|
||||||
|
DisplayName *string `json:"displayName,omitempty"`
|
||||||
|
Role *string `json:"role,omitempty"`
|
||||||
|
LastActivityAt *string `json:"lastActivityAt,omitempty"`
|
||||||
CurrentTask *string `json:"currentTask,omitempty"`
|
CurrentTask *string `json:"currentTask,omitempty"`
|
||||||
TaskProgress *int `json:"taskProgress,omitempty" validate:"omitempty,min=0,max=100"`
|
TaskProgress *int `json:"taskProgress,omitempty" validate:"omitempty,min=0,max=100"`
|
||||||
TaskElapsed *string `json:"taskElapsed,omitempty"`
|
TaskElapsed *string `json:"taskElapsed,omitempty"`
|
||||||
|
|||||||
Reference in New Issue
Block a user