Migrate some SQL sanitizer tests to java (#7148)

This commit is contained in:
jason plumb 2022-11-14 22:15:32 -08:00 committed by GitHub
parent c6bbf28eac
commit d4b29cf521
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 628 additions and 374 deletions

View File

@ -50,7 +50,7 @@ WHITESPACE = [ \t\r\n]+
}
// max length of the sanitized statement - SQLs longer than this will be trimmed
private static final int LIMIT = 32 * 1024;
static final int LIMIT = 32 * 1024;
private final StringBuilder builder = new StringBuilder();

View File

@ -1,148 +0,0 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/
package io.opentelemetry.instrumentation.api.db
import spock.lang.Specification
import spock.lang.Unroll
class RedisCommandSanitizerTest extends Specification {
// Data-driven check: each row feeds a command plus its arguments to the sanitizer and
// compares the rendered text with the expected form, in which masked arguments show as "?".
@Unroll
def "should sanitize #expected"() {
  expect:
  RedisCommandSanitizer.create(true).sanitize(command, args) == expected

  where:
  command | args | expected
  // Connection
  "AUTH" | ["password"] | "AUTH ?"
  "HELLO" | ["3", "AUTH", "username", "password"] | "HELLO 3 AUTH ? ?"
  "HELLO" | ["3"] | "HELLO 3"
  // Hashes
  "HMSET" | ["hash", "key1", "value1", "key2", "value2"] | "HMSET hash key1 ? key2 ?"
  "HSET" | ["hash", "key1", "value1", "key2", "value2"] | "HSET hash key1 ? key2 ?"
  "HSETNX" | ["hash", "key", "value"] | "HSETNX hash key ?"
  // HyperLogLog
  "PFADD" | ["hll", "a", "b", "c"] | "PFADD hll ? ? ?"
  // Keys
  "MIGRATE" | ["127.0.0.1", "4242", "key", "0", "5000", "AUTH", "password"] | "MIGRATE 127.0.0.1 4242 key 0 5000 AUTH ?"
  "RESTORE" | ["key", "42", "value"] | "RESTORE key 42 ?"
  // Lists
  "LINSERT" | ["list", "BEFORE", "value1", "value2"] | "LINSERT list BEFORE ? ?"
  "LPOS" | ["list", "value"] | "LPOS list ?"
  "LPUSH" | ["list", "value1", "value2"] | "LPUSH list ? ?"
  "LPUSHX" | ["list", "value1", "value2"] | "LPUSHX list ? ?"
  "LREM" | ["list", "2", "value"] | "LREM list ? ?"
  "LSET" | ["list", "2", "value"] | "LSET list ? ?"
  "RPUSH" | ["list", "value1", "value2"] | "RPUSH list ? ?"
  "RPUSHX" | ["list", "value1", "value2"] | "RPUSHX list ? ?"
  // Pub/Sub
  "PUBLISH" | ["channel", "message"] | "PUBLISH channel ?"
  // Scripting
  "EVAL" | ["script", "2", "key1", "key2", "value"] | "EVAL script 2 key1 key2 ?"
  "EVALSHA" | ["sha1", "0", "value1", "value2"] | "EVALSHA sha1 0 ? ?"
  // Sets
  "SADD" | ["set", "value1", "value2"] | "SADD set ? ?"
  "SISMEMBER" | ["set", "value"] | "SISMEMBER set ?"
  "SMISMEMBER" | ["set", "value1", "value2"] | "SMISMEMBER set ? ?"
  "SMOVE" | ["set1", "set2", "value"] | "SMOVE set1 set2 ?"
  "SREM" | ["set", "value1", "value2"] | "SREM set ? ?"
  // Server
  "CONFIG" | ["SET", "masterpassword", "password"] | "CONFIG SET masterpassword ?"
  // Sorted Sets
  "ZADD" | ["sset", "1", "value1", "2", "value2"] | "ZADD sset ? ? ? ?"
  "ZCOUNT" | ["sset", "1", "10"] | "ZCOUNT sset ? ?"
  "ZINCRBY" | ["sset", "1", "value"] | "ZINCRBY sset ? ?"
  "ZLEXCOUNT" | ["sset", "1", "10"] | "ZLEXCOUNT sset ? ?"
  "ZMSCORE" | ["sset", "value1", "value2"] | "ZMSCORE sset ? ?"
  "ZRANGEBYLEX" | ["sset", "1", "10"] | "ZRANGEBYLEX sset ? ?"
  "ZRANGEBYSCORE" | ["sset", "1", "10"] | "ZRANGEBYSCORE sset ? ?"
  "ZRANK" | ["sset", "value"] | "ZRANK sset ?"
  "ZREM" | ["sset", "value1", "value2"] | "ZREM sset ? ?"
  "ZREMRANGEBYLEX" | ["sset", "1", "10"] | "ZREMRANGEBYLEX sset ? ?"
  "ZREMRANGEBYSCORE" | ["sset", "1", "10"] | "ZREMRANGEBYSCORE sset ? ?"
  "ZREVRANGEBYLEX" | ["sset", "1", "10"] | "ZREVRANGEBYLEX sset ? ?"
  "ZREVRANGEBYSCORE" | ["sset", "1", "10"] | "ZREVRANGEBYSCORE sset ? ?"
  "ZREVRANK" | ["sset", "value"] | "ZREVRANK sset ?"
  "ZSCORE" | ["sset", "value"] | "ZSCORE sset ?"
  // Streams
  "XADD" | ["stream", "*", "key1", "value1", "key2", "value2"] | "XADD stream * key1 ? key2 ?"
  // Strings
  "APPEND" | ["key", "value"] | "APPEND key ?"
  "GETSET" | ["key", "value"] | "GETSET key ?"
  "MSET" | ["key1", "value1", "key2", "value2"] | "MSET key1 ? key2 ?"
  "MSETNX" | ["key1", "value1", "key2", "value2"] | "MSETNX key1 ? key2 ?"
  "PSETEX" | ["key", "10000", "value"] | "PSETEX key 10000 ?"
  "SET" | ["key", "value"] | "SET key ?"
  "SETEX" | ["key", "10", "value"] | "SETEX key 10 ?"
  "SETNX" | ["key", "value"] | "SETNX key ?"
  "SETRANGE" | ["key", "42", "value"] | "SETRANGE key ? ?"
}
// For every command listed below the sanitizer is expected to echo both arguments
// untouched, i.e. the output is the command followed by the space-joined arguments.
@Unroll
def "should keep all arguments of #command"() {
  given:
  def untouchedArgs = ["arg1", "arg 2"]
  def expected = command + " " + untouchedArgs.join(" ")

  expect:
  RedisCommandSanitizer.create(true).sanitize(command, untouchedArgs) == expected

  where:
  command << [
    // Cluster
    "CLUSTER", "READONLY", "READWRITE",
    // Connection
    "CLIENT", "ECHO", "PING", "QUIT", "SELECT",
    // Geo
    "GEOADD", "GEODIST", "GEOHASH", "GEOPOS", "GEORADIUS", "GEORADIUSBYMEMBER",
    // Hashes
    "HDEL", "HEXISTS", "HGET", "HGETALL", "HINCRBY", "HINCRBYFLOAT", "HKEYS", "HLEN", "HMGET",
    "HSCAN", "HSTRLEN", "HVALS",
    // HyperLogLog
    "PFCOUNT", "PFMERGE",
    // Keys
    "DEL", "DUMP", "EXISTS", "EXPIRE", "EXPIREAT", "KEYS", "MOVE", "OBJECT", "PERSIST", "PEXPIRE",
    "PEXPIREAT", "PTTL", "RANDOMKEY", "RENAME", "RENAMENX", "RESTORE", "SCAN", "SORT", "TOUCH",
    "TTL", "TYPE", "UNLINK", "WAIT",
    // Lists
    "BLMOVE", "BLPOP", "BRPOP", "BRPOPLPUSH", "LINDEX", "LLEN", "LMOVE", "LPOP", "LRANGE",
    "LTRIM", "RPOP", "RPOPLPUSH",
    // Pub/Sub
    "PSUBSCRIBE", "PUBSUB", "PUNSUBSCRIBE", "SUBSCRIBE", "UNSUBSCRIBE",
    // Server
    "ACL", "BGREWRITEAOF", "BGSAVE", "COMMAND", "DBSIZE", "DEBUG", "FLUSHALL", "FLUSHDB", "INFO",
    "LASTSAVE", "LATENCY", "LOLWUT", "MEMORY", "MODULE", "MONITOR", "PSYNC", "REPLICAOF", "ROLE",
    "SAVE", "SHUTDOWN", "SLAVEOF", "SLOWLOG", "SWAPDB", "SYNC", "TIME",
    // Sets
    "SCARD", "SDIFF", "SDIFFSTORE", "SINTER", "SINTERSTORE", "SMEMBERS", "SPOP", "SRANDMEMBER",
    "SSCAN", "SUNION", "SUNIONSTORE",
    // Sorted Sets
    "BZPOPMAX", "BZPOPMIN", "ZCARD", "ZINTER", "ZINTERSTORE", "ZPOPMAX", "ZPOPMIN", "ZRANGE",
    "ZREMRANGEBYRANK", "ZREVRANGE", "ZSCAN", "ZUNION", "ZUNIONSTORE",
    // Streams
    "XACK", "XCLAIM", "XDEL", "XGROUP", "XINFO", "XLEN", "XPENDING", "XRANGE", "XREAD",
    "XREADGROUP", "XREVRANGE", "XTRIM",
    // Strings
    "BITCOUNT", "BITFIELD", "BITOP", "BITPOS", "DECR", "DECRBY", "GET", "GETBIT", "GETRANGE",
    "INCR", "INCRBY", "INCRBYFLOAT", "MGET", "SETBIT", "STRALGO", "STRLEN",
    // Transactions
    "DISCARD", "EXEC", "MULTI", "UNWATCH", "WATCH"
  ]
}
// "NEWAUTH" appears in neither fixture list of this spec; both of its arguments must be masked.
def "should mask all arguments of an unknown command"() {
  given:
  def sanitizer = RedisCommandSanitizer.create(true)

  expect:
  sanitizer.sanitize("NEWAUTH", ["password", "secret"]) == "NEWAUTH ? ?"
}
}

