Need to check if code contains certain identifiers

I am going to be dynamically compiling and executing code using Roslyn like the example below. I want to make sure the code does not violate some of my rules, like:

  • Does not use Reflection
  • Does not use HttpClient or WebClient
  • Does not use File or Directory classes in System.IO namespace
  • Does not use Source Generators
  • Does not call unmanaged code

Where in the following code would I insert my rules/checks and how would I do them?

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
using System.Reflection;
using System.Runtime.CompilerServices;

string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;

namespace Customization
{
    public class Script
    {
        public async Task<object?> RunAsync(object? data)
        {
            //The following should not be allowed
            File.Delete(@""C:\Temp\log.txt"");

            return await Task.FromResult(data);
        }
    }
}";

var compilation = Compile(code);
var bytes = Build(compilation);

Console.WriteLine("Done");

CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (String.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    string assemblyName = Path.GetRandomFileName();
    List<MetadataReference> references = new();
    references.Add(MetadataReference.CreateFromFile(typeof(object).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Console).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Task).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")));

    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));


    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();

    //TODO: Check the code for use classes that are not allowed such as File in the System.IO namespace.
    //Not exactly sure how to walk through identifiers.
    IEnumerable<IdentifierNameSyntax> identifiers = root.DescendantNodes()
        .Where(s => s is IdentifierNameSyntax)
        .Cast<IdentifierNameSyntax>();


    return compilation;
}

[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using (MemoryStream ms = new())
    {
        //Emit to catch build errors
        EmitResult emitResult = compilation.Emit(ms);

        if (!emitResult.Success)
        {
            Diagnostic? firstError =
                emitResult
                    .Diagnostics
                    .FirstOrDefault
                    (
                        diagnostic => diagnostic.IsWarningAsError ||
                            diagnostic.Severity == DiagnosticSeverity.Error
                    );

            throw new Exception(firstError?.GetMessage());
        }

        return ms.ToArray();
    }
}
System answered 25/3, 2022 at 3:54 Comment(0)

When checking for the use of a particular class you can look for IdentifierNameSyntax type nodes by using the OfType<>() method and filter the results by class name:

var names = root.DescendantNodes()
    .OfType<IdentifierNameSyntax>()
    .Where(i => string.Equals(i.Identifier.ValueText, className, StringComparison.OrdinalIgnoreCase));

You can then use the SemanticModel to check the namespace of the class:

foreach (var name in names)
{
    var typeInfo = model.GetTypeInfo(name);
    if (string.Equals(typeInfo.Type?.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase))
    {
        throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
    }
}

To check for the use of reflection or unmanaged code, you could check for the relevant using directives, System.Reflection and System.Runtime.InteropServices:

if (root.Usings.Any(u => string.Equals(u.Name.ToString(), disallowedNamespace, StringComparison.OrdinalIgnoreCase)))
{
    throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
}

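This also flags scripts whose usings are unused, i.e., no actual reflection or unmanaged code, but that seems like an acceptable trade-off. Note that a check on using directives alone does not see fully qualified names such as System.Reflection.Assembly, which compile without any using.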
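
Because a check on using directives cannot see fully qualified names, unmanaged calls could also be looked for syntactically. The following is a minimal sketch, not a complete sandbox: UsesExternOrUnsafe is a hypothetical helper name (it is not part of Roslyn) that scans the tokens of a tree for the extern and unsafe keywords. It would also flag extern alias directives, which may or may not be what you want.

```csharp
using System;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;

// Hypothetical helper: true when the tree contains an 'extern' or 'unsafe' keyword token.
bool UsesExternOrUnsafe(SyntaxNode root) =>
    root.DescendantTokens().Any(t =>
        t.IsKind(SyntaxKind.ExternKeyword) || t.IsKind(SyntaxKind.UnsafeKeyword));

// No using directive here, so a check on root.Usings alone would not notice the P/Invoke.
SyntaxNode pinvoke = CSharpSyntaxTree.ParseText(@"
class Native
{
    [System.Runtime.InteropServices.DllImport(""kernel32.dll"")]
    static extern uint GetTickCount();
}").GetRoot();

// A script with neither keyword, for comparison.
SyntaxNode plain = CSharpSyntaxTree.ParseText("class Plain { int One() => 1; }").GetRoot();

bool pinvokeFlagged = UsesExternOrUnsafe(pinvoke);
bool plainFlagged = UsesExternOrUnsafe(plain);
Console.WriteLine($"pinvoke: {pinvokeFlagged}, plain: {plainFlagged}");
```

This only needs the syntax tree, so it can run before the compilation is created.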

I'm not sure what to do about the source generator checks as these are normally included as project references so I don't know how they'd run against dynamically compiled code.

Keeping the checks in the same place and updating your code gives:

using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;

string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;
using System.Net.Http;
using System.Reflection;
using System.Runtime.InteropServices;

namespace Customization
{
    public class Script
    {
        static readonly HttpClient client = new HttpClient();

        public async Task<object?> RunAsync(object? data)
        {
            //The following should not be allowed
            File.Delete(@""C:\Temp\log.txt"");

            return await Task.FromResult(data);
        }
    }
}";

var compilation = Compile(code);

var bytes = Build(compilation);
Console.WriteLine("Done");


CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (String.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    string assemblyName = Path.GetRandomFileName();
    List<MetadataReference> references = new();
    references.Add(MetadataReference.CreateFromFile(typeof(object).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Console).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Task).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(HttpClient).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")));

    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));


    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();

    ThrowOnDisallowedClass("File", "System.IO", root, model);
    ThrowOnDisallowedClass("HttpClient", "System.Net.Http", root, model);
    ThrowOnDisallowedNamespace("System.Reflection", root);
    ThrowOnDisallowedNamespace("System.Runtime.InteropServices", root);

    return compilation;
}

[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using (MemoryStream ms = new())
    {
        //Emit to catch build errors
        EmitResult emitResult = compilation.Emit(ms);

        if (!emitResult.Success)
        {
            Diagnostic? firstError =
                emitResult
                    .Diagnostics
                    .FirstOrDefault
                    (
                        diagnostic => diagnostic.IsWarningAsError ||
                            diagnostic.Severity == DiagnosticSeverity.Error
                    );

            throw new Exception(firstError?.GetMessage());
        }

        return ms.ToArray();
    }
}

void ThrowOnDisallowedClass(string className, string containingNamespace, CompilationUnitSyntax root, SemanticModel model)
{
    var names = root.DescendantNodes()
                    .OfType<IdentifierNameSyntax>()
                    .Where(i => string.Equals(i.Identifier.ValueText, className, StringComparison.OrdinalIgnoreCase));

    foreach (var name in names)
    {
        var typeInfo = model.GetTypeInfo(name);
        if (string.Equals(typeInfo.Type?.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase))
        {
            throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
        }
    }
}

void ThrowOnDisallowedNamespace(string disallowedNamespace, CompilationUnitSyntax root)
{
    if (root.Usings.Any(u => string.Equals(u.Name.ToString(), disallowedNamespace, StringComparison.OrdinalIgnoreCase)))
    {
        throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
    }
}

I've used throw for rule violations here, which means only the first violation is reported. You may want to collect all the violations and report them together instead.
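
One way to report everything at once is to gather the violations into a list and let the caller decide what to do with them. This is a rough sketch under the same assumptions as the code above. FindViolations and the variable names in the usage part are hypothetical, and the reference list is only enough for this tiny demo script.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

// Hypothetical helper: gathers every violation instead of throwing on the first one.
List<string> FindViolations(
    CompilationUnitSyntax root,
    SemanticModel model,
    (string ClassName, string Namespace)[] bannedClasses,
    string[] bannedNamespaces)
{
    var violations = new List<string>();

    foreach (var name in root.DescendantNodes().OfType<IdentifierNameSyntax>())
    {
        foreach (var banned in bannedClasses.Where(b => b.ClassName == name.Identifier.ValueText))
        {
            var ns = model.GetTypeInfo(name).Type?.ContainingNamespace?.ToDisplayString();
            if (ns == banned.Namespace)
            {
                // Line numbers from Roslyn are zero-based.
                int line = name.GetLocation().GetLineSpan().StartLinePosition.Line + 1;
                violations.Add($"Line {line}: class {ns}.{banned.ClassName} is not allowed.");
            }
        }
    }

    foreach (var directive in root.Usings)
    {
        var ns = directive.Name?.ToString();
        if (ns != null && bannedNamespaces.Contains(ns))
        {
            violations.Add($"Namespace {ns} is not allowed.");
        }
    }

    return violations;
}

// Usage: a demo script that breaks one class rule and one namespace rule.
var demoTree = CSharpSyntaxTree.ParseText(@"
using System.IO;
using System.Reflection;
class Script { void Run() { File.Delete(""log.txt""); } }");

var demoReferences = new[] { typeof(object), typeof(File) }
    .Select(t => t.Assembly.Location)
    .Distinct()
    .Select(path => MetadataReference.CreateFromFile(path));

var demoCompilation = CSharpCompilation.Create("Demo", new[] { demoTree }, demoReferences);
var demoRoot = (CompilationUnitSyntax)demoTree.GetRoot();

var found = FindViolations(demoRoot, demoCompilation.GetSemanticModel(demoTree),
    new[] { ("File", "System.IO") }, new[] { "System.Reflection" });

foreach (var message in found) Console.WriteLine(message);
```

The caller can then throw a single exception that lists every message, or return the list to whoever submitted the script.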

Killer answered 5/4, 2022 at 19:0 Comment(1)
I have awarded the bounty to you. Thank you. I also posted an answer that uses the SymbolInfo class. - System

The SymbolInfo struct provides some of the metadata needed to create rules that restrict the use of certain code. Here is what I have come up with so far. Any suggestions on how to improve this would be appreciated.

//Check for banned namespaces
//"identifiers" and "semanticModel" are assumed to be the IdentifierNameSyntax
//sequence and the SemanticModel (named "model" in the question) built in Compile.
string[] namespaceBlacklist = new string[] { "System.Net", "System.IO" };

foreach (IdentifierNameSyntax identifier in identifiers)
{
    SymbolInfo symbolInfo = semanticModel.GetSymbolInfo(identifier);

    if (symbolInfo.Symbol is { })
    {
        if (symbolInfo.Symbol.Kind == SymbolKind.Namespace)
        {
            //The identifier names a banned namespace, e.g. in a using directive
            if (namespaceBlacklist.Any(ns => ns == symbolInfo.Symbol.ToDisplayString()))
            {
                throw new Exception($"Declaration of namespace '{symbolInfo.Symbol.ToDisplayString()}' is not allowed.");
            }
        }
        else if (symbolInfo.Symbol.Kind == SymbolKind.NamedType)
        {
            //The identifier names a type whose full name starts with a banned namespace
            if (namespaceBlacklist.Any(ns => symbolInfo.Symbol.ToDisplayString().StartsWith(ns + ".", StringComparison.Ordinal)))
            {
                throw new Exception($"Use of type '{symbolInfo.Symbol.ToDisplayString()}' is not allowed.");
            }
        }
    }
}
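
One possible refinement, offered only as a sketch, is to walk the chain of containing namespaces instead of comparing display-string prefixes, so that methods and other members are covered as well as named types. IsInBannedNamespace and ContainsBannedSymbol are hypothetical helper names, and the reference list is just enough for the demo scripts.

```csharp
using System;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

// Hypothetical helper: true if the symbol is, or lives anywhere inside, a banned namespace.
bool IsInBannedNamespace(ISymbol symbol, string[] bannedNamespaces)
{
    // Start at the symbol itself when it is a namespace, else at its containing namespace.
    for (var ns = symbol as INamespaceSymbol ?? symbol.ContainingNamespace;
         ns != null && !ns.IsGlobalNamespace;
         ns = ns.ContainingNamespace)
    {
        if (bannedNamespaces.Contains(ns.ToDisplayString()))
        {
            return true;
        }
    }
    return false;
}

// Hypothetical helper: compiles the source and checks every identifier's symbol.
bool ContainsBannedSymbol(string source, string[] bannedNamespaces)
{
    var tree = CSharpSyntaxTree.ParseText(source);
    var references = new[] { typeof(object), typeof(File) }
        .Select(t => t.Assembly.Location).Distinct()
        .Select(path => MetadataReference.CreateFromFile(path));
    var model = CSharpCompilation.Create("Demo", new[] { tree }, references).GetSemanticModel(tree);

    return tree.GetRoot().DescendantNodes()
        .OfType<IdentifierNameSyntax>()
        .Select(id => model.GetSymbolInfo(id).Symbol)
        .Any(symbol => symbol != null && IsInBannedNamespace(symbol, bannedNamespaces));
}

string[] banned = { "System.Net", "System.IO" };
bool qualified = ContainsBannedSymbol(@"class A { void M() { System.IO.File.Delete(""x""); } }", banned);
bool harmless = ContainsBannedSymbol("class B { int M() => 1; }", banned);
Console.WriteLine($"qualified: {qualified}, harmless: {harmless}");
```

Because the walk visits parent namespaces, banning System.Net in this sketch also covers nested namespaces such as System.Net.Http.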
System answered 7/4, 2022 at 18:19 Comment(0)
