diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 73eff07..906a3a4 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -64,7 +64,7 @@ type EventEmitter struct { agentType mf.AgentType chans map[int]chan Event chanIdx int - subscriptionBufSize int + subscriptionBufSize uint screen string } @@ -81,20 +81,37 @@ func convertStatus(status st.ConversationStatus) AgentStatus { } } -// subscriptionBufSize is the size of the buffer for each subscription. -// Once the buffer is full, the channel will be closed. -// Listeners must actively drain the channel, so it's important to -// set this to a value that is large enough to handle the expected -// number of events. -func NewEventEmitter(subscriptionBufSize int) *EventEmitter { - return &EventEmitter{ - mu: sync.Mutex{}, +const defaultSubscriptionBufSize uint = 1024 + +type EventEmitterOption func(*EventEmitter) + +func WithSubscriptionBufSize(size uint) EventEmitterOption { + return func(e *EventEmitter) { + if size == 0 { + e.subscriptionBufSize = defaultSubscriptionBufSize + } else { + e.subscriptionBufSize = size + } + } +} + +func WithAgentType(agentType mf.AgentType) EventEmitterOption { + return func(e *EventEmitter) { + e.agentType = agentType + } +} + +func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter { + e := &EventEmitter{ messages: make([]st.ConversationMessage, 0), status: AgentStatusRunning, chans: make(map[int]chan Event), - chanIdx: 0, - subscriptionBufSize: subscriptionBufSize, + subscriptionBufSize: defaultSubscriptionBufSize, + } + for _, opt := range opts { + opt(e) } + return e } // Assumes the caller holds the lock. @@ -122,7 +139,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { // Assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. -func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) { +func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { e.mu.Lock() defer e.mu.Unlock() @@ -137,6 +154,9 @@ func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.Conversatio newMsg = newMessages[i] } if oldMsg != newMsg { + if i >= len(newMessages) { + continue + } e.notifyChannels(EventTypeMessageUpdate, MessageUpdateBody{ Id: newMessages[i].Id, Role: newMessages[i].Role, @@ -149,7 +169,7 @@ func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.Conversatio e.messages = newMessages } -func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatus, agentType mf.AgentType) { +func (e *EventEmitter) EmitStatus(newStatus st.ConversationStatus) { e.mu.Lock() defer e.mu.Unlock() @@ -158,12 +178,11 @@ func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatu return } - e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: agentType}) + e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: e.agentType}) e.status = newAgentStatus - e.agentType = agentType } -func (e *EventEmitter) UpdateScreenAndEmitChanges(newScreen string) { +func (e *EventEmitter) EmitScreen(newScreen string) { e.mu.Lock() defer e.mu.Unlock() diff --git a/lib/httpapi/events_test.go b/lib/httpapi/events_test.go index 46ccea5..a1d024c 100644 --- a/lib/httpapi/events_test.go +++ b/lib/httpapi/events_test.go @@ -5,14 +5,13 @@ import ( "testing" "time" - mf "github.com/coder/agentapi/lib/msgfmt" st "github.com/coder/agentapi/lib/screentracker" "github.com/stretchr/testify/assert" ) func TestEventEmitter(t *testing.T) { t.Run("single-subscription", func(t *testing.T) { - emitter := NewEventEmitter(10) + emitter := NewEventEmitter(WithSubscriptionBufSize(10)) _, ch, stateEvents := emitter.Subscribe() assert.Empty(t, ch) assert.Equal(t, []Event{ @@ -27,7 +26,7 @@ func TestEventEmitter(t *testing.T) { }, stateEvents) now := time.Now() - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }) newEvent := <-ch @@ -36,7 +35,7 @@ func TestEventEmitter(t *testing.T) { Payload: MessageUpdateBody{Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }, newEvent) - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world! (updated)", Role: st.ConversationRoleUser, Time: now}, {Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now}, }) @@ -52,16 +51,16 @@ func TestEventEmitter(t *testing.T) { Payload: MessageUpdateBody{Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now}, }, newEvent) - emitter.UpdateStatusAndEmitChanges(st.ConversationStatusStable, mf.AgentTypeAider) + emitter.EmitStatus(st.ConversationStatusStable) newEvent = <-ch assert.Equal(t, Event{ Type: EventTypeStatusChange, - Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: mf.AgentTypeAider}, + Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: ""}, }, newEvent) }) t.Run("multiple-subscriptions", func(t *testing.T) { - emitter := NewEventEmitter(10) + emitter := NewEventEmitter(WithSubscriptionBufSize(10)) channels := make([]<-chan Event, 0, 10) for i := 0; i < 10; i++ { _, ch, _ := emitter.Subscribe() @@ -69,7 +68,7 @@ func TestEventEmitter(t *testing.T) { } now := time.Now() - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now}, }) for _, ch := range channels { @@ -82,10 +81,10 @@ func TestEventEmitter(t *testing.T) { }) t.Run("close-channel", func(t *testing.T) { - emitter := NewEventEmitter(1) + emitter := NewEventEmitter(WithSubscriptionBufSize(1)) _, ch, _ := emitter.Subscribe() for i := range 5 { - emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{ + emitter.EmitMessages([]st.ConversationMessage{ {Id: i, Message: fmt.Sprintf("Hello, world! %d", i), Role: st.ConversationRoleUser, Time: time.Now()}, }) } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e43315b..956cfb8 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -244,7 +244,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - emitter := NewEventEmitter(1024) + emitter := NewEventEmitter(WithAgentType(config.AgentType)) // Format initial prompt into message parts if provided var initialPrompt []st.MessagePart @@ -262,16 +262,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { ReadyForInitialPrompt: isAgentReadyForInitialPrompt, FormatToolCall: formatToolCall, InitialPrompt: initialPrompt, - // OnSnapshot uses a callback rather than passing the emitter directly - // to keep the screentracker package decoupled from httpapi concerns. - // This preserves clean package boundaries and avoids import cycles. - OnSnapshot: func(status st.ConversationStatus, messages []st.ConversationMessage, screen string) { - emitter.UpdateStatusAndEmitChanges(status, config.AgentType) - emitter.UpdateMessagesAndEmitChanges(messages) - emitter.UpdateScreenAndEmitChanges(screen) - }, - Logger: logger, - }) + Logger: logger, + }, emitter) // Create temporary directory for uploads tempDir, err := os.MkdirTemp("", "agentapi-uploads-") diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 9e6b856..8299faa 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -65,6 +65,13 @@ type Conversation interface { Text() string } +// Emitter receives conversation state updates. +type Emitter interface { + EmitMessages([]ConversationMessage) + EmitStatus(ConversationStatus) + EmitScreen(string) +} + type ConversationMessage struct { Id int Message string diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index ff0c7ee..2728377 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -68,9 +68,7 @@ type PTYConversationConfig struct { FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready InitialPrompt []MessagePart - // OnSnapshot is called after each snapshot with current status, messages, and screen content - OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) - Logger *slog.Logger + Logger *slog.Logger } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -86,7 +84,8 @@ func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { // PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication. // It uses a combination of polling and diffs to detect changes in the screen. type PTYConversation struct { - cfg PTYConversationConfig + cfg PTYConversationConfig + emitter Emitter // How many stable snapshots are required to consider the screen stable stableSnapshotsThreshold int snapshotBuffer *RingBuffer[screenSnapshot] @@ -115,13 +114,23 @@ type PTYConversation struct { var _ Conversation = &PTYConversation{} -func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { +type noopEmitter struct{} + +func (noopEmitter) EmitMessages([]ConversationMessage) {} +func (noopEmitter) EmitStatus(ConversationStatus) {} +func (noopEmitter) EmitScreen(string) {} + +func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation { if cfg.Clock == nil { cfg.Clock = quartz.NewReal() } + if emitter == nil { + emitter = noopEmitter{} + } threshold := cfg.getStableSnapshotsThreshold() c := &PTYConversation{ cfg: cfg, + emitter: emitter, stableSnapshotsThreshold: threshold, snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), messages: []ConversationMessage{ @@ -139,9 +148,6 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { if len(cfg.InitialPrompt) > 0 { c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil} } - if c.cfg.OnSnapshot == nil { - c.cfg.OnSnapshot = func(ConversationStatus, []ConversationMessage, string) {} - } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } } @@ -173,7 +179,9 @@ func (c *PTYConversation) Start(ctx context.Context) { } c.lock.Unlock() - c.cfg.OnSnapshot(status, messages, screen) + c.emitter.EmitStatus(status) + c.emitter.EmitMessages(messages) + c.emitter.EmitScreen(screen) return nil }, "snapshot") diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index eaa4a69..19b4511 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -20,8 +20,8 @@ const testTimeout = 10 * time.Second // testAgent is a goroutine-safe mock implementation of AgentIO. type testAgent struct { - mu sync.Mutex - screen string + mu sync.Mutex + screen string // onWrite is called during Write to simulate the agent reacting to // terminal input (e.g., changing the screen), which unblocks // writeStabilize's polling loops. @@ -49,6 +49,12 @@ func (a *testAgent) setScreen(s string) { a.screen = s } +type testEmitter struct{} + +func (testEmitter) EmitMessages([]st.ConversationMessage) {} +func (testEmitter) EmitStatus(st.ConversationStatus) {} +func (testEmitter) EmitScreen(string) {} + // advanceFor is a shorthand for advanceUntil with a time-based condition. func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { t.Helper() @@ -125,7 +131,7 @@ func statusTest(t *testing.T, params statusTestParams) { params.cfg.AgentIO = agent params.cfg.Logger = slog.New(slog.NewTextHandler(io.Discard, nil)) - c := st.NewPTY(ctx, params.cfg) + c := st.NewPTY(ctx, params.cfg, &testEmitter{}) c.Start(ctx) assert.Equal(t, st.ConversationStatusInitializing, c.Status()) @@ -220,11 +226,11 @@ func TestMessages(t *testing.T) { mClock := quartz.NewMock(t) mClock.Set(now) cfg := st.PTYConversationConfig{ - Clock: mClock, - AgentIO: agent, - SnapshotInterval: 100 * time.Millisecond, - ScreenStabilityLength: 200 * time.Millisecond, - Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + Clock: mClock, + AgentIO: agent, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } for _, opt := range opts { opt(&cfg) @@ -233,7 +239,7 @@ func TestMessages(t *testing.T) { agent = a } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) return c, agent, mClock @@ -460,7 +466,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Take a snapshot with "loading...". Threshold is 1 (stability 0 / interval 1s = 0 + 1 = 1). @@ -488,7 +494,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Agent not ready initially. @@ -513,18 +519,18 @@ func TestInitialPromptReadiness(t *testing.T) { agent.screen = fmt.Sprintf("__write_%d", writeCounter) } cfg := st.PTYConversationConfig{ - Clock: mClock, - SnapshotInterval: 1 * time.Second, - ScreenStabilityLength: 0, - AgentIO: agent, + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, ReadyForInitialPrompt: func(message string) bool { return message == "ready" }, - InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "initial prompt here"}}, - Logger: discardLogger, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "initial prompt here"}}, + Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Status is "changing" while waiting for readiness. @@ -564,7 +570,7 @@ func TestInitialPromptReadiness(t *testing.T) { Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) advanceFor(ctx, t, mClock, 1*time.Second) @@ -579,14 +585,14 @@ func TestInitialPromptReadiness(t *testing.T) { mClock := quartz.NewMock(t) agent := &testAgent{screen: "ready"} cfg := st.PTYConversationConfig{ - Clock: mClock, - SnapshotInterval: 1 * time.Second, - ScreenStabilityLength: 2 * time.Second, // threshold = 3 - AgentIO: agent, - Logger: discardLogger, + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 2 * time.Second, // threshold = 3 + AgentIO: agent, + Logger: discardLogger, } - c := st.NewPTY(ctx, cfg) + c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) // Fill buffer to reach stability with "ready" screen.