View File

@ -1,225 +0,0 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/
package io.opentelemetry.instrumentation.api.db
import spock.lang.Specification
import spock.lang.Unroll
class SqlStatementSanitizerTest extends Specification {
// Data-driven check of the full sanitized statement text for the default dialect.
// @Unroll matches the sibling "should simplify #sql" feature and makes the #originalSql
// placeholder in the name expand per row on Spock versions that do not unroll by default.
@Unroll
def "normalize #originalSql"() {
  setup:
  def actualSanitized = SqlStatementSanitizer.create(true).sanitize(originalSql)

  expect:
  actualSanitized.getFullStatement() == sanitizedSql

  where:
  originalSql | sanitizedSql
  // Numbers
  "SELECT * FROM TABLE WHERE FIELD=1234" | "SELECT * FROM TABLE WHERE FIELD=?"
  "SELECT * FROM TABLE WHERE FIELD = 1234" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD>=-1234" | "SELECT * FROM TABLE WHERE FIELD>=?"
  "SELECT * FROM TABLE WHERE FIELD<-1234" | "SELECT * FROM TABLE WHERE FIELD<?"
  "SELECT * FROM TABLE WHERE FIELD <.1234" | "SELECT * FROM TABLE WHERE FIELD <?"
  "SELECT 1.2" | "SELECT ?"
  "SELECT -1.2" | "SELECT ?"
  "SELECT -1.2e-9" | "SELECT ?"
  "SELECT 2E+9" | "SELECT ?"
  "SELECT +0.2" | "SELECT ?"
  "SELECT .2" | "SELECT ?"
  "7" | "?"
  ".7" | "?"
  "-7" | "?"
  "+7" | "?"
  "SELECT 0x0af764" | "SELECT ?"
  "SELECT 0xdeadBEEF" | "SELECT ?"
  "SELECT * FROM \"TABLE\"" | "SELECT * FROM \"TABLE\""
  // Not numbers but could be confused as such
  "SELECT A + B" | "SELECT A + B"
  "SELECT -- comment" | "SELECT -- comment"
  "SELECT * FROM TABLE123" | "SELECT * FROM TABLE123"
  "SELECT FIELD2 FROM TABLE_123 WHERE X<>7" | "SELECT FIELD2 FROM TABLE_123 WHERE X<>?"
  // Semi-nonsensical almost-numbers to elide or not
  "SELECT --83--...--8e+76e3E-1" | "SELECT ?"
  "SELECT DEADBEEF" | "SELECT DEADBEEF"
  "SELECT 123-45-6789" | "SELECT ?"
  "SELECT 1/2/34" | "SELECT ?/?/?"
  // Basic ' strings
  "SELECT * FROM TABLE WHERE FIELD = ''" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = 'words and spaces'" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = ' an escaped '' quote mark inside'" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = '\\\\'" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = '\"inside doubles\"'" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = '\"\$\$\$\$\"'" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = 'a single \" doublequote inside'" | "SELECT * FROM TABLE WHERE FIELD = ?"
  // Some databases allow using dollar-quoted strings
  "SELECT * FROM TABLE WHERE FIELD = \$\$\$\$" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \$\$words and spaces\$\$" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \$\$quotes '\" inside\$\$" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \$\$\"''\"\$\$" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \$\$\\\\\$\$" | "SELECT * FROM TABLE WHERE FIELD = ?"
  // Unicode, including a unicode identifier with a trailing number
  "SELECT * FROM TABLE\u09137 WHERE FIELD = '\u0194'" | "SELECT * FROM TABLE\u09137 WHERE FIELD = ?"
  // whitespace normalization
  "SELECT * \t\r\nFROM TABLE WHERE FIELD1 = 12344 AND FIELD2 = 5678" | "SELECT * FROM TABLE WHERE FIELD1 = ? AND FIELD2 = ?"
  // hibernate/jpa query language
  "FROM TABLE WHERE FIELD=1234" | "FROM TABLE WHERE FIELD=?"
}
// Same check as the default-dialect normalization, but sanitizing with SqlDialect.COUCHBASE:
// every double-quoted value in the rows below is expected to be replaced by "?".
// @Unroll matches the sibling "should simplify #sql" feature and makes the #originalSql
// placeholder in the name expand per row on Spock versions that do not unroll by default.
@Unroll
def "normalize couchbase #originalSql"() {
  setup:
  def actualSanitized = SqlStatementSanitizer.create(true).sanitize(originalSql, SqlDialect.COUCHBASE)

  expect:
  actualSanitized.getFullStatement() == sanitizedSql

  where:
  originalSql | sanitizedSql
  // Some databases support/encourage " instead of ' with same escape rules
  "SELECT * FROM TABLE WHERE FIELD = \"\"" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \"words and spaces'\"" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \" an escaped \"\" quote mark inside\"" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \"\\\\\"" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \"'inside singles'\"" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \"'\$\$\$\$'\"" | "SELECT * FROM TABLE WHERE FIELD = ?"
  "SELECT * FROM TABLE WHERE FIELD = \"a single ' singlequote inside\"" | "SELECT * FROM TABLE WHERE FIELD = ?"
}
// Compares the complete SqlStatementInfo returned by the sanitizer with the expected one.
// Rows that pass `sql` to SqlStatementInfo.create expect the statement text to come back
// unchanged; the other rows spell out the sanitized text explicitly.
@Unroll
def "should simplify #sql"() {
  when:
  def actual = SqlStatementSanitizer.create(true).sanitize(sql)

  then:
  actual == expected

  where:
  sql | expected
  // Select
  'SELECT x, y, z FROM schema.table' | SqlStatementInfo.create(sql, 'SELECT', 'schema.table')
  'SELECT x, y, z FROM `schema table`' | SqlStatementInfo.create(sql, 'SELECT', 'schema table')
  'SELECT x, y, z FROM "schema table"' | SqlStatementInfo.create(sql, 'SELECT', 'schema table')
  'WITH subquery as (select a from b) SELECT x, y, z FROM table' | SqlStatementInfo.create(sql, 'SELECT', null)
  'SELECT x, y, (select a from b) as z FROM table' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select delete, insert into, merge, update from table' | SqlStatementInfo.create(sql, 'SELECT', 'table')
  'select col /* from table2 */ from table' | SqlStatementInfo.create(sql, 'SELECT', 'table')
  'select col from table join anotherTable' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from (select * from anotherTable)' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from (select * from anotherTable) alias' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from table1 union select col from table2' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from table where col in (select * from anotherTable)' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from table1, table2' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from table1 t1, table2 t2' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from table1 as t1, table2 as t2' | SqlStatementInfo.create(sql, 'SELECT', null)
  'select col from table where col in (1, 2, 3)' | SqlStatementInfo.create('select col from table where col in (?, ?, ?)', 'SELECT', 'table')
  'select col from table order by col, col2' | SqlStatementInfo.create(sql, 'SELECT', 'table')
  'select ąś∂ń© from źćļńĶ order by col, col2' | SqlStatementInfo.create(sql, 'SELECT', 'źćļńĶ')
  'select 12345678' | SqlStatementInfo.create('select ?', 'SELECT', null)
  '/* update comment */ select * from table1' | SqlStatementInfo.create(sql, 'SELECT', 'table1')
  'select /*((*/abc from table' | SqlStatementInfo.create(sql, 'SELECT', 'table')
  'SeLeCT * FrOm TAblE' | SqlStatementInfo.create(sql, 'SELECT', 'TAblE')
  // hibernate/jpa
  'FROM schema.table' | SqlStatementInfo.create(sql, 'SELECT', 'schema.table')
  '/* update comment */ from table1' | SqlStatementInfo.create(sql, 'SELECT', 'table1')
  // Insert
  ' insert into table where lalala' | SqlStatementInfo.create(sql, 'INSERT', 'table')
  'insert insert into table where lalala' | SqlStatementInfo.create(sql, 'INSERT', 'table')
  'insert into db.table where lalala' | SqlStatementInfo.create(sql, 'INSERT', 'db.table')
  'insert into `db table` where lalala' | SqlStatementInfo.create(sql, 'INSERT', 'db table')
  'insert into "db table" where lalala' | SqlStatementInfo.create(sql, 'INSERT', 'db table')
  'insert without i-n-t-o' | SqlStatementInfo.create(sql, 'INSERT', null)
  // Delete
  'delete from table where something something' | SqlStatementInfo.create(sql, 'DELETE', 'table')
  'delete from `my table` where something something' | SqlStatementInfo.create(sql, 'DELETE', 'my table')
  'delete from "my table" where something something' | SqlStatementInfo.create(sql, 'DELETE', 'my table')
  'delete from 12345678' | SqlStatementInfo.create('delete from ?', 'DELETE', null)
  'delete (((' | SqlStatementInfo.create('delete (((', 'DELETE', null)
  // Update
  'update table set answer=42' | SqlStatementInfo.create('update table set answer=?', 'UPDATE', 'table')
  'update `my table` set answer=42' | SqlStatementInfo.create('update `my table` set answer=?', 'UPDATE', 'my table')
  'update "my table" set answer=42' | SqlStatementInfo.create('update "my table" set answer=?', 'UPDATE', 'my table')
  'update /*table' | SqlStatementInfo.create(sql, 'UPDATE', null)
  // Merge
  'merge into table' | SqlStatementInfo.create(sql, 'MERGE', 'table')
  'merge into `my table`' | SqlStatementInfo.create(sql, 'MERGE', 'my table')
  'merge into "my table"' | SqlStatementInfo.create(sql, 'MERGE', 'my table')
  'merge table (into is optional in some dbs)' | SqlStatementInfo.create(sql, 'MERGE', 'table')
  'merge (into )))' | SqlStatementInfo.create(sql, 'MERGE', null)
  // Unknown operation
  'and now for something completely different' | SqlStatementInfo.create(sql, null, null)
  '' | SqlStatementInfo.create(sql, null, null)
  null | SqlStatementInfo.create(sql, null, null)
}
// Builds a WHERE clause with 2000 numeric comparisons and expects the sanitizer to mask
// each "=123" and cut the statement off at AutoSqlSanitizer.LIMIT characters.
def "very long SELECT statements don't cause problems"() {
  given:
  def queryBuilder = new StringBuilder("SELECT * FROM table WHERE")
  2000.times { columnIndex ->
    queryBuilder.append(" column").append(columnIndex).append("=123 and")
  }
  def query = queryBuilder.toString()
  def truncatedExpectation = query.replace('=123', '=?').substring(0, AutoSqlSanitizer.LIMIT)

  when:
  def actual = SqlStatementSanitizer.create(true).sanitize(query)

  then:
  actual == SqlStatementInfo.create(truncatedExpectation, "SELECT", "table")
}
// Sanitizes strings made of 1 up to 10000 single quotes; each call only has to
// return a non-null result.
def "lots and lots of ticks don't cause stack overflow or long runtimes"() {
  setup:
  String ticks = "'"
  10000.times {
    assert SqlStatementSanitizer.create(true).sanitize(ticks) != null
    ticks = ticks + "'"
  }
}
// Concatenates the decimal forms of 0..9999 into one giant numeric literal and expects it
// to be sanitized to a single "?".
def "very long numbers don't cause a problem"() {
  given:
  // StringBuilder avoids re-copying the growing string on each of the 10000 appends.
  def digits = new StringBuilder()
  for (int i = 0; i < 10000; i++) {
    digits.append(i)
  }

  expect:
  SqlStatementSanitizer.create(true).sanitize(digits.toString()).getFullStatement() == "?"
}
// An identifier "A" followed by the digits of 0..9999 must come back unchanged apart from
// being cut to AutoSqlSanitizer.LIMIT characters.
def "very long numbers at end of table name don't cause problem"() {
  given:
  // StringBuilder avoids re-copying the growing string on each of the 10000 appends.
  def identifier = new StringBuilder("A")
  for (int i = 0; i < 10000; i++) {
    identifier.append(i)
  }
  String s = identifier.toString()

  expect:
  SqlStatementSanitizer.create(true).sanitize(s).getFullStatement() == s.substring(0, AutoSqlSanitizer.LIMIT)
}
// Repeats one statement 10000 times and checks that the sanitized text is no longer than
// AutoSqlSanitizer.LIMIT and that no "1234" literal survives in it.
def "test 32k truncation"() {
  given:
  // A method-local buffer needs no synchronization, so StringBuilder is sufficient.
  def sql = new StringBuilder()
  for (int i = 0; i < 10000; i++) {
    sql.append("SELECT * FROM TABLE WHERE FIELD = 1234 AND ")
  }

  when:
  String sanitized = SqlStatementSanitizer.create(true).sanitize(sql.toString()).getFullStatement()

  then:
  sanitized.length() <= AutoSqlSanitizer.LIMIT
  !sanitized.contains("1234")
}
// Smoke test: 1000 strings of 1000 pseudo-random chars each are pushed through the
// sanitizer. Nothing is asserted on the output; the feature fails only if a call throws.
def "random bytes don't cause exceptions or timeouts"() {
  setup:
  // Fixed seed keeps the generated inputs identical from run to run.
  Random r = new Random(0)
  for (int i = 0; i < 1000; i++) {
    // A method-local buffer needs no synchronization, so StringBuilder is sufficient.
    StringBuilder sb = new StringBuilder()
    for (int c = 0; c < 1000; c++) {
      sb.append((char) r.nextInt((int) Character.MAX_VALUE))
    }
    SqlStatementSanitizer.create(true).sanitize(sb.toString())
  }
}
}

