- byron10000's blog
BF 语言 JIT 编译器
- 2024-5-2 22:30:29 @
Supports amd64 (x86-64) Linux only.
#ifdef _USE_PCH_
#include "pch.hpp"
#else
#include <bits/stdc++.h>
#endif
#include <sys/mman.h>
#include <unistd.h>
#include <stdexcept>
using namespace std;
namespace
{
/// Reads the entire contents of the file `filename` into a string.
///
/// The file is opened in binary mode so that the size reported by the stream
/// matches the number of bytes that are read back.
///
/// @param filename Path of the file to read; it does not need to be
///                 null-terminated.
/// @return The raw bytes of the file.
/// @throws std::runtime_error if the file cannot be opened or its size cannot
///         be determined.
static std::string read_all(std::string_view filename) {
    // A string_view carries no terminator guarantee, so make an owning copy
    // before handing the path to the stream.
    const std::string path(filename);
    std::ifstream is(path, std::ios::in | std::ios::binary);
    if (!is) throw std::runtime_error("read_all: cannot open '" + path + "'");

    is.seekg(0, std::ios::end);
    const std::streamoff len = is.tellg();
    if (len < 0) throw std::runtime_error("read_all: cannot determine size of '" + path + "'");
    is.seekg(0, std::ios::beg);

    std::string ret(static_cast<std::size_t>(len), '\0');
    is.read(ret.data(), len);
    // Keep only what was actually read in case the stream delivered fewer bytes.
    ret.resize(static_cast<std::size_t>(is.gcount()));
    return ret;
}
// Measures the run of identical characters that starts at src[i].
// On return, i has been advanced to the last character of that run and the
// run length (at least 1) is returned.
static int count_repeat(string_view src, int& i) {
    const int start = i;
    const int limit = static_cast<int>(src.size());
    for (int next = start + 1; next < limit && src[next] == src[next - 1]; ++next) {
        i = next;
    }
    return i - start + 1;
}
// Returns the unsigned byte that has the same bit pattern as the signed byte x.
// Signed-to-unsigned conversion is modulo 2^8, which is exactly that pattern.
static uint8_t bit_cast(int8_t x) { return static_cast<uint8_t>(x); }
// Same for 32-bit values: conversion is modulo 2^32.
static uint32_t bit_cast(int32_t x) { return static_cast<uint32_t>(x); }
} // namespace
namespace bfjit
{
/// Growable byte buffer used to assemble machine code piece by piece.
/// Every `add` overload appends at the end and returns `*this` for chaining.
struct BinStrBuilder {
    /// Bytes appended so far, in order.
    std::vector<uint8_t> data;

    /// Appends the N bytes of a fixed-size array.
    template<uint32_t N> BinStrBuilder& add(const uint8_t (&x)[N]) {
        data.insert(data.end(), x, x + N);
        return *this;
    }

    /// Appends `x` as four bytes in the host's byte order.
    BinStrBuilder& add(uint32_t x) {
        // Copy through a byte array: the end of the buffer is generally not
        // 4-byte aligned, so storing through a uint32_t* would be undefined.
        uint8_t bytes[sizeof x];
        std::memcpy(bytes, &x, sizeof x);
        data.insert(data.end(), bytes, bytes + sizeof x);
        return *this;
    }

    /// Appends all bytes of another builder; appending a builder to itself
    /// doubles its contents.
    BinStrBuilder& add(const BinStrBuilder& o) {
        // Capture the source length first so that o == *this stays correct
        // after the resize below.
        const std::size_t n = o.data.size();
        if (n == 0) return *this;
        const std::size_t cur = data.size();
        data.resize(cur + n);
        std::memcpy(data.data() + cur, o.data.data(), n);
        return *this;
    }

