Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
if err := check(err); err != nil {
return nil, err
}
if err := check(validate.OnConflictClause(c.catalog, n, table)); err != nil {
return nil, err
}
}

if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- name: AddItem :exec
INSERT INTO cart_items (owner_id, product_id, price_amount, price_currency)
VALUES ($1, $2, $3, $4)
ON CONFLICT (owner_id, product_id) DO UPDATE
SET price_amount1 = EXCLUDED.price_amount1,
price_currency = EXCLUDED.price_currency;
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CREATE TABLE cart_items (
owner_id VARCHAR(255) NOT NULL,
product_id UUID NOT NULL,
price_amount DECIMAL NOT NULL,
price_currency VARCHAR(3) NOT NULL,
PRIMARY KEY (owner_id, product_id)
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"version": "2",
"sql": [
{
"engine": "postgresql",
"schema": "schema.sql",
"queries": "query.sql",
"gen": {
"go": {
"package": "querytest",
"out": "go",
"sql_package": "pgx/v5"
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:5:9: column "price_amount1" of relation "cart_items" does not exist
106 changes: 106 additions & 0 deletions internal/sql/validate/on_conflict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package validate

import (
"fmt"
"strings"

"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
)

func OnConflictClause(cat *catalog.Catalog, stmt *ast.InsertStmt, tableName *ast.TableName) error {
if stmt.OnConflictClause == nil {
return nil
}

occ := stmt.OnConflictClause

if occ.Action != ast.OnConflictActionUpdate {
return nil
}

if tableName == nil {
return nil
}

tbl, err := cat.GetTable(tableName)
if err != nil {
return err
}

relName := ""
if tbl.Rel != nil {
relName = tbl.Rel.Name
}

validCols := make(map[string]struct{}, len(tbl.Columns))
for _, c := range tbl.Columns {
validCols[strings.ToLower(c.Name)] = struct{}{}
}

if occ.TargetList == nil {
return nil
}

for _, item := range occ.TargetList.Items {
res, ok := item.(*ast.ResTarget)
if !ok {
continue
}

if res.Name != nil {
colName := strings.ToLower(*res.Name)
if _, exists := validCols[colName]; !exists {
return &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column %q of relation %q does not exist", *res.Name, relName),
Location: res.Location,
}
}
}

if res.Val != nil {
if err := validateExcludedRefs(res.Val, validCols, relName); err != nil {
return err
}
}
}

return nil
}

func validateExcludedRefs(node ast.Node, validCols map[string]struct{}, tableName string) error {
refs := astutils.Search(node, func(n ast.Node) bool {
_, ok := n.(*ast.ColumnRef)
return ok
})

for _, ref := range refs.Items {
colRef, ok := ref.(*ast.ColumnRef)
if !ok {
continue
}

parts := make([]string, 0, len(colRef.Fields.Items))
for _, field := range colRef.Fields.Items {
if s, ok := field.(*ast.String); ok {
parts = append(parts, s.Str)
}
}

if len(parts) == 2 && strings.ToLower(parts[0]) == "excluded" {
colName := strings.ToLower(parts[1])
if _, exists := validCols[colName]; !exists {
return &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column %q does not exist in relation %q (via EXCLUDED)", parts[1], tableName),
Location: colRef.Location,
}
}
}
}

return nil
}
130 changes: 130 additions & 0 deletions internal/sql/validate/on_conflict_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package validate

import (
"strings"
"testing"

"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
)

func makeTestCatalog(t *testing.T) (*catalog.Catalog, *ast.TableName) {
t.Helper()

p := postgresql.NewParser()
stmts, err := p.Parse(strings.NewReader(`
CREATE TABLE cart_items (
owner_id VARCHAR(255) NOT NULL,
product_id UUID NOT NULL,
price_amount DECIMAL NOT NULL,
price_currency VARCHAR(3) NOT NULL,
PRIMARY KEY (owner_id, product_id)
);
`))
if err != nil {
t.Fatalf("parse schema: %v", err)
}

cat := catalog.New("public")
for _, stmt := range stmts {
if err := cat.Update(stmt, nil); err != nil {
t.Fatalf("update catalog: %v", err)
}
}

tableName := &ast.TableName{Schema: "public", Name: "cart_items"}
return cat, tableName
}

func makeStmt(action ast.OnConflictAction, setItems []struct{ col, val string }) *ast.InsertStmt {
stmt := &ast.InsertStmt{
Relation: &ast.RangeVar{
Schemaname: strPtr("public"),
Relname: strPtr("cart_items"),
},
}

if action == ast.OnConflictActionNone {
return stmt
}

items := make([]ast.Node, 0, len(setItems))
for _, si := range setItems {
colName := si.col
items = append(items, &ast.ResTarget{
Name: &colName,
Val: &ast.ColumnRef{
Fields: &ast.List{
Items: []ast.Node{
&ast.String{Str: "excluded"},
&ast.String{Str: si.val},
},
},
},
})
}

stmt.OnConflictClause = &ast.OnConflictClause{
Action: action,
TargetList: &ast.List{Items: items},
}
return stmt
}

func strPtr(s string) *string { return &s }

func TestOnConflictClause(t *testing.T) {
cat, tableName := makeTestCatalog(t)

tests := []struct {
name string
stmt *ast.InsertStmt
wantErr bool
}{
{
name: "valid columns in SET and EXCLUDED",
stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{
{"price_amount", "price_amount"},
{"price_currency", "price_currency"},
}),
wantErr: false,
},
{
name: "invalid column on left side of SET",
stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{
{"price_amount1", "price_amount"},
}),
wantErr: true,
},
{
name: "invalid EXCLUDED reference on right side",
stmt: makeStmt(ast.OnConflictActionUpdate, []struct{ col, val string }{
{"price_amount", "price_amount1"},
}),
wantErr: true,
},
{
name: "DO NOTHING skips column validation",
stmt: makeStmt(ast.OnConflictActionNothing, nil),
wantErr: false,
},
{
name: "no OnConflictClause passes without error",
stmt: makeStmt(ast.OnConflictActionNone, nil),
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := OnConflictClause(cat, tt.stmt, tableName)
if tt.wantErr && err == nil {
t.Error("expected error but got none")
}
if !tt.wantErr && err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
Loading