Skip to content

Commit 6325e3c

Browse files
authored
[parallelisation] Define a transformation group (#693)
<!-- Copyright (C) 2020-2022 Arm Limited or its affiliates and Contributors. All rights reserved. SPDX-License-Identifier: Apache-2.0 --> ### Description - Add a transform group - reuse groups to define worker pools ### Test Coverage <!-- Please put an `x` in the correct box e.g. `[x]` to indicate the testing coverage of this change. --> - [x] This change is covered by existing or additional automated tests. - [ ] Manual testing has been performed (and evidence provided) as automated testing was not feasible. - [ ] Additional tests are not required for this change (e.g. documentation update).
1 parent 36cad0d commit 6325e3c

File tree

11 files changed

+416
-64
lines changed

11 files changed

+416
-64
lines changed

changes/20250905171217.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `[parallelisation]` Define a transformation group

changes/20250908111211.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `[collection]` Added a `Range` function to populate slices of integers

utils/collection/range.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package collection
2+
3+
import "github.com/ARM-software/golang-utils/utils/field"
4+
5+
func sign(x int) int {
6+
if x < 0 {
7+
return -1
8+
}
9+
return 1
10+
}
11+
12+
// Range returns a slice of integers similar to Python's built-in range().
13+
// https://docs.python.org/2/library/functions.html#range
14+
//
15+
// Note: The stop value is always exclusive.
16+
func Range(start, stop int, step *int) []int {
17+
s := field.OptionalInt(step, 1)
18+
if s == 0 {
19+
return []int{}
20+
}
21+
22+
// Compute length
23+
length := 0
24+
if (s > 0 && start < stop) || (s < 0 && start > stop) {
25+
length = (stop - start + s - sign(s)) / s
26+
}
27+
28+
result := make([]int, length)
29+
for i, v := 0, start; i < length; i, v = i+1, v+s {
30+
result[i] = v
31+
}
32+
return result
33+
}

utils/collection/range_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package collection
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
9+
"github.com/ARM-software/golang-utils/utils/field"
10+
)
11+
12+
func TestRange(t *testing.T) {
13+
tests := []struct {
14+
start int
15+
stop int
16+
step *int
17+
expected []int
18+
}{
19+
20+
{2, 5, nil, []int{2, 3, 4}},
21+
{5, 2, nil, []int{}}, // empty, since stop < start
22+
{2, 10, field.ToOptionalInt(2), []int{2, 4, 6, 8}},
23+
{0, 10, field.ToOptionalInt(3), []int{0, 3, 6, 9}},
24+
{1, 10, field.ToOptionalInt(3), []int{1, 4, 7}},
25+
{10, 2, field.ToOptionalInt(-2), []int{10, 8, 6, 4}},
26+
{5, -1, field.ToOptionalInt(-1), []int{5, 4, 3, 2, 1, 0}},
27+
{0, -5, field.ToOptionalInt(-2), []int{0, -2, -4}},
28+
{0, 5, nil, []int{0, 1, 2, 3, 4}},
29+
{0, 5, field.ToOptionalInt(0), []int{}},
30+
{2, 2, field.ToOptionalInt(1), []int{}},
31+
{2, 2, field.ToOptionalInt(-1), []int{}},
32+
}
33+
34+
for i := range tests {
35+
test := tests[i]
36+
t.Run(fmt.Sprintf("[%v,%v,%v]", test.start, test.stop, test.step), func(t *testing.T) {
37+
assert.Equal(t, test.expected, Range(test.start, test.stop, test.step))
38+
})
39+
}
40+
}

utils/collection/search.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ func AnyFunc[S ~[]E, E any](s S, f func(E) bool) bool {
7171
return conditions.Any()
7272
}
7373

