Line data Source code
1 : #pragma once
#include <chrono>
#include <cmath>
#include <cstddef>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
7 :
namespace test
{
    /// A single registered test: a display name and the function that runs it.
    /// TEST_SUITE counts any exception escaping `func` as a failure of this test.
    struct TestCase
    {
        std::string name;
        void (*func)();
    };

    /// @brief Approximate numeric comparison used by ASSERT_NEAR.
    /// @param a          first value (converted to double)
    /// @param b          second value (converted to double)
    /// @param precision  exclusive tolerance on |a - b| (converted to double)
    /// @return true when a and b compare exactly equal, or when |a - b| < precision.
    /// The exact-equality check comes first so that matching infinities are
    /// considered near (inf - inf would otherwise yield NaN and fail), and so that
    /// near(x, x, 0) holds. NaN is never near anything, including itself.
    template<typename T1, typename T2, typename T3 = double>
    inline bool near(T1 a, T2 b, T3 precision = 1e-6)
    {
        const double lhs = static_cast<double>(a);
        const double rhs = static_cast<double>(b);
        if (lhs == rhs)
        {
            return true;
        }
        return std::abs(lhs - rhs) < static_cast<double>(precision);
    }
}
22 :
/// Opens the definition of a file-local test body named test_<name>.
#define TEST(name) static void test_##name(void)
/// Expands to a braced {"<name>", test_<name>} initializer for a test::TestCase entry.
#define RUN_TEST(name) { #name, test_##name }
25 :
/// Fails the current test unless `a == b`.
/// Both arguments are parenthesized in the comparison so that an expression such
/// as ASSERT_EQ(x & 1, 0) compares (x & 1) with 0 instead of parsing as
/// x & (1 == 0). On failure, writes both source expressions and the line number
/// to std::cerr and throws std::runtime_error("test failed").
#define ASSERT_EQ(a, b) \
    if (!((a) == (b))) \
    { \
        std::cerr << "\n [FAIL] Expected Equality: " << #a << " == " << #b << " at line " << __LINE__ << "\n"; \
        throw std::runtime_error("test failed"); \
    }
32 :
/// Fails the current test unless `(a).is_approx((b)[, extra args...])` is true.
/// `a` is parenthesized so the member call applies to the whole argument
/// expression (e.g. `cond ? x : y`), not just its last operand. Any extra
/// arguments (such as a tolerance) are forwarded to is_approx; with none, the
/// leading comma is dropped via the GNU `, ##__VA_ARGS__` extension.
/// On failure, writes both source expressions and the line number to std::cerr
/// and throws std::runtime_error("test failed").
#define ASSERT_APPROX(a, b, ...) \
    if (!((a).is_approx((b), ##__VA_ARGS__))) \
    { \
        std::cerr << "\n [FAIL] Expected Approx Equality: " << #a << " approx " << #b \
                  << " at line " << __LINE__ << "\n"; \
        throw std::runtime_error("test failed"); \
    }
39 : }
40 :
/// Fails the current test unless `cond` holds: reports the stringified
/// condition and the line number on std::cerr, then throws
/// std::runtime_error("test failed").
#define ASSERT_TRUE(cond) \
    if (!(cond)) \
    { \
        std::cerr << " FAIL: " #cond " at line " << __LINE__ << '\n'; \
        throw std::runtime_error("test failed"); \
    }
46 : }
47 :
/// Fails the current test if `cond` holds: reports the stringified condition
/// and the line number on std::cerr (suffixed with "should be false"), then
/// throws std::runtime_error("test failed").
#define ASSERT_FALSE(cond) \
    if ((cond)) \
    { \
        std::cerr << " FAIL: " #cond " at line " << __LINE__ << " should be false\n"; \
        throw std::runtime_error("test failed"); \
    }
53 : }
54 :
55 :
/// Fails the current test unless test::near(a, b[, precision]) is true.
/// `a` and `b` are each evaluated exactly once into local copies. The values
/// reported on failure are therefore the ones actually compared, even for
/// arguments with side effects. It also keeps operator precedence inside the
/// arguments from interacting with the `<<` chain. The expansion is a single
/// compound statement, usable with or without a trailing semicolon.
/// On failure, writes both values and the line number to std::cerr and throws
/// std::runtime_error("test failed").
#define ASSERT_NEAR(a, b, ...) \
    { \
        const auto assert_near_lhs = (a); \
        const auto assert_near_rhs = (b); \
        if (!test::near(assert_near_lhs, assert_near_rhs, ##__VA_ARGS__)) \
        { \
            std::cerr << " FAIL: " << assert_near_lhs << " not near " << assert_near_rhs << " at line " << __LINE__ << "\n"; \
            throw std::runtime_error("test failed"); \
        } \
    }
61 : }
62 :
/// Defines `main()` for a test executable from a list of RUN_TEST(name) entries.
/// Behavior:
/// - Tests run in the listed order.
/// - Any exception escaping a test counts as a failure.
/// - For std::exception-derived errors, what() is written to std::cerr, so
///   failures thrown by the code under test (not only by ASSERT_* macros) are
///   diagnosable.
/// - Tests are stored in a std::vector, so an empty TEST_SUITE() compiles and
///   reports "0 passed, 0 failed."
/// - Prints a pass/fail summary and returns 1 if any test failed, else 0.
#define TEST_SUITE(...) \
    int main() \
    { \
        const std::vector<test::TestCase> tests{__VA_ARGS__}; \
        std::size_t failed = 0; \
        for (const auto &t : tests) \
        { \
            std::cout << "[RUN] " << t.name << "..." << std::flush; \
            try \
            { \
                t.func(); \
                std::cout << " PASS\n"; \
            } \
            catch (const std::exception &e) \
            { \
                std::cerr << "\n [REASON] " << e.what() << "\n"; \
                std::cout << " FAILED\n"; \
                ++failed; \
            } \
            catch (...) \
            { \
                std::cerr << "\n [REASON] unknown exception\n"; \
                std::cout << " FAILED\n"; \
                ++failed; \
            } \
        } \
        std::cout << "\nResult: " << (tests.size() - failed) << " passed, " << failed << " failed.\n"; \
        return failed > 0 ? 1 : 0; \
    }
|