View File

@ -0,0 +1,303 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/
package io.opentelemetry.instrumentation.api.db;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
class RedisCommandSanitizerTest {
/** Sanitizes each command/argument row from {@link SanitizeArgs} and compares the rendering. */
@ParameterizedTest
@ArgumentsSource(SanitizeArgs.class)
void shouldSanitizeExpected(String command, List<String> args, String expected) {
  RedisCommandSanitizer sanitizer = RedisCommandSanitizer.create(true);
  assertThat(sanitizer.sanitize(command, args)).isEqualTo(expected);
}
/**
 * For each command from {@link KeepArguments}, expects the output to be the command followed
 * by both arguments, space-separated and unmasked.
 */
@ParameterizedTest
@ArgumentsSource(KeepArguments.class)
void shouldKeepAllArguments(String command) {
  List<String> untouched = list("arg1", "arg 2");
  String expected = command + " " + String.join(" ", untouched);

  String actual = RedisCommandSanitizer.create(true).sanitize(command, untouched);

  assertThat(actual).isEqualTo(expected);
}
/** A command absent from both fixture providers must have every argument replaced by {@code ?}. */
@Test
void maskAllArgsOfUnknownCommand() {
  RedisCommandSanitizer sanitizer = RedisCommandSanitizer.create(true);
  List<String> secrets = list("password", "secret");

  assertThat(sanitizer.sanitize("NEWAUTH", secrets)).isEqualTo("NEWAUTH ? ?");
}
/**
 * Supplies (command, arguments, expected sanitized text) triples for
 * {@code shouldSanitizeExpected}, grouped by Redis command family.
 */
static class SanitizeArgs implements ArgumentsProvider {
  // The body throws nothing checked, so the override does not redeclare "throws Exception".
  @Override
  public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
    return Stream.of(
        // Connection
        Arguments.of("AUTH", list("password"), "AUTH ?"),
        Arguments.of("HELLO", list("3", "AUTH", "username", "password"), "HELLO 3 AUTH ? ?"),
        Arguments.of("HELLO", list("3"), "HELLO 3"),
        // Hashes
        Arguments.of(
            "HMSET",
            list("hash", "key1", "value1", "key2", "value2"),
            "HMSET hash key1 ? key2 ?"),
        Arguments.of(
            "HSET", list("hash", "key1", "value1", "key2", "value2"), "HSET hash key1 ? key2 ?"),
        Arguments.of("HSETNX", list("hash", "key", "value"), "HSETNX hash key ?"),
        // HyperLogLog
        Arguments.of("PFADD", list("hll", "a", "b", "c"), "PFADD hll ? ? ?"),
        // Keys
        Arguments.of(
            "MIGRATE",
            list("127.0.0.1", "4242", "key", "0", "5000", "AUTH", "password"),
            "MIGRATE 127.0.0.1 4242 key 0 5000 AUTH ?"),
        Arguments.of("RESTORE", list("key", "42", "value"), "RESTORE key 42 ?"),
        // Lists
        Arguments.of(
            "LINSERT", list("list", "BEFORE", "value1", "value2"), "LINSERT list BEFORE ? ?"),
        Arguments.of("LPOS", list("list", "value"), "LPOS list ?"),
        Arguments.of("LPUSH", list("list", "value1", "value2"), "LPUSH list ? ?"),
        Arguments.of("LPUSHX", list("list", "value1", "value2"), "LPUSHX list ? ?"),
        Arguments.of("LREM", list("list", "2", "value"), "LREM list ? ?"),
        Arguments.of("LSET", list("list", "2", "value"), "LSET list ? ?"),
        Arguments.of("RPUSH", list("list", "value1", "value2"), "RPUSH list ? ?"),
        Arguments.of("RPUSHX", list("list", "value1", "value2"), "RPUSHX list ? ?"),
        // Pub/Sub
        Arguments.of("PUBLISH", list("channel", "message"), "PUBLISH channel ?"),
        // Scripting
        Arguments.of(
            "EVAL", list("script", "2", "key1", "key2", "value"), "EVAL script 2 key1 key2 ?"),
        Arguments.of("EVALSHA", list("sha1", "0", "value1", "value2"), "EVALSHA sha1 0 ? ?"),
        // Sets
        Arguments.of("SADD", list("set", "value1", "value2"), "SADD set ? ?"),
        Arguments.of("SISMEMBER", list("set", "value"), "SISMEMBER set ?"),
        Arguments.of("SMISMEMBER", list("set", "value1", "value2"), "SMISMEMBER set ? ?"),
        Arguments.of("SMOVE", list("set1", "set2", "value"), "SMOVE set1 set2 ?"),
        Arguments.of("SREM", list("set", "value1", "value2"), "SREM set ? ?"),
        // Server
        Arguments.of(
            "CONFIG", list("SET", "masterpassword", "password"), "CONFIG SET masterpassword ?"),
        // Sorted Sets
        Arguments.of("ZADD", list("sset", "1", "value1", "2", "value2"), "ZADD sset ? ? ? ?"),
        Arguments.of("ZCOUNT", list("sset", "1", "10"), "ZCOUNT sset ? ?"),
        Arguments.of("ZINCRBY", list("sset", "1", "value"), "ZINCRBY sset ? ?"),
        Arguments.of("ZLEXCOUNT", list("sset", "1", "10"), "ZLEXCOUNT sset ? ?"),
        Arguments.of("ZMSCORE", list("sset", "value1", "value2"), "ZMSCORE sset ? ?"),
        Arguments.of("ZRANGEBYLEX", list("sset", "1", "10"), "ZRANGEBYLEX sset ? ?"),
        Arguments.of("ZRANGEBYSCORE", list("sset", "1", "10"), "ZRANGEBYSCORE sset ? ?"),
        Arguments.of("ZRANK", list("sset", "value"), "ZRANK sset ?"),
        Arguments.of("ZREM", list("sset", "value1", "value2"), "ZREM sset ? ?"),
        Arguments.of("ZREMRANGEBYLEX", list("sset", "1", "10"), "ZREMRANGEBYLEX sset ? ?"),
        Arguments.of("ZREMRANGEBYSCORE", list("sset", "1", "10"), "ZREMRANGEBYSCORE sset ? ?"),
        Arguments.of("ZREVRANGEBYLEX", list("sset", "1", "10"), "ZREVRANGEBYLEX sset ? ?"),
        Arguments.of("ZREVRANGEBYSCORE", list("sset", "1", "10"), "ZREVRANGEBYSCORE sset ? ?"),
        Arguments.of("ZREVRANK", list("sset", "value"), "ZREVRANK sset ?"),
        Arguments.of("ZSCORE", list("sset", "value"), "ZSCORE sset ?"),
        // Streams
        Arguments.of(
            "XADD",
            list("stream", "*", "key1", "value1", "key2", "value2"),
            "XADD stream * key1 ? key2 ?"),
        // Strings
        Arguments.of("APPEND", list("key", "value"), "APPEND key ?"),
        Arguments.of("GETSET", list("key", "value"), "GETSET key ?"),
        Arguments.of("MSET", list("key1", "value1", "key2", "value2"), "MSET key1 ? key2 ?"),
        Arguments.of("MSETNX", list("key1", "value1", "key2", "value2"), "MSETNX key1 ? key2 ?"),
        Arguments.of("PSETEX", list("key", "10000", "value"), "PSETEX key 10000 ?"),
        Arguments.of("SET", list("key", "value"), "SET key ?"),
        Arguments.of("SETEX", list("key", "10", "value"), "SETEX key 10 ?"),
        Arguments.of("SETNX", list("key", "value"), "SETNX key ?"),
        Arguments.of("SETRANGE", list("key", "42", "value"), "SETRANGE key ? ?"));
  }
}
/**
 * Supplies the command names for {@code shouldKeepAllArguments}, one {@link Arguments} per
 * name, listed by Redis command family.
 */
