From ca351fc2c9d152516395667b29996d018ddf9d95 Mon Sep 17 00:00:00 2001 From: "Giau. Tran Minh" Date: Thu, 21 Jul 2022 02:12:12 +0700 Subject: [PATCH 1/3] entgql: reduce edge code by call Paginate --- entgql/internal/todo/ent/gql_edge.go | 242 ++------------------- entgql/internal/todogotype/ent/gql_edge.go | 242 ++------------------- entgql/internal/todopulid/ent/gql_edge.go | 242 ++------------------- entgql/internal/todouuid/ent/gql_edge.go | 242 ++------------------- entgql/template/edge.tmpl | 13 +- 5 files changed, 89 insertions(+), 892 deletions(-) diff --git a/entgql/internal/todo/ent/gql_edge.go b/entgql/internal/todo/ent/gql_edge.go index a30781f2d..8970d3093 100644 --- a/entgql/internal/todo/ent/gql_edge.go +++ b/entgql/internal/todo/ent/gql_edge.go @@ -16,11 +16,7 @@ package ent -import ( - "context" - - "github.com/99designs/gqlgen/graphql" -) +import "context" func (c *Category) Todos( ctx context.Context, after *Cursor, first *int, before *Cursor, last *int, orderBy *TodoOrder, where *TodoWhereInput, @@ -31,67 +27,18 @@ func (c *Category) Todos( } totalCount := c.Edges.totalCount[0] if nodes, err := c.Edges.TodosOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := c.QueryTodos() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return c.QueryTodos().Paginate(ctx, after, first, before, last, opts...) } func (f *Friendship) User(ctx context.Context) (*User, error) { @@ -118,67 +65,18 @@ func (gr *Group) Users( } totalCount := gr.Edges.totalCount[0] if nodes, err := gr.Edges.UsersOrErr(); err == nil || totalCount != nil { - conn := &UserConnection{Edges: []*UserEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newUserPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := gr.QueryUsers() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newUserPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &UserConnection{Edges: []*UserEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &UserConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return gr.QueryUsers().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Parent(ctx context.Context) (*Todo, error) { @@ -198,67 +96,18 @@ func (t *Todo) Children( } totalCount := t.Edges.totalCount[1] if nodes, err := t.Edges.ChildrenOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := t.QueryChildren() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return t.QueryChildren().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Category(ctx context.Context) (*Category, error) { @@ -277,67 +126,18 @@ func (u *User) Groups( } totalCount := u.Edges.totalCount[0] if nodes, err := u.Edges.GroupsOrErr(); err == nil || totalCount != nil { - conn := &GroupConnection{Edges: []*GroupEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newGroupPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := u.QueryGroups() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newGroupPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &GroupConnection{Edges: []*GroupEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &GroupConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return u.QueryGroups().Paginate(ctx, after, first, before, last, opts...) } func (u *User) Friends(ctx context.Context) ([]*User, error) { diff --git a/entgql/internal/todogotype/ent/gql_edge.go b/entgql/internal/todogotype/ent/gql_edge.go index a30781f2d..8970d3093 100644 --- a/entgql/internal/todogotype/ent/gql_edge.go +++ b/entgql/internal/todogotype/ent/gql_edge.go @@ -16,11 +16,7 @@ package ent -import ( - "context" - - "github.com/99designs/gqlgen/graphql" -) +import "context" func (c *Category) Todos( ctx context.Context, after *Cursor, first *int, before *Cursor, last *int, orderBy *TodoOrder, where *TodoWhereInput, @@ -31,67 +27,18 @@ func (c *Category) Todos( } totalCount := c.Edges.totalCount[0] if nodes, err := c.Edges.TodosOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := c.QueryTodos() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return c.QueryTodos().Paginate(ctx, after, first, before, last, opts...) } func (f *Friendship) User(ctx context.Context) (*User, error) { @@ -118,67 +65,18 @@ func (gr *Group) Users( } totalCount := gr.Edges.totalCount[0] if nodes, err := gr.Edges.UsersOrErr(); err == nil || totalCount != nil { - conn := &UserConnection{Edges: []*UserEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newUserPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := gr.QueryUsers() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newUserPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &UserConnection{Edges: []*UserEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &UserConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return gr.QueryUsers().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Parent(ctx context.Context) (*Todo, error) { @@ -198,67 +96,18 @@ func (t *Todo) Children( } totalCount := t.Edges.totalCount[1] if nodes, err := t.Edges.ChildrenOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := t.QueryChildren() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return t.QueryChildren().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Category(ctx context.Context) (*Category, error) { @@ -277,67 +126,18 @@ func (u *User) Groups( } totalCount := u.Edges.totalCount[0] if nodes, err := u.Edges.GroupsOrErr(); err == nil || totalCount != nil { - conn := &GroupConnection{Edges: []*GroupEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newGroupPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := u.QueryGroups() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newGroupPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &GroupConnection{Edges: []*GroupEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &GroupConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return u.QueryGroups().Paginate(ctx, after, first, before, last, opts...) } func (u *User) Friends(ctx context.Context) ([]*User, error) { diff --git a/entgql/internal/todopulid/ent/gql_edge.go b/entgql/internal/todopulid/ent/gql_edge.go index a30781f2d..8970d3093 100644 --- a/entgql/internal/todopulid/ent/gql_edge.go +++ b/entgql/internal/todopulid/ent/gql_edge.go @@ -16,11 +16,7 @@ package ent -import ( - "context" - - "github.com/99designs/gqlgen/graphql" -) +import "context" func (c *Category) Todos( ctx context.Context, after *Cursor, first *int, before *Cursor, last *int, orderBy *TodoOrder, where *TodoWhereInput, @@ -31,67 +27,18 @@ func (c *Category) Todos( } totalCount := c.Edges.totalCount[0] if nodes, err := c.Edges.TodosOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := c.QueryTodos() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return c.QueryTodos().Paginate(ctx, after, first, before, last, opts...) } func (f *Friendship) User(ctx context.Context) (*User, error) { @@ -118,67 +65,18 @@ func (gr *Group) Users( } totalCount := gr.Edges.totalCount[0] if nodes, err := gr.Edges.UsersOrErr(); err == nil || totalCount != nil { - conn := &UserConnection{Edges: []*UserEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newUserPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := gr.QueryUsers() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newUserPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &UserConnection{Edges: []*UserEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &UserConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return gr.QueryUsers().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Parent(ctx context.Context) (*Todo, error) { @@ -198,67 +96,18 @@ func (t *Todo) Children( } totalCount := t.Edges.totalCount[1] if nodes, err := t.Edges.ChildrenOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := t.QueryChildren() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return t.QueryChildren().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Category(ctx context.Context) (*Category, error) { @@ -277,67 +126,18 @@ func (u *User) Groups( } totalCount := u.Edges.totalCount[0] if nodes, err := u.Edges.GroupsOrErr(); err == nil || totalCount != nil { - conn := &GroupConnection{Edges: []*GroupEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newGroupPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := u.QueryGroups() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newGroupPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &GroupConnection{Edges: []*GroupEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &GroupConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return u.QueryGroups().Paginate(ctx, after, first, before, last, opts...) } func (u *User) Friends(ctx context.Context) ([]*User, error) { diff --git a/entgql/internal/todouuid/ent/gql_edge.go b/entgql/internal/todouuid/ent/gql_edge.go index a30781f2d..8970d3093 100644 --- a/entgql/internal/todouuid/ent/gql_edge.go +++ b/entgql/internal/todouuid/ent/gql_edge.go @@ -16,11 +16,7 @@ package ent -import ( - "context" - - "github.com/99designs/gqlgen/graphql" -) +import "context" func (c *Category) Todos( ctx context.Context, after *Cursor, first *int, before *Cursor, last *int, orderBy *TodoOrder, where *TodoWhereInput, @@ -31,67 +27,18 @@ func (c *Category) Todos( } totalCount := c.Edges.totalCount[0] if nodes, err := c.Edges.TodosOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := c.QueryTodos() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return c.QueryTodos().Paginate(ctx, after, first, before, last, opts...) } func (f *Friendship) User(ctx context.Context) (*User, error) { @@ -118,67 +65,18 @@ func (gr *Group) Users( } totalCount := gr.Edges.totalCount[0] if nodes, err := gr.Edges.UsersOrErr(); err == nil || totalCount != nil { - conn := &UserConnection{Edges: []*UserEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newUserPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := gr.QueryUsers() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newUserPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &UserConnection{Edges: []*UserEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &UserConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return gr.QueryUsers().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Parent(ctx context.Context) (*Todo, error) { @@ -198,67 +96,18 @@ func (t *Todo) Children( } totalCount := t.Edges.totalCount[1] if nodes, err := t.Edges.ChildrenOrErr(); err == nil || totalCount != nil { - conn := &TodoConnection{Edges: []*TodoEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newTodoPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := t.QueryChildren() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newTodoPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &TodoConnection{Edges: []*TodoEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &TodoConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return t.QueryChildren().Paginate(ctx, after, first, before, last, opts...) } func (t *Todo) Category(ctx context.Context) (*Category, error) { @@ -277,67 +126,18 @@ func (u *User) Groups( } totalCount := u.Edges.totalCount[0] if nodes, err := u.Edges.GroupsOrErr(); err == nil || totalCount != nil { - conn := &GroupConnection{Edges: []*GroupEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := newGroupPager(opts) if err != nil { return nil, err } - conn.build(nodes, pager, after, first, before, last) - return conn, nil - } - query := u.QueryGroups() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newGroupPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &GroupConnection{Edges: []*GroupEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 + conn := &GroupConnection{} + if totalCount != nil { + conn.TotalCount = *totalCount } + conn.build(nodes, pager, after, first, before, last) return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return u.QueryGroups().Paginate(ctx, after, first, before, last, opts...) } func (u *User) Friends(ctx context.Context) ([]*User, error) { diff --git a/entgql/template/edge.tmpl b/entgql/template/edge.tmpl index 53b264e30..e2c1fe81b 100644 --- a/entgql/template/edge.tmpl +++ b/entgql/template/edge.tmpl @@ -62,21 +62,18 @@ import "context" totalCount := {{ $r }}.Edges.totalCount[{{ $i }}] {{- /* Nodes were loaded, totalCount was loaded, or both. */}} if nodes, err := {{ $r }}.Edges.{{ $e.StructField }}OrErr(); err == nil || totalCount != nil { - conn := &{{ $conn }}{Edges: []*{{ $edge }}{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } pager, err := {{ $newPager }}(opts) if err != nil { return nil, err } + conn := &{{ $conn }}{} + if totalCount != nil { + conn.TotalCount = *totalCount + } conn.build(nodes, pager, after, first, before, last) return conn, nil } - query := {{ $r }}.Query{{ $e.StructField }}() - {{ with extend $n "Node" $e.Type "Query" "query" "TotalCount" "totalCount" -}} - {{ template "gql_pagination/helper/paginate" . }} - {{- end -}} + return {{ $r }}.Query{{ $e.StructField }}().Paginate(ctx, after, first, before, last, opts...) } {{ end }} From c053dd2b2068761d73f7b78ec178027dd4b2d45d Mon Sep 17 00:00:00 2001 From: "Giau. Tran Minh" Date: Thu, 21 Jul 2022 03:22:31 +0700 Subject: [PATCH 2/3] entgql: cleanup loadTotal logic --- entgql/internal/todo/ent/gql_collection.go | 220 +++++------------- .../internal/todogotype/ent/gql_collection.go | 220 +++++------------- .../internal/todopulid/ent/gql_collection.go | 220 +++++------------- .../internal/todouuid/ent/gql_collection.go | 220 +++++------------- entgql/template/collection.tmpl | 32 +-- 5 files changed, 257 insertions(+), 655 deletions(-) diff --git a/entgql/internal/todo/ent/gql_collection.go b/entgql/internal/todo/ent/gql_collection.go index 572582b41..784cfdb32 100644 --- a/entgql/internal/todo/ent/gql_collection.go +++ b/entgql/internal/todo/ent/gql_collection.go @@ -61,8 +61,10 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { ids := make([]driver.Value, len(nodes)) @@ -89,45 +91,20 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC } return nil }) + } else { + c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { + for i := range nodes { + n := len(nodes[i].Edges.Todos) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID int `sql:"category_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(category.TodosColumn, ids...)) - }) - if err := query.GroupBy(category.TodosColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[int]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { - for i := range nodes { - n := len(nodes[i].Edges.Todos) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(category.TodosColumn, limit, pager.orderExpr(args.last != nil)) @@ -298,8 +275,10 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { ids := make([]driver.Value, len(nodes)) @@ -330,49 +309,20 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon } return nil }) + } else { + gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { + for i := range nodes { + n := len(nodes[i].Edges.Users) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID int `sql:"group_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), ids...)) - s.Select(joinT.C(group.UsersPrimaryKey[1]), sql.Count("*")) - s.GroupBy(joinT.C(group.UsersPrimaryKey[1])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[int]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { - for i := range nodes { - n := len(nodes[i].Edges.Users) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(args.last != nil)) @@ -462,8 +412,10 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { ids := make([]driver.Value, len(nodes)) @@ -490,45 +442,20 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { + for i := range nodes { + n := len(nodes[i].Edges.Children) + nodes[i].Edges.totalCount[1] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID int `sql:"todo_children"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(todo.ChildrenColumn, ids...)) - }) - if err := query.GroupBy(todo.ChildrenColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[int]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) - } else { - t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { - for i := range nodes { - n := len(nodes[i].Edges.Children) - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(args.last != nil)) @@ -640,8 +567,10 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { ids := make([]driver.Value, len(nodes)) @@ -672,49 +601,20 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { + for i := range nodes { + n := len(nodes[i].Edges.Groups) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID int `sql:"user_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), ids...)) - s.Select(joinT.C(user.GroupsPrimaryKey[0]), sql.Count("*")) - s.GroupBy(joinT.C(user.GroupsPrimaryKey[0])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[int]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { - for i := range nodes { - n := len(nodes[i].Edges.Groups) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(args.last != nil)) diff --git a/entgql/internal/todogotype/ent/gql_collection.go b/entgql/internal/todogotype/ent/gql_collection.go index 4b2ed1ff4..3f042d77c 100644 --- a/entgql/internal/todogotype/ent/gql_collection.go +++ b/entgql/internal/todogotype/ent/gql_collection.go @@ -62,8 +62,10 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { ids := make([]driver.Value, len(nodes)) @@ -90,45 +92,20 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC } return nil }) + } else { + c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { + for i := range nodes { + n := len(nodes[i].Edges.Todos) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID bigintgql.BigInt `sql:"category_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(category.TodosColumn, ids...)) - }) - if err := query.GroupBy(category.TodosColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[bigintgql.BigInt]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { - for i := range nodes { - n := len(nodes[i].Edges.Todos) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(category.TodosColumn, limit, pager.orderExpr(args.last != nil)) @@ -299,8 +276,10 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { ids := make([]driver.Value, len(nodes)) @@ -331,49 +310,20 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon } return nil }) + } else { + gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { + for i := range nodes { + n := len(nodes[i].Edges.Users) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID string `sql:"group_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), ids...)) - s.Select(joinT.C(group.UsersPrimaryKey[1]), sql.Count("*")) - s.GroupBy(joinT.C(group.UsersPrimaryKey[1])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[string]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { - for i := range nodes { - n := len(nodes[i].Edges.Users) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(args.last != nil)) @@ -509,8 +459,10 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { ids := make([]driver.Value, len(nodes)) @@ -537,45 +489,20 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { + for i := range nodes { + n := len(nodes[i].Edges.Children) + nodes[i].Edges.totalCount[1] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID string `sql:"todo_children"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(todo.ChildrenColumn, ids...)) - }) - if err := query.GroupBy(todo.ChildrenColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[string]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) - } else { - t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { - for i := range nodes { - n := len(nodes[i].Edges.Children) - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(args.last != nil)) @@ -687,8 +614,10 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { ids := make([]driver.Value, len(nodes)) @@ -719,49 +648,20 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { + for i := range nodes { + n := len(nodes[i].Edges.Groups) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID string `sql:"user_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), ids...)) - s.Select(joinT.C(user.GroupsPrimaryKey[0]), sql.Count("*")) - s.GroupBy(joinT.C(user.GroupsPrimaryKey[0])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[string]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { - for i := range nodes { - n := len(nodes[i].Edges.Groups) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(args.last != nil)) diff --git a/entgql/internal/todopulid/ent/gql_collection.go b/entgql/internal/todopulid/ent/gql_collection.go index 7be03b676..286a79cb9 100644 --- a/entgql/internal/todopulid/ent/gql_collection.go +++ b/entgql/internal/todopulid/ent/gql_collection.go @@ -62,8 +62,10 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { ids := make([]driver.Value, len(nodes)) @@ -90,45 +92,20 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC } return nil }) + } else { + c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { + for i := range nodes { + n := len(nodes[i].Edges.Todos) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID pulid.ID `sql:"category_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(category.TodosColumn, ids...)) - }) - if err := query.GroupBy(category.TodosColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[pulid.ID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { - for i := range nodes { - n := len(nodes[i].Edges.Todos) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(category.TodosColumn, limit, pager.orderExpr(args.last != nil)) @@ -299,8 +276,10 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { ids := make([]driver.Value, len(nodes)) @@ -331,49 +310,20 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon } return nil }) + } else { + gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { + for i := range nodes { + n := len(nodes[i].Edges.Users) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID pulid.ID `sql:"group_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), ids...)) - s.Select(joinT.C(group.UsersPrimaryKey[1]), sql.Count("*")) - s.GroupBy(joinT.C(group.UsersPrimaryKey[1])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[pulid.ID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { - for i := range nodes { - n := len(nodes[i].Edges.Users) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(args.last != nil)) @@ -463,8 +413,10 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { ids := make([]driver.Value, len(nodes)) @@ -491,45 +443,20 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { + for i := range nodes { + n := len(nodes[i].Edges.Children) + nodes[i].Edges.totalCount[1] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID pulid.ID `sql:"todo_children"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(todo.ChildrenColumn, ids...)) - }) - if err := query.GroupBy(todo.ChildrenColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[pulid.ID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) - } else { - t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { - for i := range nodes { - n := len(nodes[i].Edges.Children) - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(args.last != nil)) @@ -641,8 +568,10 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { ids := make([]driver.Value, len(nodes)) @@ -673,49 +602,20 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { + for i := range nodes { + n := len(nodes[i].Edges.Groups) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID pulid.ID `sql:"user_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), ids...)) - s.Select(joinT.C(user.GroupsPrimaryKey[0]), sql.Count("*")) - s.GroupBy(joinT.C(user.GroupsPrimaryKey[0])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[pulid.ID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { - for i := range nodes { - n := len(nodes[i].Edges.Groups) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(args.last != nil)) diff --git a/entgql/internal/todouuid/ent/gql_collection.go b/entgql/internal/todouuid/ent/gql_collection.go index c9eea13f2..0ca61c337 100644 --- a/entgql/internal/todouuid/ent/gql_collection.go +++ b/entgql/internal/todouuid/ent/gql_collection.go @@ -62,8 +62,10 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { ids := make([]driver.Value, len(nodes)) @@ -90,45 +92,20 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC } return nil }) + } else { + c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { + for i := range nodes { + n := len(nodes[i].Edges.Todos) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - c.loadTotal = append(c.loadTotal, func(ctx context.Context, nodes []*Category) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID uuid.UUID `sql:"category_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(category.TodosColumn, ids...)) - }) - if err := query.GroupBy(category.TodosColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[uuid.UUID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - c.loadTotal = append(c.loadTotal, func(_ context.Context, nodes []*Category) error { - for i := range nodes { - n := len(nodes[i].Edges.Todos) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(category.TodosColumn, limit, pager.orderExpr(args.last != nil)) @@ -299,8 +276,10 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { ids := make([]driver.Value, len(nodes)) @@ -331,49 +310,20 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon } return nil }) + } else { + gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { + for i := range nodes { + n := len(nodes[i].Edges.Users) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - gr.loadTotal = append(gr.loadTotal, func(ctx context.Context, nodes []*Group) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID uuid.UUID `sql:"group_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), ids...)) - s.Select(joinT.C(group.UsersPrimaryKey[1]), sql.Count("*")) - s.GroupBy(joinT.C(group.UsersPrimaryKey[1])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[uuid.UUID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - gr.loadTotal = append(gr.loadTotal, func(_ context.Context, nodes []*Group) error { - for i := range nodes { - n := len(nodes[i].Edges.Users) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(group.UsersPrimaryKey[1], limit, pager.orderExpr(args.last != nil)) @@ -463,8 +413,10 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { ids := make([]driver.Value, len(nodes)) @@ -491,45 +443,20 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { + for i := range nodes { + n := len(nodes[i].Edges.Children) + nodes[i].Edges.totalCount[1] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - t.loadTotal = append(t.loadTotal, func(ctx context.Context, nodes []*Todo) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID uuid.UUID `sql:"todo_children"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(todo.ChildrenColumn, ids...)) - }) - if err := query.GroupBy(todo.ChildrenColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[uuid.UUID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) - } else { - t.loadTotal = append(t.loadTotal, func(_ context.Context, nodes []*Todo) error { - for i := range nodes { - n := len(nodes[i].Edges.Children) - nodes[i].Edges.totalCount[1] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(todo.ChildrenColumn, limit, pager.orderExpr(args.last != nil)) @@ -641,8 +568,10 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { ids := make([]driver.Value, len(nodes)) @@ -673,49 +602,20 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte } return nil }) + } else { + u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { + for i := range nodes { + n := len(nodes[i].Edges.Groups) + nodes[i].Edges.totalCount[0] = &n + } + return nil + }) } - continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - query := query.Clone() - u.loadTotal = append(u.loadTotal, func(ctx context.Context, nodes []*User) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID uuid.UUID `sql:"user_id"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), ids...)) - s.Select(joinT.C(user.GroupsPrimaryKey[0]), sql.Count("*")) - s.GroupBy(joinT.C(user.GroupsPrimaryKey[0])) - }) - if err := query.Select().Scan(ctx, &v); err != nil { - return err - } - m := make(map[uuid.UUID]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - u.loadTotal = append(u.loadTotal, func(_ context.Context, nodes []*User) error { - for i := range nodes { - n := len(nodes[i].Edges.Groups) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) + if ignoredEdges { + continue } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { modify := limitRows(user.GroupsPrimaryKey[0], limit, pager.orderExpr(args.last != nil)) diff --git a/entgql/template/collection.tmpl b/entgql/template/collection.tmpl index 6e9e4a6bd..2aa7018e9 100644 --- a/entgql/template/collection.tmpl +++ b/entgql/template/collection.tmpl @@ -62,28 +62,30 @@ func ({{ $receiver }} *{{ $query }}) collectField(ctx context.Context, op *graph if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + {{- /* Only add loadTotal query when needs */}} + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { {{- with extend $node "Edge" $e "Index" $i "Receiver" $receiver }} {{- template "gql_pagination/helper/load_total" . }} {{- end -}} + } else { + {{- /* All records will be loaded, so just count it */}} + {{ $receiver }}.loadTotal = append({{ $receiver }}.loadTotal, func(_ context.Context, nodes []*{{ $node.Name }}) error { + for i := range nodes { + n := len(nodes[i].Edges.{{ $e.StructField }}) + nodes[i].Edges.totalCount[{{ $i }}] = &n + } + return nil + }) } + } + if ignoredEdges { {{- /* Skip querying edges if "edges" "node" was not required. */}} continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) { - {{- with extend $node "Edge" $e "Index" $i "Receiver" $receiver }} - {{- template "gql_pagination/helper/load_total" . }} - {{- end -}} - } else { - {{ $receiver }}.loadTotal = append({{ $receiver }}.loadTotal, func(_ context.Context, nodes []*{{ $node.Name }}) error { - for i := range nodes { - n := len(nodes[i].Edges.{{ $e.StructField }}) - nodes[i].Edges.totalCount[{{ $i }}] = &n - } - return nil - }) - } + query = pager.applyCursors(query, args.after, args.before) if limit := paginateLimit(args.first, args.last); limit > 0 { {{- $fk := print $node.Package "." $fc.Edge.ColumnConstant }} From 3a215568fc84acbe205fb9fcf215ca2f9a47cc8f Mon Sep 17 00:00:00 2001 From: "Giau. Tran Minh" Date: Thu, 21 Jul 2022 03:50:35 +0700 Subject: [PATCH 3/3] entgql: skip query if no needs to load --- entgql/internal/todo/ent/gql_collection.go | 8 +++--- entgql/internal/todo/todo_test.go | 28 +++++++++++++++++++ .../internal/todogotype/ent/gql_collection.go | 8 +++--- .../internal/todopulid/ent/gql_collection.go | 8 +++--- .../internal/todouuid/ent/gql_collection.go | 8 +++--- entgql/template/collection.tmpl | 2 +- 6 files changed, 45 insertions(+), 17 deletions(-) diff --git a/entgql/internal/todo/ent/gql_collection.go b/entgql/internal/todo/ent/gql_collection.go index 784cfdb32..86a57210c 100644 --- a/entgql/internal/todo/ent/gql_collection.go +++ b/entgql/internal/todo/ent/gql_collection.go @@ -101,7 +101,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -319,7 +319,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -452,7 +452,7 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -611,7 +611,7 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } diff --git a/entgql/internal/todo/todo_test.go b/entgql/internal/todo/todo_test.go index 607c136f7..e04cd92f3 100644 --- a/entgql/internal/todo/todo_test.go +++ b/entgql/internal/todo/todo_test.go @@ -638,6 +638,34 @@ func (s *todoTestSuite) TestPaginationFiltering() { s.NoError(err) s.Equal(s.ent.Todo.Query().CountX(context.Background()), rsp.Todos.TotalCount) }) + + s.Run("Zero first", func() { + var ( + rsp response + query = `query() { + todos(first: 0) { + totalCount + } + }` + ) + err := s.Post(query, &rsp) + s.NoError(err) + s.Equal(s.ent.Todo.Query().CountX(context.Background()), rsp.Todos.TotalCount) + }) + + s.Run("Zero last", func() { + var ( + rsp response + query = `query() { + todos(last: 0) { + totalCount + } + }` + ) + err := s.Post(query, &rsp) + s.NoError(err) + s.Equal(s.ent.Todo.Query().CountX(context.Background()), rsp.Todos.TotalCount) + }) } func (s *todoTestSuite) TestFilteringWithCustomPredicate() { diff --git a/entgql/internal/todogotype/ent/gql_collection.go b/entgql/internal/todogotype/ent/gql_collection.go index 3f042d77c..10dd3d854 100644 --- a/entgql/internal/todogotype/ent/gql_collection.go +++ b/entgql/internal/todogotype/ent/gql_collection.go @@ -102,7 +102,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -320,7 +320,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -499,7 +499,7 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -658,7 +658,7 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } diff --git a/entgql/internal/todopulid/ent/gql_collection.go b/entgql/internal/todopulid/ent/gql_collection.go index 286a79cb9..496c4b987 100644 --- a/entgql/internal/todopulid/ent/gql_collection.go +++ b/entgql/internal/todopulid/ent/gql_collection.go @@ -102,7 +102,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -320,7 +320,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -453,7 +453,7 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -612,7 +612,7 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } diff --git a/entgql/internal/todouuid/ent/gql_collection.go b/entgql/internal/todouuid/ent/gql_collection.go index 0ca61c337..32c125c8b 100644 --- a/entgql/internal/todouuid/ent/gql_collection.go +++ b/entgql/internal/todouuid/ent/gql_collection.go @@ -102,7 +102,7 @@ func (c *CategoryQuery) collectField(ctx context.Context, op *graphql.OperationC }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -320,7 +320,7 @@ func (gr *GroupQuery) collectField(ctx context.Context, op *graphql.OperationCon }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -453,7 +453,7 @@ func (t *TodoQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } @@ -612,7 +612,7 @@ func (u *UserQuery) collectField(ctx context.Context, op *graphql.OperationConte }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } diff --git a/entgql/template/collection.tmpl b/entgql/template/collection.tmpl index 2aa7018e9..4ee0794cb 100644 --- a/entgql/template/collection.tmpl +++ b/entgql/template/collection.tmpl @@ -81,7 +81,7 @@ func ({{ $receiver }} *{{ $query }}) collectField(ctx context.Context, op *graph }) } } - if ignoredEdges { + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { {{- /* Skip querying edges if "edges" "node" was not required. */}} continue }