Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions crates/rmcp-macros/src/task_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
status: rmcp::model::TaskStatus::Working,
status_message: None,
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
last_updated_at: timestamp,
ttl: None,
poll_interval: None,
}
Expand Down Expand Up @@ -111,7 +111,7 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
status: rmcp::model::TaskStatus::Working,
status_message: Some("Task accepted".to_string()),
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
last_updated_at: timestamp,
ttl: None,
poll_interval: None,
};
Expand All @@ -128,7 +128,7 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
&self,
request: rmcp::model::GetTaskInfoParam,
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::GetTaskInfoResult, McpError> {
) -> Result<rmcp::model::GetTaskResult, McpError> {
use rmcp::task_manager::current_timestamp;
let task_id = request.task_id.clone();
let mut processor = (#processor).lock().await;
Expand Down Expand Up @@ -156,11 +156,11 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
status,
status_message: None,
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
last_updated_at: timestamp,
ttl: completed_result.descriptor.ttl,
poll_interval: None,
};
return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) });
return Ok(rmcp::model::GetTaskResult { meta: None, task });
}

// If not completed, check running
Expand All @@ -172,14 +172,14 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
status: rmcp::model::TaskStatus::Working,
status_message: None,
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
last_updated_at: timestamp,
ttl: None,
poll_interval: None,
};
return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) });
return Ok(rmcp::model::GetTaskResult { meta: None, task });
}

