diff --git a/bridge.go b/bridge.go index 16ad1eb..d925019 100644 --- a/bridge.go +++ b/bridge.go @@ -1,8 +1,10 @@ package wasm import ( + "context" "encoding/binary" "fmt" + "log" "math" "reflect" "sync" @@ -16,7 +18,7 @@ import ( var undefined = &struct{}{} var bridges = map[string]*Bridge{} var mu sync.RWMutex // to protect bridges -type context struct{ n string } +type bctx struct{ n string } func getCtxData(b *Bridge) (unsafe.Pointer, error) { mu.Lock() @@ -26,12 +28,12 @@ func getCtxData(b *Bridge) (unsafe.Pointer, error) { } bridges[b.name] = b - return unsafe.Pointer(&context{n: b.name}), nil + return unsafe.Pointer(&bctx{n: b.name}), nil } func getBridge(ctx unsafe.Pointer) *Bridge { ictx := wasmer.IntoInstanceContext(ctx) - c := (*context)(ictx.Data()) + c := (*bctx)(ictx.Data()) mu.RLock() defer mu.RUnlock() return bridges[c.n] @@ -40,12 +42,13 @@ func getBridge(ctx unsafe.Pointer) *Bridge { type Bridge struct { name string instance wasmer.Instance - done chan bool exitCode int values []interface{} + valuesMu sync.RWMutex refs map[interface{}]int memory []byte exited bool + cancF context.CancelFunc } func BridgeFromBytes(name string, bytes []byte, imports *wasmer.Imports) (*Bridge, error) { @@ -175,11 +178,10 @@ func (b *Bridge) check() { } // Run start the wasm instance. -func (b *Bridge) Run(init chan error, done chan bool) { +func (b *Bridge) Run(ctx context.Context, init chan error) { b.check() defer b.instance.Close() - b.done = done run := b.instance.Exports["run"] _, err := run(0, 0) if err != nil { @@ -187,9 +189,15 @@ func (b *Bridge) Run(init chan error, done chan bool) { return } + ctx, cancF := context.WithCancel(ctx) + b.cancF = cancF init <- nil - <-b.done - fmt.Printf("WASM exited with code: %v\n", b.exitCode) + select { + case <-ctx.Done(): + log.Printf("stopping WASM[%s] instance...\n", b.name) + b.exited = true + return + } } func (b *Bridge) mem() []byte { @@ -297,6 +305,9 @@ func (b *Bridge) loadValue(addr int32) interface{} { return f } + b.valuesMu.RLock() + defer b.valuesMu.RUnlock() + return b.values[b.getUint32(addr)] } @@ -353,9 +364,11 @@ func (b *Bridge) storeValue(addr int32, v interface{}) { ref, ok := b.refs[v] if !ok { + b.valuesMu.RLock() ref = len(b.values) b.values = append(b.values, v) b.refs[v] = ref + b.valuesMu.RUnlock() } typeFlag := 0 @@ -422,7 +435,9 @@ type funcWrapper struct { } func (b *Bridge) makeFuncWrapper(id, this interface{}, args *[]interface{}) (interface{}, error) { + b.valuesMu.RLock() goObj := b.values[7].(*object) + b.valuesMu.RUnlock() event := propObject("_pendingEvent", map[string]interface{}{ "id": id, "this": nil, @@ -440,15 +455,21 @@ func (b *Bridge) makeFuncWrapper(id, this interface{}, args *[]interface{}) (int func (b *Bridge) CallFunc(fn string, args []interface{}) (interface{}, error) { b.check() + b.valuesMu.RLock() fw, ok := b.values[5].(*object).props[fn] if !ok { return nil, fmt.Errorf("missing function: %v", fn) } - return b.makeFuncWrapper(fw.(*funcWrapper).id, b.values[7], &args) + this := b.values[7] + b.valuesMu.RUnlock() + + return b.makeFuncWrapper(fw.(*funcWrapper).id, this, &args) } func (b *Bridge) SetFunc(fname string, fn Func) error { + b.valuesMu.RLock() + defer b.valuesMu.RUnlock() b.values[5].(*object).props[fname] = &fn return nil } diff --git a/examples/caller/main.go b/examples/caller/main.go index 32956cd..df9dfc9 100644 --- a/examples/caller/main.go +++ b/examples/caller/main.go @@ -1,36 +1,51 @@ package main import ( + "context" "log" "github.com/vedhavyas/go-wasm" ) -func proxy(b *wasm.Bridge) wasm.Func { +func addProxy(b *wasm.Bridge) wasm.Func { return func(args []interface{}) (i interface{}, e error) { log.Println("In Go", args) return b.CallFunc("addition", args) } } -func main() { - b, err := wasm.BridgeFromFile("test", "./examples/wasm/main.wasm", nil) +func multiply(b *wasm.Bridge, a int) (int, error) { + m, err := b.CallFunc("multiplier", nil) if err != nil { - log.Fatal(err) + return 0, err } - err = b.SetFunc("proxy", proxy(b)) + return a * int(m.(float64)), nil +} + +func main() { + b, err := wasm.BridgeFromFile("test", "./examples/wasm/main.wasm", nil) if err != nil { panic(err) } - init, done := make(chan error), make(chan bool) - go b.Run(init, done) + err = b.SetFunc("addProxy", addProxy(b)) + if err != nil { + panic(err) + } + + init := make(chan error) + ctx, cancF := context.WithCancel(context.Background()) + defer cancF() + go b.Run(ctx, init) err = <-init if err != nil { panic(err) } - <-done - log.Println("wasm exited", err) + mul, err := multiply(b, 10) + if err != nil { + panic(err) + } + log.Printf("Multiplier: %v\n", mul) } diff --git a/examples/wasm/main.go b/examples/wasm/main.go index d0dfd1e..bdccbe5 100644 --- a/examples/wasm/main.go +++ b/examples/wasm/main.go @@ -13,14 +13,18 @@ func addition(this js.Value, args []js.Value) interface{} { return a + b } +func multiplier(this js.Value, args []js.Value) interface{} { + return 10 +} + func main() { ch := make(chan bool) // register functions - fun := js.FuncOf(addition) - js.Global().Set("addition", fun) + js.Global().Set("addition", js.FuncOf(addition)) + js.Global().Set("multiplier", js.FuncOf(multiplier)) - res := js.Global().Get("proxy").Invoke(1, 2) + res := js.Global().Get("addProxy").Invoke(1, 2) log.Printf("1 + 2 = %d\n", res.Int()) <-ch } diff --git a/examples/wasm/main.wasm b/examples/wasm/main.wasm index 9fdae53..a8916e5 100755 Binary files a/examples/wasm/main.wasm and b/examples/wasm/main.wasm differ diff --git a/imports.go b/imports.go index d7871ba..6b0d69b 100644 --- a/imports.go +++ b/imports.go @@ -29,6 +29,7 @@ import "C" import ( "crypto/rand" "fmt" + "log" "reflect" "syscall" "time" @@ -38,16 +39,15 @@ import ( ) //export debug -func debug(ctx unsafe.Pointer, sp int32) { - fmt.Println(sp) +func debug(_ unsafe.Pointer, sp int32) { + log.Println(sp) } //export wexit func wexit(ctx unsafe.Pointer, sp int32) { b := getBridge(ctx) b.exitCode = int(b.getUint32(sp + 8)) - b.exited = true - close(b.done) + b.cancF() } //export wwrite @@ -56,7 +56,10 @@ func wwrite(ctx unsafe.Pointer, sp int32) { fd := int(b.getInt64(sp + 8)) p := int(b.getInt64(sp + 16)) l := int(b.getInt32(sp + 24)) - syscall.Write(fd, b.mem()[p:p+l]) + _, err := syscall.Write(fd, b.mem()[p:p+l]) + if err != nil { + panic(fmt.Errorf("wasm-write: %v", err)) + } } //export nanotime @@ -76,12 +79,12 @@ func walltime(ctx unsafe.Pointer, sp int32) { } //export scheduleCallback -func scheduleCallback(ctx unsafe.Pointer, sp int32) { +func scheduleCallback(_ unsafe.Pointer, _ int32) { panic("schedule callback") } //export clearScheduledCallback -func clearScheduledCallback(ctx unsafe.Pointer, sp int32) { +func clearScheduledCallback(_ unsafe.Pointer, _ int32) { panic("clear scheduled callback") } @@ -89,7 +92,6 @@ func clearScheduledCallback(ctx unsafe.Pointer, sp int32) { func getRandomData(ctx unsafe.Pointer, sp int32) { s := getBridge(ctx).loadSlice(sp + 8) _, err := rand.Read(s) - // TODO how to pass error? if err != nil { panic("failed: getRandomData") } @@ -146,7 +148,7 @@ func valueIndex(ctx unsafe.Pointer, sp int32) { } //export valueSetIndex -func valueSetIndex(ctx unsafe.Pointer, sp int32) { +func valueSetIndex(_ unsafe.Pointer, _ int32) { panic("valueSetIndex") } @@ -234,12 +236,12 @@ func valueLoadString(ctx unsafe.Pointer, sp int32) { } //export scheduleTimeoutEvent -func scheduleTimeoutEvent(ctx unsafe.Pointer, sp int32) { +func scheduleTimeoutEvent(_ unsafe.Pointer, _ int32) { panic("scheduleTimeoutEvent") } //export clearTimeoutEvent -func clearTimeoutEvent(ctx unsafe.Pointer, sp int32) { +func clearTimeoutEvent(_ unsafe.Pointer, _ int32) { panic("clearTimeoutEvent") }