// Copyright 2019 DeepMap, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package codegen
import (
"bufio"
"bytes"
"fmt"
"go/format"
"regexp"
"sort"
"strings"
"text/template"
"github.com/getkin/kin-openapi/openapi3"
"github.com/pkg/errors"
"github.com/deepmap/oapi-codegen/pkg/codegen/templates"
)
// Options defines the optional code to generate.
type Options struct {
	// GenerateChiServer specifies whether to generate chi server boilerplate.
	GenerateChiServer bool
	// GenerateEchoServer specifies whether to generate echo server boilerplate.
	GenerateEchoServer bool
	// GenerateClient specifies whether to generate client boilerplate. It
	// enables both the plain client and the "client with responses".
	GenerateClient bool
	// GenerateTypes specifies whether to generate type definitions.
	GenerateTypes bool
	// EmbedSpec controls whether the swagger spec is embedded in the generated code.
	EmbedSpec bool
	// SkipFmt skips running go/format over the generated code.
	SkipFmt bool
	// SkipPrune skips pruning unused components from the spec before generation.
	SkipPrune bool
	// IncludeTags restricts generation to operations that have one of these
	// tags. Ignored when empty.
	IncludeTags []string
	// ExcludeTags excludes operations that have one of these tags. Ignored
	// when empty.
	ExcludeTags []string
	// UserTemplates overrides built-in templates, keyed by template name.
	// Entries whose name does not match a built-in template are ignored.
	UserTemplates map[string]string
	// ImportMapping maps each external reference (URL or file path) to the Go
	// package path that provides its types. Each distinct package path is
	// imported under a generated "externalRefN" alias.
	ImportMapping map[string]string
	// ExcludeSchemas lists components/schemas names that must not be
	// generated. Ignored when empty.
	ExcludeSchemas []string
}
// goImport describes one import spec the generated file may need.
type goImport struct {
	lookFor     string // regexp fragment whose presence in generated code implies this import
	alias       string // optional local package name; empty means no alias
	packageName string // import path, e.g. "net/http"
}

// String renders the spec as it appears inside an import block: the quoted
// path, prefixed by the alias and a space when an alias is set.
func (imp goImport) String() string {
	quoted := fmt.Sprintf("%q", imp.packageName)
	if imp.alias == "" {
		return quoted
	}
	return imp.alias + " " + quoted
}