static class KeepArguments implements ArgumentsProvider {
  @Override
  public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
    Stream<String> commands =
        Stream.of(
            // Cluster
            "CLUSTER", "READONLY", "READWRITE",
            // Connection
            "CLIENT", "ECHO", "PING", "QUIT", "SELECT",
            // Geo
            "GEOADD", "GEODIST", "GEOHASH", "GEOPOS", "GEORADIUS", "GEORADIUSBYMEMBER",
            // Hashes
            "HDEL", "HEXISTS", "HGET", "HGETALL", "HINCRBY", "HINCRBYFLOAT", "HKEYS", "HLEN",
            "HMGET", "HSCAN", "HSTRLEN", "HVALS",
            // HyperLogLog
            "PFCOUNT", "PFMERGE",
            // Keys
            "DEL", "DUMP", "EXISTS", "EXPIRE", "EXPIREAT", "KEYS", "MOVE", "OBJECT", "PERSIST",
            "PEXPIRE", "PEXPIREAT", "PTTL", "RANDOMKEY", "RENAME", "RENAMENX", "RESTORE", "SCAN",
            "SORT", "TOUCH", "TTL", "TYPE", "UNLINK", "WAIT",
            // Lists
            "BLMOVE", "BLPOP", "BRPOP", "BRPOPLPUSH", "LINDEX", "LLEN", "LMOVE", "LPOP", "LRANGE",
            "LTRIM", "RPOP", "RPOPLPUSH",
            // Pub/Sub
            "PSUBSCRIBE", "PUBSUB", "PUNSUBSCRIBE", "SUBSCRIBE", "UNSUBSCRIBE",
            // Server
            "ACL", "BGREWRITEAOF", "BGSAVE", "COMMAND", "DBSIZE", "DEBUG", "FLUSHALL", "FLUSHDB",
            "INFO", "LASTSAVE", "LATENCY", "LOLWUT", "MEMORY", "MODULE", "MONITOR", "PSYNC",
            "REPLICAOF", "ROLE", "SAVE", "SHUTDOWN", "SLAVEOF", "SLOWLOG", "SWAPDB", "SYNC", "TIME",
            // Sets
            "SCARD", "SDIFF", "SDIFFSTORE", "SINTER", "SINTERSTORE", "SMEMBERS", "SPOP",
            "SRANDMEMBER", "SSCAN", "SUNION", "SUNIONSTORE",
            // Sorted Sets
            "BZPOPMAX", "BZPOPMIN", "ZCARD", "ZINTER", "ZINTERSTORE", "ZPOPMAX", "ZPOPMIN",
            "ZRANGE", "ZREMRANGEBYRANK", "ZREVRANGE", "ZSCAN", "ZUNION", "ZUNIONSTORE",
            // Streams
            "XACK", "XCLAIM", "XDEL", "XGROUP", "XINFO", "XLEN", "XPENDING", "XRANGE", "XREAD",
            "XREADGROUP", "XREVRANGE", "XTRIM",
            // Strings
            "BITCOUNT", "BITFIELD", "BITOP", "BITPOS", "DECR", "DECRBY", "GET", "GETBIT",
            "GETRANGE", "INCR", "INCRBY", "INCRBYFLOAT", "MGET", "SETBIT", "STRALGO", "STRLEN",
            // Transactions
            "DISCARD", "EXEC", "MULTI", "UNWATCH", "WATCH");
    // Each name becomes a single-argument invocation of the parameterized test.
    return commands.map(Arguments::of);
  }
}
/** Varargs shorthand that wraps the given strings in a fixed-size {@link List}. */
static List<String> list(String... args) {
  List<String> values = Arrays.asList(args);
  return values;
}
}

