package gateway import ( "context" "encoding/json" "log/slog" "net/http" "net/http/httptest" "strings" "sync/atomic" "testing" "time" "code.cubecraftcreations.com/CubeCraft-Creations/Control-Center/go-backend/internal/models" "github.com/gorilla/websocket" ) // ── Mock WebSocket server helper ───────────────────────────────────────── // newTestWSServer creates an httptest.Server that upgrades to WebSocket and // delegates each connection to handler. The server URL can be converted to // a ws:// URL by replacing "http" with "ws". func newTestWSServer(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server { t.Helper() upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } handler(conn) })) return srv } // wsURL converts an httptest.Server http URL to a ws URL. func wsURL(srv *httptest.Server) string { return "ws" + strings.TrimPrefix(srv.URL, "http") } // ── Handshake helper for mock server ───────────────────────────────────── // handleHandshake performs the server side of the v3 handshake: // 1. Send connect.challenge // 2. Read connect request // 3. Send hello-ok response // // Returns the connect request frame for inspection. func handleHandshake(t *testing.T, conn *websocket.Conn) map[string]any { t.Helper() // 1. Send connect.challenge challenge := map[string]any{ "type": "event", "event": "connect.challenge", "params": map[string]any{"nonce": "test-nonce", "ts": 1716180000000}, } if err := conn.WriteJSON(challenge); err != nil { t.Fatalf("server: write challenge: %v", err) } // 2. Read connect request var req map[string]any if err := conn.ReadJSON(&req); err != nil { t.Fatalf("server: read connect request: %v", err) } if req["method"] != "connect" { t.Fatalf("server: expected method=connect, got %v", req["method"]) } // 3. Send hello-ok response // Note: helloOKResponse expects ConnID at the top level of the result, // matching the WSClient's JSON struct tags. result := map[string]any{ "type": "hello-ok", "protocol": 3, "connId": "test-conn-123", "features": map[string]any{"methods": []string{}, "events": []string{}}, "auth": map[string]any{"role": "operator", "scopes": []string{"operator.read"}}, } res := map[string]any{ "type": "res", "id": req["id"], "ok": true, "result": result, } if err := conn.WriteJSON(res); err != nil { t.Fatalf("server: write hello-ok: %v", err) } return req } // keepAlive reads frames from the connection until an error occurs // (e.g., the client disconnects). Used as the default "do nothing" // server loop after handshake. func keepAlive(conn *websocket.Conn) { for { var m map[string]any if err := conn.ReadJSON(&m); err != nil { break } } } // ── 1. Test: Full handshake ────────────────────────────────────────────── func TestWSClient_Handshake(t *testing.T) { srv := newTestWSServer(t, func(conn *websocket.Conn) { handleHandshake(t, conn) keepAlive(conn) }) defer srv.Close() client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() done := make(chan struct{}) go func() { client.Start(ctx) close(done) }() // Wait briefly for handshake to complete time.Sleep(200 * time.Millisecond) // Verify connId was set client.connMu.Lock() connID := client.connId client.connMu.Unlock() if connID != "test-conn-123" { t.Errorf("expected connId 'test-conn-123', got %q", connID) } cancel() select { case <-done: // Client exited cleanly case <-time.After(3 * time.Second): t.Fatal("WSClient did not shut down after context cancellation") } } // ── 2. Test: Send() with response matching ─────────────────────────────── func TestWSClient_Send(t *testing.T) { srv := newTestWSServer(t, func(conn *websocket.Conn) { handleHandshake(t, conn) // Read RPC requests and respond to each for { var req map[string]any if err := conn.ReadJSON(&req); err != nil { break } reqID, _ := req["id"].(string) method, _ := req["method"].(string) var result any switch method { case "agents.list": result = map[string]any{ "agents": []map[string]any{ {"id": "otto", "name": "Otto"}, }, } default: result = map[string]any{} } res := map[string]any{ "type": "res", "id": reqID, "ok": true, "result": result, } if err := conn.WriteJSON(res); err != nil { break } } }) defer srv.Close() client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() go client.Start(ctx) // Give the client time to complete handshake time.Sleep(300 * time.Millisecond) resp, err := client.Send("agents.list", nil) if err != nil { t.Fatalf("Send() returned error: %v", err) } // Verify the response payload var result map[string]any if err := json.Unmarshal(resp, &result); err != nil { t.Fatalf("unmarshal response: %v", err) } agents, ok := result["agents"].([]any) if !ok || len(agents) != 1 { t.Errorf("expected 1 agent in response, got %v", result) } cancel() } // ── 3. Test: Event handler routing ─────────────────────────────────────── func TestWSClient_EventRouting(t *testing.T) { eventReceived := make(chan json.RawMessage, 1) srv := newTestWSServer(t, func(conn *websocket.Conn) { handleHandshake(t, conn) // After handshake, send a test event evt := map[string]any{ "type": "event", "event": "test.event", "params": map[string]any{"greeting": "hello from server"}, } if err := conn.WriteJSON(evt); err != nil { t.Logf("server: write event: %v", err) return } keepAlive(conn) }) defer srv.Close() client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) // Register event handler BEFORE starting the client client.OnEvent("test.event", func(payload json.RawMessage) { eventReceived <- payload }) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() go client.Start(ctx) // Wait for the event handler to fire select { case payload := <-eventReceived: var data map[string]any if err := json.Unmarshal(payload, &data); err != nil { t.Fatalf("unmarshal event payload: %v", err) } if greeting, _ := data["greeting"].(string); greeting != "hello from server" { t.Errorf("expected greeting 'hello from server', got %q", greeting) } case <-time.After(3 * time.Second): t.Fatal("timed out waiting for event handler to fire") } cancel() } // ── 4. Test: Concurrent Send ───────────────────────────────────────────── func TestWSClient_ConcurrentSend(t *testing.T) { var reqCount atomic.Int32 srv := newTestWSServer(t, func(conn *websocket.Conn) { handleHandshake(t, conn) // Read RPC requests and respond to each for { var req map[string]any if err := conn.ReadJSON(&req); err != nil { break } reqID, _ := req["id"].(string) n := reqCount.Add(1) res := map[string]any{ "type": "res", "id": reqID, "ok": true, "result": map[string]any{"index": n, "method": req["method"]}, } if err := conn.WriteJSON(res); err != nil { break } } }) defer srv.Close() client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() go client.Start(ctx) // Give the client time to complete handshake time.Sleep(300 * time.Millisecond) // Fire 3 concurrent Send() calls type sendResult struct { method string payload json.RawMessage err error } results := make(chan sendResult, 3) methods := []string{"agents.list", "sessions.list", "agents.config"} for _, method := range methods { go func(m string) { resp, err := client.Send(m, nil) results <- sendResult{method: m, payload: resp, err: err} }(method) } // Collect all results for i := 0; i < 3; i++ { select { case r := <-results: if r.err != nil { t.Errorf("Send(%q) returned error: %v", r.method, r.err) continue } var result map[string]any if err := json.Unmarshal(r.payload, &result); err != nil { t.Errorf("Send(%q) unmarshal error: %v", r.method, err) continue } gotMethod, _ := result["method"].(string) if gotMethod != r.method { t.Errorf("Send(%q) got response for %q (mismatched)", r.method, gotMethod) } case <-time.After(5 * time.Second): t.Fatal("timed out waiting for concurrent Send results") } } cancel() } // ── 5. Test: Clean shutdown ────────────────────────────────────────────── func TestWSClient_CleanShutdown(t *testing.T) { srv := newTestWSServer(t, func(conn *websocket.Conn) { handleHandshake(t, conn) keepAlive(conn) }) defer srv.Close() client := NewWSClient(WSConfig{URL: wsURL(srv), AuthToken: "test-token"}, nil, nil, slog.Default()) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) done := make(chan struct{}) go func() { client.Start(ctx) close(done) }() // Let the client connect and complete handshake time.Sleep(200 * time.Millisecond) // Cancel context — should trigger clean shutdown cancel() select { case <-done: // Client exited cleanly — pass case <-time.After(3 * time.Second): t.Fatal("WSClient did not shut down cleanly within timeout") } } // ── Pure utility tests (from CUB-205) ───────────────────────────────────── func TestMapSessionStatus(t *testing.T) { tests := []struct { input string expected models.AgentStatus }{ {"running", models.AgentStatusActive}, {"streaming", models.AgentStatusActive}, {"done", models.AgentStatusIdle}, {"error", models.AgentStatusError}, {"", models.AgentStatusIdle}, {"garbage", models.AgentStatusIdle}, } for _, tt := range tests { result := mapSessionStatus(tt.input) if result != tt.expected { t.Errorf("mapSessionStatus(%q) = %q, want %q", tt.input, result, tt.expected) } } } func TestAgentItemToCard(t *testing.T) { t.Run("full fields", func(t *testing.T) { item := agentListItem{ ID: "dex", Name: "Dex", Role: "backend", Channel: "telegram", } card := agentItemToCard(item) if card.ID != "dex" { t.Errorf("ID = %q, want %q", card.ID, "dex") } if card.DisplayName != "Dex" { t.Errorf("DisplayName = %q, want %q", card.DisplayName, "Dex") } if card.Role != "backend" { t.Errorf("Role = %q, want %q", card.Role, "backend") } if card.Channel != "telegram" { t.Errorf("Channel = %q, want %q", card.Channel, "telegram") } if card.Status != models.AgentStatusIdle { t.Errorf("Status = %q, want %q", card.Status, models.AgentStatusIdle) } }) t.Run("empty fields use defaults", func(t *testing.T) { item := agentListItem{ ID: "otto", } card := agentItemToCard(item) if card.ID != "otto" { t.Errorf("ID = %q, want %q", card.ID, "otto") } if card.DisplayName != "otto" { t.Errorf("DisplayName = %q, want %q (should fallback to ID)", card.DisplayName, "otto") } if card.Role != "agent" { t.Errorf("Role = %q, want %q (default)", card.Role, "agent") } if card.Channel != "unknown" { t.Errorf("Channel = %q, want %q (per Grimm requirement)", card.Channel, "unknown") } if card.Status != models.AgentStatusIdle { t.Errorf("Status = %q, want %q", card.Status, models.AgentStatusIdle) } }) t.Run("empty name falls back to ID", func(t *testing.T) { item := agentListItem{ ID: "hex", Name: "", Role: "database", } card := agentItemToCard(item) if card.DisplayName != "hex" { t.Errorf("DisplayName = %q, want %q (ID fallback)", card.DisplayName, "hex") } }) } func TestStrPtr(t *testing.T) { s := "hello" p := strPtr(s) if p == nil { t.Fatal("strPtr returned nil") } if *p != s { t.Errorf("strPtr(%q) = %q, want %q", s, *p, s) } empty := "" ep := strPtr(empty) if *ep != empty { t.Errorf("strPtr(empty) = %q, want %q", *ep, empty) } }