Skip to content

Commit 5e6677f

Browse files
Add test and fix bug
1 parent 6b5ff5e commit 5e6677f

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) "Neo4j"
2+
// Neo4j Sweden AB [https://neo4j.com]
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License").
5+
// You may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
using System.Collections.Generic;
17+
using System.Linq;
18+
using System.Threading.Tasks;
19+
using Neo4j.Driver.Internal.Mapping;
20+
using Xunit;
21+
using Xunit.Abstractions;
22+
23+
namespace Neo4j.Driver.Tests.Mapping;
24+
25+
public class MappingConcurrencyTests(ITestOutputHelper testOutputHelper)
26+
{
27+
private interface ITestTask
28+
{
29+
Task Start();
30+
}
31+
32+
private class TestTask<T> : ITestTask
33+
{
34+
public Task Start()
35+
{
36+
return Task.Run(
37+
() =>
38+
{
39+
for (var i = 0; i < 50; i++)
40+
{
41+
DefaultMapper.Get<T>();
42+
DefaultMapper.Reset();
43+
}
44+
});
45+
}
46+
}
47+
48+
private record DummyType1(string Name, int Age);
49+
private record DummyType2(string Name, int Age);
50+
private record DummyType3(string Name, int Age);
51+
private record DummyType4(string Name, int Age);
52+
53+
[Fact]
54+
public async void DefaultMapperShouldBeThreadSafe()
55+
{
56+
List<ITestTask> threads =
57+
[
58+
new TestTask<DummyType1>(),
59+
new TestTask<DummyType2>(),
60+
new TestTask<DummyType3>(),
61+
new TestTask<DummyType4>()
62+
];
63+
64+
// wait for all threads to finish
65+
await Task.WhenAll(threads.Select(t => t.Start()));
66+
67+
testOutputHelper.WriteLine("All threads finished.");
68+
}
69+
}

Neo4j.Driver/Neo4j.Driver/Internal/Mapping/DefaultMapper.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
// limitations under the License.
1515

1616
using System;
17+
using System.Collections.Concurrent;
1718
using System.Collections.Generic;
1819
using System.Reflection;
1920
using Neo4j.Driver.Mapping;
@@ -22,7 +23,7 @@ namespace Neo4j.Driver.Internal.Mapping;
2223

2324
internal static class DefaultMapper
2425
{
25-
private static readonly Dictionary<Type, object> Mappers = new();
26+
private static readonly ConcurrentDictionary<Type, object> Mappers = new();
2627

2728
public static void Reset()
2829
{
@@ -75,7 +76,7 @@ public static IRecordMapper<T> Get<T>(HashSet<MethodInfo> mappedSetters = null)
7576
mapper = mappingBuilder.Build();
7677

7778
// cache the mapper for future use
78-
Mappers[type] = mapper;
79+
Mappers.TryAdd(type, mapper);
7980
return (IRecordMapper<T>)mapper;
8081
}
8182

0 commit comments

Comments
 (0)