Skip to content

Commit eaa4286

Browse files
committed
Add a mod for using DSL types in bindings, scared to test it...
1 parent f0f2e61 commit eaa4286

File tree

1 file changed

+224
-0
lines changed

1 file changed

+224
-0
lines changed
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Linq;
5+
using Microsoft.CodeAnalysis;
6+
using Microsoft.CodeAnalysis.CSharp;
7+
using Microsoft.CodeAnalysis.CSharp.Syntax;
8+
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
9+
10+
namespace SilkTouchX.Mods;
11+
12+
/// <summary>
13+
/// Mods the bindings to use the Silk.NET.Core pointer types.
14+
/// </summary>
15+
public class UseSilkDSL : IMod
16+
{
17+
class Rewriter : CSharpSyntaxRewriter
18+
{
19+
private HashSet<string>? _parameterIdentifiers = null;
20+
private bool _doReplacement = false;
21+
private bool _returnTypeReplaceable = false;
22+
23+
public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
24+
{
25+
Debug.Assert(!_returnTypeReplaceable);
26+
Debug.Assert(_parameterIdentifiers is null);
27+
28+
// Make sure the function either has a body or is an extern function
29+
var consider = node.Body is not null || node.Modifiers.Any(x => x.IsKind(SyntaxKind.ExternKeyword));
30+
if (!consider)
31+
{
32+
return base.VisitMethodDeclaration(node);
33+
}
34+
35+
// Get the list of DSL applicable parameters
36+
var paramsToChange = node.ParameterList.Parameters
37+
.Where(x => x.Type is not null && IsDSLApplicable(x.Type))
38+
.ToArray();
39+
_parameterIdentifiers = paramsToChange.Select(x => x.Identifier.ToString()).ToHashSet();
40+
_returnTypeReplaceable = IsDSLApplicable(node.ReturnType);
41+
42+
// VisitParameter and VisitIdentifierName will change the parameter types and replace any references of
43+
// the parameter with the "inner identifier" - the name of the variable yielded from the fixed statement
44+
// that we're yet to generate.
45+
if (base.VisitMethodDeclaration(node) is not MethodDeclarationSyntax methWithReplacementsButNoFixed)
46+
{
47+
return null;
48+
}
49+
50+
// If we didn't do any replacements and aren't doing anything to the return type, don't do anything
51+
if (paramsToChange.Length == 0 && !_returnTypeReplaceable)
52+
{
53+
return methWithReplacementsButNoFixed;
54+
}
55+
56+
// If body is null, it would only be because the original body was null which must've meant we passed the
57+
// "extern" check when determining whether to consider this function, ergo we need to make a P/Invoke
58+
// wrapper.
59+
var hasRet = node.ReturnType.IsEquivalentTo(PredefinedType(Token(SyntaxKind.VoidKeyword)));
60+
var body = (StatementSyntax?)methWithReplacementsButNoFixed.Body;
61+
if (body is null)
62+
{
63+
var ident = IdentToPInvokeIdent(node.Identifier);
64+
// Declare the P/Invoke function
65+
var fun = LocalFunctionStatement(
66+
node.AttributeLists,
67+
TokenList(node.Modifiers.Where(x => x.Kind() switch {
68+
SyntaxKind.PublicKeyword or SyntaxKind.PrivateKeyword or SyntaxKind.InternalKeyword
69+
or SyntaxKind.ProtectedKeyword => true,
70+
_ => false
71+
})), node.ReturnType, ident, TypeParameterList(),
72+
node.ParameterList, List<TypeParameterConstraintClauseSyntax>(), node.Body,
73+
node.ExpressionBody);
74+
75+
// Call the P/Invoke function with the converted values
76+
var inv = InvocationExpression(IdentifierName(ident),
77+
ArgumentList(SeparatedList(node.ParameterList.Parameters.Select(x =>
78+
Argument(IdentifierName(x.Type is not null && IsDSLApplicable(x.Type)
79+
? IdentToInnerIdent(x.Identifier)
80+
: x.Identifier))))));
81+
body = Block(fun, hasRet ? ReturnStatement(inv) : ExpressionStatement(inv));
82+
}
83+
84+
// Convert expression bodies to statement bodies
85+
if (body is ExpressionStatementSyntax expr)
86+
{
87+
body = Block(hasRet ? ExpressionStatement(expr.Expression) : ReturnStatement(expr.Expression));
88+
}
89+
90+
// Generate the fixed blocks for the "inner idents"
91+
Debug.Assert(body is BlockSyntax);
92+
foreach (var param in paramsToChange)
93+
{
94+
Debug.Assert(param.Type is not null);
95+
body = Block(FixedStatement(
96+
VariableDeclaration(param.Type,
97+
SingletonSeparatedList(VariableDeclarator(IdentToInnerIdent(param.Identifier))
98+
.WithInitializer(EqualsValueClause(IdentifierName(param.Identifier))))), body));
99+
}
100+
101+
// Need to check on the return type, but assume that there's an implicit conversion in the DSL
102+
if (_returnTypeReplaceable)
103+
{
104+
_returnTypeReplaceable = false;
105+
methWithReplacementsButNoFixed =
106+
methWithReplacementsButNoFixed.WithReturnType(GetDSLType(node.ReturnType, node.AttributeLists,
107+
SyntaxKind.ReturnKeyword));
108+
}
109+
110+
return methWithReplacementsButNoFixed.WithBody((BlockSyntax)body);
111+
}
112+
113+
public override SyntaxNode? VisitParameter(ParameterSyntax node)
114+
{
115+
Debug.Assert(!_doReplacement);
116+
var ret = base.VisitParameter(node) as ParameterSyntax;
117+
if (_doReplacement && ret is { Type: not null })
118+
{
119+
_doReplacement = false;
120+
return ret.WithType(GetDSLType(ret.Type, node.AttributeLists, null));
121+
}
122+
123+
return ret;
124+
}
125+
126+
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
127+
{
128+
var ret = base.VisitIdentifierName(node) as IdentifierNameSyntax;
129+
if (ret is null)
130+
{
131+
return ret;
132+
}
133+
134+
if (!(_parameterIdentifiers?.Contains(node.Identifier.ToString()) ?? false))
135+
{
136+
return ret;
137+
}
138+
139+
if (node.Parent is not ParameterSyntax)
140+
{
141+
return IdentifierName(IdentToInnerIdent(ret.Identifier)).WithTriviaFrom(ret);
142+
}
143+
144+
_doReplacement = true;
145+
return ret;
146+
}
147+
148+
public override SyntaxNode? VisitAttribute(AttributeSyntax node)
149+
{
150+
if ((_parameterIdentifiers?.Count).GetValueOrDefault() == 0 && !_returnTypeReplaceable)
151+
{
152+
return base.VisitAttribute(node);
153+
}
154+
155+
var sep = node.Name.ToString().Split("::")[1];
156+
return sep == "DllImport" || sep == "DllImportAttribute" ||
157+
sep.EndsWith("System.Runtime.InteropServices.DllImport") ||
158+
sep.EndsWith("System.Runtime.InteropServices.DllImportAttribute")
159+
? null // Remove the attribute as it is being moved to a local function
160+
: base.VisitAttribute(node);
161+
}
162+
163+
private static SyntaxToken IdentToInnerIdent(SyntaxToken token)
164+
{
165+
Debug.Assert(token.IsKind(SyntaxKind.IdentifierToken));
166+
return Identifier($"__dsl_{token}");
167+
}
168+
169+
private static SyntaxToken IdentToPInvokeIdent(SyntaxToken token)
170+
{
171+
Debug.Assert(token.IsKind(SyntaxKind.IdentifierToken));
172+
return Identifier($"__DSL_{token}");
173+
}
174+
175+
176+
private static bool IsDSLApplicable(TypeSyntax syn) => syn is PointerTypeSyntax;
177+
178+
private static TypeSyntax GetDSLType(TypeSyntax syntax, IEnumerable<AttributeListSyntax?>? attrLists, SyntaxKind? target)
179+
{
180+
var indirectionLevels = 0;
181+
while (syntax is PointerTypeSyntax inner)
182+
{
183+
indirectionLevels++;
184+
syntax = inner.ElementType;
185+
}
186+
187+
if (indirectionLevels > 2)
188+
{
189+
throw new ArgumentOutOfRangeException(nameof(syntax),
190+
"Indirection levels greater than 2 are currently unsupported by SilkDSL.");
191+
}
192+
193+
var isConst = false;
194+
if (attrLists is not null)
195+
{
196+
foreach (var attrs in attrLists)
197+
{
198+
if (attrs is null ||
199+
(target is not null && !(attrs.Target?.Identifier.IsKind(target.Value)).GetValueOrDefault()) ||
200+
(target is null && attrs.Target is not null))
201+
{
202+
continue;
203+
}
204+
foreach (var attributeSyntax in attrs.Attributes)
205+
{
206+
if (attributeSyntax.Name.ToString() == "NativeTypeName" &&
207+
attributeSyntax.ArgumentList?.Arguments.FirstOrDefault()?.Expression is
208+
LiteralExpressionSyntax lit && lit.Token.ToString().StartsWith("const "))
209+
{
210+
isConst = true;
211+
}
212+
}
213+
}
214+
}
215+
216+
return GenericName(Identifier(isConst switch {
217+
true when indirectionLevels > 1 => $"ConstPtr{indirectionLevels}D",
218+
true => "ConstPtr",
219+
false when indirectionLevels > 1 => $"Ptr{indirectionLevels}D",
220+
false => "Ptr"
221+
})).WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(syntax)));
222+
}
223+
}
224+
}

0 commit comments

Comments
 (0)