package codegen
import (
"fmt"
"github.com/getkin/kin-openapi/openapi3"
)
// stringInSlice reports whether a is equal to at least one element of list.
// A nil or empty list never contains anything, so the result is false.
func stringInSlice(a string, list []string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == a {
			return true
		}
	}
	return false
}
// RefWrapper is the uniform view of a reference node that the walk functions
// in this file hand to their callback, regardless of the concrete ref type.
type RefWrapper struct {
// Ref is copied from the wrapped node's Ref field; the walkers pass it
// through unchanged, so it is empty when the node carries no ref string.
Ref string
// HasValue is true when the wrapped node's Value field is non-nil.
HasValue bool
// SourceRef is the wrapped node itself; the walkers store the pointer they
// were given (for example a *openapi3.SchemaRef or *openapi3.ParameterRef).
SourceRef interface{}
}
// walkSwagger visits the reference nodes reachable from swagger: first the
// parameters and operations of every path item, then the components section.
// doFn is invoked for each visited node.
//
// A nil swagger is a no-op. The first error returned by any of the nested
// walks aborts the traversal and is returned to the caller; previously such
// errors were silently discarded and the function always returned nil.
func walkSwagger(swagger *openapi3.Swagger, doFn func(RefWrapper) (bool, error)) error {
	if swagger == nil {
		return nil
	}

	for _, p := range swagger.Paths {
		// A nil path item has nothing to visit; skipping it avoids a nil
		// pointer dereference on p.Parameters / p.Operations().
		if p == nil {
			continue
		}
		for _, param := range p.Parameters {
			if err := walkParameterRef(param, doFn); err != nil {
				return err
			}
		}
		for _, op := range p.Operations() {
			if err := walkOperation(op, doFn); err != nil {
				return err
			}
		}
	}

	return walkComponents(&swagger.Components, doFn)
}
// walkOperation visits the reference nodes hanging off a single operation, in
// this order: parameters, request body, responses, callbacks. doFn is invoked
// for each visited node.
//
// A nil op is a no-op. The first error returned by a nested walk aborts the
// traversal and is returned; previously these errors were dropped.
func walkOperation(op *openapi3.Operation, doFn func(RefWrapper) (bool, error)) error {
	// Nothing to visit for a missing operation.
	if op == nil {
		return nil
	}

	for _, param := range op.Parameters {
		if err := walkParameterRef(param, doFn); err != nil {
			return err
		}
	}

	if err := walkRequestBodyRef(op.RequestBody, doFn); err != nil {
		return err
	}

	for _, response := range op.Responses {
		if err := walkResponseRef(response, doFn); err != nil {
			return err
		}
	}

	for _, callback := range op.Callbacks {
		if err := walkCallbackRef(callback, doFn); err != nil {
			return err
		}
	}

	return nil
}
// walkComponents visits every entry of the components section, in this order:
// schemas, parameters, headers, request bodies, responses, security schemes,
// examples, links, callbacks. doFn is invoked for each visited node.
//
// A nil components pointer is a no-op. The first error returned by a nested
// walk aborts the traversal and is returned; previously every such error was
// explicitly discarded.
func walkComponents(components *openapi3.Components, doFn func(RefWrapper) (bool, error)) error {
	// Nothing to visit without a components section.
	if components == nil {
		return nil
	}

	for _, schema := range components.Schemas {
		if err := walkSchemaRef(schema, doFn); err != nil {
			return err
		}
	}

	for _, param := range components.Parameters {
		if err := walkParameterRef(param, doFn); err != nil {
			return err
		}
	}

	for _, header := range components.Headers {
		if err := walkHeaderRef(header, doFn); err != nil {
			return err
		}
	}

	for _, requestBody := range components.RequestBodies {
		if err := walkRequestBodyRef(requestBody, doFn); err != nil {
			return err
		}
	}

	for _, response := range components.Responses {
		if err := walkResponseRef(response, doFn); err != nil {
			return err
		}
	}

	for _, securityScheme := range components.SecuritySchemes {
		if err := walkSecuritySchemeRef(securityScheme, doFn); err != nil {
			return err
		}
	}

	for _, example := range components.Examples {
		if err := walkExampleRef(example, doFn); err != nil {
			return err
		}
	}

	for _, link := range components.Links {
		if err := walkLinkRef(link, doFn); err != nil {
			return err
		}
	}

	for _, callback := range components.Callbacks {
		if err := walkCallbackRef(callback, doFn); err != nil {
			return err
		}
	}

	return nil
}
// walkSchemaRef calls doFn for ref and, when doFn asks to continue and the
// ref has a value, recurses into the nested schemas: OneOf, AnyOf, AllOf, Not,
// Items, Properties and AdditionalProperties (in that order).
//
// A nil ref is a no-op. An error from doFn, or from any nested walk, aborts
// the traversal and is returned; previously errors raised below the top-level
// node were dropped.
func walkSchemaRef(ref *openapi3.SchemaRef, doFn func(RefWrapper) (bool, error)) error {
	// Not a valid ref, ignore it and continue
	if ref == nil {
		return nil
	}

	refWrapper := RefWrapper{Ref: ref.Ref, HasValue: ref.Value != nil, SourceRef: ref}
	shouldContinue, err := doFn(refWrapper)
	if err != nil {
		return err
	}
	// Stop here when the callback declines to descend or there is nothing
	// below this node.
	if !shouldContinue || ref.Value == nil {
		return nil
	}

	schema := ref.Value

	for _, child := range schema.OneOf {
		if err := walkSchemaRef(child, doFn); err != nil {
			return err
		}
	}

	for _, child := range schema.AnyOf {
		if err := walkSchemaRef(child, doFn); err != nil {
			return err
		}
	}

	for _, child := range schema.AllOf {
		if err := walkSchemaRef(child, doFn); err != nil {
			return err
		}
	}

	if err := walkSchemaRef(schema.Not, doFn); err != nil {
		return err
	}

	if err := walkSchemaRef(schema.Items, doFn); err != nil {
		return err
	}

	for _, child := range schema.Properties {
		if err := walkSchemaRef(child, doFn); err != nil {
			return err
		}
	}

	return walkSchemaRef(schema.AdditionalProperties, doFn)
}
// walkParameterRef calls doFn for ref and, when doFn asks to continue and the
// ref has a value, visits the parameter's schema, its examples, and the schema
// and examples of every non-nil media type in its content.
//
// A nil ref is a no-op. An error from doFn, or from any nested walk, aborts
// the traversal and is returned; previously errors raised below the top-level
// node were dropped.
func walkParameterRef(ref *openapi3.ParameterRef, doFn func(RefWrapper) (bool, error)) error {
	// Not a valid ref, ignore it and continue
	if ref == nil {
		return nil
	}

	refWrapper := RefWrapper{Ref: ref.Ref, HasValue: ref.Value != nil, SourceRef: ref}
	shouldContinue, err := doFn(refWrapper)
	if err != nil {
		return err
	}
	// Stop here when the callback declines to descend or there is nothing
	// below this node.
	if !shouldContinue || ref.Value == nil {
		return nil
	}

	if err := walkSchemaRef(ref.Value.Schema, doFn); err != nil {
		return err
	}

	for _, example := range ref.Value.Examples {
		if err := walkExampleRef(example, doFn); err != nil {
			return err
		}
	}

	for _, mediaType := range ref.Value.Content {
		if mediaType == nil {
			continue
		}

		if err := walkSchemaRef(mediaType.Schema, doFn); err != nil {
			return err
		}

		for _, example := range mediaType.Examples {
			if err := walkExampleRef(example, doFn); err != nil {
				return err
			}
		}
	}

	return nil
}
// walkRequestBodyRef calls doFn for ref and, when doFn asks to continue and
// the ref has a value, visits the schema and examples of every non-nil media
// type in the request body's content.
//
// A nil ref is a no-op. An error from doFn, or from any nested walk, aborts
// the traversal and is returned; previously errors raised below the top-level
// node were dropped.
func walkRequestBodyRef(ref *openapi3.RequestBodyRef, doFn func(RefWrapper) (bool, error)) error {
	// Not a valid ref, ignore it and continue
	if ref == nil {
		return nil
	}

	refWrapper := RefWrapper{Ref: ref.Ref, HasValue: ref.Value != nil, SourceRef: ref}
	shouldContinue, err := doFn(refWrapper)
	if err != nil {
		return err
	}
	// Stop here when the callback declines to descend or there is nothing
	// below this node.
	if !shouldContinue || ref.Value == nil {
		return nil
	}

	for _, mediaType := range ref.Value.Content {
		if mediaType == nil {
			continue
		}

		if err := walkSchemaRef(mediaType.Schema, doFn); err != nil {
			return err
		}

		for _, example := range mediaType.Examples {
			if err := walkExampleRef(example, doFn); err != nil {
				return err
			}
		}
	}

	return nil
}
// walkResponseRef calls doFn for ref and, when doFn asks to continue and the
// ref has a value, visits the response's headers, the schema and examples of
// every non-nil media type in its content, and its links.
//
// A nil ref is a no-op. An error from doFn, or from any nested walk, aborts
// the traversal and is returned; previously errors raised below the top-level
// node were dropped.
func walkResponseRef(ref *openapi3.ResponseRef, doFn func(RefWrapper) (bool, error)) error {
	// Not a valid ref, ignore it and continue
	if ref == nil {
		return nil
	}

	refWrapper := RefWrapper{Ref: ref.Ref, HasValue: ref.Value != nil, SourceRef: ref}
	shouldContinue, err := doFn(refWrapper)
	if err != nil {
		return err
	}
	// Stop here when the callback declines to descend or there is nothing
	// below this node.
	if !shouldContinue || ref.Value == nil {
		return nil
	}

	for _, header := range ref.Value.Headers {
		if err := walkHeaderRef(header, doFn); err != nil {
			return err
		}
	}

	for _, mediaType := range ref.Value.Content {
		if mediaType == nil {
			continue
		}

		if err := walkSchemaRef(mediaType.Schema, doFn); err != nil {
			return err
		}

		for _, example := range mediaType.Examples {
			if err := walkExampleRef(example, doFn); err != nil {
				return err
			}
		}
	}

	for _, link := range ref.Value.Links {
		if err := walkLinkRef(link, doFn); err != nil {
			return err
		}
	}

	return nil
}
// walkCallbackRef calls doFn for ref and, when doFn asks to continue and the
// ref has a value, visits every path item of the callback: its parameters
// first, then its Connect, Delete, Get, Head, Options, Patch, Post, Put and
// Trace operations (in that order).
//
// A nil ref is a no-op and nil path items are skipped rather than
// dereferenced. An error from doFn, or from any nested walk, aborts the
// traversal and is returned; previously errors raised below the top-level
// node were dropped.
func walkCallbackRef(ref *openapi3.CallbackRef, doFn func(RefWrapper) (bool, error)) error {
	// Not a valid ref, ignore it and continue
	if ref == nil {
		return nil
	}

	refWrapper := RefWrapper{Ref: ref.Ref, HasValue: ref.Value != nil, SourceRef: ref}
	shouldContinue, err := doFn(refWrapper)
	if err != nil {
		return err
	}
	// Stop here when the callback declines to descend or there is nothing
	// below this node.
	if !shouldContinue || ref.Value == nil {
		return nil
	}

	for _, pathItem := range *ref.Value {
		// A nil path item has nothing to visit; skipping it avoids a nil
		// pointer dereference on its fields.
		if pathItem == nil {
			continue
		}

		for _, parameter := range pathItem.Parameters {
			if err := walkParameterRef(parameter, doFn); err != nil {
				return err
			}
		}

		// Listed explicitly to keep the visiting order fixed; walkOperation
		// ignores the nil entries of methods the path item does not define.
		operations := []*openapi3.Operation{
			pathItem.Connect,
			pathItem.Delete,
			pathItem.Get,
			pathItem.Head,
			pathItem.Options,
			pathItem.Patch,
			pathItem.Post,
			pathItem.Put,
			pathItem.Trace,
		}
		for _, op := range operations {
			if err := walkOperation(op, doFn); err != nil {
				return err
			}
		}
	}

	return nil
}
// walkHeaderRef calls doFn for ref and, when doFn asks to continue and the ref
// has a value, visits the header's schema.
//
// A nil ref is a no-op. An error from doFn or from the schema walk is
// returned; previously an error from the schema walk was dropped.
func walkHeaderRef(ref *openapi3.HeaderRef, doFn func(RefWrapper) (bool, error)) error {
	// Not a valid ref, ignore it and continue
	if ref == nil {
		return nil
	}

	refWrapper := RefWrapper{Ref: ref.Ref, HasValue: ref.Value != nil, SourceRef: ref}
	shouldContinue, err := doFn(refWrapper)
	if err != nil {
		return err
	}
	// Stop here when the callback declines to descend or there is nothing
	// below this node.
	if !shouldContinue || ref.Value == nil {
		return nil
	}

	return walkSchemaRef(ref.Value.Schema, doFn)
}
// walkSecuritySchemeRef reports ref to doFn and returns whatever error doFn
// produced. A nil ref is skipped without calling doFn.
//
// Security schemes have no children that can hold refs, so the "continue"
// flag returned by doFn has nothing to influence and is not inspected.
func walkSecuritySchemeRef(ref *openapi3.SecuritySchemeRef, doFn func(RefWrapper) (bool, error)) error {
	if ref == nil {
		// Nothing to report for a missing ref.
		return nil
	}

	_, err := doFn(RefWrapper{
		Ref:       ref.Ref,
		HasValue:  ref.Value != nil,
		SourceRef: ref,
	})
	return err
}
// walkLinkRef reports ref to doFn and returns whatever error doFn produced.
// A nil ref is skipped without calling doFn.
//
// No children of a link are visited, so the "continue" flag returned by doFn
// has nothing to influence and is not inspected.
func walkLinkRef(ref *openapi3.LinkRef, doFn func(RefWrapper) (bool, error)) error {
	if ref == nil {
		// Nothing to report for a missing ref.
		return nil
	}

	_, err := doFn(RefWrapper{
		Ref:       ref.Ref,
		HasValue:  ref.Value != nil,
		SourceRef: ref,
	})
	return err
}
// walkExampleRef reports ref to doFn and returns whatever error doFn produced.
// A nil ref is skipped without calling doFn.
//
// Examples have no children that can hold refs, so the "continue" flag
// returned by doFn has nothing to influence and is not inspected.
func walkExampleRef(ref *openapi3.ExampleRef, doFn func(RefWrapper) (bool, error)) error {
	if ref == nil {
		// Nothing to report for a missing ref.
		return nil
	}

	_, err := doFn(RefWrapper{
		Ref:       ref.Ref,
		HasValue:  ref.Value != nil,
		SourceRef: ref,
	})
	return err
}
// findComponentRefs walks swagger and returns the ref string of every visited
// node whose Ref is non-empty. The walk does not descend below such a node;
// nodes without a ref string are descended into. Repeated refs are kept, and
// the returned slice is always non-nil.
func findComponentRefs(swagger *openapi3.Swagger) []string {
	collected := make([]string, 0)

	collect := func(node RefWrapper) (bool, error) {
		if node.Ref == "" {
			// Inline definition: keep descending to find refs inside it.
			return true, nil
		}
		collected = append(collected, node.Ref)
		return false, nil
	}

	// The collector never reports an error, so the walk result is discarded.
	_ = walkSwagger(swagger, collect)

	return collected
}
// removeOrphanedComponents deletes every component whose
// "#/components/<section>/<name>" path does not appear in refs and returns the
// number of components deleted. The sections handled are schemas, parameters,
// requestBodies, responses, headers, examples, links and callbacks.
//
// Security schemes are deliberately left alone: they are referenced directly
// by name rather than by $ref, so they never show up in refs.
//
// A nil swagger yields 0 instead of panicking. refs is indexed into a set once
// so each component costs a single lookup rather than a scan of the slice.
func removeOrphanedComponents(swagger *openapi3.Swagger, refs []string) int {
	if swagger == nil {
		return 0
	}

	referenced := make(map[string]struct{}, len(refs))
	for _, ref := range refs {
		referenced[ref] = struct{}{}
	}

	// isOrphan reports whether the component named key in the given section
	// is absent from refs.
	isOrphan := func(section, key string) bool {
		_, used := referenced[fmt.Sprintf("#/components/%s/%s", section, key)]
		return !used
	}

	components := &swagger.Components
	countRemoved := 0

	// Deleting the entry currently being visited while ranging over a map is
	// permitted in Go.
	for key := range components.Schemas {
		if isOrphan("schemas", key) {
			countRemoved++
			delete(components.Schemas, key)
		}
	}

	for key := range components.Parameters {
		if isOrphan("parameters", key) {
			countRemoved++
			delete(components.Parameters, key)
		}
	}

	// components.SecuritySchemes is intentionally not pruned; see the
	// function comment.

	for key := range components.RequestBodies {
		if isOrphan("requestBodies", key) {
			countRemoved++
			delete(components.RequestBodies, key)
		}
	}

	for key := range components.Responses {
		if isOrphan("responses", key) {
			countRemoved++
			delete(components.Responses, key)
		}
	}

	for key := range components.Headers {
		if isOrphan("headers", key) {
			countRemoved++
			delete(components.Headers, key)
		}
	}

	for key := range components.Examples {
		if isOrphan("examples", key) {
			countRemoved++
			delete(components.Examples, key)
		}
	}

	for key := range components.Links {
		if isOrphan("links", key) {
			countRemoved++
			delete(components.Links, key)
		}
	}

	for key := range components.Callbacks {
		if isOrphan("callbacks", key) {
			countRemoved++
			delete(components.Callbacks, key)
		}
	}

	return countRemoved
}
// pruneUnusedComponents repeatedly collects the refs still present in swagger
// and removes the components none of them point at, stopping after the first
// pass that removes nothing. Several passes can be needed because removing one
// component may take away the only ref to another.
func pruneUnusedComponents(swagger *openapi3.Swagger) {
	for removeOrphanedComponents(swagger, findComponentRefs(swagger)) > 0 {
		// Something was removed, so run another pass.
	}
}