// goImports is a list of candidate imports.
type goImports []goImport
var (
	// allGoImports lists every package the generated code may reference.
	// Generate matches each lookFor regexp fragment against the rendered,
	// non-comment lines of output and imports the package when it is found.
	// Note: "errors." resolves to github.com/pkg/errors, not the standard
	// library. Only the two listed time.X names trigger the "time" import;
	// any other use of package time would not be detected.
	allGoImports = goImports{
		{lookFor: "base64\\.", packageName: "encoding/base64"},
		{lookFor: "bytes\\.", packageName: "bytes"},
		{lookFor: "chi\\.", packageName: "github.com/go-chi/chi"},
		{lookFor: "context\\.", packageName: "context"},
		{lookFor: "echo\\.", packageName: "github.com/labstack/echo/v4"},
		{lookFor: "errors\\.", packageName: "github.com/pkg/errors"},
		{lookFor: "fmt\\.", packageName: "fmt"},
		{lookFor: "gzip\\.", packageName: "compress/gzip"},
		{lookFor: "http\\.", packageName: "net/http"},
		{lookFor: "io\\.", packageName: "io"},
		{lookFor: "ioutil\\.", packageName: "io/ioutil"},
		{lookFor: "json\\.", packageName: "encoding/json"},
		{lookFor: "openapi3\\.", packageName: "github.com/getkin/kin-openapi/openapi3"},
		{lookFor: "openapi_types\\.", alias: "openapi_types", packageName: "github.com/deepmap/oapi-codegen/pkg/types"},
		{lookFor: "path\\.", packageName: "path"},
		{lookFor: "runtime\\.", packageName: "github.com/deepmap/oapi-codegen/pkg/runtime"},
		{lookFor: "strings\\.", packageName: "strings"},
		{lookFor: "time\\.Duration", packageName: "time"},
		{lookFor: "time\\.Time", packageName: "time"},
		{lookFor: "url\\.", packageName: "net/url"},
		{lookFor: "xml\\.", packageName: "encoding/xml"},
		{lookFor: "yaml\\.", packageName: "gopkg.in/yaml.v2"},
	}

	// importMapping maps each external reference (URL or path) from
	// Options.ImportMapping to the aliased import of its Go package.
	// Generate reassigns it on every call, so it is shared mutable state.
	importMapping = map[string]goImport{}
)
// constructImportMapping converts the user-supplied mapping of external
// reference (URL or path) -> Go package path into goImports.
//
// Each distinct package path receives the alias "externalRef<N>". N is the
// path's index in sorted order, so aliases are stable from run to run.
// References that point to the same package share a single alias.
func constructImportMapping(input map[string]string) map[string]goImport {
	// Gather the distinct package paths in a deterministic order.
	seen := make(map[string]bool, len(input))
	var distinct []string
	for _, pkg := range input {
		if !seen[pkg] {
			seen[pkg] = true
			distinct = append(distinct, pkg)
		}
	}
	sort.Strings(distinct)

	aliasOf := make(map[string]string, len(distinct))
	for idx, pkg := range distinct {
		aliasOf[pkg] = fmt.Sprintf("externalRef%d", idx)
	}

	mapped := make(map[string]goImport, len(input))
	for ref, pkg := range input {
		mapped[ref] = goImport{alias: aliasOf[pkg], packageName: pkg}
	}
	return mapped
}
// Generate uses the Go templating engine to render Go source for swagger.
// opts selects which sections are produced: type definitions, client,
// echo/chi server wrappers and an embedded copy of the spec. packageName is
// used for the package clause.
//
// Steps, in order:
//  1. Build the import mapping.
//  2. Filter and prune the spec.
//  3. Parse the built-in templates, applying opts.UserTemplates overrides.
//  4. Render each enabled section.
//  5. Detect which imports the rendered code needs.
//  6. Concatenate everything and, unless opts.SkipFmt is set, run the result
//     through go/format.
//
// NOTE(review): Generate assigns the package-level importMapping, so
// concurrent calls race. filterOperationsByTag and pruneUnusedComponents
// presumably modify swagger in place; confirm before reusing one document
// for several calls.
func Generate(swagger *openapi3.Swagger, packageName string, opts Options) (string, error) {
	importMapping = constructImportMapping(opts.ImportMapping)

	filterOperationsByTag(swagger, opts)
	if !opts.SkipPrune {
		pruneUnusedComponents(swagger)
	}

	// This creates the golang templates text package
	t := template.New("oapi-codegen").Funcs(TemplateFunctions)
	// This parses all of our own template files into the template object
	// above
	t, err := templates.Parse(t)
	if err != nil {
		return "", errors.Wrap(err, "error parsing oapi-codegen templates")
	}

	// Override built-in templates with user-provided versions. Only names
	// that already exist as built-in templates are considered. Parsing under
	// an existing name redefines that template.
	for _, tpl := range t.Templates() {
		name := tpl.Name()
		src, ok := opts.UserTemplates[name]
		if !ok {
			continue
		}
		if _, err := t.New(name).Parse(src); err != nil {
			return "", errors.Wrapf(err, "error parsing user-provided template %q", name)
		}
	}

	ops, err := OperationDefinitions(swagger)
	if err != nil {
		return "", errors.Wrap(err, "error creating operation definitions")
	}

	var typeDefinitions string
	if opts.GenerateTypes {
		typeDefinitions, err = GenerateTypeDefinitions(t, swagger, ops, opts.ExcludeSchemas)
		if err != nil {
			return "", errors.Wrap(err, "error generating type definitions")
		}
	}

	var echoServerOut string
	if opts.GenerateEchoServer {
		echoServerOut, err = GenerateEchoServer(t, ops)
		if err != nil {
			return "", errors.Wrap(err, "error generating Go handlers for Paths")
		}
	}

	var chiServerOut string
	if opts.GenerateChiServer {
		chiServerOut, err = GenerateChiServer(t, ops)
		if err != nil {
			return "", errors.Wrap(err, "error generating Go handlers for Paths")
		}
	}

	var clientOut, clientWithResponsesOut string
	if opts.GenerateClient {
		clientOut, err = GenerateClient(t, ops)
		if err != nil {
			return "", errors.Wrap(err, "error generating client")
		}
		clientWithResponsesOut, err = GenerateClientWithResponses(t, ops)
		if err != nil {
			return "", errors.Wrap(err, "error generating client with responses")
		}
	}

	var inlinedSpec string
	if opts.EmbedSpec {
		inlinedSpec, err = GenerateInlinedSpec(t, swagger)
		if err != nil {
			return "", errors.Wrap(err, "error generating inlined spec")
		}
	}

	// Compile each detection pattern once, not once per scanned line. The
	// pattern must also match at the start of a line. Lines are trimmed
	// before matching, so a statement such as `fmt.Println(x)` has no
	// preceding character.
	type importDetector struct {
		re   *regexp.Regexp
		spec string
	}
	detectors := make([]importDetector, 0, len(allGoImports))
	for _, gi := range allGoImports {
		re, err := regexp.Compile(fmt.Sprintf("(?:^|[^a-zA-Z0-9_])%s", gi.lookFor))
		if err != nil {
			return "", errors.Wrap(err, "error figuring out imports")
		}
		detectors = append(detectors, importDetector{re: re, spec: gi.String()})
	}

	// Collect import specs as a set. Several ImportMapping entries may point
	// to the same package and therefore share one alias. Emitting the same
	// import spec twice would not compile.
	importSet := make(map[string]struct{})
	for _, gi := range importMapping {
		importSet[gi.String()] = struct{}{}
	}
	for _, section := range []string{typeDefinitions, chiServerOut, echoServerOut, clientOut, clientWithResponsesOut, inlinedSpec} {
		for _, line := range strings.Split(strings.TrimSpace(section), "\n") {
			line = strings.TrimSpace(line)
			// Blank lines and whole-line comments cannot require an import.
			if line == "" || strings.HasPrefix(line, "//") {
				continue
			}
			for _, d := range detectors {
				if d.re.MatchString(line) {
					importSet[d.spec] = struct{}{}
				}
			}
		}
	}
	imports := make([]string, 0, len(importSet))
	for spec := range importSet {
		imports = append(imports, spec)
	}

	importsOut, err := GenerateImports(t, imports, packageName)
	if err != nil {
		return "", errors.Wrap(err, "error generating imports")
	}

	var buf bytes.Buffer
	w := bufio.NewWriter(&buf)

	if _, err := w.WriteString(importsOut); err != nil {
		return "", errors.Wrap(err, "error writing imports")
	}
	if _, err := w.WriteString(typeDefinitions); err != nil {
		return "", errors.Wrap(err, "error writing type definitions")
	}
	if opts.GenerateClient {
		if _, err := w.WriteString(clientOut); err != nil {
			return "", errors.Wrap(err, "error writing client")
		}
		if _, err := w.WriteString(clientWithResponsesOut); err != nil {
			return "", errors.Wrap(err, "error writing client")
		}
	}
	if opts.GenerateEchoServer {
		if _, err := w.WriteString(echoServerOut); err != nil {
			return "", errors.Wrap(err, "error writing server path handlers")
		}
	}
	if opts.GenerateChiServer {
		if _, err := w.WriteString(chiServerOut); err != nil {
			return "", errors.Wrap(err, "error writing server path handlers")
		}
	}
	if opts.EmbedSpec {
		if _, err := w.WriteString(inlinedSpec); err != nil {
			return "", errors.Wrap(err, "error writing inlined spec")
		}
	}
	if err := w.Flush(); err != nil {
		return "", errors.Wrap(err, "error flushing output buffer")
	}

	// remove any byte-order-marks which break Go-Code
	goCode := SanitizeCode(buf.String())

	// The generation code produces unindented horrors. Use the Go formatter
	// to make it all pretty.
	if opts.SkipFmt {
		return goCode, nil
	}
	outBytes, err := format.Source([]byte(goCode))
	if err != nil {
		// Dump the unformatted code so the line/column in the format error
		// can be located in it.
		fmt.Println(goCode)
		return "", errors.Wrap(err, "error formatting Go code")
	}
	return string(outBytes), nil
}
// GenerateTypeDefinitions renders every Go type derived from the spec and
// returns the pieces concatenated in this order:
//  1. Component types (schemas, parameters, responses, request bodies)
//     rendered via typedef.tmpl.
//  2. Per-operation parameter types.
//  3. The additionalProperties helpers for the component types.
//
// excludeSchemas names components/schemas entries to skip.
func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.Swagger, ops []OperationDefinition, excludeSchemas []string) (string, error) {
	schemaTypes, err := GenerateTypesForSchemas(t, swagger.Components.Schemas, excludeSchemas)
	if err != nil {
		return "", errors.Wrap(err, "error generating Go types for component schemas")
	}
	paramTypes, err := GenerateTypesForParameters(t, swagger.Components.Parameters)
	if err != nil {
		return "", errors.Wrap(err, "error generating Go types for component parameters")
	}
	responseTypes, err := GenerateTypesForResponses(t, swagger.Components.Responses)
	if err != nil {
		return "", errors.Wrap(err, "error generating Go types for component responses")
	}
	bodyTypes, err := GenerateTypesForRequestBodies(t, swagger.Components.RequestBodies)
	if err != nil {
		return "", errors.Wrap(err, "error generating Go types for component request bodies")
	}

	// Component types in a fixed order: schemas, parameters, responses, bodies.
	componentTypes := schemaTypes
	for _, group := range [][]TypeDefinition{paramTypes, responseTypes, bodyTypes} {
		componentTypes = append(componentTypes, group...)
	}

	operationTypesOut, err := GenerateTypesForOperations(t, ops)
	if err != nil {
		return "", errors.Wrap(err, "error generating Go types for operation parameters")
	}
	componentTypesOut, err := GenerateTypes(t, componentTypes)
	if err != nil {
		return "", errors.Wrap(err, "error generating code for type definitions")
	}
	boilerplateOut, err := GenerateAdditionalPropertyBoilerplate(t, componentTypes)
	if err != nil {
		return "", errors.Wrap(err, "error generating allOf boilerplate")
	}

	return componentTypesOut + operationTypesOut + boilerplateOut, nil
}
// GenerateTypesForSchemas generates type definitions for every custom type
// in the components/schemas section of the spec.
//
// Schemas named in excludeSchemas are skipped. Each generated schema
// contributes its own TypeDefinition followed by any additional type
// definitions its Go schema carries. The result is never nil.
func GenerateTypesForSchemas(t *template.Template, schemas map[string]*openapi3.SchemaRef, excludeSchemas []string) ([]TypeDefinition, error) {
	excluded := make(map[string]struct{}, len(excludeSchemas))
	for _, name := range excludeSchemas {
		excluded[name] = struct{}{}
	}

	types := make([]TypeDefinition, 0)
	for _, schemaName := range SortedSchemaKeys(schemas) {
		if _, skip := excluded[schemaName]; skip {
			continue
		}

		goSchema, err := GenerateGoSchema(schemas[schemaName], []string{schemaName})
		if err != nil {
			return nil, errors.Wrapf(err, "error converting Schema %s to Go type", schemaName)
		}

		def := TypeDefinition{
			JsonName: schemaName,
			TypeName: SchemaNameToTypeName(schemaName),
			Schema:   goSchema,
		}
		types = append(types, def)
		types = append(types, goSchema.GetAdditionalTypeDefs()...)
	}
	return types, nil
}
// GenerateTypesForParameters generates type definitions for every custom
// type in the components/parameters section of the spec.
//
// A parameter defined through a $ref is named after the Go type derived
// from its reference path rather than after its own key.
func GenerateTypesForParameters(t *template.Template, params map[string]*openapi3.ParameterRef) ([]TypeDefinition, error) {
	var defs []TypeDefinition
	for _, name := range SortedParameterKeys(params) {
		param := params[name]

		goType, err := paramToGoType(param.Value, nil)
		if err != nil {
			return nil, errors.Wrapf(err, "error generating Go type for schema in parameter %s", name)
		}

		def := TypeDefinition{
			JsonName: name,
			Schema:   goType,
			TypeName: SchemaNameToTypeName(name),
		}
		if ref := param.Ref; ref != "" {
			refType, err := RefPathToGoType(ref)
			if err != nil {
				return nil, errors.Wrapf(err, "error generating Go type for (%s) in parameter %s", ref, name)
			}
			def.TypeName = SchemaNameToTypeName(refType)
		}
		defs = append(defs, def)
	}
	return defs, nil
}
// GenerateTypesForResponses generates type definitions for the custom types
// in the components/responses section of the spec.
//
// Only responses that have an "application/json" media type produce a Go
// type. Other responses should simply be specified as strings or byte
// arrays. A response defined through a $ref is named after the Go type
// derived from its reference path rather than after its own key.
func GenerateTypesForResponses(t *template.Template, responses openapi3.Responses) ([]TypeDefinition, error) {
	var types []TypeDefinition
	for _, responseName := range SortedResponsesKeys(responses) {
		responseOrRef := responses[responseName]

		jsonResponse, found := responseOrRef.Value.Content["application/json"]
		if !found {
			continue
		}

		goType, err := GenerateGoSchema(jsonResponse.Schema, []string{responseName})
		if err != nil {
			return nil, errors.Wrapf(err, "error generating Go type for schema in response %s", responseName)
		}

		typeDef := TypeDefinition{
			JsonName: responseName,
			Schema:   goType,
			TypeName: SchemaNameToTypeName(responseName),
		}
		if responseOrRef.Ref != "" {
			// Generate a reference type for referenced responses.
			refType, err := RefPathToGoType(responseOrRef.Ref)
			if err != nil {
				// Previously this said "in parameter", which was a copy-paste
				// from GenerateTypesForParameters.
				return nil, errors.Wrapf(err, "error generating Go type for (%s) in response %s", responseOrRef.Ref, responseName)
			}
			typeDef.TypeName = SchemaNameToTypeName(refType)
		}
		types = append(types, typeDef)
	}
	return types, nil
}
// GenerateTypesForRequestBodies generates type definitions for the custom
// types in the components/requestBodies section of the spec.
//
// As for responses, only bodies that have an "application/json" media type
// produce a Go type. Other body formats are left to the user. A body
// defined through a $ref is named after the Go type derived from its
// reference path.
func GenerateTypesForRequestBodies(t *template.Template, bodies map[string]*openapi3.RequestBodyRef) ([]TypeDefinition, error) {
	var defs []TypeDefinition
	for _, name := range SortedRequestBodyKeys(bodies) {
		bodyOrRef := bodies[name]

		media, isJSON := bodyOrRef.Value.Content["application/json"]
		if !isJSON {
			continue
		}

		goType, err := GenerateGoSchema(media.Schema, []string{name})
		if err != nil {
			return nil, errors.Wrapf(err, "error generating Go type for schema in body %s", name)
		}

		def := TypeDefinition{
			JsonName: name,
			Schema:   goType,
			TypeName: SchemaNameToTypeName(name),
		}
		if ref := bodyOrRef.Ref; ref != "" {
			refType, err := RefPathToGoType(ref)
			if err != nil {
				return nil, errors.Wrapf(err, "error generating Go type for (%s) in body %s", ref, name)
			}
			def.TypeName = SchemaNameToTypeName(refType)
		}
		defs = append(defs, def)
	}
	return defs, nil
}
// GenerateTypes renders the given type definitions through the
// "typedef.tmpl" template and returns the resulting source text.
func GenerateTypes(t *template.Template, types []TypeDefinition) (string, error) {
	var out bytes.Buffer
	bw := bufio.NewWriter(&out)

	data := struct {
		Types []TypeDefinition
	}{Types: types}

	if err := t.ExecuteTemplate(bw, "typedef.tmpl", data); err != nil {
		return "", errors.Wrap(err, "error generating types")
	}
	if err := bw.Flush(); err != nil {
		return "", errors.Wrap(err, "error flushing output buffer for types")
	}
	return out.String(), nil
}
// GenerateImports renders the package clause and import block through the
// "imports.tmpl" template.
//
// imports holds fully rendered import specs, such as `"fmt"` or
// `alias "path/to/pkg"`. They are passed to the template sorted, with exact
// duplicates removed, because a repeated import spec does not compile. The
// caller's slice is left unmodified.
func GenerateImports(t *template.Template, imports []string, packageName string) (string, error) {
	// Sort a copy so the caller's slice keeps its order.
	sorted := make([]string, len(imports))
	copy(sorted, imports)
	sort.Strings(sorted)

	// Equal specs are adjacent after sorting; keep the first of each run.
	unique := sorted[:0]
	for _, spec := range sorted {
		if len(unique) > 0 && unique[len(unique)-1] == spec {
			continue
		}
		unique = append(unique, spec)
	}

	var buf bytes.Buffer
	w := bufio.NewWriter(&buf)

	context := struct {
		Imports     []string
		PackageName string
	}{
		Imports:     unique,
		PackageName: packageName,
	}
	if err := t.ExecuteTemplate(w, "imports.tmpl", context); err != nil {
		return "", errors.Wrap(err, "error generating imports")
	}
	if err := w.Flush(); err != nil {
		return "", errors.Wrap(err, "error flushing output buffer for imports")
	}
	return buf.String(), nil
}
// GenerateAdditionalPropertyBoilerplate renders, through the
// "additional-properties.tmpl" template, the glue code (accessors and JSON
// (un)marshalling) for the subset of typeDefs whose schema has
// HasAdditionalProperties set.
func GenerateAdditionalPropertyBoilerplate(t *template.Template, typeDefs []TypeDefinition) (string, error) {
	// Use a distinct loop name so the template parameter t is not shadowed.
	var withAdditional []TypeDefinition
	for _, td := range typeDefs {
		if td.Schema.HasAdditionalProperties {
			withAdditional = append(withAdditional, td)
		}
	}

	var out bytes.Buffer
	bw := bufio.NewWriter(&out)

	data := struct {
		Types []TypeDefinition
	}{Types: withAdditional}

	if err := t.ExecuteTemplate(bw, "additional-properties.tmpl", data); err != nil {
		return "", errors.Wrap(err, "error generating additional properties code")
	}
	if err := bw.Flush(); err != nil {
		return "", errors.Wrap(err, "error flushing output buffer for additional properties")
	}
	return out.String(), nil
}
// SanitizeCode runs sanitizers across the generated Go code to ensure the
// generated code will be able to compile. At present it strips every U+FEFF
// byte-order mark, which can break compilation of Go code.
// See: https://groups.google.com/forum/#!topic/golang-nuts/OToNIPdfkks
func SanitizeCode(goCode string) string {
	const byteOrderMark = "\uFEFF"
	return strings.ReplaceAll(goCode, byteOrderMark, "")
}