Skip to content

Commit 2a84238

Browse files
committed
Add ExprInterpreter test, and fix Ramp of vector.
1 parent cc6b154 commit 2a84238

File tree

3 files changed

+162
-21
lines changed

3 files changed

+162
-21
lines changed

src/ExprInterpreter.cpp

Lines changed: 157 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -378,24 +378,30 @@ void ExprInterpreter::visit(const Let *op) {
378378
void ExprInterpreter::visit(const Ramp *op) {
379379
EvalValue base = eval(op->base), stride = eval(op->stride);
380380
result = EvalValue(op->type);
381-
std::visit(
382-
[&](auto b, auto s) {
383-
if constexpr (std::is_same_v<decltype(b), decltype(s)>) {
384-
for (int j = 0; j < op->lanes; j++) {
385-
auto res = b + j * s;
386-
if (op->type.is_float()) {
387-
result.lanes[j] = static_cast<double>(res);
388-
} else if (op->type.is_int()) {
389-
result.lanes[j] = static_cast<int64_t>(res);
381+
382+
int n = base.type.lanes(); // The lane-width of the base and stride
383+
384+
// ramp(b, s, l) = concat_vectors(b, b + s, b + 2*s, ... b + (l-1)*s)
385+
for (int j = 0; j < op->lanes; j++) {
386+
for (int k = 0; k < n; k++) {
387+
std::visit(
388+
[&](auto b, auto s) {
389+
if constexpr (std::is_same_v<decltype(b), decltype(s)>) {
390+
auto res = b + j * s;
391+
if (op->type.is_float()) {
392+
result.lanes[j * n + k] = static_cast<double>(res);
393+
} else if (op->type.is_int()) {
394+
result.lanes[j * n + k] = static_cast<int64_t>(res);
395+
} else {
396+
result.lanes[j * n + k] = static_cast<uint64_t>(res);
397+
}
390398
} else {
391-
result.lanes[j] = static_cast<uint64_t>(res);
399+
internal_error << "Ramp base and stride type mismatch";
392400
}
393-
}
394-
} else {
395-
internal_error << "Ramp base and stride type mismatch";
396-
}
397-
},
398-
base.lanes[0], stride.lanes[0]);
401+
},
402+
base.lanes[k], stride.lanes[k]);
403+
}
404+
}
399405
}
400406

401407
void ExprInterpreter::visit(const Broadcast *op) {
@@ -587,15 +593,15 @@ void ExprInterpreter::visit(const Call *op) {
587593
result = args[0];
588594
} else if (op->is_intrinsic({Call::return_second, Call::require})) {
589595
result = args[1];
590-
} else if (op->name == "sin") {
596+
} else if (starts_with(op->name, "sin_")) {
591597
result = apply_unary(op->type, args[0], [](auto a) { return std::sin(a); });
592-
} else if (op->name == "cos") {
598+
} else if (starts_with(op->name, "cos_")) {
593599
result = apply_unary(op->type, args[0], [](auto a) { return std::cos(a); });
594-
} else if (op->name == "exp") {
600+
} else if (starts_with(op->name, "exp_")) {
595601
result = apply_unary(op->type, args[0], [](auto a) { return std::exp(a); });
596-
} else if (op->name == "log") {
602+
} else if (starts_with(op->name, "log_")) {
597603
result = apply_unary(op->type, args[0], [](auto a) { return std::log(a); });
598-
} else if (op->name == "sqrt") {
604+
} else if (starts_with(op->name, "sqrt_")) {
599605
result = apply_unary(op->type, args[0], [](auto a) { return std::sqrt(a); });
600606
} else if (op->is_intrinsic(Call::strict_fma)) {
601607
internal_assert(op->args.size() == 3);
@@ -634,5 +640,135 @@ void ExprInterpreter::visit(const Call *op) {
634640
}
635641
}
636642

643+
namespace {
644+
645+
void test_scalar_equivalence() {
646+
ExprInterpreter interp;
647+
648+
// 1. Integer scalar math equivalence
649+
auto math_test_int = [](const auto &x, const auto &y) {
650+
// Keeps values positive to align C++ truncation division with Halide's Euclidean division
651+
return (x + y) * (x - y) + (x / y) + (x % y);
652+
};
653+
654+
int32_t cx = 42, cy = 5;
655+
int32_t c_res = math_test_int(cx, cy);
656+
657+
Expr hx = Expr(cx), hy = Expr(cy);
658+
Expr h_ast = math_test_int(hx, hy);
659+
660+
auto eval_res = interp.eval(h_ast);
661+
internal_assert(eval_res.type.is_int() && eval_res.type.bits() == 32 && eval_res.type.lanes() == 1);
662+
internal_assert(std::get<int64_t>(eval_res.lanes[0]) == c_res)
663+
<< "Integer scalar evaluation mismatch. Expected: " << c_res
664+
<< ", Got: " << std::get<int64_t>(eval_res.lanes[0]);
665+
666+
// 2. Float scalar math equivalence
667+
using std::sin;
668+
using Halide::sin;
669+
auto math_test_float = [](const auto &x, const auto &y) {
670+
return (x * y) - sin(x / (y + 1.0f));
671+
};
672+
673+
float fx = 3.14f, fy = 2.0f;
674+
float f_res = math_test_float(fx, fy);
675+
676+
Expr hfx = Expr(fx), hfy = Expr(fy);
677+
Expr hf_ast = math_test_float(hfx, hfy);
678+
679+
auto eval_f_res = interp.eval(hf_ast);
680+
internal_assert(eval_f_res.type.is_float() && eval_f_res.type.bits() == 32 && eval_f_res.type.lanes() == 1);
681+
682+
double diff = std::abs(std::get<double>(eval_f_res.lanes[0]) - f_res);
683+
internal_assert(diff < 1e-5) << "Float scalar evaluation mismatch.";
684+
}
685+
686+
void test_vector_operations() {
687+
ExprInterpreter interp;
688+
689+
// 1. Ramp: create a vector <10, 13, 16, 19>
690+
Expr base = Expr(10);
691+
Expr stride = Expr(3);
692+
Expr ramp = Ramp::make(base, stride, 4);
693+
694+
auto eval_ramp = interp.eval(ramp);
695+
internal_assert(eval_ramp.type.lanes() == 4);
696+
internal_assert(std::get<int64_t>(eval_ramp.lanes[0]) == 10);
697+
internal_assert(std::get<int64_t>(eval_ramp.lanes[1]) == 13);
698+
internal_assert(std::get<int64_t>(eval_ramp.lanes[2]) == 16);
699+
internal_assert(std::get<int64_t>(eval_ramp.lanes[3]) == 19);
700+
701+
// 2. Broadcast: <5, 5, 5>
702+
Expr bc = Broadcast::make(Expr(5), 3);
703+
auto eval_bc = interp.eval(bc);
704+
internal_assert(eval_bc.type.lanes() == 3);
705+
internal_assert(std::get<int64_t>(eval_bc.lanes[0]) == 5);
706+
internal_assert(std::get<int64_t>(eval_bc.lanes[1]) == 5);
707+
internal_assert(std::get<int64_t>(eval_bc.lanes[2]) == 5);
708+
709+
// 3. Shuffle: reverse the ramp -> <19, 16, 13, 10>
710+
Expr reversed = Shuffle::make({ramp}, {3, 2, 1, 0});
711+
auto eval_rev = interp.eval(reversed);
712+
internal_assert(eval_rev.type.lanes() == 4);
713+
internal_assert(std::get<int64_t>(eval_rev.lanes[0]) == 19);
714+
internal_assert(std::get<int64_t>(eval_rev.lanes[1]) == 16);
715+
internal_assert(std::get<int64_t>(eval_rev.lanes[2]) == 13);
716+
internal_assert(std::get<int64_t>(eval_rev.lanes[3]) == 10);
717+
718+
// 4. VectorReduce: Sum the ramp -> 10 + 13 + 16 + 19 = 58
719+
Expr sum = VectorReduce::make(VectorReduce::Add, ramp, 1);
720+
auto eval_sum = interp.eval(sum);
721+
internal_assert(eval_sum.type.lanes() == 1);
722+
internal_assert(std::get<int64_t>(eval_sum.lanes[0]) == 58);
723+
724+
// 5. Ramp of Ramp
725+
Expr ramp_of_ramp = Ramp::make(ramp, Broadcast::make(100, 4), 4);
726+
auto eval_ror = interp.eval(ramp_of_ramp);
727+
internal_assert(eval_ror.type.lanes() == 16);
728+
for (int i = 0; i < 4; ++i) {
729+
internal_assert(std::get<int64_t>(eval_ror.lanes[4 * i + 0]) == 100 * i + 10);
730+
internal_assert(std::get<int64_t>(eval_ror.lanes[4 * i + 1]) == 100 * i + 13);
731+
internal_assert(std::get<int64_t>(eval_ror.lanes[4 * i + 2]) == 100 * i + 16);
732+
internal_assert(std::get<int64_t>(eval_ror.lanes[4 * i + 3]) == 100 * i + 19);
733+
}
734+
735+
// 6. Broadcast of Ramp
736+
Expr bc_of_ramp = Broadcast::make(ramp, 5);
737+
auto eval_bor = interp.eval(bc_of_ramp);
738+
internal_assert(eval_bor.type.lanes() == 20);
739+
for (int i = 0; i < 5; ++i) {
740+
internal_assert(std::get<int64_t>(eval_bor.lanes[4 * i + 0]) == 10);
741+
internal_assert(std::get<int64_t>(eval_bor.lanes[4 * i + 1]) == 13);
742+
internal_assert(std::get<int64_t>(eval_bor.lanes[4 * i + 2]) == 16);
743+
internal_assert(std::get<int64_t>(eval_bor.lanes[4 * i + 3]) == 19);
744+
}
745+
}
746+
747+
void test_let_and_scoping() {
748+
ExprInterpreter interp;
749+
750+
// Test: let x = 42 in (let x = x + 8 in x * 2)
751+
// Inner scoping should shadow outer scoping and evaluate cleanly
752+
Expr var_x = Variable::make(Int(32), "x");
753+
Expr inner_let = Let::make("x", var_x + Expr(8), var_x * Expr(2));
754+
Expr outer_let = Let::make("x", Expr(42), inner_let);
755+
756+
auto res = interp.eval(outer_let);
757+
internal_assert(res.type.is_int() && res.type.lanes() == 1);
758+
759+
// (42 + 8) * 2 = 100
760+
internal_assert(std::get<int64_t>(res.lanes[0]) == 100)
761+
<< "Variable scoping / Let evaluation failed.";
762+
}
763+
} // namespace
764+
765+
// Runs the interpreter self-tests in a fixed order, then prints a success
// line to stdout. Each suite reports failures through internal_assert.
void ExprInterpreter::test() {
    void (*const suites[])() = {
        test_scalar_equivalence,
        test_vector_operations,
        test_let_and_scoping,
    };
    for (auto *suite : suites) {
        suite();
    }

    std::cout << "ExprInterpreter tests passed!" << "\n";
}
772+
637773
} // namespace Internal
638774
} // namespace Halide

src/ExprInterpreter.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ class ExprInterpreter : public IRVisitor {
7575

7676
template<typename F>
7777
EvalValue apply_cmp(Type t, const EvalValue &a, const EvalValue &b, F f);
78+
79+
public:
80+
static void test();
7881
};
7982

8083
} // namespace Internal

test/internal.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
#include "Solve.h"
1919
#include "SpirvIR.h"
2020
#include "UniquifyVariableNames.h"
21+
#include "ExprInterpreter.h"
2122

2223
using namespace Halide;
2324
using namespace Halide::Internal;
2425

2526
int main(int argc, const char **argv) {
2627
IRPrinter::test();
2728
CodeGen_C::test();
29+
ExprInterpreter::test();
2830
ir_equality_test();
2931
bounds_test();
3032
expr_match_test();

0 commit comments

Comments
 (0)