wasm: update ABI for passing parameters to wasm UDFs

WebAssembly uses a 32-bit address space, while also
having 64-bit integers as one of its native types. As a result,
when passing the size of an object in memory and its address,
the two can be combined into one 64-bit value. As a bonus,
if the object is null, we can signal it by passing -1 as
its size.

This patch implements handling of this new ABI and adjusts
examples in test_wasm.py.

Signed-off-by: Wojciech Mitros <wojciech.mitros@scylladb.com>
This commit is contained in:
Wojciech Mitros
2022-03-30 04:28:29 +02:00
parent 7fd81e6dae
commit a7ee3ccf52
2 changed files with 294 additions and 319 deletions

View File

@@ -66,7 +66,7 @@ static void init_abstract_arg(const abstract_type& t, const bytes_opt& param, st
auto memory = std::get<wasmtime::Memory>(*memory_export);
uint8_t* data = memory.data(store).data();
size_t mem_size = memory.size(store) * WASM_PAGE_SIZE;
const int32_t serialized_size = param ? param->size() : 0;
int32_t serialized_size = param ? param->size() : 0;
if (serialized_size > std::numeric_limits<int32_t>::max()) {
throw wasm::exception(format("Serialized parameter is too large: {} > {}", param->size(), std::numeric_limits<int32_t>::max()));
}
@@ -75,18 +75,16 @@ static void init_abstract_arg(const abstract_type& t, const bytes_opt& param, st
throw wasm::exception(format("Failed to grow wasm memory to {}: {}", serialized_size, grown.err().message()));
}
if (param) {
// put the size in wasm module's memory
std::memcpy(data + mem_size, reinterpret_cast<const char*>(&serialized_size), sizeof(int32_t));
// put the argument in wasm module's memory
std::memcpy(data + mem_size + sizeof(int32_t), param->data(), serialized_size);
std::memcpy(data + mem_size, param->data(), serialized_size);
} else {
// size of -1 means that the value is null
const int32_t is_null = -1;
std::memcpy(data + mem_size, reinterpret_cast<const char*>(&is_null), sizeof(int32_t));
serialized_size = -1;
}
// the place inside wasm memory where the struct is placed
argv.push_back(int32_t(mem_size));
// the size of the struct in top 32 bits and the place inside wasm memory where the struct is placed in the bottom 32 bits
int64_t arg_combined = ((int64_t)serialized_size << 32) | mem_size;
argv.push_back(arg_combined);
}
struct init_arg_visitor {
@@ -185,20 +183,18 @@ struct from_val_visitor {
}
bytes_opt operator()(const abstract_type& t) {
expect_kind(wasmtime::ValKind::I32);
expect_kind(wasmtime::ValKind::I64);
auto memory_export = instance.get(store, "memory");
if (!memory_export) {
throw wasm::exception("memory export not found - please export `memory` in the wasm module");
}
auto memory = std::get<wasmtime::Memory>(*memory_export);
uint8_t* mem_base = memory.data(store).data();
uint8_t* data = mem_base + val.i32();
int32_t ret_size;
std::memcpy(reinterpret_cast<char*>(&ret_size), data, 4);
uint8_t* data = mem_base + (val.i64() & 0xffffffff);
int32_t ret_size = val.i64() >> 32;
if (ret_size == -1) {
return bytes_opt{};
}
data += sizeof(int32_t); // size of the return type was consumed
return t.decompose(t.deserialize(bytes_view(reinterpret_cast<int8_t*>(data), ret_size)));
}

View File

@@ -83,11 +83,6 @@ def test_fib(cql, test_keyspace, table1, scylla_with_wasm_only):
# Compiled from:
# const int WASM_PAGE_SIZE = 64 * 1024;
#
# struct __attribute__((packed)) nullable_bigint {
# int size;
# long long v;
# };
#
# static long long swap_int64(long long val) {
# val = ((val << 8) & 0xFF00FF00FF00FF00ULL ) | ((val >> 8) & 0x00FF00FF00FF00FFULL );
# val = ((val << 16) & 0xFFFF0000FFFF0000ULL ) | ((val >> 16) & 0x0000FFFF0000FFFFULL );
@@ -100,35 +95,31 @@ def test_fib(cql, test_keyspace, table1, scylla_with_wasm_only):
# }
# return fib_aux(n-1) + fib_aux(n-2);
# }
#
# int fib(struct nullable_bigint* p) {
# // Initialize memory for the return struct
# struct nullable_bigint* ret = (struct nullable_bigint*)(__builtin_wasm_memory_size(0) * WASM_PAGE_SIZE);
# __builtin_wasm_memory_grow(0, 1 + (sizeof(struct nullable_bigint) - 1) / WASM_PAGE_SIZE); // round up
#
# ret->size = sizeof(long long);
# if (p->size == -1) {
# ret->v = swap_int64(42);
# long long fib(long long p) {
# int size = p >> 32;
# long long* p_val = (long long*)(p & 0xffffffff);
# // Initialize memory for the return value
# long long* ret_val = (long long*)(__builtin_wasm_memory_size(0) * WASM_PAGE_SIZE);
# __builtin_wasm_memory_grow(0, 1); // long long fits in one wasm page
# if (size == -1) {
# *ret_val = swap_int64(42);
# } else {
# ret->v = swap_int64(fib_aux(swap_int64(p->v)));
# *ret_val = swap_int64(fib_aux(swap_int64(*p_val)));
# }
# return (int)ret;
# // 8 is the size of a bigint
# return (long long)(8ll << 32) | (long long)ret_val;
# }
#
# with:
# $ clang -O2 --target=wasm32 --no-standard-libraries -Wl,--export-all -Wl,--no-entry fibnull.c -o fibnull.wasm
# $ clang -O2 --target=wasm32 --no-standard-libraries -Wl,--export=fib -Wl,--no-entry fibnull.c -o fibnull.wasm
# $ wasm2wat fibnull.wasm > fibnull.wat
def test_fib_called_on_null(cql, test_keyspace, table1, scylla_with_wasm_only):
table = table1
fib_name = unique_name()
fib_source = f"""
(module
(type (;0;) (func))
(type (;1;) (func (param i64) (result i64)))
(type (;2;) (func (param i32) (result i32)))
(func (;0;) (type 0)
nop)
(func (;1;) (type 1) (param i64) (result i64)
(type (;0;) (func (param i64) (result i64)))
(func (;0;) (type 0) (param i64) (result i64)
(local i64 i32)
local.get 0
i64.const 2
@@ -141,7 +132,7 @@ def test_fib_called_on_null(cql, test_keyspace, table1, scylla_with_wasm_only):
local.get 0
i64.const 1
i64.sub
call 1
call 0
local.get 1
i64.add
local.set 1
@@ -159,132 +150,124 @@ def test_fib_called_on_null(cql, test_keyspace, table1, scylla_with_wasm_only):
local.get 0
local.get 1
i64.add)
(func (;2;) (type 2) (param i32) (result i32)
(local i64 i32)
(func (;1;) (type 0) (param i64) (result i64)
(local i32 i64)
memory.size
local.set 2
local.set 1
i32.const 1
memory.grow
drop
local.get 2
i64.const 3026418949592973312
local.set 2
local.get 1
i32.const 16
i32.shl
local.tee 2
i32.const 8
i32.store
local.tee 1
local.get 0
i32.load align=1
i32.const -1
i32.eq
if ;; label = @1
i64.const -4294967297
i64.le_u
if (result i64) ;; label = @1
local.get 0
i32.wrap_i64
i64.load
local.tee 0
i64.const 56
i64.shl
local.get 0
i64.const 40
i64.shl
i64.const 71776119061217280
i64.and
i64.or
local.get 0
i64.const 24
i64.shl
i64.const 280375465082880
i64.and
local.get 0
i64.const 8
i64.shl
i64.const 1095216660480
i64.and
i64.or
i64.or
local.get 0
i64.const 8
i64.shr_u
i64.const 4278190080
i64.and
local.get 0
i64.const 24
i64.shr_u
i64.const 16711680
i64.and
i64.or
local.get 0
i64.const 40
i64.shr_u
i64.const 65280
i64.and
local.get 0
i64.const 56
i64.shr_u
i64.or
i64.or
i64.or
call 0
local.tee 0
i64.const 56
i64.shl
local.get 0
i64.const 40
i64.shl
i64.const 71776119061217280
i64.and
i64.or
local.get 0
i64.const 24
i64.shl
i64.const 280375465082880
i64.and
local.get 0
i64.const 8
i64.shl
i64.const 1095216660480
i64.and
i64.or
i64.or
local.get 0
i64.const 8
i64.shr_u
i64.const 4278190080
i64.and
local.get 0
i64.const 24
i64.shr_u
i64.const 16711680
i64.and
i64.or
local.get 0
i64.const 40
i64.shr_u
i64.const 65280
i64.and
local.get 0
i64.const 56
i64.shr_u
i64.or
i64.or
i64.or
else
local.get 2
i64.const 3026418949592973312
i64.store offset=4 align=4
local.get 2
return
end
local.get 2
local.get 0
i64.load offset=4 align=1
local.tee 1
i64.const 56
i64.shl
i64.store
local.get 1
i64.const 40
i64.shl
i64.const 71776119061217280
i64.and
i64.or
local.get 1
i64.const 24
i64.shl
i64.const 280375465082880
i64.and
local.get 1
i64.const 8
i64.shl
i64.const 1095216660480
i64.and
i64.or
i64.or
local.get 1
i64.const 8
i64.shr_u
i64.const 4278190080
i64.and
local.get 1
i64.const 24
i64.shr_u
i64.const 16711680
i64.and
i64.or
local.get 1
i64.const 40
i64.shr_u
i64.const 65280
i64.and
local.get 1
i64.const 56
i64.shr_u
i64.or
i64.or
i64.or
call 1
local.tee 1
i64.const 56
i64.shl
local.get 1
i64.const 40
i64.shl
i64.const 71776119061217280
i64.and
i64.or
local.get 1
i64.const 24
i64.shl
i64.const 280375465082880
i64.and
local.get 1
i64.const 8
i64.shl
i64.const 1095216660480
i64.and
i64.or
i64.or
local.get 1
i64.const 8
i64.shr_u
i64.const 4278190080
i64.and
local.get 1
i64.const 24
i64.shr_u
i64.const 16711680
i64.and
i64.or
local.get 1
i64.const 40
i64.shr_u
i64.const 65280
i64.and
local.get 1
i64.const 56
i64.shr_u
i64.or
i64.or
i64.or
i64.store offset=4 align=4
local.get 2)
i64.extend_i32_u
i64.const 34359738368
i64.or)
(memory (;0;) 2)
(global (;0;) i32 (i32.const 1024))
(global (;1;) i32 (i32.const 1024))
(global (;2;) i32 (i32.const 1028))
(global (;3;) i32 (i32.const 1024))
(global (;4;) i32 (i32.const 66576))
(global (;5;) i32 (i32.const 0))
(global (;6;) i32 (i32.const 1))
(export "memory" (memory 0))
(export "{fib_name}" (func 2)))
(export "{fib_name}" (func 1)))
"""
src = f"(input bigint) CALLED ON NULL INPUT RETURNS bigint LANGUAGE xwasm AS '{fib_source}'"
with new_function(cql, test_keyspace, src, fib_name):
@@ -632,27 +615,21 @@ def test_validate_params(cql, test_keyspace, table1, scylla_with_wasm_only):
# Created with:
# const int WASM_PAGE_SIZE = 64 * 1024;
# struct __attribute__((packed)) param {
# int size;
# char buf[0];
# };
# int dbl(struct param* par) {
# int size = par->size;
# int position = (int)par->buf;
# long long dbl(long long par) {
# int size = par >> 32;
# int position = par & 0xffffffff;
# int orig_size = __builtin_wasm_memory_size(0) * WASM_PAGE_SIZE;
# __builtin_wasm_memory_grow(0, 1 + (2 * size - 1) / WASM_PAGE_SIZE);
# char* p = (char*)0;
# for (int i = 0; i < size; ++i) {
# p[orig_size + sizeof(int) + i] = p[position + i];
# p[orig_size + size + sizeof(int) + i] = p[position + i];
# p[orig_size + i] = p[position + i];
# p[orig_size + size + i] = p[position + i];
# }
# int* ret = (int*)orig_size;
# *ret = 2*size;
# return orig_size;
# long long ret = ((long long)2 * size << 32) | (long long)orig_size;
# return ret;
# }
# ... and compiled with
# clang --target=wasm32 --no-standard-libraries -Wl,--export-all -Wl,--no-entry demo.c -o demo.wasm
# clang --target=wasm32 --no-standard-libraries -Wl,--export=dbl -Wl,--no-entry demo.c -o demo.wasm
# wasm2wat demo.wasm > demo.wat
def test_word_double(cql, test_keyspace, table1, scylla_with_wasm_only):
@@ -660,14 +637,12 @@ def test_word_double(cql, test_keyspace, table1, scylla_with_wasm_only):
dbl_name = unique_name()
dbl_source = f"""
(module
(type (;0;) (func))
(type (;1;) (func (param i32) (result i32)))
(func $__wasm_call_ctors (type 0))
(func $dbl (type 1) (param i32) (result i32)
(local i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32)
(type (;0;) (func (param i64) (result i64)))
(func $dbl (type 0) (param i64) (result i64)
(local i32 i32 i32 i64 i64 i64 i32 i64 i64 i64 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i32 i64 i64 i64 i64 i64 i32 i32 i64 i64 i64)
global.get $__stack_pointer
local.set 1
i32.const 32
i32.const 48
local.set 2
local.get 1
local.get 2
@@ -675,253 +650,257 @@ def test_word_double(cql, test_keyspace, table1, scylla_with_wasm_only):
local.set 3
local.get 3
local.get 0
i32.store offset=28
i64.store offset=40
local.get 3
i32.load offset=28
i64.load offset=40
local.set 4
local.get 4
i32.load align=1
i64.const 32
local.set 5
local.get 3
local.get 4
local.get 5
i32.store offset=24
local.get 3
i32.load offset=28
i64.shr_s
local.set 6
i32.const 4
local.set 7
local.get 6
local.get 7
i32.add
local.set 8
i32.wrap_i64
local.set 7
local.get 3
local.get 8
i32.store offset=20
memory.size
local.get 7
i32.store offset=36
local.get 3
i64.load offset=40
local.set 8
i64.const 4294967295
local.set 9
i32.const 16
local.set 10
local.get 8
local.get 9
i64.and
local.set 10
local.get 10
i32.shl
i32.wrap_i64
local.set 11
local.get 3
local.get 11
i32.store offset=16
local.get 3
i32.load offset=24
i32.store offset=32
memory.size
local.set 12
i32.const 1
i32.const 16
local.set 13
local.get 12
local.get 13
i32.shl
local.set 14
i32.const 1
local.set 15
local.get 3
local.get 14
local.get 15
i32.sub
local.set 16
i32.const 65536
local.set 17
local.get 16
local.get 17
i32.div_s
local.set 18
i32.store offset=28
local.get 3
i32.load offset=36
local.set 15
i32.const 1
local.set 19
local.set 16
local.get 15
local.get 16
i32.shl
local.set 17
i32.const 1
local.set 18
local.get 17
local.get 18
local.get 19
i32.add
i32.sub
local.set 19
i32.const 65536
local.set 20
local.get 19
local.get 20
i32.div_s
local.set 21
i32.const 1
local.set 22
local.get 21
local.get 22
i32.add
local.set 23
local.get 23
memory.grow
drop
i32.const 0
local.set 21
local.set 24
local.get 3
local.get 21
i32.store offset=12
local.get 24
i32.store offset=24
i32.const 0
local.set 22
local.set 25
local.get 3
local.get 22
i32.store offset=8
local.get 25
i32.store offset=20
block ;; label = @1
loop ;; label = @2
local.get 3
i32.load offset=8
local.set 23
local.get 3
i32.load offset=24
local.set 24
local.get 23
local.set 25
local.get 24
i32.load offset=20
local.set 26
local.get 25
local.get 26
i32.lt_s
local.get 3
i32.load offset=36
local.set 27
i32.const 1
local.get 26
local.set 28
local.get 27
local.get 28
i32.and
local.set 29
local.get 28
local.get 29
i32.lt_s
local.set 30
i32.const 1
local.set 31
local.get 30
local.get 31
i32.and
local.set 32
local.get 32
i32.eqz
br_if 1 (;@1;)
local.get 3
i32.load offset=12
local.set 30
i32.load offset=24
local.set 33
local.get 3
i32.load offset=32
local.set 34
local.get 3
i32.load offset=20
local.set 31
local.get 3
i32.load offset=8
local.set 32
local.get 31
local.get 32
i32.add
local.set 33
local.get 30
local.get 33
i32.add
local.set 34
local.get 34
i32.load8_u
local.set 35
local.get 3
i32.load offset=12
local.set 36
local.get 3
i32.load offset=16
local.set 37
i32.const 4
local.set 38
local.get 37
local.get 38
local.get 34
local.get 35
i32.add
local.set 36
local.get 33
local.get 36
i32.add
local.set 37
local.get 37
i32.load8_u
local.set 38
local.get 3
i32.load offset=24
local.set 39
local.get 3
i32.load offset=8
i32.load offset=28
local.set 40
local.get 39
local.get 40
i32.add
local.get 3
i32.load offset=20
local.set 41
local.get 36
local.get 40
local.get 41
i32.add
local.set 42
local.get 39
local.get 42
local.get 35
i32.add
local.set 43
local.get 43
local.get 38
i32.store8
local.get 3
i32.load offset=12
local.set 43
local.get 3
i32.load offset=20
i32.load offset=24
local.set 44
local.get 3
i32.load offset=8
i32.load offset=32
local.set 45
local.get 44
local.get 45
i32.add
local.get 3
i32.load offset=20
local.set 46
local.get 43
local.get 45
local.get 46
i32.add
local.set 47
local.get 44
local.get 47
i32.load8_u
i32.add
local.set 48
local.get 3
i32.load offset=12
local.get 48
i32.load8_u
local.set 49
local.get 3
i32.load offset=16
i32.load offset=24
local.set 50
local.get 3
i32.load offset=24
i32.load offset=28
local.set 51
local.get 50
local.get 51
i32.add
local.set 52
i32.const 4
local.set 53
local.get 52
local.get 53
i32.add
local.set 54
local.get 3
i32.load offset=8
local.set 55
i32.load offset=36
local.set 52
local.get 51
local.get 52
i32.add
local.set 53
local.get 3
i32.load offset=20
local.set 54
local.get 53
local.get 54
i32.add
local.set 55
local.get 50
local.get 55
i32.add
local.set 56
local.get 49
local.get 56
i32.add
local.set 57
local.get 57
local.get 48
local.get 49
i32.store8
local.get 3
i32.load offset=8
local.set 58
i32.load offset=20
local.set 57
i32.const 1
local.set 59
local.set 58
local.get 57
local.get 58
local.get 59
i32.add
local.set 60
local.set 59
local.get 3
local.get 60
i32.store offset=8
local.get 59
i32.store offset=20
br 0 (;@2;)
end
end
local.get 3
i32.load offset=16
i32.load offset=36
local.set 60
local.get 60
local.set 61
local.get 3
local.get 61
i32.store offset=4
local.get 3
i32.load offset=24
i64.extend_i32_s
local.set 62
i32.const 1
i64.const 1
local.set 63
local.get 62
local.get 63
i32.shl
i64.shl
local.set 64
local.get 3
i32.load offset=4
i64.const 32
local.set 65
local.get 65
local.get 64
i32.store
local.get 3
i32.load offset=16
local.get 65
i64.shl
local.set 66
local.get 3
i32.load offset=28
local.set 67
local.get 67
local.set 68
local.get 68
i64.extend_i32_s
local.set 69
local.get 66
local.get 69
i64.or
local.set 70
local.get 3
local.get 70
i64.store offset=8
local.get 3
i64.load offset=8
local.set 71
local.get 71
return)
(memory (;0;) 2)
(global $__stack_pointer (mut i32) (i32.const 66576))
(global (;1;) i32 (i32.const 1024))
(global (;2;) i32 (i32.const 1024))
(global (;3;) i32 (i32.const 1028))
(global (;4;) i32 (i32.const 1024))
(global (;5;) i32 (i32.const 66576))
(global (;6;) i32 (i32.const 0))
(global (;7;) i32 (i32.const 1))
(global $__stack_pointer (mut i32) (i32.const 66560))
(export "memory" (memory 0))
(export "{dbl_name}" (func $dbl)))
"""