View File

@ -0,0 +1,324 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/
package io.opentelemetry.instrumentation.api.db;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
public class SqlStatementSanitizerTest {
/** Checks the sanitized statement text for each (original, expected) pair from {@link SqlArgs}. */
@ParameterizedTest
@ArgumentsSource(SqlArgs.class)
void sanitizeSql(String original, String expected) {
  SqlStatementSanitizer sanitizer = SqlStatementSanitizer.create(true);
  String actual = sanitizer.sanitize(original).getFullStatement();
  assertThat(actual).isEqualTo(expected);
}
/**
 * Same as {@code sanitizeSql}, but sanitizes with {@code SqlDialect.COUCHBASE} using the
 * rows from {@link CouchbaseArgs}.
 */
@ParameterizedTest
@ArgumentsSource(CouchbaseArgs.class)
void normalizeCouchbase(String original, String expected) {
  SqlStatementSanitizer sanitizer = SqlStatementSanitizer.create(true);
  String actual = sanitizer.sanitize(original, SqlDialect.COUCHBASE).getFullStatement();
  assertThat(actual).isEqualTo(expected);
}
/**
 * Sanitizes {@code original} and compares the complete {@link SqlStatementInfo} with the
 * one produced by {@code expecter}.
 *
 * <p>The whole object is compared, not just the statement text, so the operation and table
 * supplied through {@code SimplifyArgs.expect(...)} are verified as well. This is the same
 * whole-object comparison {@code veryLongSelectStatementsAreOk} performs.
 *
 * @param original the raw statement handed to the sanitizer
 * @param expecter builds the expected result, optionally from the original statement
 */
@ParameterizedTest
@ArgumentsSource(SimplifyArgs.class)
void simplifySql(String original, Function<String, SqlStatementInfo> expecter) {
  SqlStatementInfo result = SqlStatementSanitizer.create(true).sanitize(original);
  SqlStatementInfo expected = expecter.apply(original);
  assertThat(result).isEqualTo(expected);
}
/**
 * Builds a WHERE clause with 2000 numeric comparisons and expects every {@code =123} to be
 * masked and the statement cut at {@code AutoSqlSanitizer.LIMIT} characters.
 */
@Test
void veryLongSelectStatementsAreOk() {
  StringBuilder sql = new StringBuilder("SELECT * FROM table WHERE");
  int column = 0;
  while (column < 2000) {
    sql.append(" column").append(column).append("=123 and");
    column++;
  }
  String query = sql.toString();
  String truncated = query.replace("=123", "=?").substring(0, AutoSqlSanitizer.LIMIT);

  SqlStatementInfo actual = SqlStatementSanitizer.create(true).sanitize(query);

  assertThat(actual).isEqualTo(SqlStatementInfo.create(truncated, "SELECT", "table"));
}
/** Sanitizes strings of 1 up to 10000 single quotes; each call must return a non-null result. */
@Test
void lotsOfTicksDontCauseStackOverflowOrLongRuntimes() {
  SqlStatementSanitizer sanitizer = SqlStatementSanitizer.create(true);
  String ticks = "'";
  for (int remaining = 10000; remaining > 0; remaining--) {
    assertThat(sanitizer.sanitize(ticks)).isNotNull();
    ticks = ticks + "'";
  }
}
/**
 * Concatenates the decimal forms of 0..9999 into one giant numeric literal and expects it to
 * be sanitized to a single {@code ?}.
 */
@Test
void veryLongNumbersAreOk() {
  // StringBuilder avoids re-copying the growing string on each of the 10000 appends.
  StringBuilder digits = new StringBuilder();
  for (int i = 0; i < 10000; i++) {
    digits.append(i);
  }
  SqlStatementInfo result = SqlStatementSanitizer.create(true).sanitize(digits.toString());
  assertThat(result.getFullStatement()).isEqualTo("?");
}
/**
 * An identifier {@code A} followed by the digits of 0..9999 must come back unchanged apart
 * from being cut to {@code AutoSqlSanitizer.LIMIT} characters.
 */
@Test
void veryLongNumbersAtEndOfTableAreOk() {
  // StringBuilder avoids re-copying the growing string on each of the 10000 appends.
  StringBuilder identifier = new StringBuilder("A");
  for (int i = 0; i < 10000; i++) {
    identifier.append(i);
  }
  String s = identifier.toString();
  SqlStatementInfo result = SqlStatementSanitizer.create(true).sanitize(s);
  assertThat(result.getFullStatement()).isEqualTo(s.substring(0, AutoSqlSanitizer.LIMIT));
}
/**
 * Repeats one statement 10000 times and checks that the sanitized text is no longer than
 * {@code AutoSqlSanitizer.LIMIT} and that no {@code 1234} literal survives in it.
 */
@Test
void test32kTruncation() {
  // Method-local buffer: StringBuilder, as in veryLongSelectStatementsAreOk, instead of the
  // synchronized StringBuffer.
  StringBuilder s = new StringBuilder();
  for (int i = 0; i < 10000; i++) {
    s.append("SELECT * FROM TABLE WHERE FIELD = 1234 AND ");
  }
  String sanitized = SqlStatementSanitizer.create(true).sanitize(s.toString()).getFullStatement();
  assertThat(sanitized.length()).isLessThanOrEqualTo(AutoSqlSanitizer.LIMIT);
  assertThat(sanitized).doesNotContain("1234");
}
/**
 * Smoke test: pushes 1000 strings of 1000 pseudo-random chars each through the sanitizer.
 * Nothing is asserted on the output; the test fails only if a call throws.
 */
@Test
void randomBytesDontCauseExceptionsOrTimeouts() {
  // Fixed seed keeps the generated inputs identical from run to run.
  Random r = new Random(0);
  for (int i = 0; i < 1000; i++) {
    // Method-local buffer: no need for the synchronized StringBuffer.
    StringBuilder sb = new StringBuilder();
    for (int c = 0; c < 1000; c++) {
      sb.append((char) r.nextInt(Character.MAX_VALUE));
    }
    SqlStatementSanitizer.create(true).sanitize(sb.toString());
  }
}
/** Supplies (original statement, expected sanitized text) pairs for {@code sanitizeSql}. */
static class SqlArgs implements ArgumentsProvider {
  @Override
  public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
    // Two expected texts recur across many rows; name them once.
    String maskedSelect = "SELECT ?";
    String maskedField = "SELECT * FROM TABLE WHERE FIELD = ?";

