@@ -378,24 +378,30 @@ void ExprInterpreter::visit(const Let *op) {
378378void ExprInterpreter::visit (const Ramp *op) {
379379 EvalValue base = eval (op->base ), stride = eval (op->stride );
380380 result = EvalValue (op->type );
381- std::visit (
382- [&](auto b, auto s) {
383- if constexpr (std::is_same_v<decltype (b), decltype (s)>) {
384- for (int j = 0 ; j < op->lanes ; j++) {
385- auto res = b + j * s;
386- if (op->type .is_float ()) {
387- result.lanes [j] = static_cast <double >(res);
388- } else if (op->type .is_int ()) {
389- result.lanes [j] = static_cast <int64_t >(res);
381+
382+ int n = base.type .lanes (); // The lane-width of the base and stride
383+
384+ // ramp(b, s, l) = concat_vectors(b, b + s, b + 2*s, ... b + (l-1)*s)
385+ for (int j = 0 ; j < op->lanes ; j++) {
386+ for (int k = 0 ; k < n; k++) {
387+ std::visit (
388+ [&](auto b, auto s) {
389+ if constexpr (std::is_same_v<decltype (b), decltype (s)>) {
390+ auto res = b + j * s;
391+ if (op->type .is_float ()) {
392+ result.lanes [j * n + k] = static_cast <double >(res);
393+ } else if (op->type .is_int ()) {
394+ result.lanes [j * n + k] = static_cast <int64_t >(res);
395+ } else {
396+ result.lanes [j * n + k] = static_cast <uint64_t >(res);
397+ }
390398 } else {
391- result. lanes [j] = static_cast < uint64_t >(res) ;
399+ internal_error << " Ramp base and stride type mismatch " ;
392400 }
393- }
394- } else {
395- internal_error << " Ramp base and stride type mismatch" ;
396- }
397- },
398- base.lanes [0 ], stride.lanes [0 ]);
401+ },
402+ base.lanes [k], stride.lanes [k]);
403+ }
404+ }
399405}
400406
401407void ExprInterpreter::visit (const Broadcast *op) {
@@ -587,15 +593,15 @@ void ExprInterpreter::visit(const Call *op) {
587593 result = args[0 ];
588594 } else if (op->is_intrinsic ({Call::return_second, Call::require})) {
589595 result = args[1 ];
590- } else if (op->name == " sin " ) {
596+ } else if (starts_with ( op->name , " sin_ " ) ) {
591597 result = apply_unary (op->type , args[0 ], [](auto a) { return std::sin (a); });
592- } else if (op->name == " cos " ) {
598+ } else if (starts_with ( op->name , " cos_ " ) ) {
593599 result = apply_unary (op->type , args[0 ], [](auto a) { return std::cos (a); });
594- } else if (op->name == " exp " ) {
600+ } else if (starts_with ( op->name , " exp_ " ) ) {
595601 result = apply_unary (op->type , args[0 ], [](auto a) { return std::exp (a); });
596- } else if (op->name == " log " ) {
602+ } else if (starts_with ( op->name , " log_ " ) ) {
597603 result = apply_unary (op->type , args[0 ], [](auto a) { return std::log (a); });
598- } else if (op->name == " sqrt " ) {
604+ } else if (starts_with ( op->name , " sqrt_ " ) ) {
599605 result = apply_unary (op->type , args[0 ], [](auto a) { return std::sqrt (a); });
600606 } else if (op->is_intrinsic (Call::strict_fma)) {
601607 internal_assert (op->args .size () == 3 );
@@ -634,5 +640,135 @@ void ExprInterpreter::visit(const Call *op) {
634640 }
635641}
636642
643+ namespace {
644+
645+ void test_scalar_equivalence () {
646+ ExprInterpreter interp;
647+
648+ // 1. Integer scalar math equivalence
649+ auto math_test_int = [](const auto &x, const auto &y) {
650+ // Keeps values positive to align C++ truncation division with Halide's Euclidean division
651+ return (x + y) * (x - y) + (x / y) + (x % y);
652+ };
653+
654+ int32_t cx = 42 , cy = 5 ;
655+ int32_t c_res = math_test_int (cx, cy);
656+
657+ Expr hx = Expr (cx), hy = Expr (cy);
658+ Expr h_ast = math_test_int (hx, hy);
659+
660+ auto eval_res = interp.eval (h_ast);
661+ internal_assert (eval_res.type .is_int () && eval_res.type .bits () == 32 && eval_res.type .lanes () == 1 );
662+ internal_assert (std::get<int64_t >(eval_res.lanes [0 ]) == c_res)
663+ << " Integer scalar evaluation mismatch. Expected: " << c_res
664+ << " , Got: " << std::get<int64_t >(eval_res.lanes [0 ]);
665+
666+ // 2. Float scalar math equivalence
667+ using std::sin;
668+ using Halide::sin;
669+ auto math_test_float = [](const auto &x, const auto &y) {
670+ return (x * y) - sin (x / (y + 1 .0f ));
671+ };
672+
673+ float fx = 3 .14f , fy = 2 .0f ;
674+ float f_res = math_test_float (fx, fy);
675+
676+ Expr hfx = Expr (fx), hfy = Expr (fy);
677+ Expr hf_ast = math_test_float (hfx, hfy);
678+
679+ auto eval_f_res = interp.eval (hf_ast);
680+ internal_assert (eval_f_res.type .is_float () && eval_f_res.type .bits () == 32 && eval_f_res.type .lanes () == 1 );
681+
682+ double diff = std::abs (std::get<double >(eval_f_res.lanes [0 ]) - f_res);
683+ internal_assert (diff < 1e-5 ) << " Float scalar evaluation mismatch." ;
684+ }
685+
686+ void test_vector_operations () {
687+ ExprInterpreter interp;
688+
689+ // 1. Ramp: create a vector <10, 13, 16, 19>
690+ Expr base = Expr (10 );
691+ Expr stride = Expr (3 );
692+ Expr ramp = Ramp::make (base, stride, 4 );
693+
694+ auto eval_ramp = interp.eval (ramp);
695+ internal_assert (eval_ramp.type .lanes () == 4 );
696+ internal_assert (std::get<int64_t >(eval_ramp.lanes [0 ]) == 10 );
697+ internal_assert (std::get<int64_t >(eval_ramp.lanes [1 ]) == 13 );
698+ internal_assert (std::get<int64_t >(eval_ramp.lanes [2 ]) == 16 );
699+ internal_assert (std::get<int64_t >(eval_ramp.lanes [3 ]) == 19 );
700+
701+ // 2. Broadcast: <5, 5, 5>
702+ Expr bc = Broadcast::make (Expr (5 ), 3 );
703+ auto eval_bc = interp.eval (bc);
704+ internal_assert (eval_bc.type .lanes () == 3 );
705+ internal_assert (std::get<int64_t >(eval_bc.lanes [0 ]) == 5 );
706+ internal_assert (std::get<int64_t >(eval_bc.lanes [1 ]) == 5 );
707+ internal_assert (std::get<int64_t >(eval_bc.lanes [2 ]) == 5 );
708+
709+ // 3. Shuffle: reverse the ramp -> <19, 16, 13, 10>
710+ Expr reversed = Shuffle::make ({ramp}, {3 , 2 , 1 , 0 });
711+ auto eval_rev = interp.eval (reversed);
712+ internal_assert (eval_rev.type .lanes () == 4 );
713+ internal_assert (std::get<int64_t >(eval_rev.lanes [0 ]) == 19 );
714+ internal_assert (std::get<int64_t >(eval_rev.lanes [1 ]) == 16 );
715+ internal_assert (std::get<int64_t >(eval_rev.lanes [2 ]) == 13 );
716+ internal_assert (std::get<int64_t >(eval_rev.lanes [3 ]) == 10 );
717+
718+ // 4. VectorReduce: Sum the ramp -> 10 + 13 + 16 + 19 = 58
719+ Expr sum = VectorReduce::make (VectorReduce::Add, ramp, 1 );
720+ auto eval_sum = interp.eval (sum);
721+ internal_assert (eval_sum.type .lanes () == 1 );
722+ internal_assert (std::get<int64_t >(eval_sum.lanes [0 ]) == 58 );
723+
724+ // 5. Ramp of Ramp
725+ Expr ramp_of_ramp = Ramp::make (ramp, Broadcast::make (100 , 4 ), 4 );
726+ auto eval_ror = interp.eval (ramp_of_ramp);
727+ internal_assert (eval_ror.type .lanes () == 16 );
728+ for (int i = 0 ; i < 4 ; ++i) {
729+ internal_assert (std::get<int64_t >(eval_ror.lanes [4 * i + 0 ]) == 100 * i + 10 );
730+ internal_assert (std::get<int64_t >(eval_ror.lanes [4 * i + 1 ]) == 100 * i + 13 );
731+ internal_assert (std::get<int64_t >(eval_ror.lanes [4 * i + 2 ]) == 100 * i + 16 );
732+ internal_assert (std::get<int64_t >(eval_ror.lanes [4 * i + 3 ]) == 100 * i + 19 );
733+ }
734+
735+ // 6. Broadcast of Ramp
736+ Expr bc_of_ramp = Broadcast::make (ramp, 5 );
737+ auto eval_bor = interp.eval (bc_of_ramp);
738+ internal_assert (eval_bor.type .lanes () == 20 );
739+ for (int i = 0 ; i < 5 ; ++i) {
740+ internal_assert (std::get<int64_t >(eval_bor.lanes [4 * i + 0 ]) == 10 );
741+ internal_assert (std::get<int64_t >(eval_bor.lanes [4 * i + 1 ]) == 13 );
742+ internal_assert (std::get<int64_t >(eval_bor.lanes [4 * i + 2 ]) == 16 );
743+ internal_assert (std::get<int64_t >(eval_bor.lanes [4 * i + 3 ]) == 19 );
744+ }
745+ }
746+
747+ void test_let_and_scoping () {
748+ ExprInterpreter interp;
749+
750+ // Test: let x = 42 in (let x = x + 8 in x * 2)
751+ // Inner scoping should shadow outer scoping and evaluate cleanly
752+ Expr var_x = Variable::make (Int (32 ), " x" );
753+ Expr inner_let = Let::make (" x" , var_x + Expr (8 ), var_x * Expr (2 ));
754+ Expr outer_let = Let::make (" x" , Expr (42 ), inner_let);
755+
756+ auto res = interp.eval (outer_let);
757+ internal_assert (res.type .is_int () && res.type .lanes () == 1 );
758+
759+ // (42 + 8) * 2 = 100
760+ internal_assert (std::get<int64_t >(res.lanes [0 ]) == 100 )
761+ << " Variable scoping / Let evaluation failed." ;
762+ }
763+ } // namespace
764+
765+ void ExprInterpreter::test () {
766+ test_scalar_equivalence ();
767+ test_vector_operations ();
768+ test_let_and_scoping ();
769+
770+ std::cout << " ExprInterpreter tests passed!" << " \n " ;
771+ }
772+
637773} // namespace Internal
638774} // namespace Halide
0 commit comments