From 92e9b2458665617d6d05b968a4f8bd08afdb65a4 Mon Sep 17 00:00:00 2001 From: Maxime Chevalier-Boisvert Date: Wed, 13 Jan 2021 15:18:35 -0500 Subject: [PATCH] Keep track of incoming branches in ujit --- ujit_codegen.c | 28 ++++++++--------------- ujit_core.c | 62 ++++++++++++++++++++++++++++++++++++++++++++++---- ujit_core.h | 39 ++++++++++++++++--------------- 3 files changed, 86 insertions(+), 43 deletions(-) diff --git a/ujit_codegen.c b/ujit_codegen.c index 9a17ca942a..722f799e64 100644 --- a/ujit_codegen.c +++ b/ujit_codegen.c @@ -124,7 +124,6 @@ ujit_gen_entry(version_t* version) // Compile the block starting at this instruction uint32_t num_instrs = ujit_gen_code(version); - // FIXME: can we eliminate this check? // If no instructions were compiled if (num_instrs == 0) { return NULL; @@ -1003,24 +1002,15 @@ gen_jump(jitstate_t* jit, ctx_t* ctx) // RUBY_VM_CHECK_INTS(ec); // - // If the jump target was already compiled - if (find_block_version(jump_block, ctx)) - { - // Generate the jump instruction - gen_branch( - ctx, - jump_block, - ctx, - BLOCKID_NULL, - ctx, - gen_jump_branch - ); - } - else - { - // No need for a jump, compile the target block right here - gen_block_version(jump_block, ctx); - } + // Generate the jump instruction + gen_branch( + ctx, + jump_block, + ctx, + BLOCKID_NULL, + ctx, + gen_jump_branch + ); return true; } diff --git a/ujit_core.c b/ujit_core.c index 922ea4f1e8..2cb1c252de 100644 --- a/ujit_core.c +++ b/ujit_core.c @@ -91,6 +91,8 @@ version_t* gen_block_version(blockid_t blockid, const ctx_t* ctx) version_t* p_version = malloc(sizeof(version_t)); memcpy(&p_version->blockid, &blockid, sizeof(blockid_t)); memcpy(&p_version->ctx, ctx, sizeof(ctx_t)); + p_version->incoming = NULL; + p_version->num_incoming = 0; // Compile the block version p_version->start_pos = cb->write_pos; @@ -110,6 +112,8 @@ uint8_t* gen_entry_point(const rb_iseq_t *iseq, uint32_t insn_idx) version_t* p_version = malloc(sizeof(version_t)); blockid_t blockid = { iseq, insn_idx }; memcpy(&p_version->blockid, &blockid, sizeof(blockid_t)); + p_version->incoming = NULL; + p_version->num_incoming = 0; // The entry context makes no assumptions about types ctx_t ctx = { 0 }; @@ -120,12 +124,31 @@ uint8_t* gen_entry_point(const rb_iseq_t *iseq, uint32_t insn_idx) uint8_t* code_ptr = ujit_gen_entry(p_version); p_version->end_pos = cb->write_pos; + // If we couldn't generate any code + if (!code_ptr) + { + free(p_version); + return NULL; + } + // Keep track of the new block version st_insert(version_tbl, (st_data_t)&p_version->blockid, (st_data_t)p_version); return code_ptr; } +// Add an incoming branch for a given block version +static void add_incoming(version_t* p_version, uint32_t branch_idx) +{ + // Add this branch to the list of incoming branches for the target + uint32_t* new_list = malloc(sizeof(uint32_t) * p_version->num_incoming + 1); + memcpy(new_list, p_version->incoming, p_version->num_incoming); + new_list[p_version->num_incoming] = branch_idx; + p_version->incoming = new_list; + p_version->num_incoming += 1; + //fprintf(stderr, "num_incoming: %d\n", p_version->num_incoming); +} + // Called by the generated code when a branch stub is executed // Triggers compilation of branches and code patching uint8_t* branch_stub_hit(uint32_t branch_idx, uint32_t target_idx) @@ -161,6 +184,9 @@ uint8_t* branch_stub_hit(uint32_t branch_idx, uint32_t target_idx) p_version = gen_block_version(target, target_ctx); } + // Add this branch to the list of incoming branches for the target + add_incoming(p_version, branch_idx); + // Update the branch target address uint8_t* dst_addr = cb_get_ptr(cb, p_version->start_pos); branch->dst_addrs[target_idx] = dst_addr; @@ -192,6 +218,9 @@ uint8_t* get_branch_target( if (p_version) { + // Add an incoming branch for this version + add_incoming(p_version, branch_idx); + return cb_get_ptr(cb, p_version->start_pos); } @@ -233,9 +262,35 @@ void gen_branch( branchgen_fn gen_fn ) { - // Get branch targets or stubs (code pointers) - uint8_t* dst_addr0 = get_branch_target(target0, ctx0, ocb, num_branches, 0); - uint8_t* dst_addr1 = get_branch_target(target1, ctx1, ocb, num_branches, 1); + assert (num_branches < MAX_BRANCHES); + uint32_t branch_idx = num_branches; + + // Branch targets or stub adddresses (code pointers) + uint8_t* dst_addr0; + uint8_t* dst_addr1; + + // If there's only one branch target + if (target1.iseq == NULL) + { + version_t* p_version = find_block_version(target0, ctx0); + + // If the version doesn't already exist + if (!p_version) + { + // No need for a jump, compile the target block right here + p_version = gen_block_version(target0, ctx0); + } + + add_incoming(p_version, branch_idx); + dst_addr0 = cb_get_ptr(cb, p_version->start_pos); + dst_addr1 = NULL; + } + else + { + // Get the branch targets or stubs + dst_addr0 = get_branch_target(target0, ctx0, ocb, branch_idx, 0); + dst_addr1 = get_branch_target(target1, ctx1, ocb, branch_idx, 1); + } // Call the branch generation function uint32_t start_pos = cb->write_pos; @@ -254,7 +309,6 @@ void gen_branch( SHAPE_DEFAULT }; - assert (num_branches < MAX_BRANCHES); branch_entries[num_branches] = branch_entry; num_branches++; } diff --git a/ujit_core.h b/ujit_core.h index 065bafe786..f96d79df70 100644 --- a/ujit_core.h +++ b/ujit_core.h @@ -45,26 +45,6 @@ typedef struct BlockId // Null block id constant static const blockid_t BLOCKID_NULL = { 0, 0 }; -// Basic block version -typedef struct BlockVersion -{ - // Basic block this is a version of - blockid_t blockid; - - // Context at the start of the block - ctx_t ctx; - - // Positions where the generated code starts and ends - uint32_t start_pos; - uint32_t end_pos; - - // TODO - // TODO: list of incoming branches, branch entries - // TODO - // incoming; - -} version_t; - /// Branch code shape enumeration enum uint8_t { @@ -101,6 +81,25 @@ typedef struct BranchEntry } branch_t; +// Basic block version +typedef struct BlockVersion +{ + // Basic block this is a version of + blockid_t blockid; + + // Context at the start of the block + ctx_t ctx; + + // Positions where the generated code starts and ends + uint32_t start_pos; + uint32_t end_pos; + + // List of incoming branches indices + uint32_t* incoming; + uint32_t num_incoming; + +} version_t; + // Context object methods int ctx_get_opcode(ctx_t *ctx); uint32_t ctx_next_idx(ctx_t* ctx);