    return Stream.of(
        // Numbers
        Arguments.of("SELECT * FROM TABLE WHERE FIELD=1234", "SELECT * FROM TABLE WHERE FIELD=?"),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = 1234", maskedField),
        Arguments.of(
            "SELECT * FROM TABLE WHERE FIELD>=-1234", "SELECT * FROM TABLE WHERE FIELD>=?"),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD<-1234", "SELECT * FROM TABLE WHERE FIELD<?"),
        Arguments.of(
            "SELECT * FROM TABLE WHERE FIELD <.1234", "SELECT * FROM TABLE WHERE FIELD <?"),
        Arguments.of("SELECT 1.2", maskedSelect),
        Arguments.of("SELECT -1.2", maskedSelect),
        Arguments.of("SELECT -1.2e-9", maskedSelect),
        Arguments.of("SELECT 2E+9", maskedSelect),
        Arguments.of("SELECT +0.2", maskedSelect),
        Arguments.of("SELECT .2", maskedSelect),
        Arguments.of("7", "?"),
        Arguments.of(".7", "?"),
        Arguments.of("-7", "?"),
        Arguments.of("+7", "?"),
        Arguments.of("SELECT 0x0af764", maskedSelect),
        Arguments.of("SELECT 0xdeadBEEF", maskedSelect),
        Arguments.of("SELECT * FROM \"TABLE\"", "SELECT * FROM \"TABLE\""),
        // Not numbers but could be confused as such
        Arguments.of("SELECT A + B", "SELECT A + B"),
        Arguments.of("SELECT -- comment", "SELECT -- comment"),
        Arguments.of("SELECT * FROM TABLE123", "SELECT * FROM TABLE123"),
        Arguments.of(
            "SELECT FIELD2 FROM TABLE_123 WHERE X<>7", "SELECT FIELD2 FROM TABLE_123 WHERE X<>?"),
        // Semi-nonsensical almost-numbers to elide or not
        Arguments.of("SELECT --83--...--8e+76e3E-1", maskedSelect),
        Arguments.of("SELECT DEADBEEF", "SELECT DEADBEEF"),
        Arguments.of("SELECT 123-45-6789", maskedSelect),
        Arguments.of("SELECT 1/2/34", "SELECT ?/?/?"),
        // Basic ' strings
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = ''", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = 'words and spaces'", maskedField),
        Arguments.of(
            "SELECT * FROM TABLE WHERE FIELD = ' an escaped '' quote mark inside'", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = '\\\\'", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = '\"inside doubles\"'", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = '\"$$$$\"'", maskedField),
        Arguments.of(
            "SELECT * FROM TABLE WHERE FIELD = 'a single \" doublequote inside'", maskedField),
        // Some databases allow using dollar-quoted strings
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = $$$$", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = $$words and spaces$$", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = $$quotes '\" inside$$", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = $$\"''\"$$", maskedField),
        Arguments.of("SELECT * FROM TABLE WHERE FIELD = $$\\\\$$", maskedField),
        // Unicode, including a unicode identifier with a trailing number
        Arguments.of(
            "SELECT * FROM TABLEओ7 WHERE FIELD = 'ɣ'", "SELECT * FROM TABLEओ7 WHERE FIELD = ?"),
        // whitespace normalization
        Arguments.of(
            "SELECT * \t\r\nFROM TABLE WHERE FIELD1 = 12344 AND FIELD2 = 5678",
            "SELECT * FROM TABLE WHERE FIELD1 = ? AND FIELD2 = ?"),
        // hibernate/jpa query language
        Arguments.of("FROM TABLE WHERE FIELD=1234", "FROM TABLE WHERE FIELD=?"));
  }
}
/**
 * Supplies (original statement, expected sanitized text) pairs for {@code normalizeCouchbase}.
 * Every row shares the same expected text, so only the inputs are listed.
 */