74+
type FilterFunc[E any] func(E) bool
75+
7476
// Filter returns a new slice that contains elements from the input slice which return true when they’re passed as a parameter to the provided filtering function f.
75-
func Filter[S ~[]E, E any](s S, f func(E) bool) (result S) {
77+
func Filter[S ~[]E, E any](s S, f FilterFunc[E]) (result S) {
7678
result = make(S, 0, len(s))
7779

7880
for i := range s {
@@ -84,8 +86,16 @@ func Filter[S ~[]E, E any](s S, f func(E) bool) (result S) {
8486
return result
8587
}
8688

89+
type MapFunc[T1, T2 any] func(T1) T2
90+
91+
func IdentityMapFunc[T any]() MapFunc[T, T] {
92+
return func(i T) T {
93+
return i
94+
}
95+
}
96+
8797
// Map creates a new slice and populates it with the results of calling the provided function on every element in input slice.
88-
func Map[T1 any, T2 any](s []T1, f func(T1) T2) (result []T2) {
98+
func Map[T1 any, T2 any](s []T1, f MapFunc[T1, T2]) (result []T2) {
8999
result = make([]T2, len(s))
90100

91101
for i := range s {
@@ -97,12 +107,14 @@ func Map[T1 any, T2 any](s []T1, f func(T1) T2) (result []T2) {
97107

98108
// Reject is the opposite of Filter and returns the elements of collection for which the filtering function f returns false.
99109
// This is functionally equivalent to slices.DeleteFunc but it returns a new slice.
100-
func Reject[S ~[]E, E any](s S, f func(E) bool) S {
110+
func Reject[S ~[]E, E any](s S, f FilterFunc[E]) S {
101111
return Filter(s, func(e E) bool { return !f(e) })
102112
}
103113

114+
type ReduceFunc[T1, T2 any] func(T2, T1) T2
115+
104116
// Reduce runs a reducer function f over all elements in the array, in ascending-index order, and accumulates them into a single value.
105-
func Reduce[T1, T2 any](s []T1, accumulator T2, f func(T2, T1) T2) (result T2) {
117+
func Reduce[T1, T2 any](s []T1, accumulator T2, f ReduceFunc[T1, T2]) (result T2) {
106118
result = accumulator
107119
for i := range s {
108120
result = f(result, s[i])

utils/parallelisation/group.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,14 @@ type ICompoundExecutionGroup[T any] interface {
224224
// NewExecutionGroup returns an execution group which executes functions according to store options.
225225
func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] {
226226

227+
return NewOrderedExecutionGroup(func(ctx context.Context, index int, element T) error {
228+
return executeFunc(ctx, element)
229+
}, options...)
230+
}
231+
232+
// NewOrderedExecutionGroup returns an execution group which executes functions according to store options. It also keeps track of the input index.
233+
func NewOrderedExecutionGroup[T any](executeFunc OrderedExecuteFunc[T], options ...StoreOption) *ExecutionGroup[T] {
234+
227235
opts := WithOptions(options...)
228236
return &ExecutionGroup[T]{
229237
mu: deadlock.RWMutex{},
@@ -235,10 +243,12 @@ func NewExecutionGroup[T any](executeFunc ExecuteFunc[T], options ...StoreOption
235243

236244
type ExecuteFunc[T any] func(ctx context.Context, element T) error
237245

246+
type OrderedExecuteFunc[T any] func(ctx context.Context, index int, element T) error
247+
238248
type ExecutionGroup[T any] struct {
239249
mu deadlock.RWMutex
240250
functions []wrappedElement[T]
241-
executeFunc ExecuteFunc[T]
251+
executeFunc OrderedExecuteFunc[T]
242252
options StoreOptions
243253
}
244254

@@ -294,7 +304,7 @@ func (s *ExecutionGroup[T]) executeConcurrently(ctx context.Context, stopOnFirst
294304
g.SetLimit(workers)
295305
for i := range s.functions {
296306
g.Go(func() error {
297-
_, subErr := s.executeFunction(gCtx, s.functions[i])
307+
_, subErr := s.executeFunction(gCtx, i, s.functions[i])
298308
errCh <- subErr
299309
return subErr
300310
})
@@ -323,7 +333,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst
323333
collateErr := make([]error, funcNum)
324334
if reverse {
325335
for i := funcNum - 1; i >= 0; i-- {
326-
shouldBreak, subErr := s.executeFunction(ctx, s.functions[i])
336+
shouldBreak, subErr := s.executeFunction(ctx, i, s.functions[i])
327337
collateErr[funcNum-i-1] = subErr
328338
if shouldBreak {
329339
err = subErr
@@ -338,7 +348,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst
338348
}
339349
} else {
340350
for i := range s.functions {
341-
shouldBreak, subErr := s.executeFunction(ctx, s.functions[i])
351+
shouldBreak, subErr := s.executeFunction(ctx, i, s.functions[i])
342352
collateErr[i] = subErr
343353
if shouldBreak {
344354
err = subErr
@@ -359,7 +369,7 @@ func (s *ExecutionGroup[T]) executeSequentially(ctx context.Context, stopOnFirst
359369
return
360370
}
361371

362-
func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElement[T]) (mustBreak bool, err error) {
372+
func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, index int, w wrappedElement[T]) (mustBreak bool, err error) {
363373
err = DetermineContextError(ctx)
364374
if err != nil {
365375
mustBreak = true
@@ -370,7 +380,9 @@ func (s *ExecutionGroup[T]) executeFunction(ctx context.Context, w wrappedElemen
370380
mustBreak = true
371381
return
372382
}
373-
err = w.Execute(ctx, s.executeFunc)
383+
err = w.Execute(ctx, func(ctx context.Context, element T) error {
384+
return s.executeFunc(ctx, index, element)
385+
})
374386

375387
return
376388
}

utils/parallelisation/parallelisation.go

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import (
1212
"time"
1313

1414
"go.uber.org/atomic"
15-
"golang.org/x/sync/errgroup"
1615

16+
"github.com/ARM-software/golang-utils/utils/collection"
1717
"github.com/ARM-software/golang-utils/utils/commonerrors"
1818
)
1919

@@ -265,69 +265,35 @@ func WaitUntil(ctx context.Context, evalCondition func(ctx2 context.Context) (bo
265265
}
266266
}
267267

268-
func newWorker[JobType, ResultType any](ctx context.Context, f func(context.Context, JobType) (ResultType, bool, error), jobs chan JobType, results chan ResultType) (err error) {
269-
for job := range jobs {
270-
result, ok, subErr := f(ctx, job)
271-
if subErr != nil {
272-
err = commonerrors.WrapError(commonerrors.ErrUnexpected, subErr, "an error occurred whilst handling a job")
273-
return
274-
}
275-
276-
err = DetermineContextError(ctx)
277-
if err != nil {
278-
return
279-
}
280-
281-
if ok {
282-
results <- result
283-
}
268+
// WorkerPool parallelises an action using a worker pool of the size provided by numWorkers and retrieves all the results when all the actions have completed. It is similar to Parallelise but it uses generics instead of reflection and allows you to control the pool size
269+
func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f TransformFunc[InputType, ResultType]) (results []ResultType, err error) {
270+
g, err := workerPoolGroup[InputType, ResultType](ctx, numWorkers, jobs, f)
271+
if err != nil {
272+
return
284273
}
285-
274+
results, err = g.Outputs(ctx)
286275
return
287276
}
288277

289-
// WorkerPool parallelises an action using a worker pool of the size provided by numWorkers and retrieves all the results when all the actions have completed. It is similar to Parallelise but it uses generics instead of reflection and allows you to control the pool size
290-
func WorkerPool[InputType, ResultType any](ctx context.Context, numWorkers int, jobs []InputType, f func(context.Context, InputType) (ResultType, bool, error)) (results []ResultType, err error) {
278+
func workerPoolGroup[I, O any](ctx context.Context, numWorkers int, jobs []I, f TransformFunc[I, O]) (g *TransformGroup[I, O], err error) {
291279
if numWorkers < 1 {
292280
err = commonerrors.New(commonerrors.ErrInvalid, "numWorkers must be greater than or equal to 1")
293281
return
294282
}
295-
296-
numJobs := len(jobs)
297-
jobsChan := make(chan InputType, numJobs)
298-
resultsChan := make(chan ResultType, numJobs)
299-
300-
g, gCtx := errgroup.WithContext(ctx)
301-
g.SetLimit(numWorkers)
302-
for range numWorkers {
303-
g.Go(func() error { return newWorker(gCtx, f, jobsChan, resultsChan) })
304-
}
305-
for i := range jobs {
306-
if DetermineContextError(ctx) != nil {
307-
break
308-
}
309-
jobsChan <- jobs[i]
310-
}
311-
312-
close(jobsChan)
313-
err = g.Wait()
314-
close(resultsChan)
315-
if err == nil {
316-
err = DetermineContextError(ctx)
317-
}
283+
g = NewTransformGroup[I, O](f, Workers(numWorkers), JoinErrors)
284+
err = g.Inputs(ctx, jobs...)
318285
if err != nil {
319286
return
320287
}
321-
322-
for result := range resultsChan {
323-
results = append(results, result)
288+
err = g.Transform(ctx)
289+
if err != nil {
290+
return
324291
}
325-
326292
return
327293
}
328294

329295
// Filter is similar to collection.Filter but uses parallelisation.
330-
func Filter[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) (result []T, err error) {
296+
func Filter[T any](ctx context.Context, numWorkers int, s []T, f collection.FilterFunc[T]) (result []T, err error) {
331297
result, err = WorkerPool[T, T](ctx, numWorkers, s, func(fCtx context.Context, item T) (r T, ok bool, fErr error) {
332298
fErr = DetermineContextError(fCtx)
333299
if fErr != nil {
@@ -340,9 +306,8 @@ func Filter[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) (
340306
return
341307
}
342308

343-
// Map is similar to collection.Map but uses parallelisation.
344-
func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f func(T1) T2) (result []T2, err error) {
345-
result, err = WorkerPool[T1, T2](ctx, numWorkers, s, func(fCtx context.Context, item T1) (r T2, ok bool, fErr error) {
309+
func mapGroup[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (*TransformGroup[T1, T2], error) {
310+
return workerPoolGroup[T1, T2](ctx, numWorkers, s, func(fCtx context.Context, item T1) (r T2, ok bool, fErr error) {
346311
fErr = DetermineContextError(fCtx)
347312
if fErr != nil {
348313
return
@@ -351,10 +316,29 @@ func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f func(T1)
351316
ok = true
352317
return
353318
})
319+
}
320+
321+
// Map is similar to collection.Map but uses parallelisation.
322+
func Map[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (result []T2, err error) {
323+
g, err := mapGroup[T1, T2](ctx, numWorkers, s, f)
324+
if err != nil {
325+
return
326+
}
327+
result, err = g.Outputs(ctx)
328+
return
329+
}
330+
331+
// OrderedMap is similar to Map but ensures the results are in the same order as the input.
332+
func OrderedMap[T1 any, T2 any](ctx context.Context, numWorkers int, s []T1, f collection.MapFunc[T1, T2]) (result []T2, err error) {
333+
g, err := mapGroup[T1, T2](ctx, numWorkers, s, f)
334+
if err != nil {
335+
return
336+
}
337+
result, err = g.OrderedOutputs(ctx)
354338
return
355339
}
356340

357341
// Reject is the opposite of Filter and returns the elements of collection for which the filtering function f returns false.
358-
func Reject[T any](ctx context.Context, numWorkers int, s []T, f func(T) bool) ([]T, error) {
342+
func Reject[T any](ctx context.Context, numWorkers int, s []T, f collection.FilterFunc[T]) ([]T, error) {
359343
return Filter[T](ctx, numWorkers, s, func(e T) bool { return !f(e) })
360344
}

utils/parallelisation/parallelisation_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ import (
1919
"go.uber.org/atomic"
2020
"go.uber.org/goleak"
2121

22+
"github.com/ARM-software/golang-utils/utils/collection"
2223
"github.com/ARM-software/golang-utils/utils/commonerrors"
2324
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
25+
"github.com/ARM-software/golang-utils/utils/field"
2426
)
2527

2628
var (
@@ -636,3 +638,35 @@ func TestMap(t *testing.T) {
636638
errortest.AssertError(t, err, commonerrors.ErrCancelled)
637639
})
638640
}
641+
642+
func TestMapAndOrderedMap(t *testing.T) {
643+
defer goleak.VerifyNone(t)
644+
ctx := context.Background()
645+
mapped, err := OrderedMap(ctx, 3, []int{1, 2}, func(i int) string {
646+
return fmt.Sprintf("Hello world %v", i)
647+
})
648+
require.NoError(t, err)
649+
assert.Equal(t, []string{"Hello world 1", "Hello world 2"}, mapped)
650+
mapped, err = OrderedMap(ctx, 3, []int64{1, 2, 3, 4}, func(x int64) string {
651+
return strconv.FormatInt(x, 10)
652+
})
653+
require.NoError(t, err)
654+
assert.Equal(t, []string{"1", "2", "3", "4"}, mapped)
655+
t.Run("cancelled context", func(t *testing.T) {
656+
cancelledCtx, cancel := context.WithCancel(context.Background())
657+
cancel()
658+
_, err := Map(cancelledCtx, 3, []int{1, 2}, func(i int) string {
659+
return fmt.Sprintf("Hello world %v", i)
660+
})
661+
errortest.AssertError(t, err, commonerrors.ErrCancelled)
662+
})
663+
664+
in := collection.Range(0, 1000, field.ToOptionalInt(5))
665+
mappedInt, err := OrderedMap(ctx, 3, in, collection.IdentityMapFunc[int]())
666+
require.NoError(t, err)
667+
assert.Equal(t, in, mappedInt)
668+
mappedInt, err = Map(ctx, 3, in, collection.IdentityMapFunc[int]())
669+
require.NoError(t, err)
670+
assert.NotEqual(t, in, mappedInt)
671+
assert.ElementsMatch(t, in, mappedInt)
672+
}

0 commit comments

Comments
 (0)