LLVM MCJIT / SEH Exception handling

Lately, I've been trying to get SEH exception handling to work in LLVM (3.8.1) together with MCJIT, so far without any luck.

From what I understand from the documentation ( http://llvm.org/docs/ExceptionHandling.html ), this is pretty much how it should be implemented. Compiling a minimal piece of code with clang gives pretty much the same LLVM IR. However, when I try it, the program crashes with a nasty "Stack cookie instrumentation code detected a stack-based buffer overrun." error.

To illustrate what I've been attempting to do, I've created a minimal test case (I apologise for the amount of code...):

#include <string>
#include <iostream>
#include <exception>
#include <memory>
#include <vector>

#pragma warning(push)
#pragma warning(disable: 4267)
#pragma warning(disable: 4244)
#pragma warning(disable: 4800)
#pragma warning(disable: 4996)
#pragma warning(disable: 4141)
#pragma warning(disable: 4146)
#pragma warning(disable: 4624)
#pragma warning(disable: 4291)

#define DONT_GET_PLUGIN_LOADER_OPTION

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/Triple.h"
#include "llvm/PassRegistry.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/LinkAllPasses.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/GenericValue.h"
#include "llvm/ExecutionEngine/Interpreter.h"
#include "llvm/ExecutionEngine/MCJIT.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/PluginLoader.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"

#pragma warning(pop)

static void test()
{
    // You can use this to see that function calls work fine.
    // std::cout << "Foo!" << std::endl; 

    throw std::exception("Something we should try to catch.");
}

int main()
{
    // Initialize LLVM
    std::cout << "Initializing LLVM." << std::endl;

    llvm::InitializeNativeTarget();
    llvm::InitializeAllTargetMCs();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();

    llvm::PassRegistry *Registry = llvm::PassRegistry::getPassRegistry();
    llvm::initializeCore(*Registry);
    llvm::initializeScalarOpts(*Registry);
    llvm::initializeObjCARCOpts(*Registry);
    llvm::initializeVectorization(*Registry);
    llvm::initializeIPO(*Registry);
    llvm::initializeAnalysis(*Registry);
    llvm::initializeTransformUtils(*Registry);
    llvm::initializeInstCombine(*Registry);
    llvm::initializeInstrumentation(*Registry);
    llvm::initializeTarget(*Registry);
    // For codegen passes, only passes that do IR to IR transformation are
    // supported.
    llvm::initializeCodeGenPreparePass(*Registry);
    llvm::initializeAtomicExpandPass(*Registry);
    llvm::initializeRewriteSymbolsPass(*Registry);
    llvm::initializeWinEHPreparePass(*Registry);
    llvm::initializeDwarfEHPreparePass(*Registry);
    llvm::initializeSjLjEHPreparePass(*Registry);

    llvm::StringRef MCPU = llvm::sys::getHostCPUName();
    std::string MTrip = llvm::sys::getProcessTriple();

    static llvm::StringMap<bool, llvm::MallocAllocator> features;
    llvm::sys::getHostCPUFeatures(features);

    // Initialize module & context:
    auto context = std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext());
    auto module = std::unique_ptr<llvm::Module>(new llvm::Module("native", *context));

    // Create 'main' method:

    llvm::Type* returnType = llvm::Type::getInt32Ty(*context);
    std::vector<llvm::Type*> arguments;

    // MCJIT only supports main(int, char**)
    arguments.push_back(llvm::Type::getInt32Ty(*context));
    arguments.push_back(llvm::Type::getInt8PtrTy(*context)->getPointerTo());

    llvm::Function *fcn = llvm::cast<llvm::Function>(module->getOrInsertFunction("main", llvm::FunctionType::get(returnType, arguments, false)));

    // Generate exception handler info for main:
    llvm::AttrBuilder argBuilder;
    argBuilder.addAttribute(llvm::Attribute::UWTable);
    argBuilder.addAttribute("stack-protector-buffer-size", "8");
    fcn->addAttributes(llvm::AttributeSet::FunctionIndex, llvm::AttributeSet::get(*context, llvm::AttributeSet::FunctionIndex, argBuilder));

    // Exception handling requires a personality function. __CxxFrameHandler3 is the MSVC C++ personality routine, which is built on top of SEH.
    llvm::Function *personalityHandler = llvm::cast<llvm::Function>(module->getOrInsertFunction("__CxxFrameHandler3", llvm::FunctionType::get(llvm::Type::getInt32Ty(*context), true)));
    auto personalityPtr = llvm::ConstantExpr::getBitCast(personalityHandler, llvm::Type::getInt8PtrTy(*context));
    fcn->setPersonalityFn(personalityPtr);

    // Create some code. Basically we want to invoke our 'test' method
    auto block = llvm::BasicBlock::Create(*context, "code", fcn);
    llvm::IRBuilder<> builder(block);

    // all other cases might throw an exception
    auto continueBlock = llvm::BasicBlock::Create(*context, "invoke.cont", fcn);
    auto catchDispatch = llvm::BasicBlock::Create(*context, "catch.dispatch", fcn);

    // Register 'test' as an external function:
    const void* testFunctionPtr = &test;
    auto testFunctionType = llvm::FunctionType::get(builder.getVoidTy(), false);
    auto testFunction = llvm::Function::Create(testFunctionType, llvm::Function::ExternalLinkage, "test", module.get());

    // invoke void @test() to label %invoke.cont unwind label %catch.dispatch
    auto call = builder.CreateInvoke(testFunction, continueBlock, catchDispatch);

    // return [ 0 from ok, 1 from catch handler ]
    builder.SetInsertPoint(continueBlock);
    auto phi = builder.CreatePHI(builder.getInt32Ty(), 2, "result");
    phi->addIncoming(builder.getInt32(0), block);
    builder.CreateRet(phi);

    // Create exception handler:

    // Create default catch block. Basically handles the exception and returns '1'. 
    builder.SetInsertPoint(catchDispatch);
    auto parentPad = llvm::ConstantTokenNone::get(*context);
    // %0 = catchswitch within none [label %catch] unwind to caller
    auto catchSwitch = builder.CreateCatchSwitch(parentPad, nullptr, 1);

    auto catchBlock = llvm::BasicBlock::Create(*context, "catch", fcn);
    builder.SetInsertPoint(catchBlock);
    catchSwitch->addHandler(catchBlock);

    // MSVC code:
    // %1 = catchpad within %0 [i8* null, i32 64, i8* null] == "catch all"
    llvm::Value *nullPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(*context));
    auto catchPad = builder.CreateCatchPad(catchSwitch, { nullPtr, builder.getInt32(0x40), nullPtr });

    // catchret from %1 to label %invoke.cont
    builder.CreateCatchRet(catchPad, continueBlock);

    // set 1 for the catch handler
    phi->addIncoming(builder.getInt32(1), catchBlock);

    // *DONE* building the code. 

    // Dump the LLVM IR:
    module->dump();

    // Let's JIT the code:

    std::string error;
    auto trip = llvm::Triple::normalize(MTrip);
    llvm::Triple triple(trip);
    const llvm::Target *target = llvm::TargetRegistry::lookupTarget("x86-64", triple, error);
    if (!target)
    {
        throw error.c_str();
    }

    llvm::TargetOptions Options;

    std::unique_ptr<llvm::TargetMachine> targetMachine(
        target->createTargetMachine(trip, MCPU, "", Options, llvm::Reloc::Default, llvm::CodeModel::Default, llvm::CodeGenOpt::Aggressive));

    if (!targetMachine.get())
    {
        throw "Could not allocate target machine!";
    }

    // Set the module's data layout and target triple to match the target machine.
    auto DL = targetMachine->createDataLayout();
    module->setDataLayout(DL);
    module->setTargetTriple(trip);

    // Pass manager builder:
    llvm::PassManagerBuilder pmbuilder;
    pmbuilder.OptLevel = 3;
    pmbuilder.BBVectorize = false;
    pmbuilder.SLPVectorize = true;
    pmbuilder.LoopVectorize = true;
    pmbuilder.Inliner = llvm::createFunctionInliningPass(3, 2);
    llvm::TargetLibraryInfoImpl *TLI = new llvm::TargetLibraryInfoImpl(triple);
    pmbuilder.LibraryInfo = TLI;

    // Generate pass managers:

    // 1. Function pass manager:
    llvm::legacy::FunctionPassManager FPM(module.get());
    pmbuilder.populateFunctionPassManager(FPM);

    // 2. Module pass manager:
    llvm::legacy::PassManager PM;
    PM.add(llvm::createTargetTransformInfoWrapperPass(targetMachine->getTargetIRAnalysis()));
    pmbuilder.populateModulePassManager(PM);

    // 3. Execute passes:
    //    - Per-function passes:
    FPM.doInitialization();
    for (llvm::Module::iterator I = module->begin(), E = module->end(); I != E; ++I)
    {
        if (!I->isDeclaration())
        {
            FPM.run(*I);
        }
    }
    FPM.doFinalization();

    //   - Per-module passes:
    PM.run(*module);


    // All done, *RUN*.
    llvm::EngineBuilder engineBuilder(std::move(module));
    engineBuilder.setEngineKind(llvm::EngineKind::JIT);
    engineBuilder.setMCPU(MCPU);
    engineBuilder.setMArch("x86-64");
    engineBuilder.setUseOrcMCJITReplacement(false);
    engineBuilder.setOptLevel(llvm::CodeGenOpt::None);

    llvm::ExecutionEngine* engine = engineBuilder.create();

    // Register global 'test' function:
    engine->addGlobalMapping(testFunction, const_cast<void*>(testFunctionPtr)); // Yuck... 

    // Finalize
    engine->finalizeObject();

    // Invoke:
    std::vector<llvm::GenericValue> args(2);
    args[0].IntVal = llvm::APInt(32, static_cast<uint64_t>(0), true);
    args[1].PointerVal = nullptr;

    llvm::GenericValue gv = engine->runFunction(fcn, args);
    auto result = int(gv.IntVal.getSExtValue());

    std::cout << "Result after execution: " << result << std::endl;

    std::string s;
    std::getline(std::cin, s);
}

This produces the following IR code:

; ModuleID = 'native'

; Function Attrs: uwtable
define i32 @main(i32, i8**) #0 personality i8* bitcast (i32 (...)* @__CxxFrameHandler3 to i8*) {
code:
  invoke void @test()
          to label %invoke.cont unwind label %catch.dispatch

invoke.cont:                                      ; preds = %catch, %code
  %result = phi i32 [ 0, %code ], [ 1, %catch ]
  ret i32 %result

catch.dispatch:                                   ; preds = %code
  %2 = catchswitch within none [label %catch] unwind to caller

catch:                                            ; preds = %catch.dispatch
  %3 = catchpad within %2 [i8* null, i32 64, i8* null]
  catchret from %3 to label %invoke.cont
}

declare i32 @__CxxFrameHandler3(...)

declare void @test()

attributes #0 = { uwtable "stack-protector-buffer-size"="8" }
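
For comparison, the IR above corresponds roughly to a catch-all try/catch in C++. The following is a minimal sketch of that source-level shape, not the exact code that was fed to clang; the names `test_throws` and `guarded_call` are hypothetical, and `std::runtime_error` stands in for the MSVC-specific `std::exception(const char*)` constructor used in the test case.

```cpp
#include <stdexcept>

// Hypothetical stand-in for the host-side test() function from the question.
void test_throws() {
    throw std::runtime_error("Something we should try to catch.");
}

// Returns 0 when the call completes normally and 1 when any exception is
// caught, mirroring the two incoming values of the phi node in the IR.
int guarded_call() {
    try {
        test_throws();   // corresponds to the 'invoke' instruction
        return 0;        // the %invoke.cont path
    } catch (...) {      // corresponds to the catch-all catchpad
        return 1;        // the %catch path
    }
}
```

Compiled ahead of time, `guarded_call()` returns 1. That is the result the JIT-compiled `main` is expected to produce as well.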

Q: What am I missing? Why doesn't this work, and how can I fix it?

Tokenism answered 30/8, 2016 at 10:10

Comments (2):
Etem: Did you find out what code caused this stack overflow?
Tokenism: @ShmuelHazan It's not a stack overflow. And the details tell what causes the issue. Also, if I had figured it out, I obviously wouldn't be putting a bounty on it.

After posting this issue on the LLVM developer list, I got a friendly reply explaining that it is related to a known bug: https://llvm.org/bugs/show_bug.cgi?id=24233 .

Basically, LLVM doesn't implement the code that Windows (more specifically, SEH and debugging) requires for handling stack frames. I'm by no means an expert on this subject, but until this is implemented, SEH won't know what to do, which means C++ exceptions basically won't work.

An obvious workaround is of course to pass an error object as a pointer argument to the function call and check it with an if-then-else after the call returns. That way, exceptions are avoided. However, this is pretty nasty and will probably carry a serious performance penalty. It also makes the control flow in the compiler, as well as in the generated program, much more complicated. In other words: let's just say I'd rather not.
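
The pointer-based workaround could look roughly like the sketch below. It is only an illustration under assumptions: `ErrorInfo`, `test_no_throw` and `call_and_check` are hypothetical names that do not appear in the test case, and in the real setup the check in `call_and_check` would be emitted as LLVM IR (a call followed by a conditional branch) rather than written in C++.

```cpp
#include <string>

// Hypothetical error record that the callee fills in instead of throwing.
struct ErrorInfo {
    bool failed = false;
    std::string message;
};

// Hypothetical replacement for test(): reports failure through the
// out-parameter, so no exception ever crosses the JIT boundary.
void test_no_throw(ErrorInfo *err) {
    err->failed = true;
    err->message = "Something we should try to catch.";
}

// What the generated caller would do: make a plain call, then branch on
// the failure flag instead of relying on unwinding.
int call_and_check() {
    ErrorInfo err;
    test_no_throw(&err);
    if (err.failed) {
        return 1;  // the "catch" path
    }
    return 0;      // the normal path
}
```

Here `call_and_check()` returns 1 because the callee always reports a failure. Every call site needs such a check, which is the extra complexity and cost mentioned above.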

I'll leave the question open; if someone happens to find a hack or figure out a workaround, I'll gladly accept it.

Tokenism answered 8/9, 2016 at 15:50
