Commit e1d9c23

Code review

1 parent 9cf7226 commit e1d9c23
4 files changed, +71 -69 lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 11 additions & 1 deletion
@@ -2207,6 +2207,16 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm
 
     const auto [ir0, ir1] = get_thread_range(params, src0);
 
+    bool (*bipred)(int, int);
+
+    switch (ttype) {
+        case GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
+        case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
+        case GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
+        case GGML_TRI_TYPE_UPPER_DIAG:
+        default:                       bipred = [](int i, int r) { return i >= r; }; break;
+    }
+
     for (int64_t ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -2215,7 +2225,7 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm
         float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
         float * src_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-        ggml_vec_tri_f32(ne0, i01, dst_ptr, src_ptr, keep_org_val, c, ttype);
+        ggml_vec_tri_f32(ne0, i01, dst_ptr, src_ptr, keep_org_val, c, bipred);
     }
 
 }
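
With this change the switch over `ttype` runs once per thread range instead of once per element; each branch selects a captureless lambda, which converts implicitly to the plain `bool (*)(int, int)` pointer that `ggml_vec_tri_f32` now takes (see the vec.h hunk below). A minimal standalone sketch of that pattern, using a stand-in enum rather than the real ggml types:

    // Standalone illustration only (not ggml code): a captureless lambda converts
    // to an ordinary function pointer, so the mask predicate can be chosen once
    // and the hot loop just calls through the pointer.
    #include <cstdio>

    enum tri_type { TRI_LOWER, TRI_LOWER_DIAG, TRI_UPPER, TRI_UPPER_DIAG };

    int main() {
        const tri_type ttype = TRI_LOWER_DIAG;

        bool (*bipred)(int, int);
        switch (ttype) {
            case TRI_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
            case TRI_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
            case TRI_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
            case TRI_UPPER_DIAG:
            default:             bipred = [](int i, int r) { return i >= r; }; break;
        }

        // apply the selected predicate across one row (r = 2) of width 5
        for (int i = 0; i < 5; ++i) {
            printf("%d", bipred(i, 2) ? 1 : 0);   // prints 11100 for lower + diagonal
        }
        printf("\n");
        return 0;
    }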

ggml/src/ggml-cpu/vec.h

Lines changed: 3 additions & 11 deletions
@@ -1424,18 +1424,10 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 // src          - input array
 // keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c'
 // c            - constant value to use when not keeping original value
-// type         - type of triangular mask (lower, upper, etc.)
-inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) {
+// bipred       - the predicate on coordinates, derived from tri_type
+inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, bool (*bipred)(int, int)) {
     for (int i = 0; i < n; ++i) {
-        bool cmp = false;
-        switch (type) {
-            case GGML_TRI_TYPE_LOWER:      cmp = i <  r; break;
-            case GGML_TRI_TYPE_LOWER_DIAG: cmp = i <= r; break;
-            case GGML_TRI_TYPE_UPPER:      cmp = i >  r; break;
-            case GGML_TRI_TYPE_UPPER_DIAG:
-            default:                       cmp = i >= r; break;
-        }
-        dst[i] = cmp ? (keep_org_val ? src[i] : c) : 0.0f;
+        dst[i] = bipred(i, r) ? (keep_org_val ? src[i] : c) : 0.0f;
     }
 }
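
Since `ggml_vec_tri_f32` now just calls the supplied predicate per element, the refactor is behaviour-preserving as long as the hoisted switch picks the same comparison as the old inline one. A small self-contained equivalence check (illustrative only, again with stand-in names rather than the ggml enum):

    // Illustrative check (not part of the commit): the hoisted predicate must
    // agree with the old per-element switch for every mask type and index pair.
    #include <cassert>

    enum tri_type { TRI_LOWER, TRI_LOWER_DIAG, TRI_UPPER, TRI_UPPER_DIAG };

    // old behaviour: decide per element
    static bool old_cmp(tri_type type, int i, int r) {
        switch (type) {
            case TRI_LOWER:      return i <  r;
            case TRI_LOWER_DIAG: return i <= r;
            case TRI_UPPER:      return i >  r;
            case TRI_UPPER_DIAG:
            default:             return i >= r;
        }
    }

    // new behaviour: decide once, return a function pointer
    static bool (*select_pred(tri_type type))(int, int) {
        switch (type) {
            case TRI_LOWER:      return [](int i, int r) { return i <  r; };
            case TRI_LOWER_DIAG: return [](int i, int r) { return i <= r; };
            case TRI_UPPER:      return [](int i, int r) { return i >  r; };
            case TRI_UPPER_DIAG:
            default:             return [](int i, int r) { return i >= r; };
        }
    }

    int main() {
        for (int t = TRI_LOWER; t <= TRI_UPPER_DIAG; ++t) {
            bool (*bipred)(int, int) = select_pred((tri_type) t);
            for (int r = 0; r < 8; ++r) {
                for (int i = 0; i < 8; ++i) {
                    assert(bipred(i, r) == old_cmp((tri_type) t, i, r));
                }
            }
        }
        return 0;
    }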

ggml/src/ggml.c

Lines changed: 10 additions & 6 deletions
@@ -5082,6 +5082,9 @@ struct ggml_tensor * ggml_tri(
         float constant,
         enum ggml_tri_type tritype) {
 
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(a->ne[0] == a->ne[1]);
+
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params_i32(result, 0, tritype);
@@ -5956,6 +5959,7 @@ struct ggml_tensor * ggml_opt_step_sgd(
 }
 
 // solve_tri
+
 struct ggml_tensor * ggml_solve_tri(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
@@ -5966,9 +5970,9 @@ struct ggml_tensor * ggml_solve_tri(
     // B must have same outer dimension as A
     GGML_ASSERT(a->ne[1] == b->ne[1]);
 
-    // B must be broadcastable to A
-    GGML_ASSERT(a->ne[2] % b->ne[2] == 0);
-    GGML_ASSERT(a->ne[3] % b->ne[3] == 0);
+    // batch dimensions must be equal
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(a->ne[3] == b->ne[3]);
 
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_is_contiguous(b));
@@ -6565,12 +6569,12 @@ static void ggml_compute_backward(
                 struct ggml_tensor * neg_src0 = ggml_neg(ctx, src0);
                 struct ggml_tensor * exp_neg  = ggml_exp(ctx, neg_src0);
                 struct ggml_tensor * ones =
-                    ggml_exp(ctx, ggml_new_tensor_4d(ctx, src0->type, src0->ne[0], src0->ne[1], src0->ne[2],
-                                                     src0->ne[3]));
+                    ggml_scale_bias(ctx, ggml_new_tensor_4d(ctx, src0->type, src0->ne[0], src0->ne[1], src0->ne[2],
+                                                            src0->ne[3]), 0.0f, 1.0f);
                 struct ggml_tensor * one_plus_exp = ggml_add(ctx, ones, exp_neg);
                 struct ggml_tensor * sigmoid = ggml_div(ctx, ones, one_plus_exp);
                 ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, sigmoid));
-            }
+                }
             } break;
         default: {
             fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
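
In the backward pass above, `ggml_scale_bias(..., 0.0f, 1.0f)` builds the all-ones tensor as 0*x + 1, which is 1 regardless of what the freshly allocated buffer contains, whereas the previous `ggml_exp` of a new tensor only yields 1 if that buffer happens to be zero. The ones tensor then forms 1 / (1 + exp(-src0)) = sigmoid(src0), i.e. the derivative of softplus, which scales the incoming gradient. A standalone numeric check of that identity (plain C++, no ggml dependency):

    // Finite-difference check that d/dx softplus(x) = sigmoid(x) = 1 / (1 + exp(-x)).
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    static double softplus(double x) { return std::log1p(std::exp(x)); }
    static double sigmoid (double x) { return 1.0 / (1.0 + std::exp(-x)); }

    int main() {
        const double h = 1e-6;
        for (double x = -4.0; x <= 4.0; x += 0.5) {
            // central difference of softplus vs. the analytic sigmoid
            const double fd = (softplus(x + h) - softplus(x - h)) / (2.0 * h);
            assert(std::fabs(fd - sigmoid(x)) < 1e-6);
        }
        printf("softplus' matches sigmoid on the sampled grid\n");
        return 0;
    }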

tests/test-backend-ops.cpp

Lines changed: 47 additions & 51 deletions
@@ -188,12 +188,11 @@ static void init_tensor_causal(ggml_tensor * tensor, float min = -1.0f, float ma
     std::mt19937 gen(rd());
     std::uniform_real_distribution<float> dis(min, max);
 
-    for (int64_t i0 = 0; i0 < tensor->ne[0]; i0++) {
-        for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
-            for (int64_t i2 = 0; i2 < tensor->ne[2]; i2++) {
-                for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) {
-                    int64_t idx = i0 * tensor->nb[0] / sizeof(float) + i1 * tensor->nb[1] / sizeof(float) +
-                                  i2 * tensor->nb[2] / sizeof(float) + i3 * tensor->nb[3] / sizeof(float);
+    for (int64_t i3 = 0; i3 < tensor->ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < tensor->ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < tensor->ne[0]; i0++) {
+                    int64_t idx = (i0 * tensor->nb[0] + i1 * tensor->nb[1] + i2 * tensor->nb[2] + i3 * tensor->nb[3]) / sizeof(float);
                     if (i0 <= i1) {
                         data_f32[idx] = dis(gen);
                     } else {
@@ -4785,7 +4784,6 @@ struct test_argsort : public test_case {
     }
 };
 
-// GGML_OP_TOPK_MOE
 struct test_topk_moe: public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
@@ -4843,7 +4841,6 @@ struct test_topk_moe: public test_case {
     }
 };
 
-// GGML_MOE_EXPERT_REDUCE
 struct test_moe_expert_reduce : public test_case {
     const int64_t n_embd;
     const int64_t n_tokens;
@@ -5349,7 +5346,7 @@ struct test_pad : public test_case {
     }
 };
 
-// GGML_OP_EXT
+// GGML_OP_PAD (with extension)
 struct test_pad_ext : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne_a;
@@ -5797,49 +5794,53 @@ struct test_opt_step_sgd : public test_case {
     }
 };
 
-// GGML_OP_ADD
-// GGML_OP_SUB
-// GGML_OP_DIV
-// GGML_OP_MUL
-struct test_op_arith : public test_case {
+// GGML_OP_CUMSUM
+struct test_cumsum : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    const ggml_op op;
 
-    std::string vars() override { return VARS_TO_STR3(type, ne, op); }
+    std::string vars() override { return VARS_TO_STR2(type, ne); }
 
-    test_op_arith(ggml_op op, ggml_type type = GGML_TYPE_F32,
+    test_cumsum(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
-        : type(type), ne(ne), op(op) {
-        GGML_ASSERT(op == GGML_OP_ADD || op == GGML_OP_SUB || op == GGML_OP_DIV || op == GGML_OP_MUL);
-    }
+        : type(type), ne(ne) {}
 
-    ggml_tensor * build_graph(ggml_context * ctx) override {
+    ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
-        ggml_set_name(b, "b");
+        ggml_tensor * out = ggml_cumsum(ctx, a);
 
-        ggml_tensor * out;
+        ggml_set_name(out, "out");
 
-        switch (op) {
-            case GGML_OP_ADD:
-                out = ggml_add(ctx, a, b);
-                break;
-            case GGML_OP_SUB:
-                out = ggml_sub(ctx, a, b);
-                break;
-            case GGML_OP_DIV:
-                out = ggml_div(ctx, a, b);
-                break;
-            case GGML_OP_MUL:
-                out = ggml_mul(ctx, a, b);
-                break;
-            default:
-                GGML_ABORT("This test only supports ADD, SUB, DIV and MUL");
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -1.0f, 1.0f);
         }
+    }
+};
+
+// GGML_OP_EXPM1
+struct test_expm1 : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override { return VARS_TO_STR2(type, ne); }
+
+    test_expm1(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_expm1(ctx, a);
 
         ggml_set_name(out, "out");
 
@@ -5848,20 +5849,19 @@ struct test_op_arith : public test_case {
 
     void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, 0.1f, 1.0f); // no zeroes because div might complain
+            init_tensor_uniform(t, -1.0f, 1.0f);
        }
     }
-
 };
 
-// GGML_OP_CUMSUM
-struct test_cumsum : public test_case {
+// GGML_OP_SOFTPLUS
+struct test_softplus : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
 
     std::string vars() override { return VARS_TO_STR2(type, ne); }
 
-    test_cumsum(ggml_type type = GGML_TYPE_F32,
+    test_softplus(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
         : type(type), ne(ne) {}
 
@@ -5870,7 +5870,7 @@ struct test_cumsum : public test_case {
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_cumsum(ctx, a);
+        ggml_tensor * out = ggml_softplus(ctx, a);
 
         ggml_set_name(out, "out");
 
@@ -7256,6 +7256,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_ceil (type));
         test_cases.emplace_back(new test_round (type));
         test_cases.emplace_back(new test_trunc (type));
+        test_cases.emplace_back(new test_expm1 (type));
+        test_cases.emplace_back(new test_softplus (type));
         test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
         test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));
@@ -7269,12 +7271,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
     }
 
-    // basic arithmetic, have to do them manually now that fusion is not supported
-    test_cases.emplace_back(new test_op_arith(GGML_OP_ADD, GGML_TYPE_F32));
-    test_cases.emplace_back(new test_op_arith(GGML_OP_SUB, GGML_TYPE_F32));
-    test_cases.emplace_back(new test_op_arith(GGML_OP_DIV, GGML_TYPE_F32));
-    test_cases.emplace_back(new test_op_arith(GGML_OP_MUL, GGML_TYPE_F32));
-
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 2}, 5));
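
The reworked `init_tensor_causal` now walks the dimensions from i3 (outermost) down to i0 (innermost) and sums the byte offsets before dividing by `sizeof(float)` once, drawing random values only where i0 <= i1. A self-contained sketch of the same indexing and causal fill; the sizes, contiguous strides, and 0/1 fill here are illustrative stand-ins:

    // Standalone sketch (not test code): flat float index from byte strides,
    // filled only where i0 <= i1, mirroring the reordered loop above.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne[4] = { 4, 4, 1, 1 };                 // logical sizes
        const size_t  nb[4] = { sizeof(float),                // contiguous byte strides
                                4 * sizeof(float),
                                16 * sizeof(float),
                                16 * sizeof(float) };
        float data[16] = { 0 };

        for (int64_t i3 = 0; i3 < ne[3]; i3++)
        for (int64_t i2 = 0; i2 < ne[2]; i2++)
        for (int64_t i1 = 0; i1 < ne[1]; i1++)
        for (int64_t i0 = 0; i0 < ne[0]; i0++) {
            // sum the byte offsets first, divide by the element size once
            const int64_t idx = (i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3]) / sizeof(float);
            data[idx] = (i0 <= i1) ? 1.0f : 0.0f;             // stand-in for the random fill
        }

        for (int64_t i1 = 0; i1 < ne[1]; i1++) {              // prints a lower-triangular 0/1 mask
            for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                printf("%.0f ", data[i1 * ne[0] + i0]);
            }
            printf("\n");
        }
        return 0;
    }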
