diff --git a/chat/src/app/layout.tsx b/chat/src/app/layout.tsx index 830124ed..7c44c440 100644 --- a/chat/src/app/layout.tsx +++ b/chat/src/app/layout.tsx @@ -29,7 +29,7 @@ export default function RootLayout({ disableTransitionOnChange > {children} - + diff --git a/chat/src/components/chat-provider.tsx b/chat/src/components/chat-provider.tsx index 21a2ee3f..e8789ac0 100644 --- a/chat/src/components/chat-provider.tsx +++ b/chat/src/components/chat-provider.tsx @@ -36,6 +36,12 @@ interface StatusChangeEvent { agent_type: string; } +interface ErrorEventData { + message: string; + level: string; + time: string; +} + interface APIErrorDetail { location: string; message: string; @@ -215,6 +221,25 @@ export function ChatProvider({ children }: PropsWithChildren) { setAgentType(data.agent_type === "" ? "unknown" : data.agent_type as AgentType); }); + // Handle agent error events + eventSource.addEventListener("agent_error", (event) => { + const messageEvent = event as MessageEvent; + try { + const data: ErrorEventData = JSON.parse(messageEvent.data); + + // Display error as toast notification that persists until manually dismissed + if (data.level === "error") { + toast.error(data.message, { duration: Infinity }); + } else if (data.level === "warning") { + toast.warning(data.message, { duration: Infinity }); + } else { + toast.info(data.message, { duration: Infinity }); + } + } catch (e) { + console.error("Failed to parse agent_error event data:", e); + } + }); + // Handle connection open (server is online) eventSource.onopen = () => { // Connection is established, but we'll wait for status_change event diff --git a/cmd/server/server.go b/cmd/server/server.go index 6d5cdec3..cfb9640d 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -8,9 +8,12 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "sort" "strings" + "time" + "github.com/coder/agentapi/lib/screentracker" "github.com/mattn/go-isatty" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -103,6 +106,42 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } + // Get the variables related to state management + stateFile := viper.GetString(FlagStateFile) + loadState := false + saveState := false + + // Validate state file configuration + if stateFile != "" { + if !viper.IsSet(FlagLoadState) { + loadState = true + } else { + loadState = viper.GetBool(FlagLoadState) + } + + if !viper.IsSet(FlagSaveState) { + saveState = true + } else { + saveState = viper.GetBool(FlagSaveState) + } + } else { + if viper.IsSet(FlagLoadState) && viper.GetBool(FlagLoadState) { + return xerrors.Errorf("--load-state requires --state-file to be set") + } + if viper.IsSet(FlagSaveState) && viper.GetBool(FlagSaveState) { + return xerrors.Errorf("--save-state requires --state-file to be set") + } + } + + pidFile := viper.GetString(FlagPidFile) + + // Write PID file if configured + if pidFile != "" { + if err := writePIDFile(pidFile, logger); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + } + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -128,7 +167,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, + StatePersistenceConfig: screentracker.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: loadState, + SaveState: saveState, + }, }) + if err != nil { return xerrors.Errorf("failed to create server: %w", err) } @@ -136,7 +181,22 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } + + // Create a context for graceful shutdown + gracefulCtx, gracefulCancel := context.WithCancel(ctx) + defer gracefulCancel() + + // Setup signal handlers (they will call gracefulCancel) + handleSignals(gracefulCtx, gracefulCancel, logger, srv) + + // Setup PID file cleanup + if pidFile != "" { + defer cleanupPIDFile(pidFile, logger) + } + logger.Info("Starting server on port", "port", port) + + // Monitor process exit processExitCh := make(chan error, 1) go func() { defer close(processExitCh) @@ -147,16 +207,52 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) } } - if err := srv.Stop(ctx); err != nil { - logger.Error("Failed to stop server", "error", err) + + select { + case <-gracefulCtx.Done(): + default: + gracefulCancel() } }() - if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { - return xerrors.Errorf("failed to start server: %w", err) + + // Start the server + serverErrCh := make(chan error, 1) + go func() { + defer close(serverErrCh) + if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { + serverErrCh <- err + } + }() + + select { + case err := <-serverErrCh: + if err != nil { + return xerrors.Errorf("failed to start server: %w", err) + } + case <-gracefulCtx.Done(): + } + + if err := srv.SaveState("shutdown"); err != nil { + logger.Error("Failed to save state during shutdown", "error", err) } + + // Stop the HTTP server + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop HTTP server", "error", err) + } + + // Close the process + if err := process.Close(logger, 5*time.Second); err != nil { + logger.Error("Failed to close process cleanly", "error", err) + } + select { case err := <-processExitCh: - return xerrors.Errorf("agent exited with error: %w", err) + if err != nil { + return xerrors.Errorf("agent exited with error: %w", err) + } default: } return nil @@ -171,6 +267,35 @@ var agentNames = (func() []string { return names })() +// writePIDFile writes the current process ID to the specified file +func writePIDFile(pidFile string, logger *slog.Logger) error { + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(pidFile) + if err := os.MkdirAll(dir, 0o700); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Write PID file + if err := os.WriteFile(pidFile, []byte(pidContent), 0o600); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func cleanupPIDFile(pidFile string, logger *slog.Logger) { + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err) + } else if err == nil { + logger.Info("Removed PID file", "pidFile", pidFile) + } +} + type flagSpec struct { name string shorthand string @@ -190,6 +315,10 @@ const ( FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" FlagInitialPrompt = "initial-prompt" + FlagStateFile = "state-file" + FlagLoadState = "load-state" + FlagSaveState = "save-state" + FlagPidFile = "pid-file" ) func CreateServerCmd() *cobra.Command { @@ -228,6 +357,10 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {FlagStateFile, "s", "", "Path to file for saving/loading server state", "string"}, + {FlagLoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, + {FlagSaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, + {FlagPidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, } for _, spec := range flagSpecs { diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index bd07fc63..7b9372c1 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -2,6 +2,8 @@ package server import ( "fmt" + "io" + "log/slog" "os" "strings" "testing" @@ -477,6 +479,218 @@ func TestServerCmd_AllowedHosts(t *testing.T) { } } +func TestServerCmd_StatePersistenceFlags(t *testing.T) { + // NOTE: These tests use --exit flag to test flag parsing and defaults. + // Runtime validation that happens in runServer (e.g., "--load-state requires --state-file") + // would call os.Exit(1) which terminates the test process, so those validations + // are tested through integration/E2E tests instead. + + t.Run("state-file with defaults", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + // load-state and save-state default to true when state-file is set (validated in runServer) + }) + + t.Run("state-file with explicit load-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--load-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, false, viper.GetBool(FlagLoadState)) + }) + + t.Run("state-file with explicit save-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--save-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, false, viper.GetBool(FlagSaveState)) + }) + + t.Run("state-file with explicit load-state=true and save-state=true", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--load-state=true", + "--save-state=true", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, true, viper.GetBool(FlagLoadState)) + assert.Equal(t, true, viper.GetBool(FlagSaveState)) + }) + + t.Run("load-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--load-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(FlagLoadState)) + }) + + t.Run("save-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--save-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(FlagSaveState)) + }) + + t.Run("pid-file can be set independently", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--pid-file", "/tmp/server.pid", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/server.pid", viper.GetString(FlagPidFile)) + }) + + t.Run("state-file and pid-file can be set together", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--pid-file", "/tmp/server.pid", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, "/tmp/server.pid", viper.GetString(FlagPidFile)) + }) +} + +func TestPIDFileOperations(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("writePIDFile creates file with process ID", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify content contains current PID + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("writePIDFile creates directory if not exists", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nested/deep/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify directory was created + _, err = os.Stat(tmpDir + "/nested/deep") + require.NoError(t, err) + }) + + t.Run("writePIDFile overwrites existing file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Write initial PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Overwrite with current PID + err = writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify content is updated + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("cleanupPIDFile removes file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Create PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Cleanup + cleanupPIDFile(pidFile, discardLogger) + + // Verify file is removed + _, err = os.Stat(pidFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("cleanupPIDFile handles non-existent file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nonexistent.pid" + + // Should not panic or error + cleanupPIDFile(pidFile, discardLogger) + }) + + t.Run("cleanupPIDFile handles directory removal error gracefully", func(t *testing.T) { + // Create a file in a protected directory (this is system-dependent) + // Just verify it doesn't panic when it can't remove the file + pidFile := "/this/should/not/exist/test.pid" + + // Should not panic + cleanupPIDFile(pidFile, discardLogger) + }) +} + func TestServerCmd_AllowedOrigins(t *testing.T) { tests := []struct { name string diff --git a/cmd/server/signals_unix.go b/cmd/server/signals_unix.go new file mode 100644 index 00000000..b15b5b2b --- /dev/null +++ b/cmd/server/signals_unix.go @@ -0,0 +1,46 @@ +//go:build unix + +package server + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/httpapi" +) + +// handleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: trigger graceful shutdown by canceling the context +// - SIGUSR1: save conversation state without exiting +func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + logger.Info("Received shutdown signal", "signal", sig) + cancel() + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + defer signal.Stop(saveOnlyCh) + for { + select { + case <-saveOnlyCh: + logger.Info("Received SIGUSR1, saving state without exiting") + if err := srv.SaveState("SIGUSR1"); err != nil { + logger.Error("Failed to save state on SIGUSR1", "error", err) + } + case <-ctx.Done(): + return + } + } + }() +} diff --git a/cmd/server/signals_windows.go b/cmd/server/signals_windows.go new file mode 100644 index 00000000..b8a109c9 --- /dev/null +++ b/cmd/server/signals_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package server + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/httpapi" +) + +// handleSignals sets up signal handlers for Windows. +// Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). +func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) { + // Handle shutdown signals (SIGTERM, SIGINT only on Windows) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + logger.Info("Received shutdown signal", "signal", sig) + cancel() + }() +} diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 906a3a42..d2ba4d3f 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -18,6 +18,7 @@ const ( EventTypeMessageUpdate EventType = "message_update" EventTypeStatusChange EventType = "status_change" EventTypeScreenUpdate EventType = "screen_update" + EventTypeError EventType = "agent_error" ) type AgentStatus string @@ -52,6 +53,12 @@ type ScreenUpdateBody struct { Screen string `json:"screen"` } +type ErrorBody struct { + Message string `json:"message" doc:"Error message"` + Level string `json:"level" doc:"Error level: 'warning' or 'error'"` + Time time.Time `json:"time" doc:"Timestamp when the error occurred"` +} + type Event struct { Type EventType Payload any @@ -66,6 +73,7 @@ type EventEmitter struct { chanIdx int subscriptionBufSize uint screen string + errors []ErrorBody } func convertStatus(status st.ConversationStatus) AgentStatus { @@ -137,7 +145,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { } } -// Assumes that only the last message can change or new messages can be added. +// EmitMessages assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { e.mu.Lock() @@ -194,6 +202,22 @@ func (e *EventEmitter) EmitScreen(newScreen string) { e.screen = newScreen } +func (e *EventEmitter) EmitError(message string, level string) { + e.mu.Lock() + defer e.mu.Unlock() + + errorBody := ErrorBody{ + Message: message, + Level: level, + Time: time.Now(), + } + + // Store the error so new subscribers can receive all errors + e.errors = append(e.errors, errorBody) + + e.notifyChannels(EventTypeError, errorBody) +} + // Assumes the caller holds the lock. func (e *EventEmitter) currentStateAsEvents() []Event { events := make([]Event, 0, len(e.messages)+2) @@ -211,6 +235,15 @@ func (e *EventEmitter) currentStateAsEvents() []Event { Type: EventTypeScreenUpdate, Payload: ScreenUpdateBody{Screen: strings.TrimRight(e.screen, mf.WhiteSpaceChars)}, }) + + // Include all error events + for _, err := range e.errors { + events = append(events, Event{ + Type: EventTypeError, + Payload: err, + }) + } + return events } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 956cfb8a..f18ce679 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -40,6 +40,7 @@ type Server struct { port int srv *http.Server mu sync.RWMutex + stopOnce sync.Once logger *slog.Logger conversation st.Conversation agentio *termexec.Process @@ -48,6 +49,8 @@ type Server struct { chatBasePath string tempDir string clock quartz.Clock + shutdownCtx context.Context + shutdown context.CancelFunc } func (s *Server) NormalizeSchema(schema any) any { @@ -97,14 +100,15 @@ func (s *Server) GetOpenAPI() string { const snapshotInterval = 25 * time.Millisecond type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string - Clock quartz.Clock + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + Clock quartz.Clock + StatePersistenceConfig st.StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -253,16 +257,17 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } conversation := st.NewPTY(ctx, st.PTYConversationConfig{ - AgentType: config.AgentType, - AgentIO: config.Process, - Clock: config.Clock, - SnapshotInterval: snapshotInterval, - ScreenStabilityLength: 2 * time.Second, - FormatMessage: formatMessage, - ReadyForInitialPrompt: isAgentReadyForInitialPrompt, - FormatToolCall: formatToolCall, - InitialPrompt: initialPrompt, - Logger: logger, + AgentType: config.AgentType, + AgentIO: config.Process, + Clock: config.Clock, + SnapshotInterval: snapshotInterval, + ScreenStabilityLength: 2 * time.Second, + FormatMessage: formatMessage, + ReadyForInitialPrompt: isAgentReadyForInitialPrompt, + FormatToolCall: formatToolCall, + InitialPrompt: initialPrompt, + Logger: logger, + StatePersistenceConfig: config.StatePersistenceConfig, }, emitter) // Create temporary directory for uploads @@ -272,6 +277,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } logger.Info("Created temporary directory for uploads", "tempDir", tempDir) + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ router: router, api: api, @@ -284,6 +291,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), tempDir: tempDir, clock: config.Clock, + shutdownCtx: ctx, + shutdown: cancel, } // Register API routes @@ -387,6 +396,7 @@ func (s *Server) registerRoutes() { // Mapping of event type name to Go struct for that event. "message_update": MessageUpdateBody{}, "status_change": StatusChangeBody{}, + "agent_error": ErrorBody{}, }, s.subscribeEvents) sse.Register(s.api, huma.Operation{ @@ -511,6 +521,7 @@ func (s *Server) uploadFiles(ctx context.Context, input *struct { func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.Sender) { subscriberId, ch, stateEvents := s.emitter.Subscribe() defer s.emitter.Unsubscribe(subscriberId) + s.logger.Info("New subscriber", "subscriberId", subscriberId) for _, event := range stateEvents { if event.Type == EventTypeScreenUpdate { @@ -536,6 +547,9 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse. s.logger.Error("Failed to send event", "subscriberId", subscriberId, "error", err) return } + case <-s.shutdownCtx.Done(): + s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) + return case <-ctx.Done(): s.logger.Info("Context done", "subscriberId", subscriberId) return @@ -570,6 +584,9 @@ func (s *Server) subscribeScreen(ctx context.Context, input *struct{}, send sse. s.logger.Error("Failed to send screen event", "subscriberId", subscriberId, "error", err) return } + case <-s.shutdownCtx.Done(): + s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) + return case <-ctx.Done(): s.logger.Info("Screen context done", "subscriberId", subscriberId) return @@ -588,15 +605,20 @@ func (s *Server) Start() error { return s.srv.ListenAndServe() } -// Stop gracefully stops the HTTP server +// Stop gracefully stops the HTTP server. It is safe to call multiple times. func (s *Server) Stop(ctx context.Context) error { - // Clean up temporary directory - s.cleanupTempDir() + var err error + s.stopOnce.Do(func() { + s.shutdown() - if s.srv != nil { - return s.srv.Shutdown(ctx) - } - return nil + // Clean up temporary directory + s.cleanupTempDir() + + if s.srv != nil { + err = s.srv.Shutdown(ctx) + } + }) + return err } // cleanupTempDir removes the temporary directory and all its contents @@ -608,6 +630,14 @@ func (s *Server) cleanupTempDir() { } } +func (s *Server) SaveState(source string) error { + if err := s.conversation.SaveState(); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + return err + } + return nil +} + // registerStaticFileRoutes sets up routes for serving static files func (s *Server) registerStaticFileRoutes() { chatHandler := FileServerWithIndexFallback(s.chatBasePath) diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index c8e8b23c..82fc6713 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" @@ -956,3 +957,36 @@ func TestServer_UploadFiles_Errors(t *testing.T) { require.Contains(t, string(body), "file size exceeds 10MB limit") }) } + +func TestServer_Stop_Idempotency(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, + }) + require.NoError(t, err) + + // First call to Stop should succeed + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = srv.Stop(stopCtx) + require.NoError(t, err) + + // Second call to Stop should also succeed (no-op) + stopCtx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + err = srv.Stop(stopCtx2) + require.NoError(t, err) + + // Third call to Stop should also succeed (no-op) + stopCtx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel3() + err = srv.Stop(stopCtx3) + require.NoError(t, err) +} diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 16203041..c8d95b6e 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,10 +4,7 @@ import ( "context" "fmt" "os" - "os/signal" "strings" - "syscall" - "time" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -45,16 +42,5 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return nil, err } } - - // Handle SIGINT (Ctrl+C) and send it to the process - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - go func() { - <-signalCh - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Error closing process", "error", err) - } - }() - return process, nil } diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 8299faa1..8555e424 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -2,6 +2,7 @@ package screentracker import ( "context" + "strings" "time" "github.com/coder/agentapi/lib/util" @@ -49,6 +50,14 @@ type MessagePart interface { String() string } +func buildStringFromMessageParts(parts []MessagePart) string { + var sb strings.Builder + for _, part := range parts { + sb.WriteString(part.String()) + } + return sb.String() +} + // Conversation represents a conversation between a user and an agent. // It is intended as the primary interface for interacting with a session. // Implementations must support the following capabilities: @@ -63,6 +72,7 @@ type Conversation interface { Start(context.Context) Status() ConversationStatus Text() string + SaveState() error } // Emitter receives conversation state updates. @@ -70,6 +80,7 @@ type Emitter interface { EmitMessages([]ConversationMessage) EmitStatus(ConversationStatus) EmitScreen(string) + EmitError(message string, level string) } type ConversationMessage struct { @@ -78,3 +89,9 @@ type ConversationMessage struct { Role ConversationRole Time time.Time } + +type StatePersistenceConfig struct { + StateFile string + LoadState bool + SaveState bool +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 27283775..407bb10b 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -2,8 +2,12 @@ package screentracker import ( "context" + "encoding/json" "fmt" + "io" "log/slog" + "os" + "path/filepath" "strings" "sync" "time" @@ -26,6 +30,25 @@ type MessagePartText struct { Hidden bool } +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + InitialPromptSent bool `json:"initial_prompt_sent"` +} + +// LoadStateStatus represents the state of loading persisted conversation state. +type LoadStateStatus int + +const ( + // LoadStatePending indicates state loading has not been attempted yet. + LoadStatePending LoadStateStatus = iota + // LoadStateSucceeded indicates state was successfully loaded. + LoadStateSucceeded + // LoadStateFailed indicates state loading was attempted but failed. + LoadStateFailed +) + var _ MessagePart = &MessagePartText{} func (p MessagePartText) Do(writer AgentIO) error { @@ -67,8 +90,9 @@ type PTYConversationConfig struct { // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready - InitialPrompt []MessagePart - Logger *slog.Logger + InitialPrompt []MessagePart + Logger *slog.Logger + StatePersistenceConfig StatePersistenceConfig } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -107,9 +131,19 @@ type PTYConversation struct { stableSignal chan struct{} // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message toolCallMessageSet map[string]bool + // dirty tracks whether the conversation state has changed since the last save + dirty bool + // firstStableSnapshot is the conversation history rolled out by the agent in case of a resume (given that the agent supports it) + firstStableSnapshot string + // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state + userSentMessageAfterLoadState bool + // loadStateStatus tracks the status of loading conversation state from file. + loadStateStatus LoadStateStatus // initialPromptReady is set to true when ReadyForInitialPrompt returns true. // Checked inline in the snapshot loop on each tick. initialPromptReady bool + // initialPromptSent is set to true when the initial prompt has been enqueued to the outbound queue. + initialPromptSent bool } var _ Conversation = &PTYConversation{} @@ -119,6 +153,7 @@ type noopEmitter struct{} func (noopEmitter) EmitMessages([]ConversationMessage) {} func (noopEmitter) EmitStatus(ConversationStatus) {} func (noopEmitter) EmitScreen(string) {} +func (noopEmitter) EmitError(_ string, _ string) {} func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation { if cfg.Clock == nil { @@ -140,13 +175,13 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PT Time: cfg.Clock.Now(), }, }, - outboundQueue: make(chan outboundMessage, 1), - stableSignal: make(chan struct{}, 1), - toolCallMessageSet: make(map[string]bool), - } - // If we have an initial prompt, enqueue it - if len(cfg.InitialPrompt) > 0 { - c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil} + outboundQueue: make(chan outboundMessage, 1), + stableSignal: make(chan struct{}, 1), + toolCallMessageSet: make(map[string]bool), + dirty: false, + firstStableSnapshot: "", + userSentMessageAfterLoadState: false, + loadStateStatus: LoadStatePending, } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } @@ -169,6 +204,24 @@ func (c *PTYConversation) Start(ctx context.Context) { if !c.initialPromptReady && c.cfg.ReadyForInitialPrompt(screen) { c.initialPromptReady = true } + + if c.initialPromptReady && c.loadStateStatus == LoadStatePending && c.cfg.StatePersistenceConfig.LoadState { + if err := c.loadStateLocked(); err != nil { + c.cfg.Logger.Error("Failed to load state", "error", err) + c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), "warning") + c.loadStateStatus = LoadStateFailed + } else { + c.loadStateStatus = LoadStateSucceeded + } + } + + // Enqueue initial prompt once after agent is ready (and after state is potentially loaded) + if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent { + c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil} + c.initialPromptSent = true + c.dirty = true + } + if c.initialPromptReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() { select { case c.stableSignal <- struct{}{}: @@ -245,6 +298,9 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } + if c.loadStateStatus == LoadStateSucceeded { + agentMessage = c.adjustScreenAfterStateLoad(agentMessage) + } if c.cfg.FormatToolCall != nil { agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) } @@ -274,6 +330,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1] = conversationMessage } c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + + c.dirty = true } // caller MUST hold c.lock @@ -288,11 +346,7 @@ func (c *PTYConversation) snapshotLocked(screen string) { func (c *PTYConversation) Send(messageParts ...MessagePart) error { // Validate message content before enqueueing - var sb strings.Builder - for _, part := range messageParts { - sb.WriteString(part.String()) - } - message := sb.String() + message := buildStringFromMessageParts(messageParts) if message != msgfmt.TrimWhitespace(message) { return ErrMessageValidationWhitespace } @@ -316,11 +370,7 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { // around the parts that access shared state, but releases it during // writeStabilize to avoid blocking the snapshot loop. func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...MessagePart) error { - var sb strings.Builder - for _, part := range messageParts { - sb.WriteString(part.String()) - } - message := sb.String() + message := buildStringFromMessageParts(messageParts) c.lock.Lock() screenBeforeMessage := c.cfg.AgentIO.ReadScreen() @@ -350,6 +400,8 @@ func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...Messa Role: ConversationRoleUser, Time: now, }) + c.userSentMessageAfterLoadState = true + c.lock.Unlock() return nil } @@ -497,3 +549,149 @@ func (c *PTYConversation) Text() string { } return snapshots[len(snapshots)-1].screen } + +func (c *PTYConversation) SaveState() error { + c.lock.Lock() + defer c.lock.Unlock() + + stateFile := c.cfg.StatePersistenceConfig.StateFile + saveState := c.cfg.StatePersistenceConfig.SaveState + + if !saveState { + c.cfg.Logger.Info("State persistence is disabled") + return nil + } + + // Skip if not dirty + if !c.dirty { + c.cfg.Logger.Info("Skipping state save: no changes since last save") + return nil + } + + conversation := c.messagesLocked() + + // Serialize initial prompt from message parts + var initialPromptStr string + if len(c.cfg.InitialPrompt) > 0 { + initialPromptStr = buildStringFromMessageParts(c.cfg.InitialPrompt) + } + + // Use atomic write: write to temp file, then rename to target path + data, err := json.MarshalIndent(AgentState{ + Version: 1, + Messages: conversation, + InitialPrompt: initialPromptStr, + InitialPromptSent: c.initialPromptSent, + }, "", " ") + if err != nil { + return xerrors.Errorf("failed to marshal state: %w", err) + } + + // Create directory if it doesn't exist + dir := filepath.Dir(stateFile) + if err := os.MkdirAll(dir, 0o700); err != nil { + return xerrors.Errorf("failed to create state directory: %w", err) + } + + // Write to temp file + tempFile := stateFile + ".tmp" + if err := os.WriteFile(tempFile, data, 0o600); err != nil { + return xerrors.Errorf("failed to write temp state file: %w", err) + } + + // Atomic rename + if err := os.Rename(tempFile, stateFile); err != nil { + return xerrors.Errorf("failed to rename state file: %w", err) + } + + // Clear dirty flag after successful save + c.dirty = false + + c.cfg.Logger.Info("State saved successfully", "path", stateFile) + + return nil +} + +// loadStateLocked loads the state, this method assumes that caller holds the Lock +func (c *PTYConversation) loadStateLocked() error { + stateFile := c.cfg.StatePersistenceConfig.StateFile + loadState := c.cfg.StatePersistenceConfig.LoadState + + if !loadState || c.loadStateStatus != LoadStatePending { + return nil + } + + // Check if file exists + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) + return nil + } + + // Open state file + f, err := os.Open(stateFile) + if err != nil { + return xerrors.Errorf("failed to open state file: %w", err) + } + defer func() { + if closeErr := f.Close(); closeErr != nil { + c.cfg.Logger.Warn("Failed to close state file", "path", stateFile, "err", closeErr) + } + }() + + var agentState AgentState + decoder := json.NewDecoder(f) + if err := decoder.Decode(&agentState); err != nil { + if err == io.EOF { + c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) + return nil + } + return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) + } + + // Handle initial prompt restoration: + // - If a new initial prompt was provided via flags, check if it differs from the saved one. + // If different, mark as not sent (will be sent). If same, preserve sent status. + // - If no new prompt provided, restore the saved prompt and its sent status. + c.initialPromptSent = agentState.InitialPromptSent + if len(c.cfg.InitialPrompt) > 0 { + isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt + c.initialPromptSent = !isDifferent + } else { + c.cfg.InitialPrompt = []MessagePart{MessagePartText{ + Content: agentState.InitialPrompt, + Alias: "", + Hidden: false, + }} + } + + c.messages = agentState.Messages + + // Store the first stable snapshot for filtering later + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) > 0 && c.cfg.FormatMessage != nil { + c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") + } + + c.dirty = false + + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) + return nil +} + +func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { + + if c.firstStableSnapshot == "" { + return screen + } + + newScreen := strings.Replace(screen, c.firstStableSnapshot, "", 1) + + // Before the first user message after loading state, return the last message from the loaded state. + // This prevents computing incorrect diffs from the restored screen, as the agent's message should + // remain stable until the user continues the conversation. + if !c.userSentMessageAfterLoadState && len(c.messages) > 0 { + newScreen = "\n" + c.messages[len(c.messages)-1].Message + } + + return newScreen +} diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index 19b4511b..c8a49c7e 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,9 +2,11 @@ package screentracker_test import ( "context" + "encoding/json" "fmt" "io" "log/slog" + "os" "sync" "testing" "time" @@ -54,6 +56,7 @@ type testEmitter struct{} func (testEmitter) EmitMessages([]st.ConversationMessage) {} func (testEmitter) EmitStatus(st.ConversationStatus) {} func (testEmitter) EmitScreen(string) {} +func (testEmitter) EmitError(_ string, _ string) {} // advanceFor is a shorthand for advanceUntil with a time-based condition. func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { @@ -446,10 +449,361 @@ func TestMessages(t *testing.T) { }) } +func TestStatePersistence(t *testing.T) { + t.Run("SaveState creates file with correct structure", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + // Create temp directory for state file + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate some conversation + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Save state + err := c.SaveState() + require.NoError(t, err) + + // Read and verify the saved file + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.Equal(t, 1, agentState.Version) + assert.Equal(t, "test prompt", agentState.InitialPrompt) + assert.NotEmpty(t, agentState.Messages) + }) + + t.Run("SaveState skips when not configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + err := c.SaveState() + require.NoError(t, err) + + // File should not be created + _, err = os.Stat(stateFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("SaveState honors dirty flag", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate conversation and save + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + err := c.SaveState() + require.NoError(t, err) + + // Get file modification time + info1, err := os.Stat(stateFile) + require.NoError(t, err) + modTime1 := info1.ModTime() + + // Save again without changes - file should not be modified + err = c.SaveState() + require.NoError(t, err) + + info2, err := os.Stat(stateFile) + require.NoError(t, err) + modTime2 := info2.ModTime() + + // File modification time should be the same (dirty flag prevents save) + assert.Equal(t, modTime1, modTime2) + }) + + t.Run("SaveState creates directory if not exists", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nested/deep/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + // Verify file and directory were created + _, err = os.Stat(stateFile) + assert.NoError(t, err) + }) + + t.Run("LoadState restores conversation from file", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with test data + testState := st.AgentState{ + Version: 1, + InitialPrompt: "restored prompt", + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message 1", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "user message 1", Role: st.ConversationRoleUser, Time: time.Now()}, + {Id: 2, Message: "agent message 2", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with LoadState enabled + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until agent is ready and state is loaded + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Verify messages were restored + messages := c.Messages() + assert.Len(t, messages, 3) + assert.Equal(t, "agent message 1", messages[0].Message) + assert.Equal(t, "user message 1", messages[1].Message) + // The last agent message may have adjustments from adjustScreenAfterStateLoad + assert.Contains(t, messages[2].Message, "agent message 2") + }) + + t.Run("LoadState handles missing file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nonexistent.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles empty file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/empty.json" + + // Create empty file + err := os.WriteFile(stateFile, []byte(""), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles corrupted JSON gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/corrupted.json" + + // Create corrupted JSON file + err := os.WriteFile(stateFile, []byte("{invalid json}"), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic - logs warning and continues + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) +} + func TestInitialPromptReadiness(t *testing.T) { discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - t.Run("agent not ready - status remains changing", func(t *testing.T) { + t.Run("agent not ready - status is stable until agent becomes ready", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) t.Cleanup(cancel) mClock := quartz.NewMock(t) @@ -472,12 +826,12 @@ func TestInitialPromptReadiness(t *testing.T) { // Take a snapshot with "loading...". Threshold is 1 (stability 0 / interval 1s = 0 + 1 = 1). advanceFor(ctx, t, mClock, 1*time.Second) - // Even though screen is stable, status should be changing because - // the initial prompt is still in the outbound queue. - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + // Screen is stable and agent is not ready, so initial prompt hasn't been enqueued yet. + // Status should be stable. + assert.Equal(t, st.ConversationStatusStable, c.Status()) }) - t.Run("agent becomes ready - status stays changing until initial prompt sent", func(t *testing.T) { + t.Run("agent becomes ready - prompt enqueued and status changes to changing", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) t.Cleanup(cancel) mClock := quartz.NewMock(t) @@ -497,12 +851,11 @@ func TestInitialPromptReadiness(t *testing.T) { c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) - // Agent not ready initially. + // Agent not ready initially, status should be stable advanceFor(ctx, t, mClock, 1*time.Second) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + assert.Equal(t, st.ConversationStatusStable, c.Status()) - // Agent becomes ready, but status stays "changing" because the - // initial prompt is still in the outbound queue. + // Agent becomes ready, prompt gets enqueued, status becomes "changing" agent.setScreen("ready") advanceFor(ctx, t, mClock, 1*time.Second) assert.Equal(t, st.ConversationStatusChanging, c.Status()) @@ -533,12 +886,12 @@ func TestInitialPromptReadiness(t *testing.T) { c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) - // Status is "changing" while waiting for readiness. + // Status is "stable" while waiting for readiness (prompt not yet enqueued). advanceFor(ctx, t, mClock, 1*time.Second) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + assert.Equal(t, st.ConversationStatusStable, c.Status()) - // Agent becomes ready. The readiness loop detects this, the snapshot - // loop sees queue + stable + ready and signals the send loop. + // Agent becomes ready. The snapshot loop detects this, enqueues the prompt, + // then sees queue + stable + ready and signals the send loop. // writeStabilize runs with onWrite changing the screen, so it completes. agent.setScreen("ready") // Drive clock until the initial prompt is sent (queue drains). @@ -611,3 +964,326 @@ func TestInitialPromptReadiness(t *testing.T) { assert.Equal(t, st.ConversationStatusStable, c.Status()) }) } + +func TestInitialPromptSent(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("initialPromptSent is set when initial prompt is sent", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready and initial prompt is sent + agent.setScreen("ready") + advanceUntil(ctx, t, mClock, func() bool { + return len(c.Messages()) >= 2 + }) + + // Save state and verify initialPromptSent is persisted + agent.setScreen("response") + advanceFor(ctx, t, mClock, 2*time.Second) + + err := c.SaveState() + require.NoError(t, err) + + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.True(t, agentState.InitialPromptSent, "initialPromptSent should be true after initial prompt is sent") + assert.Equal(t, "test prompt", agentState.InitialPrompt) + }) + + t.Run("initialPromptSent prevents re-sending prompt after state load", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with initialPromptSent=true + testState := st.AgentState{ + Version: 1, + InitialPrompt: "test prompt", + InitialPromptSent: true, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "test prompt", Role: st.ConversationRoleUser, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with same initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + writeCount := 0 + agent.onWrite = func(data []byte) { + writeCount++ + agent.screen = "after_write" + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until ready and state is loaded + advanceFor(ctx, t, mClock, 500*time.Millisecond) + + // Verify the prompt was NOT re-sent (no writes occurred) + assert.Equal(t, 0, writeCount, "initial prompt should not be re-sent when already sent") + + // Messages should be restored from state (at minimum, the original 2) + messages := c.Messages() + assert.GreaterOrEqual(t, len(messages), 2, "messages should be restored from state") + // Verify the first two messages match what we saved + assert.Equal(t, "agent message", messages[0].Message) + assert.Equal(t, st.ConversationRoleAgent, messages[0].Role) + assert.Equal(t, "test prompt", messages[1].Message) + assert.Equal(t, st.ConversationRoleUser, messages[1].Role) + }) + + t.Run("new initial prompt is sent if different from saved prompt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with old prompt + testState := st.AgentState{ + Version: 1, + InitialPrompt: "old prompt", + InitialPromptSent: true, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with different initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "new prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready + agent.setScreen("ready") + + // Advance until the new prompt is sent + advanceUntil(ctx, t, mClock, func() bool { + msgs := c.Messages() + // Look for the new prompt in messages + for _, msg := range msgs { + if msg.Role == st.ConversationRoleUser && msg.Message == "new prompt" { + return true + } + } + return false + }) + + // Verify the new prompt was sent + messages := c.Messages() + found := false + for _, msg := range messages { + if msg.Role == st.ConversationRoleUser && msg.Message == "new prompt" { + found = true + break + } + } + assert.True(t, found, "new prompt should be sent when different from saved prompt") + }) + + t.Run("initialPromptSent not set when no initial prompt configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.False(t, agentState.InitialPromptSent, "initialPromptSent should be false when no initial prompt configured") + }) + + t.Run("restored prompt used when no new prompt provided", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with a prompt + testState := st.AgentState{ + Version: 1, + InitialPrompt: "saved prompt", + InitialPromptSent: false, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation without providing an initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready + agent.setScreen("ready") + + // Advance until the saved prompt is sent + advanceUntil(ctx, t, mClock, func() bool { + msgs := c.Messages() + for _, msg := range msgs { + if msg.Role == st.ConversationRoleUser && msg.Message == "saved prompt" { + return true + } + } + return false + }) + + // Verify the saved prompt was sent + messages := c.Messages() + found := false + for _, msg := range messages { + if msg.Role == st.ConversationRoleUser && msg.Message == "saved prompt" { + found = true + break + } + } + assert.True(t, found, "saved prompt should be sent when no new prompt provided") + }) +} diff --git a/openapi.json b/openapi.json index dda817cc..790338a6 100644 --- a/openapi.json +++ b/openapi.json @@ -19,6 +19,30 @@ "title": "ConversationRole", "type": "string" }, + "ErrorBody": { + "additionalProperties": false, + "properties": { + "level": { + "description": "Error level: 'warning' or 'error'", + "type": "string" + }, + "message": { + "description": "Error message", + "type": "string" + }, + "time": { + "description": "Timestamp when the error occurred", + "format": "date-time", + "type": "string" + } + }, + "required": [ + "level", + "message", + "time" + ], + "type": "object" + }, "ErrorDetail": { "additionalProperties": false, "properties": { @@ -326,6 +350,32 @@ "description": "Each oneOf object in the array represents one possible Server Sent Events (SSE) message, serialized as UTF-8 text according to the SSE specification.", "items": { "oneOf": [ + { + "properties": { + "data": { + "$ref": "#/components/schemas/ErrorBody" + }, + "event": { + "const": "agent_error", + "description": "The event name.", + "type": "string" + }, + "id": { + "description": "The event ID.", + "type": "integer" + }, + "retry": { + "description": "The retry time in milliseconds.", + "type": "integer" + } + }, + "required": [ + "data", + "event" + ], + "title": "Event agent_error", + "type": "object" + }, { "properties": { "data": {