108 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			108 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
	
| package db
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"strings"
 | |
| )
 | |
| 
 | |
// MultiInserter accumulates rows destined for a single table and turns them
// into one multi-row statement of the form
// `INSERT INTO table (...) VALUES (...),(...),...`,
// which it can also execute.
//
// Build one with NewMultiInserter; the zero value has no table or fields.
type MultiInserter struct {
	// table is the destination table name. Like fields, the constructor
	// checks that it contains only characters permitted in an unquoted
	// identifier: https://mariadb.com/kb/en/identifier-names/#unquoted
	table string
	// fields are the destination column names. Each row passed to Add
	// supplies one value per field, in this same order.
	fields []string

	// values holds the rows registered through Add, oldest first.
	values [][]interface{}
}
 | |
| 
 | |
| // NewMultiInserter creates a new MultiInserter, checking for reasonable table
 | |
| // name and list of fields.
 | |
| // Safety: `table` and `fields` must contain only strings that are known at
 | |
| // compile time. They must not contain user-controlled strings.
 | |
| func NewMultiInserter(table string, fields []string) (*MultiInserter, error) {
 | |
| 	if len(table) == 0 || len(fields) == 0 {
 | |
| 		return nil, fmt.Errorf("empty table name or fields list")
 | |
| 	}
 | |
| 
 | |
| 	err := validMariaDBUnquotedIdentifier(table)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	for _, field := range fields {
 | |
| 		err := validMariaDBUnquotedIdentifier(field)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return &MultiInserter{
 | |
| 		table:  table,
 | |
| 		fields: fields,
 | |
| 		values: make([][]interface{}, 0),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // Add registers another row to be included in the Insert query.
 | |
| func (mi *MultiInserter) Add(row []interface{}) error {
 | |
| 	if len(row) != len(mi.fields) {
 | |
| 		return fmt.Errorf("field count mismatch, got %d, expected %d", len(row), len(mi.fields))
 | |
| 	}
 | |
| 	mi.values = append(mi.values, row)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // query returns the formatted query string, and the slice of arguments for
 | |
| // for borp to use in place of the query's question marks. Currently only
 | |
| // used by .Insert(), below.
 | |
| func (mi *MultiInserter) query() (string, []interface{}) {
 | |
| 	var questionsBuf strings.Builder
 | |
| 	var queryArgs []interface{}
 | |
| 	for _, row := range mi.values {
 | |
| 		// Safety: We are interpolating a string that will be used in a SQL
 | |
| 		// query, but we constructed that string in this function and know it
 | |
| 		// consists only of question marks joined with commas.
 | |
| 		fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(len(mi.fields)))
 | |
| 		queryArgs = append(queryArgs, row...)
 | |
| 	}
 | |
| 
 | |
| 	questions := strings.TrimRight(questionsBuf.String(), ",")
 | |
| 
 | |
| 	// Safety: we are interpolating `mi.table` and `mi.fields` into an SQL
 | |
| 	// query. We know they contain, respectively, a valid unquoted identifier
 | |
| 	// and a slice of valid unquoted identifiers because we verified that in
 | |
| 	// the constructor. We know the query overall has valid syntax because we
 | |
| 	// generate it entirely within this function.
 | |
| 	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", mi.table, strings.Join(mi.fields, ","), questions)
 | |
| 
 | |
| 	return query, queryArgs
 | |
| }
 | |
| 
 | |
| // Insert inserts all the collected rows into the database represented by
 | |
| // `queryer`.
 | |
| func (mi *MultiInserter) Insert(ctx context.Context, db Execer) error {
 | |
| 	if len(mi.values) == 0 {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	query, queryArgs := mi.query()
 | |
| 	res, err := db.ExecContext(ctx, query, queryArgs...)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	affected, err := res.RowsAffected()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if affected != int64(len(mi.values)) {
 | |
| 		return fmt.Errorf("unexpected number of rows inserted: %d != %d", affected, len(mi.values))
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 |