|
| 1 | +using System; |
| 2 | +using System.Collections.Generic; |
| 3 | +using System.Diagnostics; |
| 4 | +using System.Linq; |
| 5 | +using Microsoft.CodeAnalysis; |
| 6 | +using Microsoft.CodeAnalysis.CSharp; |
| 7 | +using Microsoft.CodeAnalysis.CSharp.Syntax; |
| 8 | +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; |
| 9 | + |
| 10 | +namespace SilkTouchX.Mods; |
| 11 | + |
| 12 | +/// <summary> |
| 13 | +/// Mods the bindings to use the Silk.NET.Core pointer types. |
| 14 | +/// </summary> |
| 15 | +public class UseSilkDSL : IMod |
| 16 | +{ |
| 17 | + class Rewriter : CSharpSyntaxRewriter |
| 18 | + { |
| 19 | + private HashSet<string>? _parameterIdentifiers = null; |
| 20 | + private bool _doReplacement = false; |
| 21 | + private bool _returnTypeReplaceable = false; |
| 22 | + |
| 23 | + public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node) |
| 24 | + { |
| 25 | + Debug.Assert(!_returnTypeReplaceable); |
| 26 | + Debug.Assert(_parameterIdentifiers is null); |
| 27 | + |
| 28 | + // Make sure the function either has a body or is an extern function |
| 29 | + var consider = node.Body is not null || node.Modifiers.Any(x => x.IsKind(SyntaxKind.ExternKeyword)); |
| 30 | + if (!consider) |
| 31 | + { |
| 32 | + return base.VisitMethodDeclaration(node); |
| 33 | + } |
| 34 | + |
| 35 | + // Get the list of DSL applicable parameters |
| 36 | + var paramsToChange = node.ParameterList.Parameters |
| 37 | + .Where(x => x.Type is not null && IsDSLApplicable(x.Type)) |
| 38 | + .ToArray(); |
| 39 | + _parameterIdentifiers = paramsToChange.Select(x => x.Identifier.ToString()).ToHashSet(); |
| 40 | + _returnTypeReplaceable = IsDSLApplicable(node.ReturnType); |
| 41 | + |
| 42 | + // VisitParameter and VisitIdentifierName will change the parameter types and replace any references of |
| 43 | + // the parameter with the "inner identifier" - the name of the variable yielded from the fixed statement |
| 44 | + // that we're yet to generate. |
| 45 | + if (base.VisitMethodDeclaration(node) is not MethodDeclarationSyntax methWithReplacementsButNoFixed) |
| 46 | + { |
| 47 | + return null; |
| 48 | + } |
| 49 | + |
| 50 | + // If we didn't do any replacements and aren't doing anything to the return type, don't do anything |
| 51 | + if (paramsToChange.Length == 0 && !_returnTypeReplaceable) |
| 52 | + { |
| 53 | + return methWithReplacementsButNoFixed; |
| 54 | + } |
| 55 | + |
| 56 | + // If body is null, it would only be because the original body was null which must've meant we passed the |
| 57 | + // "extern" check when determining whether to consider this function, ergo we need to make a P/Invoke |
| 58 | + // wrapper. |
| 59 | + var hasRet = node.ReturnType.IsEquivalentTo(PredefinedType(Token(SyntaxKind.VoidKeyword))); |
| 60 | + var body = (StatementSyntax?)methWithReplacementsButNoFixed.Body; |
| 61 | + if (body is null) |
| 62 | + { |
| 63 | + var ident = IdentToPInvokeIdent(node.Identifier); |
| 64 | + // Declare the P/Invoke function |
| 65 | + var fun = LocalFunctionStatement( |
| 66 | + node.AttributeLists, |
| 67 | + TokenList(node.Modifiers.Where(x => x.Kind() switch { |
| 68 | + SyntaxKind.PublicKeyword or SyntaxKind.PrivateKeyword or SyntaxKind.InternalKeyword |
| 69 | + or SyntaxKind.ProtectedKeyword => true, |
| 70 | + _ => false |
| 71 | + })), node.ReturnType, ident, TypeParameterList(), |
| 72 | + node.ParameterList, List<TypeParameterConstraintClauseSyntax>(), node.Body, |
| 73 | + node.ExpressionBody); |
| 74 | + |
| 75 | + // Call the P/Invoke function with the converted values |
| 76 | + var inv = InvocationExpression(IdentifierName(ident), |
| 77 | + ArgumentList(SeparatedList(node.ParameterList.Parameters.Select(x => |
| 78 | + Argument(IdentifierName(x.Type is not null && IsDSLApplicable(x.Type) |
| 79 | + ? IdentToInnerIdent(x.Identifier) |
| 80 | + : x.Identifier)))))); |
| 81 | + body = Block(fun, hasRet ? ReturnStatement(inv) : ExpressionStatement(inv)); |
| 82 | + } |
| 83 | + |
| 84 | + // Convert expression bodies to statement bodies |
| 85 | + if (body is ExpressionStatementSyntax expr) |
| 86 | + { |
| 87 | + body = Block(hasRet ? ExpressionStatement(expr.Expression) : ReturnStatement(expr.Expression)); |
| 88 | + } |
| 89 | + |
| 90 | + // Generate the fixed blocks for the "inner idents" |
| 91 | + Debug.Assert(body is BlockSyntax); |
| 92 | + foreach (var param in paramsToChange) |
| 93 | + { |
| 94 | + Debug.Assert(param.Type is not null); |
| 95 | + body = Block(FixedStatement( |
| 96 | + VariableDeclaration(param.Type, |
| 97 | + SingletonSeparatedList(VariableDeclarator(IdentToInnerIdent(param.Identifier)) |
| 98 | + .WithInitializer(EqualsValueClause(IdentifierName(param.Identifier))))), body)); |
| 99 | + } |
| 100 | + |
| 101 | + // Need to check on the return type, but assume that there's an implicit conversion in the DSL |
| 102 | + if (_returnTypeReplaceable) |
| 103 | + { |
| 104 | + _returnTypeReplaceable = false; |
| 105 | + methWithReplacementsButNoFixed = |
| 106 | + methWithReplacementsButNoFixed.WithReturnType(GetDSLType(node.ReturnType, node.AttributeLists, |
| 107 | + SyntaxKind.ReturnKeyword)); |
| 108 | + } |
| 109 | + |
| 110 | + return methWithReplacementsButNoFixed.WithBody((BlockSyntax)body); |
| 111 | + } |
| 112 | + |
| 113 | + public override SyntaxNode? VisitParameter(ParameterSyntax node) |
| 114 | + { |
| 115 | + Debug.Assert(!_doReplacement); |
| 116 | + var ret = base.VisitParameter(node) as ParameterSyntax; |
| 117 | + if (_doReplacement && ret is { Type: not null }) |
| 118 | + { |
| 119 | + _doReplacement = false; |
| 120 | + return ret.WithType(GetDSLType(ret.Type, node.AttributeLists, null)); |
| 121 | + } |
| 122 | + |
| 123 | + return ret; |
| 124 | + } |
| 125 | + |
| 126 | + public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) |
| 127 | + { |
| 128 | + var ret = base.VisitIdentifierName(node) as IdentifierNameSyntax; |
| 129 | + if (ret is null) |
| 130 | + { |
| 131 | + return ret; |
| 132 | + } |
| 133 | + |
| 134 | + if (!(_parameterIdentifiers?.Contains(node.Identifier.ToString()) ?? false)) |
| 135 | + { |
| 136 | + return ret; |
| 137 | + } |
| 138 | + |
| 139 | + if (node.Parent is not ParameterSyntax) |
| 140 | + { |
| 141 | + return IdentifierName(IdentToInnerIdent(ret.Identifier)).WithTriviaFrom(ret); |
| 142 | + } |
| 143 | + |
| 144 | + _doReplacement = true; |
| 145 | + return ret; |
| 146 | + } |
| 147 | + |
| 148 | + public override SyntaxNode? VisitAttribute(AttributeSyntax node) |
| 149 | + { |
| 150 | + if ((_parameterIdentifiers?.Count).GetValueOrDefault() == 0 && !_returnTypeReplaceable) |
| 151 | + { |
| 152 | + return base.VisitAttribute(node); |
| 153 | + } |
| 154 | + |
| 155 | + var sep = node.Name.ToString().Split("::")[1]; |
| 156 | + return sep == "DllImport" || sep == "DllImportAttribute" || |
| 157 | + sep.EndsWith("System.Runtime.InteropServices.DllImport") || |
| 158 | + sep.EndsWith("System.Runtime.InteropServices.DllImportAttribute") |
| 159 | + ? null // Remove the attribute as it is being moved to a local function |
| 160 | + : base.VisitAttribute(node); |
| 161 | + } |
| 162 | + |
| 163 | + private static SyntaxToken IdentToInnerIdent(SyntaxToken token) |
| 164 | + { |
| 165 | + Debug.Assert(token.IsKind(SyntaxKind.IdentifierToken)); |
| 166 | + return Identifier($"__dsl_{token}"); |
| 167 | + } |
| 168 | + |
| 169 | + private static SyntaxToken IdentToPInvokeIdent(SyntaxToken token) |
| 170 | + { |
| 171 | + Debug.Assert(token.IsKind(SyntaxKind.IdentifierToken)); |
| 172 | + return Identifier($"__DSL_{token}"); |
| 173 | + } |
| 174 | + |
| 175 | + |
| 176 | + private static bool IsDSLApplicable(TypeSyntax syn) => syn is PointerTypeSyntax; |
| 177 | + |
| 178 | + private static TypeSyntax GetDSLType(TypeSyntax syntax, IEnumerable<AttributeListSyntax?>? attrLists, SyntaxKind? target) |
| 179 | + { |
| 180 | + var indirectionLevels = 0; |
| 181 | + while (syntax is PointerTypeSyntax inner) |
| 182 | + { |
| 183 | + indirectionLevels++; |
| 184 | + syntax = inner.ElementType; |
| 185 | + } |
| 186 | + |
| 187 | + if (indirectionLevels > 2) |
| 188 | + { |
| 189 | + throw new ArgumentOutOfRangeException(nameof(syntax), |
| 190 | + "Indirection levels greater than 2 are currently unsupported by SilkDSL."); |
| 191 | + } |
| 192 | + |
| 193 | + var isConst = false; |
| 194 | + if (attrLists is not null) |
| 195 | + { |
| 196 | + foreach (var attrs in attrLists) |
| 197 | + { |
| 198 | + if (attrs is null || |
| 199 | + (target is not null && !(attrs.Target?.Identifier.IsKind(target.Value)).GetValueOrDefault()) || |
| 200 | + (target is null && attrs.Target is not null)) |
| 201 | + { |
| 202 | + continue; |
| 203 | + } |
| 204 | + foreach (var attributeSyntax in attrs.Attributes) |
| 205 | + { |
| 206 | + if (attributeSyntax.Name.ToString() == "NativeTypeName" && |
| 207 | + attributeSyntax.ArgumentList?.Arguments.FirstOrDefault()?.Expression is |
| 208 | + LiteralExpressionSyntax lit && lit.Token.ToString().StartsWith("const ")) |
| 209 | + { |
| 210 | + isConst = true; |
| 211 | + } |
| 212 | + } |
| 213 | + } |
| 214 | + } |
| 215 | + |
| 216 | + return GenericName(Identifier(isConst switch { |
| 217 | + true when indirectionLevels > 1 => $"ConstPtr{indirectionLevels}D", |
| 218 | + true => "ConstPtr", |
| 219 | + false when indirectionLevels > 1 => $"Ptr{indirectionLevels}D", |
| 220 | + false => "Ptr" |
| 221 | + })).WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(syntax))); |
| 222 | + } |
| 223 | + } |
| 224 | +} |
0 commit comments