    /// Number of bytes appended so far.
    uint32_t size() const { return static_cast<uint32_t>(data.size()); }
};
/// Translates Brainfuck source, starting at src[i], into x86-64 machine code.
///
/// The emitted code keeps the current cell pointer in rax, calls the output
/// routine through [r12] and the input routine through [r12+8], and passes
/// r13 to both as their first argument (rdi). rcx is used as scratch.
///
/// The function recurses once per '['. A nested invocation begins right after
/// the opening bracket and returns with `i` left on the matching ']'; the
/// outermost invocation (started with i == 0) runs to the end of `src`.
///
/// @param src Program text; characters that are not commands are skipped.
/// @param i   Current position in `src`, advanced as characters are consumed.
/// @return The machine code for the translated range.
/// @throws std::runtime_error on an unmatched '[' or ']'.
BinStrBuilder do_compile(string_view src, int& i) {
    BinStrBuilder ret;
    // A nested invocation is recognised by the '[' immediately before its start.
    const bool inside_loop = i > 0 && src[i - 1] == '[';
    const int end = (int) src.size();
    for (; i < end; i++) {
        const char ch = src[i];
        switch (ch) {
        case '.':
            ret.add((uint8_t[]) {0x50, 0x4c, 0x89, 0xef, 0x40, 0x8a, 0x30, 0x41, 0xff, 0x14, 0x24, 0x58});
            /*
            push rax
            mov rdi,r13
            mov sil,byte [rax]
            call 0x00[r12]
            pop rax
            */
            break;
        case ',':
            ret.add(
                (uint8_t[]) {0x50, 0x4c, 0x89, 0xef, 0x41, 0xff, 0x54, 0x24, 0x08, 0x48, 0x89, 0xc1, 0x58, 0x88, 0x08});
            /*
            push rax
            mov rdi,r13
            call 0x08[r12]
            mov rcx,rax
            pop rax
            mov byte[rax], cl
            */
            break;
        case '+':
        case '-': {
            // A run of n identical commands becomes one 8-bit add; subtraction
            // is encoded as adding the two's-complement value.
            int c = count_repeat(src, i) & 255;
            if (ch == '-') c = 256 - c;
            ret.add((uint8_t[]) {0x80, 0x00, (uint8_t) c}); // add byte[rax],$x
            break;
        }
        case '<':
        case '>': {
            const int c = count_repeat(src, i);
            const int delta = ch == '>' ? c : -c;
            if (c <= 0x7f) ret.add((uint8_t[]) {0x48, 0x83, 0xc0, bit_cast((int8_t) delta)}); // add rax,imm8
            else ret.add((uint8_t[]) {0x48, 0x05}).add(bit_cast(delta));                       // add rax,imm32
            break;
        }
        case ']':
            if (!inside_loop) throw std::runtime_error("Compile error: Unmatched bracket");
            return ret; // i stays on the ']' for the caller
        case '[': {
            i++;
            BinStrBuilder body = do_compile(src, i);
            // Take the body length once, before the trailing jump is appended,
            // so the displacements do not depend on argument evaluation order.
            const int body_len = (int) body.size();
            ret.add((uint8_t[]) {0x8a, 0x08, 0x84, 0xc9}); // mov cl,[rax]; test cl,cl  (4 bytes)
            if (body_len + 4 + 2 + 2 <= 0x7f) {
                // Short form: 2-byte je forward and 2-byte jmp back.
                ret.add((uint8_t[]) {0x74, bit_cast((int8_t) (body_len + 2))});              // je $after_loop
                body.add((uint8_t[]) {0xeb, bit_cast((int8_t) (-(body_len + 2 + 2 + 4)))}); // jmp $begin_loop
            } else {
                // Long form: 6-byte je forward and 5-byte jmp back.
                ret.add((uint8_t[]) {0x0f, 0x84}).add(bit_cast(body_len + 5));       // je $after_loop
                body.add((uint8_t[]) {0xe9}).add(bit_cast(-(body_len + 5 + 6 + 4))); // jmp $begin_loop
            }
            ret.add(body);
            break;
        }
        default:
            break; // not a command: ignored
        }
    }
    // Reaching the end of the source inside a loop means its ']' is missing.
    if (inside_loop) throw std::runtime_error("Compile error: Unmatched bracket");
    return ret;
}
/// Compiles a whole Brainfuck program into a callable x86-64 function image.
///
/// The generated function takes three pointer arguments (rdi, rsi, rdx) and
/// moves them into rax, r12 and r13 respectively before running the body
/// produced by do_compile.
///
/// @param src Program text.
/// @return The machine code: prologue, translated body, epilogue.
vector<uint8_t> compile(string_view src) {
    BinStrBuilder ret;
    ret.add((uint8_t[]) {0x55, 0x48, 0x89, 0xe5, 0x41, 0x54, 0x41, 0x55, 0x48, 0x83, 0xec,
                         0x08, 0x48, 0x89, 0xf8, 0x49, 0x89, 0xf4, 0x49, 0x89, 0xd5});
    /*
    push rbp
    mov rbp,rsp
    push r12
    push r13
    sub rsp,8      ; padding, see below
    mov rax,rdi
    mov r12,rsi
    mov r13,rdx
    */
    // The three pushes leave rsp 16-byte aligned. The body pushes rax (8 bytes)
    // right before every call, so without the extra 8 bytes of padding each
    // callee would be entered with a misaligned stack.
    int i = 0;
    ret.add(do_compile(src, i));
    ret.add((uint8_t[]) {0x48, 0x83, 0xc4, 0x08, 0x41, 0x5d, 0x41, 0x5c, 0x5d, 0xc3});
    /*
    add rsp,8
    pop r13
    pop r12
    pop rbp
    ret
    */
    // ret.data is a member of a local, so it is not moved implicitly.
    return std::move(ret.data);
}
/// The pair of C streams a running program reads from and writes to.
struct IOData {
    FILE* infile;
    FILE* outfile;
};

/// Output callback: writes the byte `ch` to io->outfile.
static void io_write(IOData* io, char ch) { fputc(ch, io->outfile); }

/// Input callback: reads one byte from io->infile.
/// Returns 0 once the input is exhausted (or on a read error).
static char io_read(IOData* io) {
    // feof() only becomes true after a read has already failed, so the result
    // of fgetc itself must be checked to catch the first end-of-file read.
    const int c = fgetc(io->infile);
    return c == EOF ? (char) 0 : (char) c;
}
/// Runs a machine-code image as a function.
///
/// The image is copied into its own executable mapping and called with three
/// arguments: a freshly mapped data area of `heap_size` bytes, a two-entry
/// table holding the addresses of io_write and io_read, and an IOData built
/// from `infile` / `outfile`. Both mappings are released before returning.
///
/// @param prog      Machine code to run.
/// @param prog_len  Length of `prog` in bytes; must be positive.
/// @param heap_size Size of the data area in bytes; must be positive.
/// @param infile    Stream given to io_read.
/// @param outfile   Stream given to io_write.
/// @throws std::invalid_argument on a null program or non-positive sizes.
/// @throws std::runtime_error if a mapping cannot be created or protected.
void exec(const uint8_t* prog, int prog_len, int heap_size, FILE* infile, FILE* outfile) {
    if (prog == nullptr || prog_len <= 0 || heap_size <= 0) {
        throw std::invalid_argument("exec: empty program or non-positive heap size");
    }
    void* func[] = {(void*) io_write, (void*) io_read};
    IOData io {infile, outfile};

    // Round the code size up to whole pages.
    const std::size_t page_size = (std::size_t) getpagesize();
    const std::size_t code_len = ((std::size_t) prog_len + page_size - 1) / page_size * page_size;
    const std::size_t tape_len = (std::size_t) heap_size;

    // Map the code writable first, then switch it to read+execute once it has
    // been copied in, so the page is never writable and executable together.
    void* code = mmap(nullptr, code_len, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
    if (code == MAP_FAILED) throw std::runtime_error("exec: cannot map memory for the program");
    std::memcpy(code, prog, (std::size_t) prog_len);
    if (mprotect(code, code_len, PROT_READ | PROT_EXEC) != 0) {
        munmap(code, code_len);
        throw std::runtime_error("exec: cannot make the program executable");
    }

    void* tape = mmap(nullptr, tape_len, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
    if (tape == MAP_FAILED) {
        munmap(code, code_len);
        throw std::runtime_error("exec: cannot map memory for the data area");
    }

    using prog_fn_t = void (*)(void*, void*, void*);
    ((prog_fn_t) code)(tape, (void*) func, &io);

    // Each mapping is released with its own length.
    munmap(tape, tape_len);
    munmap(code, code_len);
}
/// Executes a Brainfuck program directly, without generating machine code.
///
/// Brackets are matched in a pass over the source before anything runs, so a
/// malformed program is rejected up front and loop jumps are single lookups.
///
/// @param src       Program text; characters that are not commands are skipped.
/// @param heap_size Number of zero-initialised byte cells.
/// @param infile    Stream read by ','; a cell is set to 0 once input ends.
/// @param outfile   Stream written by '.'.
/// @throws std::runtime_error on an unmatched '[' or ']'.
/// @throws std::out_of_range  if a character is about to be processed while
///         the data pointer is outside [0, heap_size).
void interpret(string_view src, int heap_size, FILE* infile, FILE* outfile) {
    const int n = (int) src.size();

    // match[k] is the index of the bracket paired with the bracket at k.
    std::vector<int> match(src.size(), -1);
    std::vector<int> open;
    for (int k = 0; k < n; k++) {
        if (src[k] == '[') {
            open.push_back(k);
        } else if (src[k] == ']') {
            if (open.empty()) throw std::runtime_error("Compile error: Unmatched bracket");
            match[k] = open.back();
            match[open.back()] = k;
            open.pop_back();
        }
    }
    if (!open.empty()) throw std::runtime_error("Compile error: Unmatched bracket");

    std::vector<uint8_t> tape(heap_size > 0 ? (std::size_t) heap_size : 0, 0);
    for (int i = 0, p = 0; i < n; i++) {
        if (p < 0 || p >= heap_size) throw std::out_of_range("interpret: data pointer out of range");
        switch (src[i]) {
        case '.': fputc(tape[p], outfile); break;
        case ',': {
            // Check fgetc's own result: feof() is still false on the first
            // read that hits the end of the input.
            const int c = fgetc(infile);
            tape[p] = c == EOF ? (uint8_t) 0 : (uint8_t) c;
            break;
        }
        case '+': tape[p]++; break;
        case '-': tape[p]--; break;
        case '<': p--; break;
        case '>': p++; break;
        case '[':
            // Zero cell: jump to the matching ']'; the loop increment then
            // steps past it.
            if (!tape[p]) i = match[i];
            break;
        case ']':
            // Non-zero cell: jump back to the matching '['; the loop increment
            // then resumes at the first command of the body.
            if (tape[p]) i = match[i];
            break;
        default: break;
        }
    }
}
} // namespace bfjit
// Empty namespace: it declares nothing in this file.
namespace test
{ }
namespace
{
/// Simple stopwatch that reports elapsed time on standard error.
/// The clock starts at construction and can be restarted with start().
class Timer {
    // steady_clock never jumps backwards or forwards when the system time is
    // adjusted, which makes it the right clock for measuring intervals.
    std::chrono::steady_clock::time_point start_time_point;
public:
    Timer() { start(); }

    /// Restarts the measurement from now.
    void start() { start_time_point = std::chrono::steady_clock::now(); }

    /// Prints the whole milliseconds elapsed since the last start, followed
    /// by "ms" and a newline, to std::cerr.
    void print_duration() const {
        const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - start_time_point);
        std::cerr << elapsed.count() << "ms" << std::endl;
    }
};
} // namespace
/// Compiles the program in "1.bf", writes the generated machine code to the
/// file "exe", then runs it with input from the file "in" and output to
/// stdout, printing the elapsed time to stderr.
///
/// If "in" cannot be opened the program reads its input from stdin instead.
/// Returns 0 on success and 1 if any step throws.
int main() {
    try {
        const auto src = read_all("1.bf");
        const auto prog = bfjit::compile(src);
        {
            ofstream out("exe", ios_base::out | ios_base::binary);
            out.write((const char*) prog.data(), (streamsize) prog.size());
        }
        // system("objdump -D -b binary -m i386:x86-64:intel exe");
        {
            Timer timer;
            FILE* opened = fopen("in", "r");
            if (opened == nullptr) cerr << "warning: cannot open 'in', reading input from stdin" << '\n';
            FILE* infile = opened != nullptr ? opened : stdin;
            // bfjit::interpret(src, 1048576, infile, stdout);
            bfjit::exec(prog.data(), (int) prog.size(), 1048576, infile, stdout);
            timer.print_duration();
            // Only close what was opened here; stdin is left alone.
            if (opened != nullptr) fclose(opened);
        }
        return 0;
    } catch (const std::exception& e) {
        cerr << "error: " << e.what() << '\n';
        return 1;
    }
}