From e3bd9369d5638fcf0a262c95d1938708ceadee5f Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 22 Jan 2026 13:56:25 +0000 Subject: [PATCH 01/14] chore(lib): extract Conversation interface --- lib/httpapi/server.go | 12 +- lib/screentracker/conversation.go | 458 ++---------------- lib/screentracker/diff.go | 56 +++ lib/screentracker/diff_internal_test.go | 39 ++ lib/screentracker/pty_conversation.go | 371 ++++++++++++++ ...ation_test.go => pty_conversation_test.go} | 166 ++----- 6 files changed, 554 insertions(+), 548 deletions(-) create mode 100644 lib/screentracker/diff.go create mode 100644 lib/screentracker/diff_internal_test.go create mode 100644 lib/screentracker/pty_conversation.go rename lib/screentracker/{conversation_test.go => pty_conversation_test.go} (75%) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 59497873..fd0a90c5 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -40,7 +40,7 @@ type Server struct { srv *http.Server mu sync.RWMutex logger *slog.Logger - conversation *st.Conversation + conversation *st.PTYConversation agentio *termexec.Process agentType mf.AgentType emitter *EventEmitter @@ -237,7 +237,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - conversation := st.NewConversation(ctx, st.ConversationConfig{ + conversation := st.NewPTY(ctx, st.PTYConversationConfig{ AgentType: config.AgentType, AgentIO: config.Process, GetTime: func() time.Time { @@ -331,7 +331,7 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { } func (s *Server) StartSnapshotLoop(ctx context.Context) { - s.conversation.StartSnapshotLoop(ctx) + s.conversation.Start(ctx) go func() { for { currentStatus := s.conversation.Status() @@ -339,7 +339,7 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { // Send initial prompt when agent becomes stable for the first time if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { + if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { s.logger.Error("Failed to send initial prompt", "error", err) } else { s.conversation.InitialPromptSent = true @@ -350,7 +350,7 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { } s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) - s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) + s.emitter.UpdateScreenAndEmitChanges(s.conversation.String()) time.Sleep(snapshotInterval) } }() @@ -449,7 +449,7 @@ func (s *Server) createMessage(ctx context.Context, input *MessageRequest) (*Mes switch input.Body.Type { case MessageTypeUser: - if err := s.conversation.SendMessage(FormatMessage(s.agentType, input.Body.Content)...); err != nil { + if err := s.conversation.Send(FormatMessage(s.agentType, input.Body.Content)...); err != nil { return nil, xerrors.Errorf("failed to send message: %w", err) } case MessageTypeRaw: diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 97a74722..db8d82d1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -2,55 +2,27 @@ package screentracker import ( "context" - "fmt" - "log/slog" - "strings" - "sync" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/coder/agentapi/lib/util" "github.com/danielgtaylor/huma/v2" "golang.org/x/xerrors" ) -type screenSnapshot struct { - timestamp time.Time - screen string -} - -type AgentIO interface { - Write(data []byte) (int, error) - ReadScreen() string -} +type ConversationStatus string -type ConversationConfig struct { - AgentType msgfmt.AgentType - AgentIO AgentIO - // GetTime returns the current time - GetTime func() time.Time - // How often to take a snapshot for the stability check - SnapshotInterval time.Duration - // How long the screen should not change to be considered stable - ScreenStabilityLength time.Duration - // Function to format the messages received from the agent - // userInput is the last user message - FormatMessage func(message string, userInput string) string - // SkipWritingMessage skips the writing of a message to the agent. - // This is used in tests - SkipWritingMessage bool - // SkipSendMessageStatusCheck skips the check for whether the message can be sent. - // This is used in tests - SkipSendMessageStatusCheck bool - // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt - ReadyForInitialPrompt func(message string) bool - // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls - FormatToolCall func(message string) (string, []string) - Logger *slog.Logger -} +const ( + ConversationStatusChanging ConversationStatus = "changing" + ConversationStatusStable ConversationStatus = "stable" + ConversationStatusInitializing ConversationStatus = "initializing" +) type ConversationRole string +func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { + return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) +} + const ( ConversationRoleUser ConversationRole = "user" ConversationRoleAgent ConversationRole = "agent" @@ -61,207 +33,15 @@ var ConversationRoleValues = []ConversationRole{ ConversationRoleAgent, } -func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { - return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) -} - -type ConversationMessage struct { - Id int - Message string - Role ConversationRole - Time time.Time -} - -type Conversation struct { - cfg ConversationConfig - // How many stable snapshots are required to consider the screen stable - stableSnapshotsThreshold int - snapshotBuffer *RingBuffer[screenSnapshot] - messages []ConversationMessage - screenBeforeLastUserMessage string - lock sync.Mutex - // InitialPrompt is the initial prompt passed to the agent - InitialPrompt string - // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents - InitialPromptSent bool - // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt - ReadyForInitialPrompt bool - // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message - toolCallMessageSet map[string]bool -} - -type ConversationStatus string - -const ( - ConversationStatusChanging ConversationStatus = "changing" - ConversationStatusStable ConversationStatus = "stable" - ConversationStatusInitializing ConversationStatus = "initializing" +var ( + MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") + MessageValidationErrorEmpty = xerrors.New("message must not be empty") + MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") ) -func getStableSnapshotsThreshold(cfg ConversationConfig) int { - length := cfg.ScreenStabilityLength.Milliseconds() - interval := cfg.SnapshotInterval.Milliseconds() - threshold := int(length / interval) - if length%interval != 0 { - threshold++ - } - return threshold + 1 -} - -func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation { - threshold := getStableSnapshotsThreshold(cfg) - c := &Conversation{ - cfg: cfg, - stableSnapshotsThreshold: threshold, - snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), - messages: []ConversationMessage{ - { - Message: "", - Role: ConversationRoleAgent, - Time: cfg.GetTime(), - }, - }, - InitialPrompt: initialPrompt, - InitialPromptSent: len(initialPrompt) == 0, - toolCallMessageSet: make(map[string]bool), - } - return c -} - -func (c *Conversation) StartSnapshotLoop(ctx context.Context) { - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(c.cfg.SnapshotInterval): - // It's important that we hold the lock while reading the screen. - // There's a race condition that occurs without it: - // 1. The screen is read - // 2. Independently, SendMessage is called and takes the lock. - // 3. AddSnapshot is called and waits on the lock. - // 4. SendMessage modifies the terminal state, releases the lock - // 5. AddSnapshot adds a snapshot from a stale screen - c.lock.Lock() - screen := c.cfg.AgentIO.ReadScreen() - c.addSnapshotInner(screen) - c.lock.Unlock() - } - } - }() -} - -func FindNewMessage(oldScreen, newScreen string, agentType msgfmt.AgentType) string { - oldLines := strings.Split(oldScreen, "\n") - newLines := strings.Split(newScreen, "\n") - oldLinesMap := make(map[string]bool) - - // -1 indicates no header - dynamicHeaderEnd := -1 - - // Skip header lines for Opencode agent type to avoid false positives - // The header contains dynamic content (token count, context percentage, cost) - // that changes between screens, causing line comparison mismatches: - // - // ┃ # Getting Started with Claude CLI ┃ - // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ - if len(newLines) >= 2 && agentType == msgfmt.AgentTypeOpencode { - dynamicHeaderEnd = 2 - } - - for _, line := range oldLines { - oldLinesMap[line] = true - } - firstNonMatchingLine := len(newLines) - for i, line := range newLines[dynamicHeaderEnd+1:] { - if !oldLinesMap[line] { - firstNonMatchingLine = i - break - } - } - newSectionLines := newLines[firstNonMatchingLine:] - - // remove leading and trailing lines which are empty or have only whitespace - startLine := 0 - endLine := len(newSectionLines) - 1 - for i := 0; i < len(newSectionLines); i++ { - if strings.TrimSpace(newSectionLines[i]) != "" { - startLine = i - break - } - } - for i := len(newSectionLines) - 1; i >= 0; i-- { - if strings.TrimSpace(newSectionLines[i]) != "" { - endLine = i - break - } - } - return strings.Join(newSectionLines[startLine:endLine+1], "\n") -} - -func (c *Conversation) lastMessage(role ConversationRole) ConversationMessage { - for i := len(c.messages) - 1; i >= 0; i-- { - if c.messages[i].Role == role { - return c.messages[i] - } - } - return ConversationMessage{} -} - -// This function assumes that the caller holds the lock -func (c *Conversation) updateLastAgentMessage(screen string, timestamp time.Time) { - agentMessage := FindNewMessage(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) - lastUserMessage := c.lastMessage(ConversationRoleUser) - var toolCalls []string - if c.cfg.FormatMessage != nil { - agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) - } - if c.cfg.FormatToolCall != nil { - agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) - } - for _, toolCall := range toolCalls { - if c.toolCallMessageSet[toolCall] == false { - c.toolCallMessageSet[toolCall] = true - c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) - } - } - shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser - lastAgentMessage := c.lastMessage(ConversationRoleAgent) - if lastAgentMessage.Message == agentMessage { - return - } - conversationMessage := ConversationMessage{ - Message: agentMessage, - Role: ConversationRoleAgent, - Time: timestamp, - } - if shouldCreateNewMessage { - c.messages = append(c.messages, conversationMessage) - - // Cleanup - c.toolCallMessageSet = make(map[string]bool) - - } else { - c.messages[len(c.messages)-1] = conversationMessage - } - c.messages[len(c.messages)-1].Id = len(c.messages) - 1 -} - -// assumes the caller holds the lock -func (c *Conversation) addSnapshotInner(screen string) { - snapshot := screenSnapshot{ - timestamp: c.cfg.GetTime(), - screen: screen, - } - c.snapshotBuffer.Add(snapshot) - c.updateLastAgentMessage(screen, snapshot.timestamp) -} - -func (c *Conversation) AddSnapshot(screen string) { - c.lock.Lock() - defer c.lock.Unlock() - - c.addSnapshotInner(screen) +type AgentIO interface { + Write(data []byte) (int, error) + ReadScreen() string } type MessagePart interface { @@ -269,198 +49,18 @@ type MessagePart interface { String() string } -type MessagePartText struct { - Content string - Alias string - Hidden bool -} - -func (p MessagePartText) Do(writer AgentIO) error { - _, err := writer.Write([]byte(p.Content)) - return err -} - -func (p MessagePartText) String() string { - if p.Hidden { - return "" - } - if p.Alias != "" { - return p.Alias - } - return p.Content -} - -func PartsToString(parts ...MessagePart) string { - var sb strings.Builder - for _, part := range parts { - sb.WriteString(part.String()) - } - return sb.String() -} - -func ExecuteParts(writer AgentIO, parts ...MessagePart) error { - for _, part := range parts { - if err := part.Do(writer); err != nil { - return xerrors.Errorf("failed to write message part: %w", err) - } - } - return nil -} - -func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, messageParts ...MessagePart) error { - if c.cfg.SkipWritingMessage { - return nil - } - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - if err := ExecuteParts(c.cfg.AgentIO, messageParts...); err != nil { - return xerrors.Errorf("failed to write message: %w", err) - } - // wait for the screen to stabilize after the message is written - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 50 * time.Millisecond, - InitialWait: true, - }, func() (bool, error) { - screen := c.cfg.AgentIO.ReadScreen() - if screen != screenBeforeMessage { - time.Sleep(1 * time.Second) - newScreen := c.cfg.AgentIO.ReadScreen() - return newScreen == screen, nil - } - return false, nil - }); err != nil { - return xerrors.Errorf("failed to wait for screen to stabilize: %w", err) - } - - // wait for the screen to change after the carriage return is written - screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen() - lastCarriageReturnTime := time.Time{} - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 25 * time.Millisecond, - }, func() (bool, error) { - // we don't want to spam additional carriage returns because the agent may process them - // (aider does this), but we do want to retry sending one if nothing's - // happening for a while - if time.Since(lastCarriageReturnTime) >= 3*time.Second { - lastCarriageReturnTime = time.Now() - if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil { - return false, xerrors.Errorf("failed to write carriage return: %w", err) - } - } - time.Sleep(25 * time.Millisecond) - screen := c.cfg.AgentIO.ReadScreen() - - return screen != screenBeforeCarriageReturn, nil - }); err != nil { - return xerrors.Errorf("failed to wait for processing to start: %w", err) - } - - return nil -} - -var MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") -var MessageValidationErrorEmpty = xerrors.New("message must not be empty") -var MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") - -func (c *Conversation) SendMessage(messageParts ...MessagePart) error { - c.lock.Lock() - defer c.lock.Unlock() - - if !c.cfg.SkipSendMessageStatusCheck && c.statusInner() != ConversationStatusStable { - return MessageValidationErrorChanging - } - - message := PartsToString(messageParts...) - if message != msgfmt.TrimWhitespace(message) { - // msgfmt formatting functions assume this - return MessageValidationErrorWhitespace - } - if message == "" { - // writeMessageWithConfirmation requires a non-empty message - return MessageValidationErrorEmpty - } - - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - now := c.cfg.GetTime() - c.updateLastAgentMessage(screenBeforeMessage, now) - - if err := c.writeMessageWithConfirmation(context.Background(), messageParts...); err != nil { - return xerrors.Errorf("failed to send message: %w", err) - } - - c.screenBeforeLastUserMessage = screenBeforeMessage - c.messages = append(c.messages, ConversationMessage{ - Id: len(c.messages), - Message: message, - Role: ConversationRoleUser, - Time: now, - }) - return nil -} - -// Assumes that the caller holds the lock -func (c *Conversation) statusInner() ConversationStatus { - // sanity checks - if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { - panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) - } - if c.stableSnapshotsThreshold == 0 { - panic("stable snapshots threshold is 0. can't check stability") - } - - snapshots := c.snapshotBuffer.GetAll() - if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { - // if the last message is a user message then the snapshot loop hasn't - // been triggered since the last user message, and we should assume - // the screen is changing - return ConversationStatusChanging - } - - if len(snapshots) != c.stableSnapshotsThreshold { - return ConversationStatusInitializing - } - - for i := 1; i < len(snapshots); i++ { - if snapshots[0].screen != snapshots[i].screen { - return ConversationStatusChanging - } - } - - if !c.InitialPromptSent && !c.ReadyForInitialPrompt { - if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { - c.ReadyForInitialPrompt = true - return ConversationStatusStable - } - return ConversationStatusChanging - } - - return ConversationStatusStable -} - -func (c *Conversation) Status() ConversationStatus { - c.lock.Lock() - defer c.lock.Unlock() - - return c.statusInner() -} - -func (c *Conversation) Messages() []ConversationMessage { - c.lock.Lock() - defer c.lock.Unlock() - - result := make([]ConversationMessage, len(c.messages)) - copy(result, c.messages) - return result +// Conversation allows tracking of a conversation between a user and an agent. +type Conversation interface { + Messages() []ConversationMessage + Snapshot(string) + Start(context.Context) + Status() ConversationStatus + String() string } -func (c *Conversation) Screen() string { - c.lock.Lock() - defer c.lock.Unlock() - - snapshots := c.snapshotBuffer.GetAll() - if len(snapshots) == 0 { - return "" - } - return snapshots[len(snapshots)-1].screen +type ConversationMessage struct { + Id int + Message string + Role ConversationRole + Time time.Time } diff --git a/lib/screentracker/diff.go b/lib/screentracker/diff.go new file mode 100644 index 00000000..47c5b78c --- /dev/null +++ b/lib/screentracker/diff.go @@ -0,0 +1,56 @@ +package screentracker + +import ( + "strings" + + "github.com/coder/agentapi/lib/msgfmt" +) + +// screenDiff compares two screen states and attempts to find latest message of the given agent type. +func screenDiff(oldScreen, newScreen string, agentType msgfmt.AgentType) string { + oldLines := strings.Split(oldScreen, "\n") + newLines := strings.Split(newScreen, "\n") + oldLinesMap := make(map[string]bool) + + // -1 indicates no header + dynamicHeaderEnd := -1 + + // Skip header lines for Opencode agent type to avoid false positives + // The header contains dynamic content (token count, context percentage, cost) + // that changes between screens, causing line comparison mismatches: + // + // ┃ # Getting Started with Claude CLI ┃ + // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ + if len(newLines) >= 2 && agentType == msgfmt.AgentTypeOpencode { + dynamicHeaderEnd = 2 + } + + for _, line := range oldLines { + oldLinesMap[line] = true + } + firstNonMatchingLine := len(newLines) + for i, line := range newLines[dynamicHeaderEnd+1:] { + if !oldLinesMap[line] { + firstNonMatchingLine = i + break + } + } + newSectionLines := newLines[firstNonMatchingLine:] + + // remove leading and trailing lines which are empty or have only whitespace + startLine := 0 + endLine := len(newSectionLines) - 1 + for i := range newSectionLines { + if strings.TrimSpace(newSectionLines[i]) != "" { + startLine = i + break + } + } + for i := len(newSectionLines) - 1; i >= 0; i-- { + if strings.TrimSpace(newSectionLines[i]) != "" { + endLine = i + break + } + } + return strings.Join(newSectionLines[startLine:endLine+1], "\n") +} diff --git a/lib/screentracker/diff_internal_test.go b/lib/screentracker/diff_internal_test.go new file mode 100644 index 00000000..d68bc36c --- /dev/null +++ b/lib/screentracker/diff_internal_test.go @@ -0,0 +1,39 @@ +package screentracker + +import ( + "embed" + "path" + "testing" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/stretchr/testify/assert" +) + +//go:embed testdata +var testdataDir embed.FS + +func TestScreenDiff(t *testing.T) { + t.Run("simple", func(t *testing.T) { + assert.Equal(t, "", screenDiff("123456", "123456", msgfmt.AgentTypeCustom)) + assert.Equal(t, "1234567", screenDiff("123456", "1234567", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) + assert.Equal(t, "12342", screenDiff("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("89", "42", msgfmt.AgentTypeCustom)) + }) + + dir := "testdata/diff" + cases, err := testdataDir.ReadDir(dir) + assert.NoError(t, err) + for _, c := range cases { + t.Run(c.Name(), func(t *testing.T) { + before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) + assert.NoError(t, err) + after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) + assert.NoError(t, err) + expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) + assert.NoError(t, err) + assert.Equal(t, string(expected), screenDiff(string(before), string(after), msgfmt.AgentTypeCustom)) + }) + } +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go new file mode 100644 index 00000000..91b956a1 --- /dev/null +++ b/lib/screentracker/pty_conversation.go @@ -0,0 +1,371 @@ +package screentracker + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/coder/agentapi/lib/util" + "golang.org/x/xerrors" +) + +// A screenSnapshot represents a snapshot of the PTY at a specific time. +type screenSnapshot struct { + timestamp time.Time + screen string +} + +type MessagePartText struct { + Content string + Alias string + Hidden bool +} + +var _ MessagePart = &MessagePartText{} + +func (p MessagePartText) Do(writer AgentIO) error { + _, err := writer.Write([]byte(p.Content)) + return err +} + +func (p MessagePartText) String() string { + if p.Hidden { + return "" + } + if p.Alias != "" { + return p.Alias + } + return p.Content +} + +// PTYConversationConfig is the configuration for a PTYConversation. +type PTYConversationConfig struct { + AgentType msgfmt.AgentType + AgentIO AgentIO + // GetTime returns the current time + GetTime func() time.Time + // How often to take a snapshot for the stability check + SnapshotInterval time.Duration + // How long the screen should not change to be considered stable + ScreenStabilityLength time.Duration + // Function to format the messages received from the agent + // userInput is the last user message + FormatMessage func(message string, userInput string) string + // SkipWritingMessage skips the writing of a message to the agent. + // This is used in tests + SkipWritingMessage bool + // SkipSendMessageStatusCheck skips the check for whether the message can be sent. + // This is used in tests + SkipSendMessageStatusCheck bool + // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt + ReadyForInitialPrompt func(message string) bool + // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls + FormatToolCall func(message string) (string, []string) + Logger *slog.Logger +} + +func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { + length := cfg.ScreenStabilityLength.Milliseconds() + interval := cfg.SnapshotInterval.Milliseconds() + threshold := int(length / interval) + if length%interval != 0 { + threshold++ + } + return threshold + 1 +} + +// PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication. +// It uses a combination of polling and diffs to detect changes in the screen. +type PTYConversation struct { + cfg PTYConversationConfig + // How many stable snapshots are required to consider the screen stable + stableSnapshotsThreshold int + snapshotBuffer *RingBuffer[screenSnapshot] + messages []ConversationMessage + screenBeforeLastUserMessage string + lock sync.Mutex + + // InitialPrompt is the initial prompt passed to the agent + InitialPrompt string + // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents + InitialPromptSent bool + // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt + ReadyForInitialPrompt bool + // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message + toolCallMessageSet map[string]bool +} + +var _ Conversation = &PTYConversation{} + +func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string) *PTYConversation { + threshold := cfg.getStableSnapshotsThreshold() + c := &PTYConversation{ + cfg: cfg, + stableSnapshotsThreshold: threshold, + snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), + messages: []ConversationMessage{ + { + Message: "", + Role: ConversationRoleAgent, + Time: cfg.GetTime(), + }, + }, + InitialPrompt: initialPrompt, + InitialPromptSent: len(initialPrompt) == 0, + toolCallMessageSet: make(map[string]bool), + } + return c +} + +func (c *PTYConversation) Start(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(c.cfg.SnapshotInterval): + // It's important that we hold the lock while reading the screen. + // There's a race condition that occurs without it: + // 1. The screen is read + // 2. Independently, SendMessage is called and takes the lock. + // 3. AddSnapshot is called and waits on the lock. + // 4. SendMessage modifies the terminal state, releases the lock + // 5. AddSnapshot adds a snapshot from a stale screen + c.lock.Lock() + screen := c.cfg.AgentIO.ReadScreen() + c.snapshotLocked(screen) + c.lock.Unlock() + } + } + }() +} + +func (c *PTYConversation) lastMessage(role ConversationRole) ConversationMessage { + for i := len(c.messages) - 1; i >= 0; i-- { + if c.messages[i].Role == role { + return c.messages[i] + } + } + return ConversationMessage{} +} + +// caller MUST hold c.lock +func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp time.Time) { + agentMessage := screenDiff(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) + lastUserMessage := c.lastMessage(ConversationRoleUser) + var toolCalls []string + if c.cfg.FormatMessage != nil { + agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) + } + if c.cfg.FormatToolCall != nil { + agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) + } + for _, toolCall := range toolCalls { + if c.toolCallMessageSet[toolCall] == false { + c.toolCallMessageSet[toolCall] = true + c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) + } + } + shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser + lastAgentMessage := c.lastMessage(ConversationRoleAgent) + if lastAgentMessage.Message == agentMessage { + return + } + conversationMessage := ConversationMessage{ + Message: agentMessage, + Role: ConversationRoleAgent, + Time: timestamp, + } + if shouldCreateNewMessage { + c.messages = append(c.messages, conversationMessage) + + // Cleanup + c.toolCallMessageSet = make(map[string]bool) + + } else { + c.messages[len(c.messages)-1] = conversationMessage + } + c.messages[len(c.messages)-1].Id = len(c.messages) - 1 +} + +func (c *PTYConversation) Snapshot(screen string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.snapshotLocked(screen) +} + +// caller MUST hold c.lock +func (c *PTYConversation) snapshotLocked(screen string) { + snapshot := screenSnapshot{ + timestamp: c.cfg.GetTime(), + screen: screen, + } + c.snapshotBuffer.Add(snapshot) + c.updateLastAgentMessageLocked(screen, snapshot.timestamp) +} + +func (c *PTYConversation) Send(messageParts ...MessagePart) error { + c.lock.Lock() + defer c.lock.Unlock() + + if !c.cfg.SkipSendMessageStatusCheck && c.statusLocked() != ConversationStatusStable { + return MessageValidationErrorChanging + } + + var sb strings.Builder + for _, part := range messageParts { + sb.WriteString(part.String()) + } + message := sb.String() + if message != msgfmt.TrimWhitespace(message) { + // msgfmt formatting functions assume this + return MessageValidationErrorWhitespace + } + if message == "" { + // writeMessageWithConfirmation requires a non-empty message + return MessageValidationErrorEmpty + } + + screenBeforeMessage := c.cfg.AgentIO.ReadScreen() + now := c.cfg.GetTime() + c.updateLastAgentMessageLocked(screenBeforeMessage, now) + + if err := c.writeStabilize(context.Background(), messageParts...); err != nil { + return xerrors.Errorf("failed to send message: %w", err) + } + + c.screenBeforeLastUserMessage = screenBeforeMessage + c.messages = append(c.messages, ConversationMessage{ + Id: len(c.messages), + Message: message, + Role: ConversationRoleUser, + Time: now, + }) + return nil +} + +// writeStabilize writes messageParts to the screen and waits for the screen to stabilize after the message is written. +func (c *PTYConversation) writeStabilize(ctx context.Context, messageParts ...MessagePart) error { + if c.cfg.SkipWritingMessage { + return nil + } + screenBeforeMessage := c.cfg.AgentIO.ReadScreen() + for _, part := range messageParts { + if err := part.Do(c.cfg.AgentIO); err != nil { + return xerrors.Errorf("failed to write message part: %w", err) + } + } + // wait for the screen to stabilize after the message is written + if err := util.WaitFor(ctx, util.WaitTimeout{ + Timeout: 15 * time.Second, + MinInterval: 50 * time.Millisecond, + InitialWait: true, + }, func() (bool, error) { + screen := c.cfg.AgentIO.ReadScreen() + if screen != screenBeforeMessage { + time.Sleep(1 * time.Second) + newScreen := c.cfg.AgentIO.ReadScreen() + return newScreen == screen, nil + } + return false, nil + }); err != nil { + return xerrors.Errorf("failed to wait for screen to stabilize: %w", err) + } + + // wait for the screen to change after the carriage return is written + screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen() + lastCarriageReturnTime := time.Time{} + if err := util.WaitFor(ctx, util.WaitTimeout{ + Timeout: 15 * time.Second, + MinInterval: 25 * time.Millisecond, + }, func() (bool, error) { + // we don't want to spam additional carriage returns because the agent may process them + // (aider does this), but we do want to retry sending one if nothing's + // happening for a while + if time.Since(lastCarriageReturnTime) >= 3*time.Second { + lastCarriageReturnTime = time.Now() + if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil { + return false, xerrors.Errorf("failed to write carriage return: %w", err) + } + } + time.Sleep(25 * time.Millisecond) + screen := c.cfg.AgentIO.ReadScreen() + + return screen != screenBeforeCarriageReturn, nil + }); err != nil { + return xerrors.Errorf("failed to wait for processing to start: %w", err) + } + + return nil +} + +func (c *PTYConversation) Status() ConversationStatus { + c.lock.Lock() + defer c.lock.Unlock() + + return c.statusLocked() +} + +// caller MUST hold c.lock +func (c *PTYConversation) statusLocked() ConversationStatus { + // sanity checks + if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { + panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) + } + if c.stableSnapshotsThreshold == 0 { + panic("stable snapshots threshold is 0. can't check stability") + } + + snapshots := c.snapshotBuffer.GetAll() + if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { + // if the last message is a user message then the snapshot loop hasn't + // been triggered since the last user message, and we should assume + // the screen is changing + return ConversationStatusChanging + } + + if len(snapshots) != c.stableSnapshotsThreshold { + return ConversationStatusInitializing + } + + for i := 1; i < len(snapshots); i++ { + if snapshots[0].screen != snapshots[i].screen { + return ConversationStatusChanging + } + } + + if !c.InitialPromptSent && !c.ReadyForInitialPrompt { + if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { + c.ReadyForInitialPrompt = true + return ConversationStatusStable + } + return ConversationStatusChanging + } + + return ConversationStatusStable +} + +func (c *PTYConversation) Messages() []ConversationMessage { + c.lock.Lock() + defer c.lock.Unlock() + + result := make([]ConversationMessage, len(c.messages)) + copy(result, c.messages) + return result +} + +func (c *PTYConversation) String() string { + c.lock.Lock() + defer c.lock.Unlock() + + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) == 0 { + return "" + } + return snapshots[len(snapshots)-1].screen +} diff --git a/lib/screentracker/conversation_test.go b/lib/screentracker/pty_conversation_test.go similarity index 75% rename from lib/screentracker/conversation_test.go rename to lib/screentracker/pty_conversation_test.go index 9b888813..6798de4d 100644 --- a/lib/screentracker/conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,13 +2,10 @@ package screentracker_test import ( "context" - "embed" "fmt" - "path" "testing" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/stretchr/testify/assert" st "github.com/coder/agentapi/lib/screentracker" @@ -19,7 +16,7 @@ type statusTestStep struct { status st.ConversationStatus } type statusTestParams struct { - cfg st.ConversationConfig + cfg st.PTYConversationConfig steps []statusTestStep } @@ -42,11 +39,11 @@ func statusTest(t *testing.T, params statusTestParams) { if params.cfg.GetTime == nil { params.cfg.GetTime = func() time.Time { return time.Now() } } - c := st.NewConversation(ctx, params.cfg, "") + c := st.NewPTY(ctx, params.cfg, "") assert.Equal(t, st.ConversationStatusInitializing, c.Status()) for i, step := range params.steps { - c.AddSnapshot(step.snapshot) + c.Snapshot(step.snapshot) assert.Equal(t, step.status, c.Status(), "step %d", i) } }) @@ -58,7 +55,7 @@ func TestConversation(t *testing.T) { initializing := st.ConversationStatusInitializing statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, // stability threshold: 3 @@ -76,7 +73,7 @@ func TestConversation(t *testing.T) { }) statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 2 * time.Second, ScreenStabilityLength: 3 * time.Second, // stability threshold: 3 @@ -95,7 +92,7 @@ func TestConversation(t *testing.T) { }) statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 6 * time.Second, ScreenStabilityLength: 14 * time.Second, // stability threshold: 4 @@ -133,11 +130,11 @@ func TestMessages(t *testing.T) { Time: now, } } - sendMsg := func(c *st.Conversation, msg string) error { - return c.SendMessage(st.MessagePartText{Content: msg}) + sendMsg := func(c *st.PTYConversation, msg string) error { + return c.Send(st.MessagePartText{Content: msg}) } - newConversation := func(opts ...func(*st.ConversationConfig)) *st.Conversation { - cfg := st.ConversationConfig{ + newConversation := func(opts ...func(*st.PTYConversationConfig)) *st.PTYConversation { + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, @@ -147,7 +144,7 @@ func TestMessages(t *testing.T) { for _, opt := range opts { opt(&cfg) } - return st.NewConversation(context.Background(), cfg, "") + return st.NewPTY(context.Background(), cfg, "") } t.Run("messages are copied", func(t *testing.T) { @@ -167,7 +164,7 @@ func TestMessages(t *testing.T) { t.Run("whitespace-padding", func(t *testing.T) { c := newConversation() for _, msg := range []string{"123 ", " 123", "123\t\t", "\n123", "123\n\t", " \t123\n\t"} { - err := c.SendMessage(st.MessagePartText{Content: msg}) + err := c.Send(st.MessagePartText{Content: msg}) assert.Error(t, err, st.MessageValidationErrorWhitespace) } }) @@ -178,33 +175,33 @@ func TestMessages(t *testing.T) { }{ Time: now, } - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.GetTime = func() time.Time { return nowWrapper.Time } }) - c.AddSnapshot("1") + c.Snapshot("1") msgs := c.Messages() assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, msgs) nowWrapper.Time = nowWrapper.Add(1 * time.Second) - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, msgs, c.Messages()) }) t.Run("tracking messages", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // agent message is recorded when the first snapshot is added - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, c.Messages()) // agent message is updated when the screen changes - c.AddSnapshot("2") + c.Snapshot("2") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), }, c.Messages()) @@ -218,7 +215,7 @@ func TestMessages(t *testing.T) { }, c.Messages()) // agent message is added after a user message - c.AddSnapshot("4") + c.Snapshot("4") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), userMsg(1, "3"), @@ -236,9 +233,9 @@ func TestMessages(t *testing.T) { }, c.Messages()) // conversation status is changing right after a user message - c.AddSnapshot("7") - c.AddSnapshot("7") - c.AddSnapshot("7") + c.Snapshot("7") + c.Snapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) agent.screen = "7" assert.NoError(t, sendMsg(c, "8")) @@ -254,21 +251,21 @@ func TestMessages(t *testing.T) { // conversation status is back to stable after a snapshot that // doesn't change the screen - c.AddSnapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) }) t.Run("tracking messages overlap", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // common overlap between screens is removed after a user message - c.AddSnapshot("1") + c.Snapshot("1") agent.screen = "1" assert.NoError(t, sendMsg(c, "2")) - c.AddSnapshot("1\n3") + c.Snapshot("1\n3") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), userMsg(1, "2"), @@ -277,7 +274,7 @@ func TestMessages(t *testing.T) { agent.screen = "1\n3x" assert.NoError(t, sendMsg(c, "4")) - c.AddSnapshot("1\n3x\n5") + c.Snapshot("1\n3x\n5") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), userMsg(1, "2"), @@ -289,7 +286,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return message + " " + userInput @@ -302,7 +299,7 @@ func TestMessages(t *testing.T) { userMsg(1, "2"), }, c.Messages()) agent.screen = "x" - c.AddSnapshot("x") + c.Snapshot("x") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1 "), userMsg(1, "2"), @@ -312,7 +309,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return "formatted" @@ -329,7 +326,7 @@ func TestMessages(t *testing.T) { }) t.Run("send-message-status-check", func(t *testing.T) { - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.SkipSendMessageStatusCheck = false cfg.SnapshotInterval = 1 * time.Second cfg.ScreenStabilityLength = 2 * time.Second @@ -337,10 +334,10 @@ func TestMessages(t *testing.T) { }) assert.Error(t, sendMsg(c, "1"), st.MessageValidationErrorChanging) for range 3 { - c.AddSnapshot("1") + c.Snapshot("1") } assert.NoError(t, sendMsg(c, "4")) - c.AddSnapshot("2") + c.Snapshot("2") assert.Error(t, sendMsg(c, "5"), st.MessageValidationErrorChanging) }) @@ -350,68 +347,11 @@ func TestMessages(t *testing.T) { }) } -//go:embed testdata -var testdataDir embed.FS - -func TestFindNewMessage(t *testing.T) { - assert.Equal(t, "", st.FindNewMessage("123456", "123456", msgfmt.AgentTypeCustom)) - assert.Equal(t, "1234567", st.FindNewMessage("123456", "1234567", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) - assert.Equal(t, "12342", st.FindNewMessage("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("89", "42", msgfmt.AgentTypeCustom)) - - dir := "testdata/diff" - cases, err := testdataDir.ReadDir(dir) - assert.NoError(t, err) - for _, c := range cases { - t.Run(c.Name(), func(t *testing.T) { - before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) - assert.NoError(t, err) - after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) - assert.NoError(t, err) - expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) - assert.NoError(t, err) - assert.Equal(t, string(expected), st.FindNewMessage(string(before), string(after), msgfmt.AgentTypeCustom)) - }) - } -} - -func TestPartsToString(t *testing.T) { - assert.Equal(t, "123", st.PartsToString(st.MessagePartText{Content: "123"})) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - ), - ) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "x", Hidden: true}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - st.MessagePartText{Content: "y", Hidden: true}, - ), - ) - assert.Equal(t, - "ab", - st.PartsToString( - st.MessagePartText{Content: "1", Alias: "a"}, - st.MessagePartText{Content: "2", Alias: "b"}, - st.MessagePartText{Content: "3", Alias: "c", Hidden: true}, - ), - ) -} - func TestInitialPromptReadiness(t *testing.T) { now := time.Now() t.Run("agent not ready - status remains changing", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -420,10 +360,10 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Fill buffer with stable snapshots, but agent is not ready - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Even though screen is stable, status should be changing because agent is not ready assert.Equal(t, st.ConversationStatusChanging, c.Status()) @@ -432,7 +372,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("agent becomes ready - status changes to stable", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -441,14 +381,14 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Agent not ready initially - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) @@ -456,7 +396,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("ready for initial prompt lifecycle: false -> true -> false", func(t *testing.T) { agent := &testAgent{screen: "loading..."} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -467,23 +407,23 @@ func TestInitialPromptReadiness(t *testing.T) { SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Initial state: ReadyForInitialPrompt should be false - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.False(t, c.ReadyForInitialPrompt, "should start as false") assert.False(t, c.InitialPromptSent) assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready: ReadyForInitialPrompt should become true agent.screen = "ready" - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt, "should become true when ready") assert.False(t, c.InitialPromptSent) // Send the initial prompt - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // After sending initial prompt: ReadyForInitialPrompt should be set back to false // (simulating what happens in the actual server code) @@ -496,7 +436,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("no initial prompt - normal status logic applies", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -506,9 +446,9 @@ func TestInitialPromptReadiness(t *testing.T) { }, } // Empty initial prompt means no need to wait for readiness - c := st.NewConversation(context.Background(), cfg, "") + c := st.NewPTY(context.Background(), cfg, "") - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Status should be stable because no initial prompt to wait for assert.Equal(t, st.ConversationStatusStable, c.Status()) @@ -518,7 +458,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("initial prompt sent - normal status logic applies", func(t *testing.T) { agent := &testAgent{screen: "ready"} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -529,24 +469,24 @@ func TestInitialPromptReadiness(t *testing.T) { SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // First, agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) // Send the initial prompt agent.screen = "processing..." - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // Mark initial prompt as sent (simulating what the server does) c.InitialPromptSent = true c.ReadyForInitialPrompt = false // Now test that status logic works normally after initial prompt is sent - c.AddSnapshot("processing...") + c.Snapshot("processing...") // Status should be stable because initial prompt was already sent // and the readiness check is bypassed From a0f8bb563bd4cde3c40752c1f2fdecb454f995b7 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Sat, 31 Jan 2026 20:52:52 +0530 Subject: [PATCH 02/14] feat: implement state persistence --- cmd/server/server.go | 38 +++++- lib/httpapi/events.go | 2 +- lib/httpapi/server.go | 162 ++++++++++++++++++++------ lib/httpapi/setup.go | 14 --- lib/screentracker/conversation.go | 9 ++ lib/screentracker/pty_conversation.go | 116 ++++++++++++++++++ 6 files changed, 287 insertions(+), 54 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 6a7fa7f0..5a125af4 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -103,6 +103,26 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } + // Get the variables related to state management + stateFile := viper.GetString(StateFile) + loadState := true + saveState := true + if stateFile != "" { + if !viper.IsSet(LoadState) { + loadState = true + } else { + loadState = viper.GetBool(LoadState) + } + + if !viper.IsSet(SaveState) { + saveState = true + } else { + saveState = viper.GetBool(SaveState) + } + } + + pidFile := viper.GetString(PidFile) + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -128,7 +148,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, + StatePersistenceCfg: httpapi.StatePersistenceCfg{ + StateFile: stateFile, + LoadState: loadState, + SaveState: saveState, + PidFile: pidFile, + }, }) + if err != nil { return xerrors.Errorf("failed to create server: %w", err) } @@ -137,6 +164,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er return nil } srv.StartSnapshotLoop(ctx) + srv.HandleSignals(ctx, process) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { @@ -152,7 +180,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er logger.Error("Failed to stop server", "error", err) } }() - if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { return xerrors.Errorf("failed to start server: %w", err) } select { @@ -191,6 +219,10 @@ const ( FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" FlagInitialPrompt = "initial-prompt" + StateFile = "state-file" + LoadState = "load-state" + SaveState = "save-state" + PidFile = "pid-file" ) func CreateServerCmd() *cobra.Command { @@ -229,6 +261,10 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {StateFile, "s", "", "Path to file for saving/loading server state", "string"}, + {LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, + {SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, + {PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, } for _, spec := range flagSpecs { diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 73eff07b..e8cabab6 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -120,7 +120,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { } } -// Assumes that only the last message can change or new messages can be added. +// UpdateMessagesAndEmitChanges assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) { e.mu.Lock() diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index cc330c6b..965fa28f 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -11,11 +11,13 @@ import ( "net/http" "net/url" "os" + "os/signal" "path/filepath" "slices" "sort" "strings" "sync" + "syscall" "time" "unicode" @@ -34,18 +36,20 @@ import ( // Server represents the HTTP server type Server struct { - router chi.Router - api huma.API - port int - srv *http.Server - mu sync.RWMutex - logger *slog.Logger - conversation *st.PTYConversation - agentio *termexec.Process - agentType mf.AgentType - emitter *EventEmitter - chatBasePath string - tempDir string + router chi.Router + api huma.API + port int + srv *http.Server + mu sync.RWMutex + logger *slog.Logger + conversation *st.PTYConversation + agentio *termexec.Process + agentType mf.AgentType + emitter *EventEmitter + chatBasePath string + tempDir string + statePersistenceCfg StatePersistenceCfg + stateLoadComplete bool } func (s *Server) NormalizeSchema(schema any) any { @@ -94,14 +98,22 @@ func (s *Server) GetOpenAPI() string { // because the action of taking a snapshot takes time too. const snapshotInterval = 25 * time.Millisecond +type StatePersistenceCfg struct { + StateFile string + LoadState bool + SaveState bool + PidFile string +} + type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + StatePersistenceCfg StatePersistenceCfg } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -260,16 +272,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { logger.Info("Created temporary directory for uploads", "tempDir", tempDir) s := &Server{ - router: router, - api: api, - port: config.Port, - conversation: conversation, - logger: logger, - agentio: config.Process, - agentType: config.AgentType, - emitter: emitter, - chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), - tempDir: tempDir, + router: router, + api: api, + port: config.Port, + conversation: conversation, + logger: logger, + agentio: config.Process, + agentType: config.AgentType, + emitter: emitter, + chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), + tempDir: tempDir, + statePersistenceCfg: config.StatePersistenceCfg, + stateLoadComplete: false, } // Register API routes @@ -337,15 +351,26 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { currentStatus := s.conversation.Status() // Send initial prompt when agent becomes stable for the first time - if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - - if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { - s.logger.Error("Failed to send initial prompt", "error", err) - } else { - s.conversation.InitialPromptSent = true - s.conversation.ReadyForInitialPrompt = false - currentStatus = st.ConversationStatusChanging - s.logger.Info("Initial prompt sent successfully") + if convertStatus(currentStatus) == AgentStatusStable { + + if !s.stateLoadComplete && s.statePersistenceCfg.LoadState { + _, err := s.conversation.LoadState(s.statePersistenceCfg.StateFile) + if err != nil { + s.logger.Warn("Failed to load state file", "path", s.statePersistenceCfg.StateFile, "err", err) + } else { + s.logger.Info("Successfully loaded state", "path", s.statePersistenceCfg.StateFile) + } + s.stateLoadComplete = true + } + if !s.conversation.InitialPromptSent { + if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { + s.logger.Error("Failed to send initial prompt", "error", err) + } else { + s.conversation.InitialPromptSent = true + s.conversation.ReadyForInitialPrompt = false + currentStatus = st.ConversationStatusChanging + s.logger.Info("Initial prompt sent successfully") + } } } s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) @@ -592,6 +617,15 @@ func (s *Server) Start() error { // Stop gracefully stops the HTTP server func (s *Server) Stop(ctx context.Context) error { + // Save conversation state if configured + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state", "error", err) + } else { + s.logger.Info("Saved conversation state", "stateFile", s.statePersistenceCfg.StateFile) + } + } + // Clean up temporary directory s.cleanupTempDir() @@ -610,6 +644,58 @@ func (s *Server) cleanupTempDir() { } } +// HandleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process +// - SIGUSR1: save conversation state without exiting +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + go func() { + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + // Save conversation state if configured (synchronously before closing process) + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state on signal", "signal", sig, "error", err) + } else { + s.logger.Info("Saved conversation state on signal", "signal", sig, "stateFile", s.statePersistenceCfg.StateFile) + } + } + + // Now close the process + if err := process.Close(s.logger, 5*time.Second); err != nil { + s.logger.Error("Error closing process", "signal", sig, "error", err) + } + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + for { + select { + case <-saveOnlyCh: + s.logger.Info("Received SIGUSR1, saving state without exiting") + + // Save conversation state if configured + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state on SIGUSR1", "error", err) + } else { + s.logger.Info("Saved conversation state on SIGUSR1", "stateFile", s.statePersistenceCfg.StateFile) + } + } else { + s.logger.Warn("SIGUSR1 received but state saving is not configured") + } + case <-ctx.Done(): + return + } + } + }() +} + // registerStaticFileRoutes sets up routes for serving static files func (s *Server) registerStaticFileRoutes() { chatHandler := FileServerWithIndexFallback(s.chatBasePath) diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 16203041..c8d95b6e 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,10 +4,7 @@ import ( "context" "fmt" "os" - "os/signal" "strings" - "syscall" - "time" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -45,16 +42,5 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return nil, err } } - - // Handle SIGINT (Ctrl+C) and send it to the process - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - go func() { - <-signalCh - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Error closing process", "error", err) - } - }() - return process, nil } diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index db8d82d1..daf129a1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -52,6 +52,8 @@ type MessagePart interface { // Conversation allows tracking of a conversation between a user and an agent. type Conversation interface { Messages() []ConversationMessage + SaveState([]ConversationMessage, string) error + LoadState(string) ([]ConversationMessage, error) Snapshot(string) Start(context.Context) Status() ConversationStatus @@ -64,3 +66,10 @@ type ConversationMessage struct { Role ConversationRole Time time.Time } + +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + InitialPromptSent bool `json:"initial_prompt_sent"` +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 91b956a1..2f8e804f 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -2,8 +2,11 @@ package screentracker import ( "context" + "encoding/json" "fmt" "log/slog" + "os" + "path/filepath" "strings" "sync" "time" @@ -97,6 +100,12 @@ type PTYConversation struct { ReadyForInitialPrompt bool // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message toolCallMessageSet map[string]bool + // dirty tracks whether the conversation state has changed since the last save + dirty bool + // firstStableSnapshot is the conversation history rolled out by the agent in case of a resume (given that the agent supports it) + firstStableSnapshot string + // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state + userSentMessageAfterLoadState bool } var _ Conversation = &PTYConversation{} @@ -161,6 +170,7 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } + agentMessage = c.skipInitialSnapshot(agentMessage) if c.cfg.FormatToolCall != nil { agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) } @@ -190,6 +200,7 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1] = conversationMessage } c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + c.dirty = true } func (c *PTYConversation) Snapshot(screen string) { @@ -246,6 +257,9 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { Role: ConversationRoleUser, Time: now, }) + c.dirty = true + c.userSentMessageAfterLoadState = true + return nil } @@ -369,3 +383,105 @@ func (c *PTYConversation) String() string { } return snapshots[len(snapshots)-1].screen } + +func (c *PTYConversation) SaveState(conversation []ConversationMessage, stateFile string) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Skip if state file is not configured + if stateFile == "" { + return nil + } + + // Skip if not dirty + if !c.dirty { + return nil + } + + // Use atomic write: write to temp file, then rename to target path + data, err := json.MarshalIndent(AgentState{ + Version: 1, + Messages: conversation, + InitialPrompt: c.InitialPrompt, + InitialPromptSent: c.InitialPromptSent, + }, "", " ") + if err != nil { + return xerrors.Errorf("failed to marshal state: %w", err) + } + + // Create directory if it doesn't exist + dir := filepath.Dir(stateFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create state directory: %w", err) + } + + // Write to temp file + tempFile := stateFile + ".tmp" + if err := os.WriteFile(tempFile, data, 0o644); err != nil { + return xerrors.Errorf("failed to write temp state file: %w", err) + } + + // Atomic rename + if err := os.Rename(tempFile, stateFile); err != nil { + return xerrors.Errorf("failed to rename state file: %w", err) + } + + // Clear dirty flag after successful save + c.dirty = false + return nil +} + +func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, error) { + c.lock.Lock() + defer c.lock.Unlock() + + // Skip if state file is not configured + if stateFile == "" { + return nil, nil + } + + // Check if file exists + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + return nil, nil + } + + // Read state file + data, err := os.ReadFile(stateFile) + if err != nil { + return nil, xerrors.Errorf("failed to read state file: %w", err) + } + + if len(data) == 0 { + return nil, xerrors.Errorf("failed to read state file: empty state file") + } + + var agentState AgentState + if err := json.Unmarshal(data, &agentState); err != nil { + return nil, xerrors.Errorf("failed to unmarshal state: %w", err) + } + + c.InitialPromptSent = agentState.InitialPromptSent + c.InitialPrompt = agentState.InitialPrompt + c.messages = agentState.Messages + + // Store the first stable snapshot for filtering later + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) > 0 { + c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") + } + + return c.messages, nil +} + +func (c *PTYConversation) skipInitialSnapshot(screen string) string { + newScreen := strings.ReplaceAll(screen, c.firstStableSnapshot, "") + + // Before the first user message after loading state, return the last message from the loaded state. + // This prevents computing incorrect diffs from the restored screen, as the agent's message should + // remain stable until the user continues the conversation. + if c.userSentMessageAfterLoadState == false { + newScreen = "\n" + c.messages[len(c.messages)-1].Message + } + + return newScreen +} From ca3cdff8c019238b02f5f38644ae96d074b0f2dc Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Sat, 31 Jan 2026 21:56:44 +0530 Subject: [PATCH 03/14] feat: pid file writing and clearing and improved error handling for load state --- lib/httpapi/server.go | 57 +++++++++++++++++++++++---- lib/screentracker/pty_conversation.go | 9 ++++- 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 965fa28f..e895c4ec 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -350,16 +350,11 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { for { currentStatus := s.conversation.Status() - // Send initial prompt when agent becomes stable for the first time + // Send initial prompt & load state when agent becomes stable for the first time if convertStatus(currentStatus) == AgentStatusStable { if !s.stateLoadComplete && s.statePersistenceCfg.LoadState { - _, err := s.conversation.LoadState(s.statePersistenceCfg.StateFile) - if err != nil { - s.logger.Warn("Failed to load state file", "path", s.statePersistenceCfg.StateFile, "err", err) - } else { - s.logger.Info("Successfully loaded state", "path", s.statePersistenceCfg.StateFile) - } + _, _ = s.conversation.LoadState(s.statePersistenceCfg.StateFile) s.stateLoadComplete = true } if !s.conversation.InitialPromptSent { @@ -612,6 +607,11 @@ func (s *Server) Start() error { Handler: s.router, } + // Write PID file if configured + if err := s.writePIDFile(); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + return s.srv.ListenAndServe() } @@ -626,6 +626,9 @@ func (s *Server) Stop(ctx context.Context) error { } } + // Clean up PID file + s.cleanupPIDFile() + // Clean up temporary directory s.cleanupTempDir() @@ -644,6 +647,43 @@ func (s *Server) cleanupTempDir() { } } +// writePIDFile writes the current process ID to the configured PID file +func (s *Server) writePIDFile() error { + if s.statePersistenceCfg.PidFile == "" { + return nil + } + + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(s.statePersistenceCfg.PidFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Write PID file + if err := os.WriteFile(s.statePersistenceCfg.PidFile, []byte(pidContent), 0o644); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + s.logger.Info("Wrote PID file", "pidFile", s.statePersistenceCfg.PidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func (s *Server) cleanupPIDFile() { + if s.statePersistenceCfg.PidFile == "" { + return + } + + if err := os.Remove(s.statePersistenceCfg.PidFile); err != nil && !os.IsNotExist(err) { + s.logger.Error("Failed to remove PID file", "pidFile", s.statePersistenceCfg.PidFile, "error", err) + } else if err == nil { + s.logger.Info("Removed PID file", "pidFile", s.statePersistenceCfg.PidFile) + } +} + // HandleSignals sets up signal handlers for: // - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process // - SIGUSR1: save conversation state without exiting @@ -664,6 +704,9 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { } } + // Clean up PID file + s.cleanupPIDFile() + // Now close the process if err := process.Close(s.logger, 5*time.Second); err != nil { s.logger.Error("Error closing process", "signal", sig, "error", err) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 2f8e804f..abfc886e 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -442,22 +442,26 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er // Check if file exists if _, err := os.Stat(stateFile); os.IsNotExist(err) { + c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) return nil, nil } // Read state file data, err := os.ReadFile(stateFile) if err != nil { + c.cfg.Logger.Warn("Failed to load state file", "path", stateFile, "err", err) return nil, xerrors.Errorf("failed to read state file: %w", err) } if len(data) == 0 { - return nil, xerrors.Errorf("failed to read state file: empty state file") + c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) + return nil, nil } var agentState AgentState if err := json.Unmarshal(data, &agentState); err != nil { - return nil, xerrors.Errorf("failed to unmarshal state: %w", err) + c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) + return nil, xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } c.InitialPromptSent = agentState.InitialPromptSent @@ -470,6 +474,7 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") } + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) return c.messages, nil } From 1c224e9235a7187f74d1d3a09a36f46423cc53cc Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Sat, 31 Jan 2026 22:19:39 +0530 Subject: [PATCH 04/14] refactor: remove redundant save logic --- lib/httpapi/server.go | 12 ------------ lib/screentracker/pty_conversation.go | 1 - 2 files changed, 13 deletions(-) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e895c4ec..68be4b4a 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -617,18 +617,6 @@ func (s *Server) Start() error { // Stop gracefully stops the HTTP server func (s *Server) Stop(ctx context.Context) error { - // Save conversation state if configured - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { - s.logger.Error("Failed to save conversation state", "error", err) - } else { - s.logger.Info("Saved conversation state", "stateFile", s.statePersistenceCfg.StateFile) - } - } - - // Clean up PID file - s.cleanupPIDFile() - // Clean up temporary directory s.cleanupTempDir() diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index abfc886e..9c603b04 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -257,7 +257,6 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { Role: ConversationRoleUser, Time: now, }) - c.dirty = true c.userSentMessageAfterLoadState = true return nil From 30f82d7c4d7080a3d6b657d5bdc8b3f3e4fc01e7 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Mon, 2 Feb 2026 16:18:40 +0530 Subject: [PATCH 05/14] feat: improve logic for first run with empty state file --- lib/screentracker/pty_conversation.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 9c603b04..a8f73c0a 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -106,6 +106,8 @@ type PTYConversation struct { firstStableSnapshot string // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state userSentMessageAfterLoadState bool + // loadStateSuccessful indicates whether conversation state was successfully restored from file. + loadStateSuccessful bool } var _ Conversation = &PTYConversation{} @@ -123,9 +125,13 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string Time: cfg.GetTime(), }, }, - InitialPrompt: initialPrompt, - InitialPromptSent: len(initialPrompt) == 0, - toolCallMessageSet: make(map[string]bool), + InitialPrompt: initialPrompt, + InitialPromptSent: len(initialPrompt) == 0, + toolCallMessageSet: make(map[string]bool), + dirty: false, + firstStableSnapshot: "", + userSentMessageAfterLoadState: false, + loadStateSuccessful: false, } return c } @@ -170,7 +176,9 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } - agentMessage = c.skipInitialSnapshot(agentMessage) + if c.loadStateSuccessful { + agentMessage = c.adjustScreenAfterStateLoad(agentMessage) + } if c.cfg.FormatToolCall != nil { agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) } @@ -473,12 +481,13 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") } + c.loadStateSuccessful = true c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) return c.messages, nil } -func (c *PTYConversation) skipInitialSnapshot(screen string) string { - newScreen := strings.ReplaceAll(screen, c.firstStableSnapshot, "") +func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { + newScreen := strings.Replace(screen, c.firstStableSnapshot, "", 1) // Before the first user message after loading state, return the last message from the loaded state. // This prevents computing incorrect diffs from the restored screen, as the agent's message should From 12bed1c23355eb6a53b57cb412e9ed7458dcf29b Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 3 Feb 2026 16:27:53 +0530 Subject: [PATCH 06/14] feat: implement platform-specific signal handling --- lib/httpapi/server.go | 72 ++++++++------------------- lib/httpapi/server_signals_unix.go | 42 ++++++++++++++++ lib/httpapi/server_signals_windows.go | 26 ++++++++++ 3 files changed, 89 insertions(+), 51 deletions(-) create mode 100644 lib/httpapi/server_signals_unix.go create mode 100644 lib/httpapi/server_signals_windows.go diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 68be4b4a..cb761189 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -11,13 +11,11 @@ import ( "net/http" "net/url" "os" - "os/signal" "path/filepath" "slices" "sort" "strings" "sync" - "syscall" "time" "unicode" @@ -672,59 +670,31 @@ func (s *Server) cleanupPIDFile() { } } -// HandleSignals sets up signal handlers for: -// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process -// - SIGUSR1: save conversation state without exiting -func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { - // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) - shutdownCh := make(chan os.Signal, 1) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) - go func() { - sig := <-shutdownCh - s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) - - // Save conversation state if configured (synchronously before closing process) - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { - s.logger.Error("Failed to save conversation state on signal", "signal", sig, "error", err) - } else { - s.logger.Info("Saved conversation state on signal", "signal", sig, "stateFile", s.statePersistenceCfg.StateFile) - } - } +// saveAndCleanup saves the conversation state and cleans up before shutdown +func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { + // Save conversation state if configured (synchronously before closing process) + s.saveStateIfConfigured(sig.String()) - // Clean up PID file - s.cleanupPIDFile() + // Clean up PID file + s.cleanupPIDFile() - // Now close the process - if err := process.Close(s.logger, 5*time.Second); err != nil { - s.logger.Error("Error closing process", "signal", sig, "error", err) - } - }() + // Now close the process + if err := process.Close(s.logger, 5*time.Second); err != nil { + s.logger.Error("Error closing process", "signal", sig, "error", err) + } +} - // Handle SIGUSR1 for save without exit - saveOnlyCh := make(chan os.Signal, 1) - signal.Notify(saveOnlyCh, syscall.SIGUSR1) - go func() { - for { - select { - case <-saveOnlyCh: - s.logger.Info("Received SIGUSR1, saving state without exiting") - - // Save conversation state if configured - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { - s.logger.Error("Failed to save conversation state on SIGUSR1", "error", err) - } else { - s.logger.Info("Saved conversation state on SIGUSR1", "stateFile", s.statePersistenceCfg.StateFile) - } - } else { - s.logger.Warn("SIGUSR1 received but state saving is not configured") - } - case <-ctx.Done(): - return - } +// saveStateIfConfigured saves the conversation state if configured +func (s *Server) saveStateIfConfigured(source string) { + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + } else { + s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceCfg.StateFile) } - }() + } else { + s.logger.Warn("Save requested but state saving is not configured", "source", source) + } } // registerStaticFileRoutes sets up routes for serving static files diff --git a/lib/httpapi/server_signals_unix.go b/lib/httpapi/server_signals_unix.go new file mode 100644 index 00000000..bfc6eaa9 --- /dev/null +++ b/lib/httpapi/server_signals_unix.go @@ -0,0 +1,42 @@ +//go:build unix + +package httpapi + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/termexec" +) + +// HandleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process +// - SIGUSR1: save conversation state without exiting +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + go func() { + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + s.saveAndCleanup(sig, process) + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + for { + select { + case <-saveOnlyCh: + s.logger.Info("Received SIGUSR1, saving state without exiting") + s.saveStateIfConfigured("SIGUSR1") + case <-ctx.Done(): + return + } + } + }() +} diff --git a/lib/httpapi/server_signals_windows.go b/lib/httpapi/server_signals_windows.go new file mode 100644 index 00000000..ea07c6ad --- /dev/null +++ b/lib/httpapi/server_signals_windows.go @@ -0,0 +1,26 @@ +//go:build windows + +package httpapi + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/termexec" +) + +// HandleSignals sets up signal handlers for Windows. +// Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT only on Windows) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) + go func() { + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + s.saveAndCleanup(sig, process) + }() +} From e366e8b8f512a5ffde415063aa04261f4b4f3dd0 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 5 Feb 2026 18:11:03 +0530 Subject: [PATCH 07/14] feat: refactor cfg -> Config and move pid ops to server --- cmd/server/server.go | 42 +++++++++++++- lib/httpapi/server.go | 126 ++++++++++++++---------------------------- 2 files changed, 80 insertions(+), 88 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 5a125af4..561877e8 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -8,6 +8,7 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "sort" "strings" @@ -123,6 +124,15 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er pidFile := viper.GetString(PidFile) + // Write PID file if configured + if pidFile != "" { + if err := writePIDFile(pidFile, logger); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + // Ensure PID file is cleaned up on exit + defer cleanupPIDFile(pidFile, logger) + } + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -148,11 +158,10 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, - StatePersistenceCfg: httpapi.StatePersistenceCfg{ + StatePersistenceConfig: httpapi.StatePersistenceConfig{ StateFile: stateFile, LoadState: loadState, SaveState: saveState, - PidFile: pidFile, }, }) @@ -200,6 +209,35 @@ var agentNames = (func() []string { return names })() +// writePIDFile writes the current process ID to the specified file +func writePIDFile(pidFile string, logger *slog.Logger) error { + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(pidFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Write PID file + if err := os.WriteFile(pidFile, []byte(pidContent), 0o644); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func cleanupPIDFile(pidFile string, logger *slog.Logger) { + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err) + } else if err == nil { + logger.Info("Removed PID file", "pidFile", pidFile) + } +} + type flagSpec struct { name string shorthand string diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index cb761189..d234bad0 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -34,20 +34,20 @@ import ( // Server represents the HTTP server type Server struct { - router chi.Router - api huma.API - port int - srv *http.Server - mu sync.RWMutex - logger *slog.Logger - conversation *st.PTYConversation - agentio *termexec.Process - agentType mf.AgentType - emitter *EventEmitter - chatBasePath string - tempDir string - statePersistenceCfg StatePersistenceCfg - stateLoadComplete bool + router chi.Router + api huma.API + port int + srv *http.Server + mu sync.RWMutex + logger *slog.Logger + conversation *st.PTYConversation + agentio *termexec.Process + agentType mf.AgentType + emitter *EventEmitter + chatBasePath string + tempDir string + statePersistenceConfig StatePersistenceConfig + stateLoadComplete bool } func (s *Server) NormalizeSchema(schema any) any { @@ -96,22 +96,21 @@ func (s *Server) GetOpenAPI() string { // because the action of taking a snapshot takes time too. const snapshotInterval = 25 * time.Millisecond -type StatePersistenceCfg struct { +type StatePersistenceConfig struct { StateFile string LoadState bool SaveState bool - PidFile string } type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string - StatePersistenceCfg StatePersistenceCfg + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + StatePersistenceConfig StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -270,18 +269,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { logger.Info("Created temporary directory for uploads", "tempDir", tempDir) s := &Server{ - router: router, - api: api, - port: config.Port, - conversation: conversation, - logger: logger, - agentio: config.Process, - agentType: config.AgentType, - emitter: emitter, - chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), - tempDir: tempDir, - statePersistenceCfg: config.StatePersistenceCfg, - stateLoadComplete: false, + router: router, + api: api, + port: config.Port, + conversation: conversation, + logger: logger, + agentio: config.Process, + agentType: config.AgentType, + emitter: emitter, + chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), + tempDir: tempDir, + statePersistenceConfig: config.StatePersistenceConfig, + stateLoadComplete: false, } // Register API routes @@ -351,8 +350,8 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { // Send initial prompt & load state when agent becomes stable for the first time if convertStatus(currentStatus) == AgentStatusStable { - if !s.stateLoadComplete && s.statePersistenceCfg.LoadState { - _, _ = s.conversation.LoadState(s.statePersistenceCfg.StateFile) + if !s.stateLoadComplete && s.statePersistenceConfig.LoadState { + _, _ = s.conversation.LoadState(s.statePersistenceConfig.StateFile) s.stateLoadComplete = true } if !s.conversation.InitialPromptSent { @@ -605,11 +604,6 @@ func (s *Server) Start() error { Handler: s.router, } - // Write PID file if configured - if err := s.writePIDFile(); err != nil { - return xerrors.Errorf("failed to write PID file: %w", err) - } - return s.srv.ListenAndServe() } @@ -633,51 +627,11 @@ func (s *Server) cleanupTempDir() { } } -// writePIDFile writes the current process ID to the configured PID file -func (s *Server) writePIDFile() error { - if s.statePersistenceCfg.PidFile == "" { - return nil - } - - pid := os.Getpid() - pidContent := fmt.Sprintf("%d\n", pid) - - // Create directory if it doesn't exist - dir := filepath.Dir(s.statePersistenceCfg.PidFile) - if err := os.MkdirAll(dir, 0o755); err != nil { - return xerrors.Errorf("failed to create PID file directory: %w", err) - } - - // Write PID file - if err := os.WriteFile(s.statePersistenceCfg.PidFile, []byte(pidContent), 0o644); err != nil { - return xerrors.Errorf("failed to write PID file: %w", err) - } - - s.logger.Info("Wrote PID file", "pidFile", s.statePersistenceCfg.PidFile, "pid", pid) - return nil -} - -// cleanupPIDFile removes the PID file if it exists -func (s *Server) cleanupPIDFile() { - if s.statePersistenceCfg.PidFile == "" { - return - } - - if err := os.Remove(s.statePersistenceCfg.PidFile); err != nil && !os.IsNotExist(err) { - s.logger.Error("Failed to remove PID file", "pidFile", s.statePersistenceCfg.PidFile, "error", err) - } else if err == nil { - s.logger.Info("Removed PID file", "pidFile", s.statePersistenceCfg.PidFile) - } -} - // saveAndCleanup saves the conversation state and cleans up before shutdown func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { // Save conversation state if configured (synchronously before closing process) s.saveStateIfConfigured(sig.String()) - // Clean up PID file - s.cleanupPIDFile() - // Now close the process if err := process.Close(s.logger, 5*time.Second); err != nil { s.logger.Error("Error closing process", "signal", sig, "error", err) @@ -686,11 +640,11 @@ func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { // saveStateIfConfigured saves the conversation state if configured func (s *Server) saveStateIfConfigured(source string) { - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + if s.statePersistenceConfig.SaveState && s.statePersistenceConfig.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceConfig.StateFile); err != nil { s.logger.Error("Failed to save conversation state", "source", source, "error", err) } else { - s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceCfg.StateFile) + s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceConfig.StateFile) } } else { s.logger.Warn("Save requested but state saving is not configured", "source", source) From 26fdf818438c6837debb644c153829a64cb76f81 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 5 Feb 2026 18:15:25 +0530 Subject: [PATCH 08/14] feat: unregister the signal handlers on teardown --- lib/httpapi/server_signals_unix.go | 2 ++ lib/httpapi/server_signals_windows.go | 1 + 2 files changed, 3 insertions(+) diff --git a/lib/httpapi/server_signals_unix.go b/lib/httpapi/server_signals_unix.go index bfc6eaa9..837db86c 100644 --- a/lib/httpapi/server_signals_unix.go +++ b/lib/httpapi/server_signals_unix.go @@ -19,6 +19,7 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) go func() { + defer signal.Stop(shutdownCh) sig := <-shutdownCh s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) @@ -29,6 +30,7 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { saveOnlyCh := make(chan os.Signal, 1) signal.Notify(saveOnlyCh, syscall.SIGUSR1) go func() { + defer signal.Stop(saveOnlyCh) for { select { case <-saveOnlyCh: diff --git a/lib/httpapi/server_signals_windows.go b/lib/httpapi/server_signals_windows.go index ea07c6ad..503e56a9 100644 --- a/lib/httpapi/server_signals_windows.go +++ b/lib/httpapi/server_signals_windows.go @@ -18,6 +18,7 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { + defer signal.Stop(shutdownCh) sig := <-shutdownCh s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) From 5795db7235a1436afc8a4ecd343ad072303dab23 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 14:27:51 +0530 Subject: [PATCH 09/14] feat: resolve conflicts and improve shutdown sequence --- cmd/server/server.go | 26 ++++- cmd/server/signals.go | 36 ++++++ .../server/signals_unix.go | 22 ++-- .../server/signals_windows.go | 12 +- lib/httpapi/server.go | 104 +++++------------- lib/screentracker/conversation.go | 7 ++ lib/screentracker/pty_conversation.go | 97 +++++++++++----- 7 files changed, 179 insertions(+), 125 deletions(-) create mode 100644 cmd/server/signals.go rename lib/httpapi/server_signals_unix.go => cmd/server/signals_unix.go (52%) rename lib/httpapi/server_signals_windows.go => cmd/server/signals_windows.go (60%) diff --git a/cmd/server/server.go b/cmd/server/server.go index b42c232a..46b21e26 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -11,7 +11,9 @@ import ( "path/filepath" "sort" "strings" + "time" + "github.com/coder/agentapi/lib/screentracker" "github.com/mattn/go-isatty" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -106,8 +108,10 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er // Get the variables related to state management stateFile := viper.GetString(StateFile) - loadState := true - saveState := true + loadState := false + saveState := false + + // Validate state file configuration if stateFile != "" { if !viper.IsSet(LoadState) { loadState = true @@ -120,6 +124,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } else { saveState = viper.GetBool(SaveState) } + } else { + // No state file provided - ensure load/save flags are not explicitly set to true + if viper.IsSet(LoadState) && viper.GetBool(LoadState) { + return xerrors.Errorf("--load-state requires --state-file to be set") + } + if viper.IsSet(SaveState) && viper.GetBool(SaveState) { + return xerrors.Errorf("--save-state requires --state-file to be set") + } } pidFile := viper.GetString(PidFile) @@ -158,7 +170,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, - StatePersistenceConfig: httpapi.StatePersistenceConfig{ + StatePersistenceConfig: screentracker.StatePersistenceConfig{ StateFile: stateFile, LoadState: loadState, SaveState: saveState, @@ -172,7 +184,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } - srv.HandleSignals(ctx, process) + handleSignals(ctx, logger, srv, process) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { @@ -184,8 +196,10 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) } } - if err := srv.Stop(ctx); err != nil { - logger.Error("Failed to stop server", "error", err) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop server after process exit", "error", err) } }() if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { diff --git a/cmd/server/signals.go b/cmd/server/signals.go new file mode 100644 index 00000000..e66554e7 --- /dev/null +++ b/cmd/server/signals.go @@ -0,0 +1,36 @@ +package server + +import ( + "context" + "log/slog" + "os" + "time" + + "github.com/coder/agentapi/lib/httpapi" + "github.com/coder/agentapi/lib/termexec" +) + +// performGracefulShutdown handles the common shutdown logic for all platforms. +// It saves state, stops the HTTP server, closes the process, and exits. +func performGracefulShutdown(sig os.Signal, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { + logger.Info("Received shutdown signal, initiating graceful shutdown", "signal", sig) + + // Save state + if err := srv.SaveState(sig.String()); err != nil { + logger.Error("Failed to save state during shutdown", "signal", sig, "error", err) + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop HTTP server", "signal", sig, "error", err) + } + + // Close the process + if err := process.Close(logger, 5*time.Second); err != nil { + logger.Error("Failed to close process cleanly", "signal", sig, "error", err) + } + + // Exit cleanly + os.Exit(0) +} diff --git a/lib/httpapi/server_signals_unix.go b/cmd/server/signals_unix.go similarity index 52% rename from lib/httpapi/server_signals_unix.go rename to cmd/server/signals_unix.go index 837db86c..fe6b4693 100644 --- a/lib/httpapi/server_signals_unix.go +++ b/cmd/server/signals_unix.go @@ -1,29 +1,29 @@ //go:build unix -package httpapi +package server import ( "context" + "log/slog" "os" "os/signal" "syscall" + "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/termexec" ) -// HandleSignals sets up signal handlers for: -// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process +// handleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, stop server, then close the process // - SIGUSR1: save conversation state without exiting -func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) shutdownCh := make(chan os.Signal, 1) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) - - s.saveAndCleanup(sig, process) + performGracefulShutdown(sig, logger, srv, process) }() // Handle SIGUSR1 for save without exit @@ -34,8 +34,10 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { for { select { case <-saveOnlyCh: - s.logger.Info("Received SIGUSR1, saving state without exiting") - s.saveStateIfConfigured("SIGUSR1") + logger.Info("Received SIGUSR1, saving state without exiting") + if err := srv.SaveState("SIGUSR1"); err != nil { + logger.Error("Failed to save state on SIGUSR1", "error", err) + } case <-ctx.Done(): return } diff --git a/lib/httpapi/server_signals_windows.go b/cmd/server/signals_windows.go similarity index 60% rename from lib/httpapi/server_signals_windows.go rename to cmd/server/signals_windows.go index 503e56a9..52d90616 100644 --- a/lib/httpapi/server_signals_windows.go +++ b/cmd/server/signals_windows.go @@ -1,27 +1,27 @@ //go:build windows -package httpapi +package server import ( "context" + "log/slog" "os" "os/signal" "syscall" + "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/termexec" ) -// HandleSignals sets up signal handlers for Windows. +// handleSignals sets up signal handlers for Windows. // Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). -func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { // Handle shutdown signals (SIGTERM, SIGINT only on Windows) shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) - - s.saveAndCleanup(sig, process) + performGracefulShutdown(sig, logger, srv, process) }() } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 4deb4e38..32cd64cf 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -40,6 +40,7 @@ type Server struct { port int srv *http.Server mu sync.RWMutex + stopOnce sync.Once logger *slog.Logger conversation st.Conversation agentio *termexec.Process @@ -48,8 +49,6 @@ type Server struct { chatBasePath string tempDir string clock quartz.Clock - statePersistenceConfig StatePersistenceConfig - stateLoadComplete bool } func (s *Server) NormalizeSchema(schema any) any { @@ -98,22 +97,16 @@ func (s *Server) GetOpenAPI() string { // because the action of taking a snapshot takes time too. const snapshotInterval = 25 * time.Millisecond -type StatePersistenceConfig struct { - StateFile string - LoadState bool - SaveState bool -} - type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string - Clock quartz.Clock - StatePersistenceConfig StatePersistenceConfig + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + Clock quartz.Clock + StatePersistenceConfig st.StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -279,7 +272,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { emitter.UpdateMessagesAndEmitChanges(messages) emitter.UpdateScreenAndEmitChanges(screen) }, - Logger: logger, + Logger: logger, + StatePersistenceConfig: config.StatePersistenceConfig, }) // Create temporary directory for uploads @@ -301,8 +295,6 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), tempDir: tempDir, clock: config.Clock, - statePersistenceConfig: config.StatePersistenceConfig, - stateLoadComplete: false, } // Register API routes @@ -373,32 +365,6 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { next(ctx) } -func (s *Server) StartSnapshotLoop(ctx context.Context) { - s.conversation.StartSnapshotLoop(ctx) - go func() { - for { - currentStatus := s.conversation.Status() - - // Send initial prompt when agent becomes stable for the first time - if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - - if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { - s.logger.Error("Failed to send initial prompt", "error", err) - } else { - s.conversation.InitialPromptSent = true - s.conversation.ReadyForInitialPrompt = false - currentStatus = st.ConversationStatusChanging - s.logger.Info("Initial prompt sent successfully") - } - } - s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) - s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) - s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) - time.Sleep(snapshotInterval) - } - }() -} - // registerRoutes sets up all API endpoints func (s *Server) registerRoutes() { // GET /status endpoint @@ -633,15 +599,19 @@ func (s *Server) Start() error { return s.srv.ListenAndServe() } -// Stop gracefully stops the HTTP server +// Stop gracefully stops the HTTP server. It is safe to call multiple times; +// only the first call will perform the shutdown, subsequent calls are no-ops. func (s *Server) Stop(ctx context.Context) error { - // Clean up temporary directory - s.cleanupTempDir() + var err error + s.stopOnce.Do(func() { + // Clean up temporary directory + s.cleanupTempDir() - if s.srv != nil { - return s.srv.Shutdown(ctx) - } - return nil + if s.srv != nil { + err = s.srv.Shutdown(ctx) + } + }) + return err } // cleanupTempDir removes the temporary directory and all its contents @@ -653,28 +623,14 @@ func (s *Server) cleanupTempDir() { } } -// saveAndCleanup saves the conversation state and cleans up before shutdown -func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { - // Save conversation state if configured (synchronously before closing process) - s.saveStateIfConfigured(sig.String()) - - // Now close the process - if err := process.Close(s.logger, 5*time.Second); err != nil { - s.logger.Error("Error closing process", "signal", sig, "error", err) - } -} - -// saveStateIfConfigured saves the conversation state if configured -func (s *Server) saveStateIfConfigured(source string) { - if s.statePersistenceConfig.SaveState && s.statePersistenceConfig.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceConfig.StateFile); err != nil { - s.logger.Error("Failed to save conversation state", "source", source, "error", err) - } else { - s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceConfig.StateFile) - } - } else { - s.logger.Warn("Save requested but state saving is not configured", "source", source) +// SaveState saves the conversation state if configured. This can be called from signal handlers. +// The source parameter indicates what triggered the save (e.g., "SIGTERM", "SIGUSR1"). +func (s *Server) SaveState(source string) error { + if err := s.conversation.SaveState(); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + return err } + return nil } // registerStaticFileRoutes sets up routes for serving static files diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 9e6b856f..44e303f1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -63,6 +63,7 @@ type Conversation interface { Start(context.Context) Status() ConversationStatus Text() string + SaveState() error } type ConversationMessage struct { @@ -71,3 +72,9 @@ type ConversationMessage struct { Role ConversationRole Time time.Time } + +type StatePersistenceConfig struct { + StateFile string + LoadState bool + SaveState bool +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 91ed7cba..fd3dedcf 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -29,6 +29,13 @@ type MessagePartText struct { Hidden bool } +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + //InitialPromptSent bool `json:"initial_prompt_sent"` +} + var _ MessagePart = &MessagePartText{} func (p MessagePartText) Do(writer AgentIO) error { @@ -72,8 +79,9 @@ type PTYConversationConfig struct { // InitialPrompt is the initial prompt to send to the agent once ready InitialPrompt []MessagePart // OnSnapshot is called after each snapshot with current status, messages, and screen content - OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) - Logger *slog.Logger + OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) + Logger *slog.Logger + StatePersistenceConfig StatePersistenceConfig } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -142,9 +150,9 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { Time: cfg.Clock.Now(), }, }, - outboundQueue: make(chan outboundMessage, 1), - stableSignal: make(chan struct{}, 1), - toolCallMessageSet: make(map[string]bool), + outboundQueue: make(chan outboundMessage, 1), + stableSignal: make(chan struct{}, 1), + toolCallMessageSet: make(map[string]bool), dirty: false, firstStableSnapshot: "", userSentMessageAfterLoadState: false, @@ -178,6 +186,12 @@ func (c *PTYConversation) Start(ctx context.Context) { if !c.initialPromptReady && c.cfg.ReadyForInitialPrompt(screen) { c.initialPromptReady = true } + + if !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { + _ = c.loadState() + c.loadStateSuccessful = true + } + if c.initialPromptReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() { select { case c.stableSignal <- struct{}{}: @@ -284,6 +298,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1] = conversationMessage } c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + + c.dirty = true } // caller MUST hold c.lock @@ -297,10 +313,6 @@ func (c *PTYConversation) snapshotLocked(screen string) { } func (c *PTYConversation) Send(messageParts ...MessagePart) error { - if !c.cfg.SkipSendMessageStatusCheck && c.statusLocked() != ConversationStatusStable { - return MessageValidationErrorChanging - } - // Validate message content before enqueueing var sb strings.Builder for _, part := range messageParts { @@ -514,26 +526,41 @@ func (c *PTYConversation) Text() string { return snapshots[len(snapshots)-1].screen } -func (c *PTYConversation) SaveState(conversation []ConversationMessage, stateFile string) error { +func (c *PTYConversation) SaveState() error { + conversation := c.Messages() + c.lock.Lock() defer c.lock.Unlock() - // Skip if state file is not configured - if stateFile == "" { + stateFile := c.cfg.StatePersistenceConfig.StateFile + saveState := c.cfg.StatePersistenceConfig.SaveState + + if !saveState { + c.cfg.Logger.Info("") return nil } // Skip if not dirty if !c.dirty { + c.cfg.Logger.Info("Skipping state save: no changes since last save") return nil } + // Serialize initial prompt from message parts + var initialPromptStr string + if len(c.cfg.InitialPrompt) > 0 { + var sb strings.Builder + for _, part := range c.cfg.InitialPrompt { + sb.WriteString(part.String()) + } + initialPromptStr = sb.String() + } + // Use atomic write: write to temp file, then rename to target path data, err := json.MarshalIndent(AgentState{ - Version: 1, - Messages: conversation, - InitialPrompt: c.InitialPrompt, - InitialPromptSent: c.InitialPromptSent, + Version: 1, + Messages: conversation, + InitialPrompt: initialPromptStr, }, "", " ") if err != nil { return xerrors.Errorf("failed to marshal state: %w", err) @@ -558,44 +585,51 @@ func (c *PTYConversation) SaveState(conversation []ConversationMessage, stateFil // Clear dirty flag after successful save c.dirty = false + + c.cfg.Logger.Info("State saved successfully to: %s", stateFile) + return nil } -func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, error) { - c.lock.Lock() - defer c.lock.Unlock() +// LoadState loads the state, this method assumes that caller holds the Lock +func (c *PTYConversation) loadState() error { + stateFile := c.cfg.StatePersistenceConfig.StateFile + loadState := c.cfg.StatePersistenceConfig.LoadState - // Skip if state file is not configured - if stateFile == "" { - return nil, nil + if !loadState { + return nil } // Check if file exists if _, err := os.Stat(stateFile); os.IsNotExist(err) { c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) - return nil, nil + return nil } // Read state file data, err := os.ReadFile(stateFile) if err != nil { c.cfg.Logger.Warn("Failed to load state file", "path", stateFile, "err", err) - return nil, xerrors.Errorf("failed to read state file: %w", err) + return xerrors.Errorf("failed to read state file: %w", err) } if len(data) == 0 { c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) - return nil, nil + return nil } var agentState AgentState if err := json.Unmarshal(data, &agentState); err != nil { c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) - return nil, xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) + return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } - c.InitialPromptSent = agentState.InitialPromptSent - c.InitialPrompt = agentState.InitialPrompt + //c.cfg.initialPromptSent = agentState.InitialPromptSent + c.cfg.InitialPrompt = []MessagePart{MessagePartText{ + Content: agentState.InitialPrompt, + Alias: "", + Hidden: false, + }} c.messages = agentState.Messages // Store the first stable snapshot for filtering later @@ -606,10 +640,15 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er c.loadStateSuccessful = true c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) - return c.messages, nil + return nil } func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { + + if c.firstStableSnapshot == "" { + return screen + } + newScreen := strings.Replace(screen, c.firstStableSnapshot, "", 1) // Before the first user message after loading state, return the last message from the loaded state. From 9deab88258e92521ca36cafac48e898e2d7ab651 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 15:08:43 +0530 Subject: [PATCH 10/14] feat: resolve conflicts --- lib/screentracker/pty_conversation.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 946a07fd..d40337c5 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -195,7 +195,7 @@ func (c *PTYConversation) Start(ctx context.Context) { c.initialPromptReady = true } - if !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { + if c.initialPromptReady && !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { _ = c.loadState() c.loadStateSuccessful = true } @@ -596,7 +596,7 @@ func (c *PTYConversation) SaveState() error { // Clear dirty flag after successful save c.dirty = false - c.cfg.Logger.Info("State saved successfully to: %s", stateFile) + c.cfg.Logger.Info(fmt.Sprintf("State saved successfully to: %s", stateFile)) return nil } From 18fb1e4bdf1dcc332461376f518740a0ebd9d5d1 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 15:13:33 +0530 Subject: [PATCH 11/14] chore: not dirty after load state --- lib/screentracker/pty_conversation.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index d40337c5..a8ddb5ea 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -649,6 +649,8 @@ func (c *PTYConversation) loadState() error { } c.loadStateSuccessful = true + c.dirty = false + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) return nil } From b719dac58d869d4a9b82cac051afb058945f185d Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 15:55:21 +0530 Subject: [PATCH 12/14] feat: add tests --- cmd/server/server_test.go | 214 +++++++++++++ lib/httpapi/server_test.go | 34 ++ lib/screentracker/pty_conversation_test.go | 353 +++++++++++++++++++++ 3 files changed, 601 insertions(+) diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index bd07fc63..4affad0d 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -2,6 +2,8 @@ package server import ( "fmt" + "io" + "log/slog" "os" "strings" "testing" @@ -477,6 +479,218 @@ func TestServerCmd_AllowedHosts(t *testing.T) { } } +func TestServerCmd_StatePersistenceFlags(t *testing.T) { + // NOTE: These tests use --exit flag to test flag parsing and defaults. + // Runtime validation that happens in runServer (e.g., "--load-state requires --state-file") + // would call os.Exit(1) which terminates the test process, so those validations + // are tested through integration/E2E tests instead. + + t.Run("state-file with defaults", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + // load-state and save-state default to true when state-file is set (validated in runServer) + }) + + t.Run("state-file with explicit load-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--load-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, false, viper.GetBool(LoadState)) + }) + + t.Run("state-file with explicit save-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--save-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, false, viper.GetBool(SaveState)) + }) + + t.Run("state-file with explicit load-state=true and save-state=true", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--load-state=true", + "--save-state=true", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, true, viper.GetBool(LoadState)) + assert.Equal(t, true, viper.GetBool(SaveState)) + }) + + t.Run("load-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--load-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(LoadState)) + }) + + t.Run("save-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--save-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(SaveState)) + }) + + t.Run("pid-file can be set independently", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--pid-file", "/tmp/server.pid", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + }) + + t.Run("state-file and pid-file can be set together", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--pid-file", "/tmp/server.pid", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + }) +} + +func TestPIDFileOperations(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("writePIDFile creates file with process ID", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify content contains current PID + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("writePIDFile creates directory if not exists", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nested/deep/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify directory was created + _, err = os.Stat(tmpDir + "/nested/deep") + require.NoError(t, err) + }) + + t.Run("writePIDFile overwrites existing file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Write initial PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Overwrite with current PID + err = writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify content is updated + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("cleanupPIDFile removes file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Create PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Cleanup + cleanupPIDFile(pidFile, discardLogger) + + // Verify file is removed + _, err = os.Stat(pidFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("cleanupPIDFile handles non-existent file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nonexistent.pid" + + // Should not panic or error + cleanupPIDFile(pidFile, discardLogger) + }) + + t.Run("cleanupPIDFile handles directory removal error gracefully", func(t *testing.T) { + // Create a file in a protected directory (this is system-dependent) + // Just verify it doesn't panic when it can't remove the file + pidFile := "/this/should/not/exist/test.pid" + + // Should not panic + cleanupPIDFile(pidFile, discardLogger) + }) +} + func TestServerCmd_AllowedOrigins(t *testing.T) { tests := []struct { name string diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index c8e8b23c..82fc6713 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" @@ -956,3 +957,36 @@ func TestServer_UploadFiles_Errors(t *testing.T) { require.Contains(t, string(body), "file size exceeds 10MB limit") }) } + +func TestServer_Stop_Idempotency(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, + }) + require.NoError(t, err) + + // First call to Stop should succeed + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = srv.Stop(stopCtx) + require.NoError(t, err) + + // Second call to Stop should also succeed (no-op) + stopCtx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + err = srv.Stop(stopCtx2) + require.NoError(t, err) + + // Third call to Stop should also succeed (no-op) + stopCtx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel3() + err = srv.Stop(stopCtx3) + require.NoError(t, err) +} diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index 19b4511b..67ff1395 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,9 +2,11 @@ package screentracker_test import ( "context" + "encoding/json" "fmt" "io" "log/slog" + "os" "sync" "testing" "time" @@ -446,6 +448,357 @@ func TestMessages(t *testing.T) { }) } +func TestStatePersistence(t *testing.T) { + t.Run("SaveState creates file with correct structure", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + // Create temp directory for state file + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate some conversation + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Save state + err := c.SaveState() + require.NoError(t, err) + + // Read and verify the saved file + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.Equal(t, 1, agentState.Version) + assert.Equal(t, "test prompt", agentState.InitialPrompt) + assert.NotEmpty(t, agentState.Messages) + }) + + t.Run("SaveState skips when not configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + err := c.SaveState() + require.NoError(t, err) + + // File should not be created + _, err = os.Stat(stateFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("SaveState honors dirty flag", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate conversation and save + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + err := c.SaveState() + require.NoError(t, err) + + // Get file modification time + info1, err := os.Stat(stateFile) + require.NoError(t, err) + modTime1 := info1.ModTime() + + // Save again without changes - file should not be modified + err = c.SaveState() + require.NoError(t, err) + + info2, err := os.Stat(stateFile) + require.NoError(t, err) + modTime2 := info2.ModTime() + + // File modification time should be the same (dirty flag prevents save) + assert.Equal(t, modTime1, modTime2) + }) + + t.Run("SaveState creates directory if not exists", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nested/deep/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + // Verify file and directory were created + _, err = os.Stat(stateFile) + assert.NoError(t, err) + }) + + t.Run("LoadState restores conversation from file", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with test data + testState := st.AgentState{ + Version: 1, + InitialPrompt: "restored prompt", + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message 1", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "user message 1", Role: st.ConversationRoleUser, Time: time.Now()}, + {Id: 2, Message: "agent message 2", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with LoadState enabled + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until agent is ready and state is loaded + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Verify messages were restored + messages := c.Messages() + assert.Len(t, messages, 3) + assert.Equal(t, "agent message 1", messages[0].Message) + assert.Equal(t, "user message 1", messages[1].Message) + // The last agent message may have adjustments from adjustScreenAfterStateLoad + assert.Contains(t, messages[2].Message, "agent message 2") + }) + + t.Run("LoadState handles missing file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nonexistent.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles empty file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/empty.json" + + // Create empty file + err := os.WriteFile(stateFile, []byte(""), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles corrupted JSON gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/corrupted.json" + + // Create corrupted JSON file + err := os.WriteFile(stateFile, []byte("{invalid json}"), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic - logs warning and continues + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) +} + func TestInitialPromptReadiness(t *testing.T) { discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) From 39590021d1e48d059dc87bd4b96d82d533544bc1 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 16:24:39 +0530 Subject: [PATCH 13/14] feat: remove comment --- lib/screentracker/pty_conversation.go | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index a8ddb5ea..a4b44124 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -33,7 +33,6 @@ type AgentState struct { Version int `json:"version"` Messages []ConversationMessage `json:"messages"` InitialPrompt string `json:"initial_prompt"` - //InitialPromptSent bool `json:"initial_prompt_sent"` } var _ MessagePart = &MessagePartText{} From 7e389d29234de38d963a70bb5dd544756a032021 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 16:28:11 +0530 Subject: [PATCH 14/14] feat: remove comments --- cmd/server/server.go | 1 - lib/httpapi/server.go | 23 ++++++++++------------- lib/screentracker/pty_conversation.go | 4 +--- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 46b21e26..c20a833c 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -125,7 +125,6 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er saveState = viper.GetBool(SaveState) } } else { - // No state file provided - ensure load/save flags are not explicitly set to true if viper.IsSet(LoadState) && viper.GetBool(LoadState) { return xerrors.Errorf("--load-state requires --state-file to be set") } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 0243038e..29f0dce1 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -255,15 +255,15 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } conversation := st.NewPTY(ctx, st.PTYConversationConfig{ - AgentType: config.AgentType, - AgentIO: config.Process, - Clock: config.Clock, - SnapshotInterval: snapshotInterval, - ScreenStabilityLength: 2 * time.Second, - FormatMessage: formatMessage, - ReadyForInitialPrompt: isAgentReadyForInitialPrompt, - FormatToolCall: formatToolCall, - InitialPrompt: initialPrompt, + AgentType: config.AgentType, + AgentIO: config.Process, + Clock: config.Clock, + SnapshotInterval: snapshotInterval, + ScreenStabilityLength: 2 * time.Second, + FormatMessage: formatMessage, + ReadyForInitialPrompt: isAgentReadyForInitialPrompt, + FormatToolCall: formatToolCall, + InitialPrompt: initialPrompt, Logger: logger, StatePersistenceConfig: config.StatePersistenceConfig, }, emitter) @@ -591,8 +591,7 @@ func (s *Server) Start() error { return s.srv.ListenAndServe() } -// Stop gracefully stops the HTTP server. It is safe to call multiple times; -// only the first call will perform the shutdown, subsequent calls are no-ops. +// Stop gracefully stops the HTTP server. It is safe to call multiple times. func (s *Server) Stop(ctx context.Context) error { var err error s.stopOnce.Do(func() { @@ -615,8 +614,6 @@ func (s *Server) cleanupTempDir() { } } -// SaveState saves the conversation state if configured. This can be called from signal handlers. -// The source parameter indicates what triggered the save (e.g., "SIGTERM", "SIGUSR1"). func (s *Server) SaveState(source string) error { if err := s.conversation.SaveState(); err != nil { s.logger.Error("Failed to save conversation state", "source", source, "error", err) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index a4b44124..e5b6feb4 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -76,9 +76,7 @@ type PTYConversationConfig struct { // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready - InitialPrompt []MessagePart - // OnSnapshot is called after each snapshot with current status, messages, and screen content - OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) + InitialPrompt []MessagePart Logger *slog.Logger StatePersistenceConfig StatePersistenceConfig }