Ok(rmcp::model::GetTaskInfoResult { task: None })
Err(McpError::resource_not_found(format!("task not found: {}", task_id), None))
}
};
item_impl.items.push(syn::parse2::<ImplItem>(get_info_fn)?);
Expand All @@ -191,7 +191,7 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
&self,
request: rmcp::model::GetTaskResultParam,
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::TaskResult, McpError> {
) -> Result<rmcp::model::GetTaskPayloadResult, McpError> {
use std::time::Duration;
let task_id = request.task_id.clone();

Expand All @@ -207,11 +207,7 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
match &tool.result {
Ok(call_tool) => {
let value = ::serde_json::to_value(call_tool).unwrap_or(::serde_json::Value::Null);
return Ok(rmcp::model::TaskResult {
content_type: "application/json".to_string(),
value,
summary: None,
});
return Ok(rmcp::model::GetTaskPayloadResult(value));
}
Err(err) => return Err(McpError::internal_error(
format!("task failed: {}", err),
Expand Down Expand Up @@ -251,12 +247,23 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
&self,
request: rmcp::model::CancelTaskParam,
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<(), McpError> {
) -> Result<rmcp::model::CancelTaskResult, McpError> {
use rmcp::task_manager::current_timestamp;
let task_id = request.task_id;
let mut processor = (#processor).lock().await;

if processor.cancel_task(&task_id) {
return Ok(());
let timestamp = current_timestamp();
let task = rmcp::model::Task {
task_id,
status: rmcp::model::TaskStatus::Cancelled,
status_message: None,
created_at: timestamp.clone(),
last_updated_at: timestamp,
ttl: None,
poll_interval: None,
};
return Ok(rmcp::model::CancelTaskResult { meta: None, task });
}

// If already completed, signal it's not cancellable
Expand Down
19 changes: 10 additions & 9 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ impl<H: ServerHandler> Service<RoleServer> for H {
ClientRequest::GetTaskInfoRequest(request) => self
.get_task_info(request.params, context)
.await
.map(ServerResult::GetTaskInfoResult),
.map(ServerResult::GetTaskResult),
ClientRequest::GetTaskResultRequest(request) => self
.get_task_result(request.params, context)
.await
.map(ServerResult::TaskResult),
.map(ServerResult::GetTaskPayloadResult),
ClientRequest::CancelTaskRequest(request) => self
.cancel_task(request.params, context)
.await
.map(ServerResult::empty),
.map(ServerResult::CancelTaskResult),
}
}

Expand Down Expand Up @@ -339,15 +339,16 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
&self,
request: GetTaskInfoParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ {
let _ = (request, context);
std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>()))
}

fn get_task_result(
&self,
request: GetTaskResultParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<TaskResult, McpError>> + Send + '_ {
) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ {
let _ = (request, context);
std::future::ready(Err(McpError::method_not_found::<GetTaskResultMethod>()))
}
Expand All @@ -356,7 +357,7 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
&self,
request: CancelTaskParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ {
let _ = (request, context);
std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>()))
}
Expand Down Expand Up @@ -543,23 +544,23 @@ macro_rules! impl_server_handler_for_wrapper {
&self,
request: GetTaskInfoParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<GetTaskInfoResult, McpError>> + Send + '_ {
) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ {
(**self).get_task_info(request, context)
}

fn get_task_result(
&self,
request: GetTaskResultParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<TaskResult, McpError>> + Send + '_ {
) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ {
(**self).get_task_result(request, context)
}

fn cancel_task(
&self,
request: CancelTaskParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ {
(**self).cancel_task(request, context)
}
}
Expand Down
16 changes: 6 additions & 10 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2537,14 +2537,9 @@ impl RequestParamsMeta for CancelTaskParams {
/// Deprecated: Use [`CancelTaskParams`] instead (SEP-1319 compliance).
#[deprecated(since = "0.13.0", note = "Use CancelTaskParams instead")]
pub type CancelTaskParam = CancelTaskParams;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GetTaskInfoResult {
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<crate::model::Task>,
}
/// Deprecated: Use [`GetTaskResult`] instead (spec alignment).
#[deprecated(since = "0.15.0", note = "Use GetTaskResult instead")]
pub type GetTaskInfoResult = GetTaskResult;

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -2720,9 +2715,10 @@ ts_union!(
| EmptyResult
| CreateTaskResult
| ListTasksResult
| GetTaskInfoResult
| TaskResult
| GetTaskResult
| CancelTaskResult
| CustomResult
| GetTaskPayloadResult
;
);

Expand Down
61 changes: 44 additions & 17 deletions crates/rmcp/src/model/task.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::Meta;

/// Canonical task lifecycle status as defined by SEP-1686.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
Expand All @@ -19,21 +21,10 @@ pub enum TaskStatus {
Cancelled,
}

/// Final result for a succeeded task (returned from `tasks/result`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct TaskResult {
/// MIME type or custom content-type identifier.
pub content_type: String,
/// The actual result payload, matching the underlying request's schema.
pub value: Value,
/// Optional short summary for UI surfaces.
#[serde(skip_serializing_if = "Option::is_none")]
pub summary: Option<String>,
}

/// Primary Task object that surfaces metadata during the task lifecycle.
///
/// Per spec, `lastUpdatedAt` and `ttl` are required fields.
/// `ttl` is nullable (`null` means unlimited retention).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
Expand All @@ -48,10 +39,9 @@ pub struct Task {
/// ISO-8601 creation timestamp.
pub created_at: String,
/// ISO-8601 timestamp for the most recent status change.
#[serde(skip_serializing_if = "Option::is_none")]
pub last_updated_at: Option<String>,
pub last_updated_at: String,
/// Retention window in milliseconds that the receiver agreed to honor.
#[serde(skip_serializing_if = "Option::is_none")]
/// `None` (serialized as `null`) means unlimited retention.
pub ttl: Option<u64>,
/// Suggested polling interval (milliseconds).
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -66,6 +56,43 @@ pub struct CreateTaskResult {
pub task: Task,
}

/// Response to a `tasks/get` request.
///
/// Per spec, `GetTaskResult = allOf[Result, Task]` — the Task fields are
/// flattened at the top level, not nested under a `task` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GetTaskResult {
#[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")]
pub meta: Option<Meta>,
#[serde(flatten)]
pub task: Task,
}

/// Response to a `tasks/result` request.
///
/// Per spec, the result structure matches the original request type
/// (e.g., `CallToolResult` for `tools/call`). This is represented as
/// an open object. The payload is the original request's result
/// serialized as a JSON value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GetTaskPayloadResult(pub Value);

/// Response to a `tasks/cancel` request.
///
/// Per spec, `CancelTaskResult = allOf[Result, Task]` — same shape as `GetTaskResult`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct CancelTaskResult {
#[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")]
pub meta: Option<Meta>,
#[serde(flatten)]
pub task: Task,
}

/// Paginated list of tasks
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down
Loading