static class CouchbaseArgs implements ArgumentsProvider {
  @Override
  public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
    String expected = "SELECT * FROM TABLE WHERE FIELD = ?";
    return Stream.of(
            // Some databases support/encourage " instead of ' with same escape rules
            "SELECT * FROM TABLE WHERE FIELD = \"\"",
            "SELECT * FROM TABLE WHERE FIELD = \"words and spaces'\"",
            "SELECT * FROM TABLE WHERE FIELD = \" an escaped \"\" quote mark inside\"",
            "SELECT * FROM TABLE WHERE FIELD = \"\\\\\"",
            "SELECT * FROM TABLE WHERE FIELD = \"'inside singles'\"",
            "SELECT * FROM TABLE WHERE FIELD = \"'$$$$'\"",
            "SELECT * FROM TABLE WHERE FIELD = \"a single ' singlequote inside\"")
        .map(original -> Arguments.of(original, expected));
  }
}
/**
 * Argument source pairing a SQL statement (or {@code null}) with a function that produces the
 * {@link SqlStatementInfo} to compare against. Cases are grouped by statement kind and emitted in
 * the order: select, hibernate/jpa, insert, delete, update, merge, unknown operation.
 */
static class SimplifyArgs implements ArgumentsProvider {

  /**
   * Expectation whose statement text is whatever string the returned function is applied to.
   *
   * @param operation operation passed to {@code SqlStatementInfo.create}; may be {@code null}
   * @param table table passed to {@code SqlStatementInfo.create}; may be {@code null}
   * @return function building a {@link SqlStatementInfo} from its input plus the given values
   */
  static Function<String, SqlStatementInfo> expect(String operation, String table) {
    return statement -> SqlStatementInfo.create(statement, operation, table);
  }

