Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 310 additions & 1 deletion soroban-spec-rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ mod syn_ext;
pub mod r#trait;
pub mod types;

use std::borrow::Cow;
use std::{fs, io};

use proc_macro2::TokenStream;
use quote::quote;
use sha2::{Digest, Sha256};
use stellar_xdr::curr as stellar_xdr;
use stellar_xdr::ScSpecEntry;
use stellar_xdr::{ScSpecEntry, ScSpecTypeDef, ScSpecTypeUdt, ScSpecUdtUnionCaseV0};
use syn::Error;

use soroban_spec::read::{from_wasm, FromWasmError};
Expand Down Expand Up @@ -107,6 +108,9 @@ pub fn generate_without_file_with_options(
specs: &[ScSpecEntry],
opts: &GenerateOptions,
) -> Result<TokenStream, GenerateError> {
let specs = apply_error_udt_override(specs);
let specs: &[ScSpecEntry] = &specs;

let mut spec_fns = Vec::new();
let mut spec_structs = Vec::new();
let mut spec_unions = Vec::new();
Expand Down Expand Up @@ -161,6 +165,101 @@ pub fn generate_without_file_with_options(
})
}

/// The `#[contractimpl]` macro emits any type named `Error` in a contract's
/// function signatures as the built-in `ScSpecTypeDef::Error` in the spec,
/// regardless of whether the contract defined its own error enum named `Error`
/// or used `soroban_sdk::Error` directly. To let clients of contracts that
/// define their own `Error` enum see the user-defined type instead of
/// `soroban_sdk::Error`, this pass rewrites every `ScSpecTypeDef::Error`
/// reference in the spec to `Udt { name: "Error" }` whenever the spec also
/// contains a `UdtErrorEnumV0` named `Error`.
///
/// This keeps the on-the-wire spec format unchanged (so already-deployed
/// contracts benefit without redeployment) and shifts the resolution to the
/// client generator.
///
/// Returns a borrowed slice when no rewrite is needed, otherwise a
/// freshly-owned `Vec` with the rewrite applied.
fn apply_error_udt_override(specs: &[ScSpecEntry]) -> Cow<'_, [ScSpecEntry]> {
let has_error_udt = specs.iter().any(|e| {
matches!(
e,
ScSpecEntry::UdtErrorEnumV0(err) if err.name.to_utf8_string_lossy() == "Error"
)
});
if has_error_udt {
let mut v = specs.to_vec();
rewrite_error_to_udt(&mut v);
Cow::Owned(v)
} else {
Cow::Borrowed(specs)
}
}

/// Rewrites every `ScSpecTypeDef::Error` reference in the given entries to
/// `ScSpecTypeDef::Udt { name: "Error" }`. Called only when the spec contains
/// a user-defined error enum named `Error`, so the UDT reference resolves to
/// that enum during code generation.
fn rewrite_error_to_udt(entries: &mut [ScSpecEntry]) {
fn rewrite_ty(t: &mut ScSpecTypeDef) {
match t {
ScSpecTypeDef::Error => {
*t = ScSpecTypeDef::Udt(ScSpecTypeUdt {
name: "Error".try_into().unwrap(),
});
}
ScSpecTypeDef::Option(o) => rewrite_ty(&mut o.value_type),
ScSpecTypeDef::Result(r) => {
rewrite_ty(&mut r.ok_type);
rewrite_ty(&mut r.error_type);
}
ScSpecTypeDef::Vec(v) => rewrite_ty(&mut v.element_type),
ScSpecTypeDef::Map(m) => {
rewrite_ty(&mut m.key_type);
rewrite_ty(&mut m.value_type);
}
ScSpecTypeDef::Tuple(tu) => {
for vt in tu.value_types.iter_mut() {
rewrite_ty(vt);
}
}
_ => {}
}
}
for entry in entries.iter_mut() {
match entry {
ScSpecEntry::FunctionV0(f) => {
for input in f.inputs.iter_mut() {
rewrite_ty(&mut input.type_);
}
for output in f.outputs.iter_mut() {
rewrite_ty(output);
}
}
ScSpecEntry::UdtStructV0(s) => {
for field in s.fields.iter_mut() {
rewrite_ty(&mut field.type_);
}
}
ScSpecEntry::UdtUnionV0(u) => {
for case in u.cases.iter_mut() {
if let ScSpecUdtUnionCaseV0::TupleV0(t) = case {
for ty in t.type_.iter_mut() {
rewrite_ty(ty);
}
}
}
}
ScSpecEntry::UdtEnumV0(_) | ScSpecEntry::UdtErrorEnumV0(_) => {}
ScSpecEntry::EventV0(e) => {
for p in e.params.iter_mut() {
rewrite_ty(&mut p.type_);
}
}
}
}
}

/// Implemented by types that can be converted into pretty formatted Strings of
/// Rust code.
pub trait ToFormattedString {
Expand Down Expand Up @@ -225,6 +324,216 @@ pub enum UdtEnum2 {
A = 10,
B = 15,
}
"#,
);
}

const ADD_U64_WASM: &[u8] =
include_bytes!("../../target/wasm32v1-none/release/test_add_u64.wasm");

/// Test that Result types with user-defined error types are generated correctly.
/// This specifically tests that:
/// - An error enum named `Error` generates `Result<u64, Error>` (not `Result<u64, soroban_sdk::Error>`)
/// - An error enum named `MyError` generates `Result<u64, MyError>`
#[test]
fn test_add_u64_result_types() {
let entries = from_wasm(ADD_U64_WASM).unwrap();
let rust = generate(&entries, "<file>", "<sha256>")
.unwrap()
.to_formatted_string()
.unwrap();
assert_eq!(
rust,
r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
fn add(env: soroban_sdk::Env, a: u64, b: u64) -> u64;
fn safe_add(env: soroban_sdk::Env, a: u64, b: u64) -> Result<u64, Error>;
fn safe_add_two(env: soroban_sdk::Env, a: u64, b: u64) -> Result<u64, MyError>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
Overflow = 1,
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum MyError {
Overflow = 1,
}
"#,
);
}

