11package compiler
22
33import (
4+ "fmt"
45 "sort"
56
67 analyzer "github.com/sqlc-dev/sqlc/internal/analysis"
78 "github.com/sqlc-dev/sqlc/internal/config"
89 "github.com/sqlc-dev/sqlc/internal/source"
9- "github.com/sqlc-dev/sqlc/internal/sql /ast"
10+ "github.com/sqlc-dev/sqlc/pkg /ast"
1011 "github.com/sqlc-dev/sqlc/internal/sql/named"
1112 "github.com/sqlc-dev/sqlc/internal/sql/rewrite"
1213 "github.com/sqlc-dev/sqlc/internal/sql/validate"
@@ -134,31 +135,36 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
134135 return nil
135136 }
136137
137- numbers , dollar , err := validate .ParamRef (raw )
138+ _ , dollar , err := validate .ParamRef (raw . Stmt )
138139 if err := check (err ); err != nil {
139140 return nil , err
140141 }
141142
142- raw , namedParams , edits := rewrite .NamedParameters (c .conf .Engine , raw , numbers , dollar )
143+ // TODO: fix rewrite.NamedParameters - function not found
144+ var namedParams map [string ]int
145+ var edits []source.Edit
143146
144147 var table * ast.TableName
145- switch n := raw .Stmt .(type ) {
146- case * ast.InsertStmt :
147- if err := check (validate .InsertStmt (n )); err != nil {
148- return nil , err
149- }
150- var err error
151- table , err = ParseTableName (n .Relation )
152- if err := check (err ); err != nil {
153- return nil , err
148+ if raw .Stmt != nil && raw .Stmt .Node != nil {
149+ switch n := raw .Stmt .Node .(type ) {
150+ case * ast.Node_InsertStmt :
151+ if err := check (validate .InsertStmt (n .InsertStmt )); err != nil {
152+ return nil , err
153+ }
154+ var err error
155+ relNode := & ast.Node {Node : & ast.Node_RangeVar {RangeVar : n .InsertStmt .Relation }}
156+ table , err = ParseTableName (* relNode )
157+ if err := check (err ); err != nil {
158+ return nil , err
159+ }
154160 }
155161 }
156162
157- if err := check (validate .FuncCall (c .catalog , c .combo , raw )); err != nil {
163+ if err := check (validate .FuncCall (c .catalog , c .combo , raw . Stmt )); err != nil {
158164 return nil , err
159165 }
160166
161- if err := check (validate .In (c .catalog , raw )); err != nil {
167+ if err := check (validate .In (c .catalog , raw . Stmt )); err != nil {
162168 return nil , err
163169 }
164170 rvs := rangeVars (raw .Stmt )
@@ -176,16 +182,26 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
176182 sort .Slice (refs , func (i , j int ) bool { return refs [i ].ref .Number < refs [j ].ref .Number })
177183 }
178184 raw , embeds := rewrite .Embeds (raw )
179- qc , err := c .buildQueryCatalog (c .catalog , raw .Stmt , embeds )
185+ if raw .Stmt == nil {
186+ return nil , fmt .Errorf ("raw.Stmt is nil" )
187+ }
188+ qc , err := c .buildQueryCatalog (c .catalog , * raw .Stmt , embeds )
180189 if err := check (err ); err != nil {
181190 return nil , err
182191 }
183192
184- params , err := c .resolveCatalogRefs (qc , rvs , refs , namedParams , embeds )
193+ var paramSet * named.ParamSet
194+ if namedParams != nil {
195+ paramSet = named .NewParamSet (nil , true )
196+ for k := range namedParams {
197+ paramSet .Add (named .NewParam (k ))
198+ }
199+ }
200+ params , err := c .resolveCatalogRefs (qc , rvs , refs , paramSet , embeds )
185201 if err := check (err ); err != nil {
186202 return nil , err
187203 }
188- cols , err := c .outputColumns (qc , raw .Stmt )
204+ cols , err := c .outputColumns (qc , * raw .Stmt )
189205 if err := check (err ); err != nil {
190206 return nil , err
191207 }
@@ -194,7 +210,9 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
194210 if check (err ); err != nil {
195211 return nil , err
196212 }
197- edits = append (edits , expandEdits ... )
213+ if expandEdits != nil {
214+ edits = append (edits , expandEdits ... )
215+ }
198216 expanded , err := source .Mutate (query , edits )
199217 if err != nil {
200218 return nil , err
@@ -205,11 +223,18 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
205223 rerr = errors [0 ]
206224 }
207225
226+ var namedParamSet * named.ParamSet
227+ if namedParams != nil {
228+ namedParamSet = named .NewParamSet (nil , true )
229+ for k := range namedParams {
230+ namedParamSet .Add (named .NewParam (k ))
231+ }
232+ }
208233 return & analysis {
209234 Table : table ,
210235 Columns : cols ,
211236 Parameters : params ,
212237 Query : expanded ,
213- Named : namedParams ,
238+ Named : namedParamSet ,
214239 }, rerr
215240}
0 commit comments