446 lines
13 KiB
Go
446 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"go/format"
|
|
"io/ioutil"
|
|
"os"
|
|
"path"
|
|
"strings"
|
|
|
|
"github.com/varlink/go/varlink/idl"
|
|
)
|
|
|
|
func writeType(b *bytes.Buffer, t *idl.Type, json bool, ident int) {
|
|
switch t.Kind {
|
|
case idl.TypeBool:
|
|
b.WriteString("bool")
|
|
|
|
case idl.TypeInt:
|
|
b.WriteString("int64")
|
|
|
|
case idl.TypeFloat:
|
|
b.WriteString("float64")
|
|
|
|
case idl.TypeString, idl.TypeEnum:
|
|
b.WriteString("string")
|
|
|
|
case idl.TypeObject:
|
|
b.WriteString("json.RawMessage")
|
|
|
|
case idl.TypeArray:
|
|
b.WriteString("[]")
|
|
writeType(b, t.ElementType, json, ident)
|
|
|
|
case idl.TypeMap:
|
|
b.WriteString("map[string]")
|
|
writeType(b, t.ElementType, json, ident)
|
|
|
|
case idl.TypeMaybe:
|
|
b.WriteString("*")
|
|
writeType(b, t.ElementType, json, ident)
|
|
|
|
case idl.TypeAlias:
|
|
b.WriteString(t.Alias)
|
|
|
|
case idl.TypeStruct:
|
|
if len(t.Fields) == 0 {
|
|
b.WriteString("struct{}")
|
|
} else {
|
|
b.WriteString("struct {\n")
|
|
for _, field := range t.Fields {
|
|
for i := 0; i < ident+1; i++ {
|
|
b.WriteString("\t")
|
|
}
|
|
|
|
b.WriteString(strings.Title(field.Name) + " ")
|
|
writeType(b, field.Type, json, ident+1)
|
|
if json {
|
|
b.WriteString(" `json:\"" + field.Name)
|
|
if field.Type.Kind == idl.TypeMaybe {
|
|
b.WriteString(",omitempty")
|
|
}
|
|
b.WriteString("\"`")
|
|
}
|
|
b.WriteString("\n")
|
|
}
|
|
for i := 0; i < ident; i++ {
|
|
b.WriteString("\t")
|
|
}
|
|
b.WriteString("}")
|
|
}
|
|
}
|
|
}
|
|
|
|
// writeDocString writes s into b as a Go line comment. Every line of s is
// prefixed with "// " (so multi-line docs stay commented out) and the comment
// is terminated with a newline. An empty s writes nothing.
func writeDocString(b *bytes.Buffer, s string) {
	if s != "" {
		lines := strings.Split(s, "\n")
		b.WriteString("// ")
		b.WriteString(strings.Join(lines, "\n// "))
		b.WriteByte('\n')
	}
}
|
|
|
|
func generateTemplate(description string) (string, []byte, error) {
|
|
description = strings.TrimRight(description, "\n")
|
|
|
|
midl, err := idl.New(description)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
pkgname := strings.Replace(midl.Name, ".", "", -1)
|
|
|
|
var b bytes.Buffer
|
|
b.WriteString("// Generated with github.com/varlink/go/cmd/varlink-go-interface-generator\n\n")
|
|
|
|
writeDocString(&b, midl.Doc)
|
|
b.WriteString("package " + pkgname + "\n\n")
|
|
b.WriteString("@IMPORTS@\n\n")
|
|
|
|
b.WriteString("// Generated type declarations\n\n")
|
|
|
|
for _, a := range midl.Aliases {
|
|
writeDocString(&b, a.Doc)
|
|
b.WriteString("type " + a.Name + " ")
|
|
writeType(&b, a.Type, true, 0)
|
|
b.WriteString("\n\n")
|
|
}
|
|
|
|
for _, a := range midl.Errors {
|
|
writeDocString(&b, a.Doc)
|
|
b.WriteString("type " + a.Name + " ")
|
|
writeType(&b, a.Type, true, 0)
|
|
b.WriteString("\nfunc (e " + a.Name + ") Error() string {\n")
|
|
b.WriteString("\treturn \"" + midl.Name + "." + a.Name + "\"\n")
|
|
b.WriteString("}\n\n")
|
|
}
|
|
|
|
b.WriteString("func Dispatch_Error(err error) error {\n")
|
|
b.WriteString("\tif e, ok := err.(*varlink.Error); ok {\n")
|
|
b.WriteString("\t\tswitch e.Name {\n")
|
|
for _, a := range midl.Errors {
|
|
b.WriteString("\t\tcase \"" + midl.Name + "." + a.Name + "\":\n")
|
|
b.WriteString("\t\t\terrorRawParameters := e.Parameters.(*json.RawMessage)\n")
|
|
b.WriteString("\t\t\tif errorRawParameters == nil {\n")
|
|
b.WriteString("\t\t\t\treturn e\n")
|
|
b.WriteString("\t\t\t}\n")
|
|
b.WriteString("\t\t\tvar param " + a.Name + "\n")
|
|
b.WriteString("\t\t\terr := json.Unmarshal(*errorRawParameters, ¶m)\n")
|
|
b.WriteString("\t\t\tif err != nil {\n")
|
|
b.WriteString("\t\t\t\treturn e\n")
|
|
b.WriteString("\t\t\t}\n")
|
|
b.WriteString("\t\t\treturn ¶m\n")
|
|
}
|
|
b.WriteString("\t\t}\n")
|
|
b.WriteString("\t}\n")
|
|
b.WriteString("\treturn err\n")
|
|
b.WriteString("}\n\n")
|
|
|
|
b.WriteString("// Generated client method calls\n\n")
|
|
|
|
for _, m := range midl.Methods {
|
|
writeDocString(&b, m.Doc)
|
|
b.WriteString("type " + m.Name + "_methods struct{}\n")
|
|
b.WriteString("func " + m.Name + "() " + m.Name + "_methods { return " + m.Name + "_methods{} }\n\n")
|
|
|
|
b.WriteString("func (m " + m.Name + "_methods) Call(c *varlink.Connection")
|
|
for _, field := range m.In.Fields {
|
|
b.WriteString(", " + field.Name + "_in_ ")
|
|
writeType(&b, field.Type, false, 1)
|
|
}
|
|
b.WriteString(") (")
|
|
for _, field := range m.Out.Fields {
|
|
b.WriteString(field.Name + "_out_ ")
|
|
writeType(&b, field.Type, false, 1)
|
|
b.WriteString(", ")
|
|
}
|
|
b.WriteString("err_ error) {\n")
|
|
b.WriteString("receive, err_ := m.Send(c, 0")
|
|
for _, field := range m.In.Fields {
|
|
b.WriteString(", " + field.Name + "_in_ ")
|
|
}
|
|
b.WriteString(")\n")
|
|
b.WriteString("if err_ != nil {\n" +
|
|
"\treturn\n" +
|
|
"}\n")
|
|
b.WriteString("\t")
|
|
for _, field := range m.Out.Fields {
|
|
b.WriteString(field.Name + "_out_ ")
|
|
b.WriteString(", ")
|
|
}
|
|
b.WriteString("_, err_ = receive()\n")
|
|
b.WriteString("\treturn\n" +
|
|
"}\n\n")
|
|
|
|
b.WriteString("func (m " + m.Name + "_methods) Send(c *varlink.Connection, flags uint64")
|
|
for _, field := range m.In.Fields {
|
|
b.WriteString(", " + field.Name + "_in_ ")
|
|
writeType(&b, field.Type, false, 1)
|
|
}
|
|
b.WriteString(") (func() (")
|
|
for _, field := range m.Out.Fields {
|
|
writeType(&b, field.Type, false, 1)
|
|
b.WriteString(", ")
|
|
}
|
|
b.WriteString("uint64, error), error) {\n")
|
|
if len(m.In.Fields) > 0 {
|
|
b.WriteString("\tvar in ")
|
|
writeType(&b, m.In, true, 1)
|
|
b.WriteString("\n")
|
|
for _, field := range m.In.Fields {
|
|
switch field.Type.Kind {
|
|
case idl.TypeStruct, idl.TypeArray, idl.TypeMap:
|
|
b.WriteString("\tin." + strings.Title(field.Name) + " = ")
|
|
writeType(&b, field.Type, true, 1)
|
|
b.WriteString("(" + field.Name + "_in_)\n")
|
|
|
|
default:
|
|
b.WriteString("\tin." + strings.Title(field.Name) + " = " + field.Name + "_in_\n")
|
|
}
|
|
}
|
|
b.WriteString("\treceive, err := c.Send(\"" + midl.Name + "." + m.Name + "\", in, flags)\n")
|
|
} else {
|
|
b.WriteString("\treceive, err := c.Send(\"" + midl.Name + "." + m.Name + "\", nil, flags)\n")
|
|
}
|
|
b.WriteString("\tif err != nil {\n" +
|
|
"\t\treturn nil, err\n" +
|
|
"\t}\n")
|
|
b.WriteString("\treturn func() (")
|
|
for _, field := range m.Out.Fields {
|
|
b.WriteString(field.Name + "_out_ ")
|
|
writeType(&b, field.Type, false, 3)
|
|
b.WriteString(", ")
|
|
}
|
|
b.WriteString("flags uint64, err error) {\n")
|
|
if len(m.Out.Fields) > 0 {
|
|
b.WriteString("\t\tvar out ")
|
|
writeType(&b, m.Out, true, 2)
|
|
b.WriteString("\n")
|
|
b.WriteString("\t\tflags, err = receive(&out)\n")
|
|
} else {
|
|
b.WriteString("\t\tflags, err = receive(nil)\n")
|
|
}
|
|
b.WriteString("\t\tif err != nil {\n" +
|
|
"\t\t\terr = Dispatch_Error(err)\n" +
|
|
"\t\t\treturn\n" +
|
|
"\t\t}\n")
|
|
for _, field := range m.Out.Fields {
|
|
b.WriteString("\t\t" + field.Name + "_out_ = ")
|
|
switch field.Type.Kind {
|
|
case idl.TypeStruct, idl.TypeArray, idl.TypeMap:
|
|
writeType(&b, field.Type, false, 2)
|
|
b.WriteString("(out." + strings.Title(field.Name) + ")\n")
|
|
|
|
default:
|
|
b.WriteString("out." + strings.Title(field.Name) + "\n")
|
|
}
|
|
}
|
|
b.WriteString("\t\treturn\n" +
|
|
"\t}, nil\n")
|
|
b.WriteString("}\n\n")
|
|
}
|
|
|
|
b.WriteString("// Generated service interface with all methods\n\n")
|
|
|
|
b.WriteString("type " + pkgname + "Interface interface {\n")
|
|
for _, m := range midl.Methods {
|
|
b.WriteString("\t" + m.Name + "(c VarlinkCall")
|
|
for _, field := range m.In.Fields {
|
|
b.WriteString(", " + field.Name + "_ ")
|
|
writeType(&b, field.Type, false, 1)
|
|
}
|
|
b.WriteString(") error\n")
|
|
}
|
|
b.WriteString("}\n\n")
|
|
|
|
b.WriteString("// Generated service object with all methods\n\n")
|
|
|
|
b.WriteString("type VarlinkCall struct{ varlink.Call }\n\n")
|
|
|
|
b.WriteString("// Generated reply methods for all varlink errors\n\n")
|
|
|
|
for _, e := range midl.Errors {
|
|
writeDocString(&b, e.Doc)
|
|
b.WriteString("func (c *VarlinkCall) Reply" + e.Name + "(")
|
|
for i, field := range e.Type.Fields {
|
|
if i > 0 {
|
|
b.WriteString(", ")
|
|
}
|
|
b.WriteString(field.Name + "_ ")
|
|
writeType(&b, field.Type, false, 1)
|
|
}
|
|
b.WriteString(") error {\n")
|
|
b.WriteString("\tvar out " + e.Name + "\n")
|
|
if len(e.Type.Fields) > 0 {
|
|
for _, field := range e.Type.Fields {
|
|
switch field.Type.Kind {
|
|
case idl.TypeStruct, idl.TypeArray, idl.TypeMap:
|
|
b.WriteString("\tout." + strings.Title(field.Name) + " = ")
|
|
writeType(&b, field.Type, true, 1)
|
|
b.WriteString("(" + field.Name + "_)\n")
|
|
|
|
default:
|
|
b.WriteString("\tout." + strings.Title(field.Name) + " = " + field.Name + "_\n")
|
|
}
|
|
}
|
|
}
|
|
b.WriteString("\treturn c.ReplyError(\"" + midl.Name + "." + e.Name + "\", &out)\n")
|
|
b.WriteString("}\n\n")
|
|
}
|
|
|
|
b.WriteString("// Generated reply methods for all varlink methods\n\n")
|
|
|
|
for _, m := range midl.Methods {
|
|
b.WriteString("func (c *VarlinkCall) Reply" + m.Name + "(")
|
|
for i, field := range m.Out.Fields {
|
|
if i > 0 {
|
|
b.WriteString(", ")
|
|
}
|
|
b.WriteString(field.Name + "_ ")
|
|
writeType(&b, field.Type, false, 1)
|
|
}
|
|
b.WriteString(") error {\n")
|
|
if len(m.Out.Fields) > 0 {
|
|
b.WriteString("\tvar out ")
|
|
writeType(&b, m.Out, true, 1)
|
|
b.WriteString("\n")
|
|
for _, field := range m.Out.Fields {
|
|
switch field.Type.Kind {
|
|
case idl.TypeStruct, idl.TypeArray, idl.TypeMap:
|
|
b.WriteString("\tout." + strings.Title(field.Name) + " = ")
|
|
writeType(&b, field.Type, true, 1)
|
|
b.WriteString("(" + field.Name + "_)\n")
|
|
|
|
default:
|
|
b.WriteString("\tout." + strings.Title(field.Name) + " = " + field.Name + "_\n")
|
|
}
|
|
}
|
|
b.WriteString("\treturn c.Reply(&out)\n")
|
|
} else {
|
|
b.WriteString("\treturn c.Reply(nil)\n")
|
|
}
|
|
b.WriteString("}\n\n")
|
|
}
|
|
|
|
b.WriteString("// Generated dummy implementations for all varlink methods\n\n")
|
|
|
|
for _, m := range midl.Methods {
|
|
writeDocString(&b, m.Doc)
|
|
b.WriteString("func (s *VarlinkInterface) " + m.Name + "(c VarlinkCall")
|
|
for _, field := range m.In.Fields {
|
|
b.WriteString(", " + field.Name + "_ ")
|
|
writeType(&b, field.Type, false, 1)
|
|
}
|
|
b.WriteString(") error {\n" +
|
|
"\treturn c.ReplyMethodNotImplemented(\"" + midl.Name + "." + m.Name + "\")\n" +
|
|
"}\n\n")
|
|
}
|
|
|
|
b.WriteString("// Generated method call dispatcher\n\n")
|
|
|
|
b.WriteString("func (s *VarlinkInterface) VarlinkDispatch(call varlink.Call, methodname string) error {\n" +
|
|
"\tswitch methodname {\n")
|
|
for _, m := range midl.Methods {
|
|
b.WriteString("\tcase \"" + m.Name + "\":\n")
|
|
if len(m.In.Fields) > 0 {
|
|
b.WriteString("\t\tvar in ")
|
|
writeType(&b, m.In, true, 2)
|
|
b.WriteString("\n")
|
|
b.WriteString("\t\terr := call.GetParameters(&in)\n" +
|
|
"\t\tif err != nil {\n" +
|
|
"\t\t\treturn call.ReplyInvalidParameter(\"parameters\")\n" +
|
|
"\t\t}\n")
|
|
b.WriteString("\t\treturn s." + pkgname + "Interface." + m.Name + "(VarlinkCall{call}")
|
|
if len(m.In.Fields) > 0 {
|
|
for _, field := range m.In.Fields {
|
|
switch field.Type.Kind {
|
|
case idl.TypeStruct, idl.TypeArray, idl.TypeMap:
|
|
b.WriteString(", ")
|
|
writeType(&b, field.Type, false, 2)
|
|
b.WriteString("(in." + strings.Title(field.Name) + ")")
|
|
|
|
default:
|
|
b.WriteString(", in." + strings.Title(field.Name))
|
|
}
|
|
}
|
|
}
|
|
b.WriteString(")\n")
|
|
} else {
|
|
b.WriteString("\t\treturn s." + pkgname + "Interface." + m.Name + "(VarlinkCall{call})\n")
|
|
}
|
|
b.WriteString("\n")
|
|
}
|
|
b.WriteString("\tdefault:\n" +
|
|
"\t\treturn call.ReplyMethodNotFound(methodname)\n" +
|
|
"\t}\n" +
|
|
"}\n\n")
|
|
|
|
b.WriteString("// Generated varlink interface name\n\n")
|
|
|
|
b.WriteString("func (s *VarlinkInterface) VarlinkGetName() string {\n" +
|
|
"\treturn `" + midl.Name + "`\n" + "}\n\n")
|
|
|
|
b.WriteString("// Generated varlink interface description\n\n")
|
|
|
|
// Special-quote backtick, it cannot be part of a backtick-quoted string
|
|
b.WriteString("func (s *VarlinkInterface) VarlinkGetDescription() string {\n" +
|
|
"\treturn `" + strings.Replace(midl.Description, "`", "` + \"`\" + `", -1) + "\n`\n}\n\n")
|
|
|
|
b.WriteString("// Generated service interface\n\n")
|
|
|
|
b.WriteString("type VarlinkInterface struct {\n" +
|
|
"\t" + pkgname + "Interface\n" +
|
|
"}\n\n")
|
|
|
|
b.WriteString("func VarlinkNew(m " + pkgname + "Interface) *VarlinkInterface {\n" +
|
|
"\treturn &VarlinkInterface{m}\n" +
|
|
"}\n")
|
|
|
|
ret_string := b.String()
|
|
|
|
if strings.Contains(ret_string, "json.RawMessage") {
|
|
ret_string = strings.Replace(ret_string, "@IMPORTS@", "import (\n\t\"github.com/varlink/go/varlink\"\n\t\"encoding/json\"\n)", 1)
|
|
} else {
|
|
ret_string = strings.Replace(ret_string, "@IMPORTS@", `import "github.com/varlink/go/varlink"`, 1)
|
|
}
|
|
|
|
pretty, err := format.Source([]byte(ret_string))
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
return pkgname, pretty, nil
|
|
}
|
|
|
|
func generateFile(varlinkFile string) {
|
|
file, err := ioutil.ReadFile(varlinkFile)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error reading file '%s': %s\n", varlinkFile, err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
pkgname, b, err := generateTemplate(string(file))
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error parsing file '%s': %s\n", varlinkFile, err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
filename := path.Dir(varlinkFile) + "/" + pkgname + ".go"
|
|
err = ioutil.WriteFile(filename, b, 0660)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error writing file '%s': %s\n", filename, err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
if len(os.Args) != 2 {
|
|
fmt.Printf("Usage: %s <file>\n", os.Args[0])
|
|
os.Exit(1)
|
|
}
|
|
generateFile(os.Args[1])
|
|
}
|