  /**
   * Expectation with a fixed statement text; the returned function disregards its input.
   *
   * @param sql statement text passed to {@code SqlStatementInfo.create}
   * @param operation operation passed to {@code SqlStatementInfo.create}; may be {@code null}
   * @param table table passed to {@code SqlStatementInfo.create}; may be {@code null}
   * @return function that always builds a {@link SqlStatementInfo} from the given values
   */
  static Function<String, SqlStatementInfo> expect(String sql, String operation, String table) {
    return unused -> SqlStatementInfo.create(sql, operation, table);
  }

  /**
   * Concatenates the per-category case lists into a single stream, keeping category order.
   *
   * @param context extension context supplied by JUnit; not used here
   * @return stream of {@code (statement, expectation function)} pairs
   */
  @Override
  public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
    return Stream.of(
            selectCases(),
            hibernateCases(),
            insertCases(),
            deleteCases(),
            updateCases(),
            mergeCases(),
            unknownOperationCases())
        .flatMap(cases -> cases);
  }

  // Select
  private static Stream<Arguments> selectCases() {
    Function<String, SqlStatementInfo> selectWithoutTable = expect("SELECT", null);
    Function<String, SqlStatementInfo> selectFromTable = expect("SELECT", "table");
    return Stream.of(
        Arguments.of("SELECT x, y, z FROM schema.table", expect("SELECT", "schema.table")),
        Arguments.of("SELECT x, y, z FROM `schema table`", expect("SELECT", "schema table")),
        Arguments.of("SELECT x, y, z FROM \"schema table\"", expect("SELECT", "schema table")),
        Arguments.of(
            "WITH subquery as (select a from b) SELECT x, y, z FROM table", selectWithoutTable),
        Arguments.of("SELECT x, y, (select a from b) as z FROM table", selectWithoutTable),
        Arguments.of("select delete, insert into, merge, update from table", selectFromTable),
        Arguments.of("select col /* from table2 */ from table", selectFromTable),
        Arguments.of("select col from table join anotherTable", selectWithoutTable),
        Arguments.of("select col from (select * from anotherTable)", selectWithoutTable),
        Arguments.of("select col from (select * from anotherTable) alias", selectWithoutTable),
        Arguments.of("select col from table1 union select col from table2", selectWithoutTable),
        Arguments.of(
            "select col from table where col in (select * from anotherTable)",
            selectWithoutTable),
        Arguments.of("select col from table1, table2", selectWithoutTable),
        Arguments.of("select col from table1 t1, table2 t2", selectWithoutTable),
        Arguments.of("select col from table1 as t1, table2 as t2", selectWithoutTable),
        Arguments.of(
            "select col from table where col in (1, 2, 3)",
            expect("select col from table where col in (?, ?, ?)", "SELECT", "table")),
        Arguments.of("select col from table order by col, col2", selectFromTable),
        Arguments.of("select ąś∂ń© from źćļńĶ order by col, col2", expect("SELECT", "źćļńĶ")),
        Arguments.of("select 12345678", expect("select ?", "SELECT", null)),
        Arguments.of("/* update comment */ select * from table1", expect("SELECT", "table1")),
        Arguments.of("select /*((*/abc from table", selectFromTable),
        Arguments.of("SeLeCT * FrOm TAblE", selectFromTable));
  }

  // hibernate/jpa
  private static Stream<Arguments> hibernateCases() {
    return Stream.of(
        Arguments.of("FROM schema.table", expect("SELECT", "schema.table")),
        Arguments.of("/* update comment */ from table1", expect("SELECT", "table1")));
  }

  // Insert
  private static Stream<Arguments> insertCases() {
    return Stream.of(
        Arguments.of(" insert into table where lalala", expect("INSERT", "table")),
        Arguments.of("insert insert into table where lalala", expect("INSERT", "table")),
        Arguments.of("insert into db.table where lalala", expect("INSERT", "db.table")),
        Arguments.of("insert into `db table` where lalala", expect("INSERT", "db table")),
        Arguments.of("insert into \"db table\" where lalala", expect("INSERT", "db table")),
        Arguments.of("insert without i-n-t-o", expect("INSERT", null)));
  }

  // Delete
  private static Stream<Arguments> deleteCases() {
    Function<String, SqlStatementInfo> deleteFromMyTable = expect("DELETE", "my table");
    return Stream.of(
        Arguments.of("delete from table where something something", expect("DELETE", "table")),
        Arguments.of("delete from `my table` where something something", deleteFromMyTable),
        Arguments.of("delete from \"my table\" where something something", deleteFromMyTable),
        Arguments.of("delete from 12345678", expect("delete from ?", "DELETE", null)),
        Arguments.of("delete (((", expect("delete (((", "DELETE", null)));
  }

  // Update
  private static Stream<Arguments> updateCases() {
    return Stream.of(
        Arguments.of(
            "update table set answer=42", expect("update table set answer=?", "UPDATE", "table")),
        Arguments.of(
            "update `my table` set answer=42",
            expect("update `my table` set answer=?", "UPDATE", "my table")),
        Arguments.of(
            "update \"my table\" set answer=42",
            expect("update \"my table\" set answer=?", "UPDATE", "my table")),
        Arguments.of("update /*table", expect("UPDATE", null)));
  }

  // Merge
  private static Stream<Arguments> mergeCases() {
    return Stream.of(
        Arguments.of("merge into table", expect("MERGE", "table")),
        Arguments.of("merge into `my table`", expect("MERGE", "my table")),
        Arguments.of("merge into \"my table\"", expect("MERGE", "my table")),
        Arguments.of("merge table (into is optional in some dbs)", expect("MERGE", "table")),
        Arguments.of("merge (into )))", expect("MERGE", null)));
  }

  // Unknown operation
  private static Stream<Arguments> unknownOperationCases() {
    return Stream.of(
        Arguments.of("and now for something completely different", expect(null, null)),
        Arguments.of("", expect(null, null)),
        Arguments.of(null, expect(null, null)));
  }
}
}