/// Test that shows the raw spec entries from the wasm.
/// Verifies that the on-the-wire spec format is unchanged: a contract
/// error enum named `Error` is still emitted as the built-in
/// `ScSpecTypeDef::Error` in function signatures (the user-defined-vs-SDK
/// disambiguation happens at client generation time, not here). A
/// differently-named error enum (`MyError`) is emitted as a UDT reference.
#[test]
fn test_add_u64_spec_entries() {
use super::ScSpecEntry;
use stellar_xdr::curr::ScSpecTypeDef;

let entries = from_wasm(ADD_U64_WASM).unwrap();

// Find the safe_add function spec
let safe_add_fn = entries
.iter()
.find_map(|e| match e {
ScSpecEntry::FunctionV0(f) if f.name.to_utf8_string().unwrap() == "safe_add" => {
Some(f)
}
_ => None,
})
.expect("safe_add function not found");

let output = safe_add_fn.outputs.to_option().expect("should have output");
let ScSpecTypeDef::Result(r) = output else {
panic!("output should be a Result type");
};
assert!(
matches!(r.ok_type.as_ref(), ScSpecTypeDef::U64),
"ok_type should be U64"
);
assert!(
matches!(r.error_type.as_ref(), ScSpecTypeDef::Error),
"error_type should be the built-in Error in the wasm spec, got {:?}",
r.error_type
);

// Find the safe_add_two function spec
let safe_add_two_fn = entries
.iter()
.find_map(|e| match e {
ScSpecEntry::FunctionV0(f)
if f.name.to_utf8_string().unwrap() == "safe_add_two" =>
{
Some(f)
}
_ => None,
})
.expect("safe_add_two function not found");

let output = safe_add_two_fn
.outputs
.to_option()
.expect("should have output");
let ScSpecTypeDef::Result(r) = output else {
panic!("output should be a Result type");
};
assert!(
matches!(r.ok_type.as_ref(), ScSpecTypeDef::U64),
"ok_type should be U64"
);
let ScSpecTypeDef::Udt(u) = r.error_type.as_ref() else {
panic!(
"error_type should be a UDT for MyError, got {:?}",
r.error_type
);
};
assert_eq!(
u.name.to_utf8_string().unwrap(),
"MyError",
"error_type should be MyError UDT"
);
}

/// When the spec references `ScSpecTypeDef::Error` and contains no error
/// enum named `Error`, the generator must leave it as `soroban_sdk::Error`.
/// This covers contracts that use `soroban_sdk::Error` directly as their
/// Result error type, including every contract compiled before the
/// error-enum override was introduced.
#[test]
fn test_missing_error_udt_falls_back_to_sdk_error() {
use super::ScSpecEntry;
use stellar_xdr::curr::{ScSpecFunctionV0, ScSpecTypeDef, ScSpecTypeResult};

let func = ScSpecFunctionV0 {
doc: "".try_into().unwrap(),
name: "safe_add".try_into().unwrap(),
inputs: [].try_into().unwrap(),
outputs: [ScSpecTypeDef::Result(Box::new(ScSpecTypeResult {
ok_type: Box::new(ScSpecTypeDef::U64),
error_type: Box::new(ScSpecTypeDef::Error),
}))]
.try_into()
.unwrap(),
};
let entries = [ScSpecEntry::FunctionV0(func)];
let rust = generate(&entries, "<file>", "<sha256>")
.unwrap()
.to_formatted_string()
.unwrap();
assert_eq!(
rust,
r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
fn safe_add(env: soroban_sdk::Env) -> Result<u64, soroban_sdk::Error>;
}
"#,
);
}

/// When the spec contains a user-defined `Error` error enum, every
/// `ScSpecTypeDef::Error` reference in the spec must be rewritten to
/// reference that UDT instead of `soroban_sdk::Error`.
#[test]
fn test_error_udt_overrides_sdk_error() {
use super::ScSpecEntry;
use stellar_xdr::curr::{
ScSpecFunctionV0, ScSpecTypeDef, ScSpecTypeResult, ScSpecUdtErrorEnumCaseV0,
ScSpecUdtErrorEnumV0,
};

let func = ScSpecFunctionV0 {
doc: "".try_into().unwrap(),
name: "safe_add".try_into().unwrap(),
inputs: [].try_into().unwrap(),
outputs: [ScSpecTypeDef::Result(Box::new(ScSpecTypeResult {
ok_type: Box::new(ScSpecTypeDef::U64),
error_type: Box::new(ScSpecTypeDef::Error),
}))]
.try_into()
.unwrap(),
};
let error_enum = ScSpecUdtErrorEnumV0 {
doc: "".try_into().unwrap(),
lib: "".try_into().unwrap(),
name: "Error".try_into().unwrap(),
cases: [ScSpecUdtErrorEnumCaseV0 {
doc: "".try_into().unwrap(),
name: "Overflow".try_into().unwrap(),
value: 1,
}]
.try_into()
.unwrap(),
};
let entries = [
ScSpecEntry::FunctionV0(func),
ScSpecEntry::UdtErrorEnumV0(error_enum),
];
let rust = generate(&entries, "<file>", "<sha256>")
.unwrap()
.to_formatted_string()
.unwrap();
assert_eq!(
rust,
r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
fn safe_add(env: soroban_sdk::Env) -> Result<u64, Error>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
Overflow = 1,
}
"#,
);
}
Expand Down
Loading
Loading