aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSina Ghaderi <32870524+Sina-Ghaderi@users.noreply.github.com>2021-03-19 08:06:20 +0330
committerGitHub <noreply@github.com>2021-03-19 08:06:20 +0330
commitb8bb371ca118aa63785adffd98bb24cc1113f1c5 (patch)
tree00bab7b7194f92bee0e831958e34c099efbf80ff
parente16ec4f65419aefa0bb401d6af2e30be0d6bfa25 (diff)
pushing to git services
-rw-r--r--go.mod3
-rw-r--r--main.go127
-rw-r--r--mysql.go388
-rw-r--r--vendor/config/config.go50
-rw-r--r--vendor/datax/datax.go166
-rw-r--r--vendor/mysql/auth.go425
-rw-r--r--vendor/mysql/auth_test.go1330
-rw-r--r--vendor/mysql/benchmark_test.go374
-rw-r--r--vendor/mysql/buffer.go182
-rw-r--r--vendor/mysql/collations.go265
-rw-r--r--vendor/mysql/conncheck.go54
-rw-r--r--vendor/mysql/conncheck_dummy.go17
-rw-r--r--vendor/mysql/conncheck_test.go38
-rw-r--r--vendor/mysql/connection.go650
-rw-r--r--vendor/mysql/connection_test.go203
-rw-r--r--vendor/mysql/connector.go146
-rw-r--r--vendor/mysql/connector_test.go30
-rw-r--r--vendor/mysql/const.go174
-rw-r--r--vendor/mysql/driver.go107
-rw-r--r--vendor/mysql/driver_test.go3211
-rw-r--r--vendor/mysql/dsn.go560
-rw-r--r--vendor/mysql/dsn_test.go415
-rw-r--r--vendor/mysql/errors.go65
-rw-r--r--vendor/mysql/errors_test.go42
-rw-r--r--vendor/mysql/fields.go194
-rw-r--r--vendor/mysql/fuzz.go24
-rw-r--r--vendor/mysql/infile.go182
-rw-r--r--vendor/mysql/nulltime.go50
-rw-r--r--vendor/mysql/nulltime_go113.go40
-rw-r--r--vendor/mysql/nulltime_legacy.go39
-rw-r--r--vendor/mysql/nulltime_test.go62
-rw-r--r--vendor/mysql/packets.go1349
-rw-r--r--vendor/mysql/packets_test.go336
-rw-r--r--vendor/mysql/result.go22
-rw-r--r--vendor/mysql/rows.go223
-rw-r--r--vendor/mysql/statement.go220
-rw-r--r--vendor/mysql/statement_test.go151
-rw-r--r--vendor/mysql/transaction.go31
-rw-r--r--vendor/mysql/utils.go868
-rw-r--r--vendor/mysql/utils_test.go508
-rw-r--r--vendor/syscon/syscon.go203
-rw-r--r--vendor/syslog/logger.go38
42 files changed, 13562 insertions, 0 deletions
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..d77aa27
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module fixrate
+
+go 1.15
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..8b9c94c
--- /dev/null
+++ b/main.go
@@ -0,0 +1,127 @@
+package main
+
+import (
+ "config"
+ "datax"
+ "flag"
+ "fmt"
+ "os"
+ "runtime"
+ "syscon"
+ "syslog"
+ "time"
+)
+
+func main() {
+ if runtime.GOOS != "linux" {
+ fmt.Println("this application is supposed to run on linux, there is nothing to do in here... \nexiting with status 0")
+ return
+ }
+ defer syslog.HandlePan()
+
+ fixDaemon := flag.NewFlagSet("daemon", flag.ExitOnError)
+ confPath := fixDaemon.String("config", "fixrate.conf", "gnu config file for fixrate service")
+ fixDaemon.Usage = fixRateUsage
+ fixUser := flag.NewFlagSet("users", flag.ExitOnError)
+ fixUser.Usage = addUserUsage
+ confPathforUser := fixUser.String("config", "fixrate.conf", "gnu config file for fixrate service")
+ userName := fixUser.String("username", "sina@snix.ir", "a username to add/modify in database")
+ userRate := fixUser.Int("userrate", 10, "how many e-mails user should be able to send")
+ userReset := fixUser.Int("counter", 120, "time interval between user counter reset")
+ flag.Usage = globUsage
+ if len(os.Args) < 2 {
+ fmt.Println("expected 'daemon' or 'users' commands")
+ flag.Usage()
+ os.Exit(1)
+ }
+ switch os.Args[1] {
+ case "daemon":
+ fixDaemon.Parse(os.Args[2:])
+ conf := config.ReadConfigFromFile(*confPath)
+ syscon.StartNewService(conf)
+ case "users":
+ fixUser.Parse(os.Args[2:])
+ if len(os.Args[2:]) == 0 {
+ fmt.Println("expect more options -- use --h for information")
+ fixUser.Usage()
+ os.Exit(1)
+ }
+ conf := config.ReadConfigFromFile(*confPathforUser)
+ if err := datax.SQLInit(conf); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ defer datax.DBClose()
+ temp := &datax.UserAccount{
+ Username: *userName,
+ UserType: true,
+ LastReset: time.Now(),
+ Limit: *userRate,
+ Reset: *userReset,
+ }
+ if err := datax.CreateNewUser(temp); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ fmt.Printf("\033[32minfo:\033[0m username: %v with %v e-mails per %v seconds added to database.\n", *userName, *userRate, *userReset)
+ default:
+ fmt.Println("expected 'daemon' or 'users' commands")
+ flag.Usage()
+ os.Exit(1)
+ }
+}
+func globUsage() {
+ fmt.Printf(`usage of fixrate postfix module snix.ir LLC:
+%v commands... [ OPTIONS ] ...
+
+commands:
+ daemon starting fixrate daemon, should be used by systemd
+ --config pass a file to read configuration from. default: ./fixrate.conf
+
+ users add or modify users and attributes in database
+ --config pass a file to read configuration from. default: ./fixrate.conf
+ --username a username to add/modify in database. default is sina@snix.ir
+ --counter time interval (seconds) between user counter reset. default is 120
+ --userrate how many e-mails user should be able to send. default is 10
+
+example:
+---- adding name@domain.com ---- 100 e-mail per 10 minutes:
+%v users --username name@domain.com -- userrate 100 --counter 600
+
+Copyright (c) 2021 git.snix.ir, All rights reserved.
+Developed BY sina@snix.ir --> Sina Ghaderi
+This work is licensed under the terms of the MIT license.
+`, os.Args[0], os.Args[0])
+}
+
+func fixRateUsage() {
+ fmt.Printf(`usage of fixrate postfix module snix.ir LLC:
+%v daemon [ OPTIONS ] ...
+
+options:
+ --config pass a file to read configuration from. default: ./fixrate.conf
+ --h print this banner and exit.
+
+Copyright (c) 2021 git.snix.ir, All rights reserved.
+Developed BY sina@snix.ir --> Sina Ghaderi
+This work is licensed under the terms of the MIT license.
+`, os.Args[0])
+}
+
+func addUserUsage() {
+ fmt.Printf(`usage of fixrate postfix module snix.ir LLC:
+%v users [ OPTIONS ] ...
+
+options:
+ --config pass a file to read configuration from. default: ./fixrate.conf
+ --username a username to add/modify in database. default is sina@snix.ir
+ --counter time interval (seconds) between user counter reset. default is 120
+ --userrate how many e-mails user should be able to send. default is 10
+
+example:
+---- adding name@domain.com ---- 100 e-mail per 10 minutes:
+%v users --username name@domain.com -- userrate 100 --counter 600
+
+Copyright (c) 2021 git.snix.ir, All rights reserved.
+Developed BY sina@snix.ir --> Sina Ghaderi
+This work is licensed under the terms of the MIT license.
+`, os.Args[0], os.Args[0])
+}
diff --git a/mysql.go b/mysql.go
new file mode 100644
index 0000000..350b3fe
--- /dev/null
+++ b/mysql.go
@@ -0,0 +1,388 @@
+package main
+
+import (
+ // this is github.com/go-sql-driver/mysql driver package placed in vendor/mysql
+ //
+ //
+ //
+ //
+ //
+ // // // // // // // //
+
+ /*
+ Mozilla Public License Version 2.0
+ ==================================
+
+ 1. Definitions
+ --------------
+
+ 1.1. "Contributor"
+ means each individual or legal entity that creates, contributes to
+ the creation of, or owns Covered Software.
+
+ 1.2. "Contributor Version"
+ means the combination of the Contributions of others (if any) used
+ by a Contributor and that particular Contributor's Contribution.
+
+ 1.3. "Contribution"
+ means Covered Software of a particular Contributor.
+
+ 1.4. "Covered Software"
+ means Source Code Form to which the initial Contributor has attached
+ the notice in Exhibit A, the Executable Form of such Source Code
+ Form, and Modifications of such Source Code Form, in each case
+ including portions thereof.
+
+ 1.5. "Incompatible With Secondary Licenses"
+ means
+
+ (a) that the initial Contributor has attached the notice described
+ in Exhibit B to the Covered Software; or
+
+ (b) that the Covered Software was made available under the terms of
+ version 1.1 or earlier of the License, but not also under the
+ terms of a Secondary License.
+
+ 1.6. "Executable Form"
+ means any form of the work other than Source Code Form.
+
+ 1.7. "Larger Work"
+ means a work that combines Covered Software with other material, in
+ a separate file or files, that is not Covered Software.
+
+ 1.8. "License"
+ means this document.
+
+ 1.9. "Licensable"
+ means having the right to grant, to the maximum extent possible,
+ whether at the time of the initial grant or subsequently, any and
+ all of the rights conveyed by this License.
+
+ 1.10. "Modifications"
+ means any of the following:
+
+ (a) any file in Source Code Form that results from an addition to,
+ deletion from, or modification of the contents of Covered
+ Software; or
+
+ (b) any new file in Source Code Form that contains any Covered
+ Software.
+
+ 1.11. "Patent Claims" of a Contributor
+ means any patent claim(s), including without limitation, method,
+ process, and apparatus claims, in any patent Licensable by such
+ Contributor that would be infringed, but for the grant of the
+ License, by the making, using, selling, offering for sale, having
+ made, import, or transfer of either its Contributions or its
+ Contributor Version.
+
+ 1.12. "Secondary License"
+ means either the GNU General Public License, Version 2.0, the GNU
+ Lesser General Public License, Version 2.1, the GNU Affero General
+ Public License, Version 3.0, or any later versions of those
+ licenses.
+
+ 1.13. "Source Code Form"
+ means the form of the work preferred for making modifications.
+
+ 1.14. "You" (or "Your")
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, "You" includes any entity that
+ controls, is controlled by, or is under common control with You. For
+ purposes of this definition, "control" means (a) the power, direct
+ or indirect, to cause the direction or management of such entity,
+ whether by contract or otherwise, or (b) ownership of more than
+ fifty percent (50%) of the outstanding shares or beneficial
+ ownership of such entity.
+
+ 2. License Grants and Conditions
+ --------------------------------
+
+ 2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free,
+ non-exclusive license:
+
+ (a) under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or
+ as part of a Larger Work; and
+
+ (b) under Patent Claims of such Contributor to make, use, sell, offer
+ for sale, have made, import, and otherwise transfer either its
+ Contributions or its Contributor Version.
+
+ 2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution
+ become effective for each Contribution on the date the Contributor first
+ distributes such Contribution.
+
+ 2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under
+ this License. No additional rights or licenses will be implied from the
+ distribution or licensing of Covered Software under this License.
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
+ Contributor:
+
+ (a) for any code that a Contributor has removed from Covered Software;
+ or
+
+ (b) for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+ (c) under Patent Claims infringed by Covered Software in the absence of
+ its Contributions.
+
+ This License does not grant any rights in the trademarks, service marks,
+ or logos of any Contributor (except as may be necessary to comply with
+ the notice requirements in Section 3.4).
+
+ 2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to
+ distribute the Covered Software under a subsequent version of this
+ License (see Section 10.2) or under the terms of a Secondary License (if
+ permitted under the terms of Section 3.3).
+
+ 2.5. Representation
+
+ Each Contributor represents that the Contributor believes its
+ Contributions are its original creation(s) or it has sufficient rights
+ to grant the rights to its Contributions conveyed by this License.
+
+ 2.6. Fair Use
+
+ This License is not intended to limit any rights You have under
+ applicable copyright doctrines of fair use, fair dealing, or other
+ equivalents.
+
+ 2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
+ in Section 2.1.
+
+ 3. Responsibilities
+ -------------------
+
+ 3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under
+ the terms of this License. You must inform recipients that the Source
+ Code Form of the Covered Software is governed by the terms of this
+ License, and how they can obtain a copy of this License. You may not
+ attempt to alter or restrict the recipients' rights in the Source Code
+ Form.
+
+ 3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ (a) such Covered Software must also be made available in Source Code
+ Form, as described in Section 3.1, and You must inform recipients of
+ the Executable Form how they can obtain a copy of such Source Code
+ Form by reasonable means in a timely manner, at a charge no more
+ than the cost of distribution to the recipient; and
+
+ (b) You may distribute such Executable Form under the terms of this
+ License, or sublicense it under different terms, provided that the
+ license for the Executable Form does not attempt to limit or alter
+ the recipients' rights in the Source Code Form under this License.
+
+ 3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for
+ the Covered Software. If the Larger Work is a combination of Covered
+ Software with a work governed by one or more Secondary Licenses, and the
+ Covered Software is not Incompatible With Secondary Licenses, this
+ License permits You to additionally distribute such Covered Software
+ under the terms of such Secondary License(s), so that the recipient of
+ the Larger Work may, at their option, further distribute the Covered
+ Software under the terms of either this License or such Secondary
+ License(s).
+
+ 3.4. Notices
+
+ You may not remove or alter the substance of any license notices
+ (including copyright notices, patent notices, disclaimers of warranty,
+ or limitations of liability) contained within the Source Code Form of
+ the Covered Software, except that You may alter any license notices to
+ the extent required to remedy known factual inaccuracies.
+
+ 3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support,
+ indemnity or liability obligations to one or more recipients of Covered
+ Software. However, You may do so only on Your own behalf, and not on
+ behalf of any Contributor. You must make it absolutely clear that any
+ such warranty, support, indemnity, or liability obligation is offered by
+ You alone, and You hereby agree to indemnify every Contributor for any
+ liability incurred by such Contributor as a result of warranty, support,
+ indemnity or liability terms You offer. You may include additional
+ disclaimers of warranty and limitations of liability specific to any
+ jurisdiction.
+
+ 4. Inability to Comply Due to Statute or Regulation
+ ---------------------------------------------------
+
+ If it is impossible for You to comply with any of the terms of this
+ License with respect to some or all of the Covered Software due to
+ statute, judicial order, or regulation then You must: (a) comply with
+ the terms of this License to the maximum extent possible; and (b)
+ describe the limitations and the code they affect. Such description must
+ be placed in a text file included with all distributions of the Covered
+ Software under this License. Except to the extent prohibited by statute
+ or regulation, such description must be sufficiently detailed for a
+ recipient of ordinary skill to be able to understand it.
+
+ 5. Termination
+ --------------
+
+ 5.1. The rights granted under this License will terminate automatically
+ if You fail to comply with any of its terms. However, if You become
+ compliant, then the rights granted under this License from a particular
+ Contributor are reinstated (a) provisionally, unless and until such
+ Contributor explicitly and finally terminates Your grants, and (b) on an
+ ongoing basis, if such Contributor fails to notify You of the
+ non-compliance by some reasonable means prior to 60 days after You have
+ come back into compliance. Moreover, Your grants from a particular
+ Contributor are reinstated on an ongoing basis if such Contributor
+ notifies You of the non-compliance by some reasonable means, this is the
+ first time You have received notice of non-compliance with this License
+ from such Contributor, and You become compliant prior to 30 days after
+ Your receipt of the notice.
+
+ 5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions,
+ counter-claims, and cross-claims) alleging that a Contributor Version
+ directly or indirectly infringes any patent, then the rights granted to
+ You by any and all Contributors for the Covered Software under Section
+ 2.1 of this License shall terminate.
+
+ 5.3. In the event of termination under Sections 5.1 or 5.2 above, all
+ end user license agreements (excluding distributors and resellers) which
+ have been validly granted by You or Your distributors under this License
+ prior to termination shall survive termination.
+
+ ************************************************************************
+ * *
+ * 6. Disclaimer of Warranty *
+ * ------------------------- *
+ * *
+ * Covered Software is provided under this License on an "as is" *
+ * basis, without warranty of any kind, either expressed, implied, or *
+ * statutory, including, without limitation, warranties that the *
+ * Covered Software is free of defects, merchantable, fit for a *
+ * particular purpose or non-infringing. The entire risk as to the *
+ * quality and performance of the Covered Software is with You. *
+ * Should any Covered Software prove defective in any respect, You *
+ * (not any Contributor) assume the cost of any necessary servicing, *
+ * repair, or correction. This disclaimer of warranty constitutes an *
+ * essential part of this License. No use of any Covered Software is *
+ * authorized under this License except under this disclaimer. *
+ * *
+ ************************************************************************
+
+ ************************************************************************
+ * *
+ * 7. Limitation of Liability *
+ * -------------------------- *
+ * *
+ * Under no circumstances and under no legal theory, whether tort *
+ * (including negligence), contract, or otherwise, shall any *
+ * Contributor, or anyone who distributes Covered Software as *
+ * permitted above, be liable to You for any direct, indirect, *
+ * special, incidental, or consequential damages of any character *
+ * including, without limitation, damages for lost profits, loss of *
+ * goodwill, work stoppage, computer failure or malfunction, or any *
+ * and all other commercial damages or losses, even if such party *
+ * shall have been informed of the possibility of such damages. This *
+ * limitation of liability shall not apply to liability for death or *
+ * personal injury resulting from such party's negligence to the *
+ * extent applicable law prohibits such limitation. Some *
+ * jurisdictions do not allow the exclusion or limitation of *
+ * incidental or consequential damages, so this exclusion and *
+ * limitation may not apply to You. *
+ * *
+ ************************************************************************
+
+ 8. Litigation
+ -------------
+
+ Any litigation relating to this License may be brought only in the
+ courts of a jurisdiction where the defendant maintains its principal
+ place of business and such litigation shall be governed by laws of that
+ jurisdiction, without reference to its conflict-of-law provisions.
+ Nothing in this Section shall prevent a party's ability to bring
+ cross-claims or counter-claims.
+
+ 9. Miscellaneous
+ ----------------
+
+ This License represents the complete agreement concerning the subject
+ matter hereof. If any provision of this License is held to be
+ unenforceable, such provision shall be reformed only to the extent
+ necessary to make it enforceable. Any law or regulation which provides
+ that the language of a contract shall be construed against the drafter
+ shall not be used to construe this License against a Contributor.
+
+ 10. Versions of the License
+ ---------------------------
+
+ 10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section
+ 10.3, no one other than the license steward has the right to modify or
+ publish new versions of this License. Each version will be given a
+ distinguishing version number.
+
+ 10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version
+ of the License under which You originally received the Covered Software,
+ or under the terms of any subsequent version published by the license
+ steward.
+
+ 10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to
+ create a new license for such software, you may create and use a
+ modified version of this License if you rename the license and remove
+ any references to the name of the license steward (except to note that
+ such modified license differs from this License).
+
+ 10.4. Distributing Source Code Form that is Incompatible With Secondary
+ Licenses
+
+ If You choose to distribute Source Code Form that is Incompatible With
+ Secondary Licenses under the terms of this version of the License, the
+ notice described in Exhibit B of this License must be attached.
+
+ Exhibit A - Source Code Form License Notice
+ -------------------------------------------
+
+ This Source Code Form is subject to the terms of the Mozilla Public
+ License, v. 2.0. If a copy of the MPL was not distributed with this
+ file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+ If it is not possible or desirable to put the notice in a particular
+ file, then You may include the notice in a location (such as a LICENSE
+ file in a relevant directory) where a recipient would be likely to look
+ for such a notice.
+
+ You may add additional accurate notices of copyright ownership.
+
+ Exhibit B - "Incompatible With Secondary Licenses" Notice
+ ---------------------------------------------------------
+
+ This Source Code Form is "Incompatible With Secondary Licenses", as
+ defined by the Mozilla Public License, v. 2.0. */
+
+ _ "mysql"
+)
diff --git a/vendor/config/config.go b/vendor/config/config.go
new file mode 100644
index 0000000..525a120
--- /dev/null
+++ b/vendor/config/config.go
@@ -0,0 +1,50 @@
+package config
+
+import (
+ "bufio"
+ "fmt"
+ "os"
+ "regexp"
+ "strings"
+ "syslog"
+)
+
+const literalRegex string = `^(\s*)((?i)(socket_path|sql_database|sql_username|sql_address|sql_tcpport|sql_password|default_ratelimit|socket_perm|listener_type|listen_addr)(?i))(\s*)=`
+
+type Config struct{}
+
+var mapconf = make(map[string][]string)
+
+func ReadConfigFromFile(path string) *Config {
+ file, err := os.Open(path)
+ if err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ defer file.Close()
+ re := regexp.MustCompile(literalRegex)
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ if find := re.MatchString(scanner.Text()); find {
+ literal := re.FindString(scanner.Text())
+ mapconf[strings.TrimSpace(literal[:len(literal)-1])] = strings.Fields(strings.Split(scanner.Text(), literal)[1])
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ return &Config{}
+}
+
+func (*Config) GetConf(str string, index int) string {
+ if args, owkey := mapconf[str]; owkey {
+ if len(args) == 0 {
+ panic(syslog.BigError{Why: fmt.Errorf("literal %v have no args in config file", str), Cod: 1})
+ }
+ if index < len(args) {
+ return args[index]
+ }
+
+ panic(syslog.BigError{Why: fmt.Errorf("not enough args in literal %v, can not reach arg %v, it's empty", str, index), Cod: 1})
+ }
+ panic(syslog.BigError{Why: fmt.Errorf("literal %v can not be found in config file", str), Cod: 1})
+}
diff --git a/vendor/datax/datax.go b/vendor/datax/datax.go
new file mode 100644
index 0000000..c07f8e8
--- /dev/null
+++ b/vendor/datax/datax.go
@@ -0,0 +1,166 @@
+package datax
+
+import (
+ "config"
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "strconv"
+ "syslog"
+ "time"
+)
+
+type UserAccount struct {
+ Username string
+ Limit, Counter, Reset int
+ LastReset time.Time
+ UserType bool
+}
+
+var (
+ DefaultUserTmp UserAccount
+ SystemDatabase *sql.DB
+)
+
+func SQLInit(file *config.Config) error {
+ dabase := file.GetConf("sql_database", 0)
+ addr := fmt.Sprintf("(%v:%v)", file.GetConf("sql_address", 0), file.GetConf("sql_tcpport", 0))
+ db, err := sql.Open("mysql", file.GetConf("sql_username", 0)+":"+file.GetConf("sql_password", 0)+"@tcp"+addr+"/"+dabase+"?parseTime=true")
+ if err != nil {
+ return err
+ }
+
+ limit, err := strconv.Atoi(file.GetConf("default_ratelimit", 0))
+ if err != nil {
+ panic(syslog.BigError{Why: errors.New("can not use non integer type (default_ratelimit)"), Cod: 1})
+ }
+ reset, err := strconv.Atoi(file.GetConf("default_ratelimit", 1))
+ if err != nil {
+ panic(syslog.BigError{Why: errors.New("can not use non integer type (default_ratelimit)"), Cod: 1})
+ }
+ DefaultUserTmp = UserAccount{
+ Limit: limit,
+ Reset: reset,
+ }
+
+ ctx, cancelfunc := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelfunc()
+
+ sqlTable := `
+ CREATE TABLE IF NOT EXISTS fixrate(
+ username VARCHAR(512) NOT NULL PRIMARY KEY,
+ limitt INT,
+ counter INT,
+ reset INT,
+ lastreset DATETIME,
+ usertype BOOL
+ );
+ `
+ if _, err := db.ExecContext(ctx, sqlTable); err != nil {
+ return err
+ }
+ SystemDatabase = db
+ if err := updateDefaultFixRates(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func CreateNewUser(item *UserAccount) error {
+ ctx, cancelfunc := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelfunc()
+ sqlAdditem := `
+ REPLACE INTO fixrate(
+ username,
+ limitt,
+ counter,
+ reset,
+ lastreset,
+ usertype
+ ) values(?, ?, ?, ?, ?, ?)
+ `
+ stmt, err := SystemDatabase.Prepare(sqlAdditem)
+ defer stmt.Close()
+
+ if err != nil {
+ return err
+ }
+ _, err = stmt.ExecContext(ctx, item.Username, item.Limit, item.Counter, item.Reset, item.LastReset, item.UserType)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func DBClose() {
+ if err := SystemDatabase.Close(); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+}
+
+func GetUserFromDatabase(user *string) (*UserAccount, error) {
+ ctx, cancelfunc := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelfunc()
+
+ sqlReadall := `SELECT * FROM fixrate WHERE username = ?`
+ rows := SystemDatabase.QueryRowContext(ctx, sqlReadall, user)
+ users := UserAccount{}
+ if err := rows.Scan(&users.Username, &users.Limit, &users.Counter, &users.Reset, &users.LastReset, &users.UserType); err != nil {
+ if err != sql.ErrNoRows {
+ return nil, err
+ }
+ newuser := DefaultUserTmp
+ newuser.Username = *user
+ newuser.LastReset = time.Now()
+ if err := CreateNewUser(&newuser); err != nil {
+ return nil, err
+ }
+ return &newuser, nil
+
+ }
+ return &users, nil
+}
+
+func (u *UserAccount) UpdateUserCounter(counter int) error {
+ ctx, cancelfunc := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelfunc()
+ stmt, err := SystemDatabase.Prepare("update fixrate set counter=? where username=?")
+ if err != nil {
+ return err
+ }
+ _, err = stmt.ExecContext(ctx, counter, u.Username)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (u *UserAccount) UpdateUserLastReset(nowtime time.Time) error {
+ ctx, cancelfunc := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelfunc()
+ stmt, err := SystemDatabase.Prepare("update fixrate set lastreset=? where username=?")
+ if err != nil {
+ return err
+ }
+ _, err = stmt.ExecContext(ctx, nowtime, u.Username)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func updateDefaultFixRates() error {
+ ctx, cancelfunc := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelfunc()
+ stmt, err := SystemDatabase.Prepare("update fixrate set limitt=?, reset=? where usertype=?")
+ if err != nil {
+ return err
+ }
+ _, err = stmt.ExecContext(ctx, DefaultUserTmp.Limit, DefaultUserTmp.Reset, DefaultUserTmp.UserType)
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/vendor/mysql/auth.go b/vendor/mysql/auth.go
new file mode 100644
index 0000000..b2f19e8
--- /dev/null
+++ b/vendor/mysql/auth.go
@@ -0,0 +1,425 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/pem"
+ "fmt"
+ "sync"
+)
+
+// server pub keys registry
+var (
+ serverPubKeyLock sync.RWMutex
+ serverPubKeyRegistry map[string]*rsa.PublicKey
+)
+
+// RegisterServerPubKey registers a server RSA public key which can be used to
+// send data in a secure manner to the server without receiving the public key
+// in a potentially insecure way from the server first.
+// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
+//
+// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
+// after registering it and may not be modified.
+//
+// data, err := ioutil.ReadFile("mykey.pem")
+// if err != nil {
+// log.Fatal(err)
+// }
+//
+// block, _ := pem.Decode(data)
+// if block == nil || block.Type != "PUBLIC KEY" {
+// log.Fatal("failed to decode PEM block containing public key")
+// }
+//
+// pub, err := x509.ParsePKIXPublicKey(block.Bytes)
+// if err != nil {
+// log.Fatal(err)
+// }
+//
+// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
+// mysql.RegisterServerPubKey("mykey", rsaPubKey)
+// } else {
+// log.Fatal("not a RSA public key")
+// }
+//
+func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
+ serverPubKeyLock.Lock()
+ if serverPubKeyRegistry == nil {
+ serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
+ }
+
+ serverPubKeyRegistry[name] = pubKey
+ serverPubKeyLock.Unlock()
+}
+
+// DeregisterServerPubKey removes the public key registered with the given name.
+func DeregisterServerPubKey(name string) {
+ serverPubKeyLock.Lock()
+ if serverPubKeyRegistry != nil {
+ delete(serverPubKeyRegistry, name)
+ }
+ serverPubKeyLock.Unlock()
+}
+
+func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
+ serverPubKeyLock.RLock()
+ if v, ok := serverPubKeyRegistry[name]; ok {
+ pubKey = v
+ }
+ serverPubKeyLock.RUnlock()
+ return
+}
+
+// Hash password using pre 4.1 (old password) method
+// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
+type myRnd struct {
+ seed1, seed2 uint32
+}
+
+const myRndMaxVal = 0x3FFFFFFF
+
+// Pseudo random number generator
+func newMyRnd(seed1, seed2 uint32) *myRnd {
+ return &myRnd{
+ seed1: seed1 % myRndMaxVal,
+ seed2: seed2 % myRndMaxVal,
+ }
+}
+
+// Tested to be equivalent to MariaDB's floating point variant
+// http://play.golang.org/p/QHvhd4qved
+// http://play.golang.org/p/RG0q4ElWDx
+func (r *myRnd) NextByte() byte {
+ r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
+ r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
+
+ return byte(uint64(r.seed1) * 31 / myRndMaxVal)
+}
+
+// Generate binary hash from byte string using insecure pre 4.1 method
+func pwHash(password []byte) (result [2]uint32) {
+ var add uint32 = 7
+ var tmp uint32
+
+ result[0] = 1345345333
+ result[1] = 0x12345671
+
+ for _, c := range password {
+ // skip spaces and tabs in password
+ if c == ' ' || c == '\t' {
+ continue
+ }
+
+ tmp = uint32(c)
+ result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
+ result[1] += (result[1] << 8) ^ result[0]
+ add += tmp
+ }
+
+ // Remove sign bit (1<<31)-1)
+ result[0] &= 0x7FFFFFFF
+ result[1] &= 0x7FFFFFFF
+
+ return
+}
+
+// Hash password using insecure pre 4.1 method
+func scrambleOldPassword(scramble []byte, password string) []byte {
+ scramble = scramble[:8]
+
+ hashPw := pwHash([]byte(password))
+ hashSc := pwHash(scramble)
+
+ r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
+
+ var out [8]byte
+ for i := range out {
+ out[i] = r.NextByte() + 64
+ }
+
+ mask := r.NextByte()
+ for i := range out {
+ out[i] ^= mask
+ }
+
+ return out[:]
+}
+
+// Hash password using 4.1+ method (SHA1)
+func scramblePassword(scramble []byte, password string) []byte {
+ if len(password) == 0 {
+ return nil
+ }
+
+ // stage1Hash = SHA1(password)
+ crypt := sha1.New()
+ crypt.Write([]byte(password))
+ stage1 := crypt.Sum(nil)
+
+ // scrambleHash = SHA1(scramble + SHA1(stage1Hash))
+ // inner Hash
+ crypt.Reset()
+ crypt.Write(stage1)
+ hash := crypt.Sum(nil)
+
+ // outer Hash
+ crypt.Reset()
+ crypt.Write(scramble)
+ crypt.Write(hash)
+ scramble = crypt.Sum(nil)
+
+ // token = scrambleHash XOR stage1Hash
+ for i := range scramble {
+ scramble[i] ^= stage1[i]
+ }
+ return scramble
+}
+
+// Hash password using MySQL 8+ method (SHA256)
+func scrambleSHA256Password(scramble []byte, password string) []byte {
+ if len(password) == 0 {
+ return nil
+ }
+
+ // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
+
+ crypt := sha256.New()
+ crypt.Write([]byte(password))
+ message1 := crypt.Sum(nil)
+
+ crypt.Reset()
+ crypt.Write(message1)
+ message1Hash := crypt.Sum(nil)
+
+ crypt.Reset()
+ crypt.Write(message1Hash)
+ crypt.Write(scramble)
+ message2 := crypt.Sum(nil)
+
+ for i := range message1 {
+ message1[i] ^= message2[i]
+ }
+
+ return message1
+}
+
+func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
+ plain := make([]byte, len(password)+1)
+ copy(plain, password)
+ for i := range plain {
+ j := i % len(seed)
+ plain[i] ^= seed[j]
+ }
+ sha1 := sha1.New()
+ return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
+}
+
+func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
+ enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
+ if err != nil {
+ return err
+ }
+ return mc.writeAuthSwitchPacket(enc)
+}
+
+func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
+ switch plugin {
+ case "caching_sha2_password":
+ authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
+ return authResp, nil
+
+ case "mysql_old_password":
+ if !mc.cfg.AllowOldPasswords {
+ return nil, ErrOldPassword
+ }
+ if len(mc.cfg.Passwd) == 0 {
+ return nil, nil
+ }
+ // Note: there are edge cases where this should work but doesn't;
+ // this is currently "wontfix":
+ // https://github.com/go-sql-driver/mysql/issues/184
+ authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
+ return authResp, nil
+
+ case "mysql_clear_password":
+ if !mc.cfg.AllowCleartextPasswords {
+ return nil, ErrCleartextPassword
+ }
+ // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
+ // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
+ return append([]byte(mc.cfg.Passwd), 0), nil
+
+ case "mysql_native_password":
+ if !mc.cfg.AllowNativePasswords {
+ return nil, ErrNativePassword
+ }
+ // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
+ // Native password authentication only need and will need 20-byte challenge.
+ authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
+ return authResp, nil
+
+ case "sha256_password":
+ if len(mc.cfg.Passwd) == 0 {
+ return []byte{0}, nil
+ }
+ if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
+ // write cleartext auth packet
+ return append([]byte(mc.cfg.Passwd), 0), nil
+ }
+
+ pubKey := mc.cfg.pubKey
+ if pubKey == nil {
+ // request public key from server
+ return []byte{1}, nil
+ }
+
+ // encrypted password
+ enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
+ return enc, err
+
+ default:
+ errLog.Print("unknown auth plugin:", plugin)
+ return nil, ErrUnknownPlugin
+ }
+}
+
+func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
+ // Read Result Packet
+ authData, newPlugin, err := mc.readAuthResult()
+ if err != nil {
+ return err
+ }
+
+ // handle auth plugin switch, if requested
+ if newPlugin != "" {
+ // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
+ // sent and we have to keep using the cipher sent in the init packet.
+ if authData == nil {
+ authData = oldAuthData
+ } else {
+ // copy data from read buffer to owned slice
+ copy(oldAuthData, authData)
+ }
+
+ plugin = newPlugin
+
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ return err
+ }
+ if err = mc.writeAuthSwitchPacket(authResp); err != nil {
+ return err
+ }
+
+ // Read Result Packet
+ authData, newPlugin, err = mc.readAuthResult()
+ if err != nil {
+ return err
+ }
+
+ // Do not allow to change the auth plugin more than once
+ if newPlugin != "" {
+ return ErrMalformPkt
+ }
+ }
+
+ switch plugin {
+
+ // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
+ case "caching_sha2_password":
+ switch len(authData) {
+ case 0:
+ return nil // auth successful
+ case 1:
+ switch authData[0] {
+ case cachingSha2PasswordFastAuthSuccess:
+ if err = mc.readResultOK(); err == nil {
+ return nil // auth successful
+ }
+
+ case cachingSha2PasswordPerformFullAuthentication:
+ if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
+ // write cleartext auth packet
+ err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
+ if err != nil {
+ return err
+ }
+ } else {
+ pubKey := mc.cfg.pubKey
+ if pubKey == nil {
+ // request public key from server
+ data, err := mc.buf.takeSmallBuffer(4 + 1)
+ if err != nil {
+ return err
+ }
+ data[4] = cachingSha2PasswordRequestPublicKey
+ mc.writePacket(data)
+
+ // parse public key
+ if data, err = mc.readPacket(); err != nil {
+ return err
+ }
+
+ block, rest := pem.Decode(data[1:])
+ if block == nil {
+ return fmt.Errorf("No Pem data found, data: %s", rest)
+ }
+ pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
+ if err != nil {
+ return err
+ }
+ pubKey = pkix.(*rsa.PublicKey)
+ }
+
+ // send encrypted password
+ err = mc.sendEncryptedPassword(oldAuthData, pubKey)
+ if err != nil {
+ return err
+ }
+ }
+ return mc.readResultOK()
+
+ default:
+ return ErrMalformPkt
+ }
+ default:
+ return ErrMalformPkt
+ }
+
+ case "sha256_password":
+ switch len(authData) {
+ case 0:
+ return nil // auth successful
+ default:
+ block, _ := pem.Decode(authData)
+ pub, err := x509.ParsePKIXPublicKey(block.Bytes)
+ if err != nil {
+ return err
+ }
+
+ // send encrypted password
+ err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey))
+ if err != nil {
+ return err
+ }
+ return mc.readResultOK()
+ }
+
+ default:
+ return nil // auth successful
+ }
+
+ return err
+}
diff --git a/vendor/mysql/auth_test.go b/vendor/mysql/auth_test.go
new file mode 100644
index 0000000..3bce7fe
--- /dev/null
+++ b/vendor/mysql/auth_test.go
@@ -0,0 +1,1330 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/pem"
+ "fmt"
+ "testing"
+)
+
+var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" +
+ "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAol0Z8G8U+25Btxk/g/fm\n" +
+ "UAW/wEKjQCTjkibDE4B+qkuWeiumg6miIRhtilU6m9BFmLQSy1ltYQuu4k17A4tQ\n" +
+ "rIPpOQYZges/qsDFkZh3wyK5jL5WEFVdOasf6wsfszExnPmcZS4axxoYJfiuilrN\n" +
+ "hnwinBAqfi3S0sw5MpSI4Zl1AbOrHG4zDI62Gti2PKiMGyYDZTS9xPrBLbN95Kby\n" +
+ "FFclQLEzA9RJcS1nHFsWtRgHjGPhhjCQxEm9NQ1nePFhCfBfApyfH1VM2VCOQum6\n" +
+ "Ci9bMuHWjTjckC84mzF99kOxOWVU7mwS6gnJqBzpuz8t3zq8/iQ2y7QrmZV+jTJP\n" +
+ "WQIDAQAB\n" +
+ "-----END PUBLIC KEY-----\n")
+
+var testPubKeyRSA *rsa.PublicKey
+
+func init() {
+ block, _ := pem.Decode(testPubKey)
+ pub, err := x509.ParsePKIXPublicKey(block.Bytes)
+ if err != nil {
+ panic(err)
+ }
+ testPubKeyRSA = pub.(*rsa.PublicKey)
+}
+
+func TestScrambleOldPass(t *testing.T) {
+ scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2}
+ vectors := []struct {
+ pass string
+ out string
+ }{
+ {" pass", "47575c5a435b4251"},
+ {"pass ", "47575c5a435b4251"},
+ {"123\t456", "575c47505b5b5559"},
+ {"C0mpl!ca ted#PASS123", "5d5d554849584a45"},
+ }
+ for _, tuple := range vectors {
+ ours := scrambleOldPassword(scramble, tuple.pass)
+ if tuple.out != fmt.Sprintf("%x", ours) {
+ t.Errorf("Failed old password %q", tuple.pass)
+ }
+ }
+}
+
+func TestScrambleSHA256Pass(t *testing.T) {
+ scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21}
+ vectors := []struct {
+ pass string
+ out string
+ }{
+ {"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"},
+ {"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"},
+ }
+ for _, tuple := range vectors {
+ ours := scrambleSHA256Password(scramble, tuple.pass)
+ if tuple.out != fmt.Sprintf("%x", ours) {
+ t.Errorf("Failed SHA256 password %q", tuple.pass)
+ }
+ }
+}
+
+func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+
+ authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69,
+ 22, 41, 84, 32, 123, 43, 118}
+ plugin := "caching_sha2_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56,
+ 139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15,
+ 184, 150, 26, 61, 57, 235}
+ if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 2, 0, 0, 2, 1, 3, // Fast Auth Success
+ 7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK
+ }
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+}
+
+func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = ""
+
+ authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69,
+ 22, 41, 84, 32, 123, 43, 118}
+ plugin := "caching_sha2_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ if writtenAuthRespLen != 0 {
+ t.Fatalf("unexpected written auth response (%d bytes): %v",
+ writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+ }
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+}
+
+func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+
+ authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+ 62, 94, 83, 80, 52, 85}
+ plugin := "caching_sha2_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165,
+ 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70,
+ 110, 40, 139, 124, 41}
+ if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 2, 0, 0, 2, 1, 4, // Perform Full Authentication
+ }
+ conn.queuedReplies = [][]byte{
+ // pub key response
+ append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...),
+
+ // OK
+ {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 3
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) {
+ t.Errorf("unexpected written data: %v", conn.written)
+ }
+}
+
+func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+ mc.cfg.pubKey = testPubKeyRSA
+
+ authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+ 62, 94, 83, 80, 52, 85}
+ plugin := "caching_sha2_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165,
+ 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70,
+ 110, 40, 139, 124, 41}
+ if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 2, 0, 0, 2, 1, 4, // Perform Full Authentication
+ }
+ conn.queuedReplies = [][]byte{
+ // OK
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 2
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) {
+ t.Errorf("unexpected written data: %v", conn.written)
+ }
+}
+
+func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+
+ authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+ 62, 94, 83, 80, 52, 85}
+ plugin := "caching_sha2_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Hack to make the caching_sha2_password plugin believe that the connection
+ // is secure
+ mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165,
+ 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70,
+ 110, 40, 139, 124, 41}
+ if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 2, 0, 0, 2, 1, 4, // Perform Full Authentication
+ }
+ conn.queuedReplies = [][]byte{
+ // OK
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 3
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) {
+ t.Errorf("unexpected written data: %v", conn.written)
+ }
+}
+
+func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) {
+ _, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+
+ authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+ 103, 26, 95, 81, 17, 24, 21}
+ plugin := "mysql_clear_password"
+
+ // Send Client Authentication Packet
+ _, err := mc.auth(authData, plugin)
+ if err != ErrCleartextPassword {
+ t.Errorf("expected ErrCleartextPassword, got %v", err)
+ }
+}
+
+func TestAuthFastCleartextPassword(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+ mc.cfg.AllowCleartextPasswords = true
+
+ authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+ 103, 26, 95, 81, 17, 24, 21}
+ plugin := "mysql_clear_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
+ if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+ }
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+}
+
+func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = ""
+ mc.cfg.AllowCleartextPasswords = true
+
+ authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+ 103, 26, 95, 81, 17, 24, 21}
+ plugin := "mysql_clear_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{0}
+ if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+ }
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+}
+
+func TestAuthFastNativePasswordNotAllowed(t *testing.T) {
+ _, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+ mc.cfg.AllowNativePasswords = false
+
+ authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+ 103, 26, 95, 81, 17, 24, 21}
+ plugin := "mysql_native_password"
+
+ // Send Client Authentication Packet
+ _, err := mc.auth(authData, plugin)
+ if err != ErrNativePassword {
+ t.Errorf("expected ErrNativePassword, got %v", err)
+ }
+}
+
+func TestAuthFastNativePassword(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+
+ authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+ 103, 26, 95, 81, 17, 24, 21}
+ plugin := "mysql_native_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252,
+ 172, 50, 211, 192, 240, 164, 26, 48, 207, 45}
+ if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+ }
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+}
+
+func TestAuthFastNativePasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = ""
+
+ authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+ 103, 26, 95, 81, 17, 24, 21}
+ plugin := "mysql_native_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ if writtenAuthRespLen != 0 {
+ t.Fatalf("unexpected written auth response (%d bytes): %v",
+ writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response
+ conn.data = []byte{
+ 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+ }
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+}
+
+func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = ""
+
+ authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+ 62, 94, 83, 80, 52, 85}
+ plugin := "sha256_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{0}
+ if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response (pub key response)
+ conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...)
+ conn.queuedReplies = [][]byte{
+ // OK
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 2
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) {
+ t.Errorf("unexpected written data: %v", conn.written)
+ }
+}
+
+func TestAuthFastSHA256PasswordRSA(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+
+ authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+ 62, 94, 83, 80, 52, 85}
+ plugin := "sha256_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{1}
+ if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response (pub key response)
+ conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...)
+ conn.queuedReplies = [][]byte{
+ // OK
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 2
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) {
+ t.Errorf("unexpected written data: %v", conn.written)
+ }
+}
+
+func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+ mc.cfg.pubKey = testPubKeyRSA
+
+ authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+ 62, 94, 83, 80, 52, 85}
+ plugin := "sha256_password"
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // auth response (OK)
+ conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0}
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+}
+
+func TestAuthFastSHA256PasswordSecure(t *testing.T) {
+ conn, mc := newRWMockConn(1)
+ mc.cfg.User = "root"
+ mc.cfg.Passwd = "secret"
+
+ // hack to make the caching_sha2_password plugin believe that the connection
+ // is secure
+ mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
+
+ authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+ 62, 94, 83, 80, 52, 85}
+ plugin := "sha256_password"
+
+ // send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // unset TLS config to prevent the actual establishment of a TLS wrapper
+ mc.cfg.tls = nil
+
+ err = mc.writeHandshakeResponsePacket(authResp, plugin)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // check written auth response
+ authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+ authRespEnd := authRespStart + 1 + len(authResp)
+ writtenAuthRespLen := conn.written[authRespStart]
+ writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+ expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
+ if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+ t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+ }
+ conn.written = nil
+
+ // auth response (OK)
+ conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0}
+ conn.maxReads = 1
+
+ // Handle response to auth packet
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ if !bytes.Equal(conn.written, []byte{}) {
+ t.Errorf("unexpected written data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = "secret"
+
+ // auth switch request
+ conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+ 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+ 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+ 50, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK
+ }
+ conn.maxReads = 3
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{
+ // 1. Packet: Hash
+ 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128,
+ 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58,
+ 153, 9, 130,
+ }
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = ""
+
+ // auth switch request
+ conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+ 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+ 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+ 50, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{0, 0, 0, 3}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = "secret"
+
+ // auth switch request
+ conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+ 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+ 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+ 50, 0}
+
+ conn.queuedReplies = [][]byte{
+ // Perform Full Authentication
+ {2, 0, 0, 4, 1, 4},
+
+ // Pub Key Response
+ append([]byte{byte(1 + len(testPubKey)), 1, 0, 6, 1}, testPubKey...),
+
+ // OK
+ {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 4
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReplyPrefix := []byte{
+ // 1. Packet: Hash
+ 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128,
+ 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58,
+ 153, 9, 130,
+
+ // 2. Packet: Pub Key Request
+ 1, 0, 0, 5, 2,
+
+ // 3. Packet: Encrypted Password
+ 0, 1, 0, 7, // [changing bytes]
+ }
+ if !bytes.HasPrefix(conn.written, expectedReplyPrefix) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = "secret"
+ mc.cfg.pubKey = testPubKeyRSA
+
+ // auth switch request
+ conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+ 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+ 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+ 50, 0}
+
+ conn.queuedReplies = [][]byte{
+ // Perform Full Authentication
+ {2, 0, 0, 4, 1, 4},
+
+ // OK
+ {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 3
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReplyPrefix := []byte{
+ // 1. Packet: Hash
+ 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128,
+ 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58,
+ 153, 9, 130,
+
+ // 2. Packet: Encrypted Password
+ 0, 1, 0, 5, // [changing bytes]
+ }
+ if !bytes.HasPrefix(conn.written, expectedReplyPrefix) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = "secret"
+
+ // Hack to make the caching_sha2_password plugin believe that the connection
+ // is secure
+ mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
+
+ // auth switch request
+ conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+ 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+ 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+ 50, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{
+ {2, 0, 0, 4, 1, 4}, // Perform Full Authentication
+ {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK
+ }
+ conn.maxReads = 3
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{
+ // 1. Packet: Hash
+ 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128,
+ 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58,
+ 153, 9, 130,
+
+ // 2. Packet: Cleartext password
+ 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0,
+ }
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+
+ conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108,
+ 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0}
+ conn.maxReads = 1
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+ err := mc.handleAuthResult(authData, plugin)
+ if err != ErrCleartextPassword {
+ t.Errorf("expected ErrCleartextPassword, got %v", err)
+ }
+}
+
+func TestAuthSwitchCleartextPassword(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowCleartextPasswords = true
+ mc.cfg.Passwd = "secret"
+
+ // auth switch request
+ conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108,
+ 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowCleartextPasswords = true
+ mc.cfg.Passwd = ""
+
+ // auth switch request
+ conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108,
+ 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{1, 0, 0, 3, 0}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowNativePasswords = false
+
+ conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97,
+ 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96,
+ 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55,
+ 31, 0}
+ conn.maxReads = 1
+ authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31,
+ 48, 31, 89, 39, 55, 31}
+ plugin := "caching_sha2_password"
+ err := mc.handleAuthResult(authData, plugin)
+ if err != ErrNativePassword {
+ t.Errorf("expected ErrNativePassword, got %v", err)
+ }
+}
+
+func TestAuthSwitchNativePassword(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowNativePasswords = true
+ mc.cfg.Passwd = "secret"
+
+ // auth switch request
+ conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97,
+ 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96,
+ 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55,
+ 31, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31,
+ 48, 31, 89, 39, 55, 31}
+ plugin := "caching_sha2_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103,
+ 21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchNativePasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowNativePasswords = true
+ mc.cfg.Passwd = ""
+
+ // auth switch request
+ conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97,
+ 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96,
+ 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55,
+ 31, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31,
+ 48, 31, 89, 39, 55, 31}
+ plugin := "caching_sha2_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{0, 0, 0, 3}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+
+ conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108,
+ 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61,
+ 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0}
+ conn.maxReads = 1
+ authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+ 84, 96, 101, 92, 123, 121, 107}
+ plugin := "mysql_native_password"
+ err := mc.handleAuthResult(authData, plugin)
+ if err != ErrOldPassword {
+ t.Errorf("expected ErrOldPassword, got %v", err)
+ }
+}
+
+// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request.
+func TestOldAuthSwitchNotAllowed(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+
+ // OldAuthSwitch request
+ conn.data = []byte{1, 0, 0, 2, 0xfe}
+ conn.maxReads = 1
+ authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+ 84, 96, 101, 92, 123, 121, 107}
+ plugin := "mysql_native_password"
+ err := mc.handleAuthResult(authData, plugin)
+ if err != ErrOldPassword {
+ t.Errorf("expected ErrOldPassword, got %v", err)
+ }
+}
+
+func TestAuthSwitchOldPassword(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowOldPasswords = true
+ mc.cfg.Passwd = "secret"
+
+ // auth switch request
+ conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108,
+ 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61,
+ 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+ 84, 96, 101, 92, 123, 121, 107}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request.
+func TestOldAuthSwitch(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowOldPasswords = true
+ mc.cfg.Passwd = "secret"
+
+ // OldAuthSwitch request
+ conn.data = []byte{1, 0, 0, 2, 0xfe}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+ 84, 96, 101, 92, 123, 121, 107}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowOldPasswords = true
+ mc.cfg.Passwd = ""
+
+ // auth switch request
+ conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108,
+ 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61,
+ 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+ 84, 96, 101, 92, 123, 121, 107}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{0, 0, 0, 3}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request.
+func TestOldAuthSwitchPasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.AllowOldPasswords = true
+ mc.cfg.Passwd = ""
+
+ // OldAuthSwitch request.
+ conn.data = []byte{1, 0, 0, 2, 0xfe}
+
+ // auth response
+ conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
+ conn.maxReads = 2
+
+ authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+ 84, 96, 101, 92, 123, 121, 107}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReply := []byte{0, 0, 0, 3}
+ if !bytes.Equal(conn.written, expectedReply) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = ""
+
+ // auth switch request
+ conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97,
+ 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69,
+ 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0}
+
+ conn.queuedReplies = [][]byte{
+ // OK
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 3
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReplyPrefix := []byte{
+ // 1. Packet: Empty Password
+ 1, 0, 0, 3, 0,
+ }
+ if !bytes.HasPrefix(conn.written, expectedReplyPrefix) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchSHA256PasswordRSA(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = "secret"
+
+ // auth switch request
+ conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97,
+ 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69,
+ 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0}
+
+ conn.queuedReplies = [][]byte{
+ // Pub Key Response
+ append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...),
+
+ // OK
+ {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 3
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReplyPrefix := []byte{
+ // 1. Packet: Pub Key Request
+ 1, 0, 0, 3, 1,
+
+ // 2. Packet: Encrypted Password
+ 0, 1, 0, 5, // [changing bytes]
+ }
+ if !bytes.HasPrefix(conn.written, expectedReplyPrefix) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = "secret"
+ mc.cfg.pubKey = testPubKeyRSA
+
+ // auth switch request
+ conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97,
+ 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69,
+ 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0}
+
+ conn.queuedReplies = [][]byte{
+ // OK
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 2
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReplyPrefix := []byte{
+ // 1. Packet: Encrypted Password
+ 0, 1, 0, 3, // [changing bytes]
+ }
+ if !bytes.HasPrefix(conn.written, expectedReplyPrefix) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
+
+func TestAuthSwitchSHA256PasswordSecure(t *testing.T) {
+ conn, mc := newRWMockConn(2)
+ mc.cfg.Passwd = "secret"
+
+ // Hack to make the caching_sha2_password plugin believe that the connection
+ // is secure
+ mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
+
+ // auth switch request
+ conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97,
+ 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69,
+ 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0}
+
+ conn.queuedReplies = [][]byte{
+ // OK
+ {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+ }
+ conn.maxReads = 2
+
+ authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+ 47, 43, 9, 41, 112, 67, 110}
+ plugin := "mysql_native_password"
+
+ if err := mc.handleAuthResult(authData, plugin); err != nil {
+ t.Errorf("got error: %v", err)
+ }
+
+ expectedReplyPrefix := []byte{
+ // 1. Packet: Cleartext Password
+ 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0,
+ }
+ if !bytes.Equal(conn.written, expectedReplyPrefix) {
+ t.Errorf("got unexpected data: %v", conn.written)
+ }
+}
diff --git a/vendor/mysql/benchmark_test.go b/vendor/mysql/benchmark_test.go
new file mode 100644
index 0000000..1030ddc
--- /dev/null
+++ b/vendor/mysql/benchmark_test.go
@@ -0,0 +1,374 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "fmt"
+ "math"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+type TB testing.B
+
+func (tb *TB) check(err error) {
+ if err != nil {
+ tb.Fatal(err)
+ }
+}
+
+func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
+ tb.check(err)
+ return db
+}
+
+func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
+ tb.check(err)
+ return rows
+}
+
+func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
+ tb.check(err)
+ return stmt
+}
+
+func initDB(b *testing.B, queries ...string) *sql.DB {
+ tb := (*TB)(b)
+ db := tb.checkDB(sql.Open("mysql", dsn))
+ for _, query := range queries {
+ if _, err := db.Exec(query); err != nil {
+ b.Fatalf("error on %q: %v", query, err)
+ }
+ }
+ return db
+}
+
+const concurrencyLevel = 10
+
+func BenchmarkQuery(b *testing.B) {
+ tb := (*TB)(b)
+ b.StopTimer()
+ b.ReportAllocs()
+ db := initDB(b,
+ "DROP TABLE IF EXISTS foo",
+ "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
+ `INSERT INTO foo VALUES (1, "one")`,
+ `INSERT INTO foo VALUES (2, "two")`,
+ )
+ db.SetMaxIdleConns(concurrencyLevel)
+ defer db.Close()
+
+ stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
+ defer stmt.Close()
+
+ remain := int64(b.N)
+ var wg sync.WaitGroup
+ wg.Add(concurrencyLevel)
+ defer wg.Wait()
+ b.StartTimer()
+
+ for i := 0; i < concurrencyLevel; i++ {
+ go func() {
+ for {
+ if atomic.AddInt64(&remain, -1) < 0 {
+ wg.Done()
+ return
+ }
+
+ var got string
+ tb.check(stmt.QueryRow(1).Scan(&got))
+ if got != "one" {
+ b.Errorf("query = %q; want one", got)
+ wg.Done()
+ return
+ }
+ }
+ }()
+ }
+}
+
+func BenchmarkExec(b *testing.B) {
+ tb := (*TB)(b)
+ b.StopTimer()
+ b.ReportAllocs()
+ db := tb.checkDB(sql.Open("mysql", dsn))
+ db.SetMaxIdleConns(concurrencyLevel)
+ defer db.Close()
+
+ stmt := tb.checkStmt(db.Prepare("DO 1"))
+ defer stmt.Close()
+
+ remain := int64(b.N)
+ var wg sync.WaitGroup
+ wg.Add(concurrencyLevel)
+ defer wg.Wait()
+ b.StartTimer()
+
+ for i := 0; i < concurrencyLevel; i++ {
+ go func() {
+ for {
+ if atomic.AddInt64(&remain, -1) < 0 {
+ wg.Done()
+ return
+ }
+
+ if _, err := stmt.Exec(); err != nil {
+ b.Logf("stmt.Exec failed: %v", err)
+ b.Fail()
+ }
+ }
+ }()
+ }
+}
+
+// data, but no db writes
+var roundtripSample []byte
+
+func initRoundtripBenchmarks() ([]byte, int, int) {
+ if roundtripSample == nil {
+ roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
+ }
+ return roundtripSample, 16, len(roundtripSample)
+}
+
+func BenchmarkRoundtripTxt(b *testing.B) {
+ b.StopTimer()
+ sample, min, max := initRoundtripBenchmarks()
+ sampleString := string(sample)
+ b.ReportAllocs()
+ tb := (*TB)(b)
+ db := tb.checkDB(sql.Open("mysql", dsn))
+ defer db.Close()
+ b.StartTimer()
+ var result string
+ for i := 0; i < b.N; i++ {
+ length := min + i
+ if length > max {
+ length = max
+ }
+ test := sampleString[0:length]
+ rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
+ if !rows.Next() {
+ rows.Close()
+ b.Fatalf("crashed")
+ }
+ err := rows.Scan(&result)
+ if err != nil {
+ rows.Close()
+ b.Fatalf("crashed")
+ }
+ if result != test {
+ rows.Close()
+ b.Errorf("mismatch")
+ }
+ rows.Close()
+ }
+}
+
+func BenchmarkRoundtripBin(b *testing.B) {
+ b.StopTimer()
+ sample, min, max := initRoundtripBenchmarks()
+ b.ReportAllocs()
+ tb := (*TB)(b)
+ db := tb.checkDB(sql.Open("mysql", dsn))
+ defer db.Close()
+ stmt := tb.checkStmt(db.Prepare("SELECT ?"))
+ defer stmt.Close()
+ b.StartTimer()
+ var result sql.RawBytes
+ for i := 0; i < b.N; i++ {
+ length := min + i
+ if length > max {
+ length = max
+ }
+ test := sample[0:length]
+ rows := tb.checkRows(stmt.Query(test))
+ if !rows.Next() {
+ rows.Close()
+ b.Fatalf("crashed")
+ }
+ err := rows.Scan(&result)
+ if err != nil {
+ rows.Close()
+ b.Fatalf("crashed")
+ }
+ if !bytes.Equal(result, test) {
+ rows.Close()
+ b.Errorf("mismatch")
+ }
+ rows.Close()
+ }
+}
+
+func BenchmarkInterpolation(b *testing.B) {
+ mc := &mysqlConn{
+ cfg: &Config{
+ InterpolateParams: true,
+ Loc: time.UTC,
+ },
+ maxAllowedPacket: maxPacketSize,
+ maxWriteSize: maxPacketSize - 1,
+ buf: newBuffer(nil),
+ }
+
+ args := []driver.Value{
+ int64(42424242),
+ float64(math.Pi),
+ false,
+ time.Unix(1423411542, 807015000),
+ []byte("bytes containing special chars ' \" \a \x00"),
+ "string containing special chars ' \" \a \x00",
+ }
+ q := "SELECT ?, ?, ?, ?, ?, ?"
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := mc.interpolateParams(q, args)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
+
+ tb := (*TB)(b)
+ stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
+ defer stmt.Close()
+
+ b.SetParallelism(p)
+ b.ReportAllocs()
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ var got string
+ for pb.Next() {
+ tb.check(stmt.QueryRow(1).Scan(&got))
+ if got != "one" {
+ b.Fatalf("query = %q; want one", got)
+ }
+ }
+ })
+}
+
+func BenchmarkQueryContext(b *testing.B) {
+ db := initDB(b,
+ "DROP TABLE IF EXISTS foo",
+ "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
+ `INSERT INTO foo VALUES (1, "one")`,
+ `INSERT INTO foo VALUES (2, "two")`,
+ )
+ defer db.Close()
+ for _, p := range []int{1, 2, 3, 4} {
+ b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
+ benchmarkQueryContext(b, db, p)
+ })
+ }
+}
+
+func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
+
+ tb := (*TB)(b)
+ stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
+ defer stmt.Close()
+
+ b.SetParallelism(p)
+ b.ReportAllocs()
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ if _, err := stmt.ExecContext(ctx); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+}
+
+func BenchmarkExecContext(b *testing.B) {
+ db := initDB(b,
+ "DROP TABLE IF EXISTS foo",
+ "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
+ `INSERT INTO foo VALUES (1, "one")`,
+ `INSERT INTO foo VALUES (2, "two")`,
+ )
+ defer db.Close()
+ for _, p := range []int{1, 2, 3, 4} {
+ b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
+ benchmarkQueryContext(b, db, p)
+ })
+ }
+}
+
+// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
+// "size=" means size of each blobs.
+func BenchmarkQueryRawBytes(b *testing.B) {
+ var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
+ db := initDB(b,
+ "DROP TABLE IF EXISTS bench_rawbytes",
+ "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
+ )
+ defer db.Close()
+
+ blob := make([]byte, sizes[len(sizes)-1])
+ for i := range blob {
+ blob[i] = 42
+ }
+ for i := 0; i < 100; i++ {
+ _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+
+ for _, s := range sizes {
+ b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
+ db.SetMaxIdleConns(0)
+ db.SetMaxIdleConns(1)
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for j := 0; j < b.N; j++ {
+ rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
+ if err != nil {
+ b.Fatal(err)
+ }
+ nrows := 0
+ for rows.Next() {
+ var buf sql.RawBytes
+ err := rows.Scan(&buf)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if len(buf) != s {
+ b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
+ }
+ nrows++
+ }
+ rows.Close()
+ if nrows != 100 {
+ b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows)
+ }
+ }
+ })
+ }
+}
diff --git a/vendor/mysql/buffer.go b/vendor/mysql/buffer.go
new file mode 100644
index 0000000..0774c5c
--- /dev/null
+++ b/vendor/mysql/buffer.go
@@ -0,0 +1,182 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "io"
+ "net"
+ "time"
+)
+
+const defaultBufSize = 4096
+const maxCachedBufSize = 256 * 1024
+
+// A buffer which is used for both reading and writing.
+// This is possible since communication on each connection is synchronous.
+// In other words, we can't write and read simultaneously on the same connection.
+// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
+// Also highly optimized for this particular use case.
+// This buffer is backed by two byte slices in a double-buffering scheme
+type buffer struct {
+ buf []byte // buf is a byte buffer who's length and capacity are equal.
+ nc net.Conn
+ idx int
+ length int
+ timeout time.Duration
+ dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
+ flipcnt uint // flipccnt is the current buffer counter for double-buffering
+}
+
+// newBuffer allocates and returns a new buffer.
+func newBuffer(nc net.Conn) buffer {
+ fg := make([]byte, defaultBufSize)
+ return buffer{
+ buf: fg,
+ nc: nc,
+ dbuf: [2][]byte{fg, nil},
+ }
+}
+
+// flip replaces the active buffer with the background buffer
+// this is a delayed flip that simply increases the buffer counter;
+// the actual flip will be performed the next time we call `buffer.fill`
+func (b *buffer) flip() {
+ b.flipcnt += 1
+}
+
+// fill reads into the buffer until at least _need_ bytes are in it
+func (b *buffer) fill(need int) error {
+ n := b.length
+ // fill data into its double-buffering target: if we've called
+ // flip on this buffer, we'll be copying to the background buffer,
+ // and then filling it with network data; otherwise we'll just move
+ // the contents of the current buffer to the front before filling it
+ dest := b.dbuf[b.flipcnt&1]
+
+ // grow buffer if necessary to fit the whole packet.
+ if need > len(dest) {
+ // Round up to the next multiple of the default size
+ dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
+
+ // if the allocated buffer is not too large, move it to backing storage
+ // to prevent extra allocations on applications that perform large reads
+ if len(dest) <= maxCachedBufSize {
+ b.dbuf[b.flipcnt&1] = dest
+ }
+ }
+
+ // if we're filling the fg buffer, move the existing data to the start of it.
+ // if we're filling the bg buffer, copy over the data
+ if n > 0 {
+ copy(dest[:n], b.buf[b.idx:])
+ }
+
+ b.buf = dest
+ b.idx = 0
+
+ for {
+ if b.timeout > 0 {
+ if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
+ return err
+ }
+ }
+
+ nn, err := b.nc.Read(b.buf[n:])
+ n += nn
+
+ switch err {
+ case nil:
+ if n < need {
+ continue
+ }
+ b.length = n
+ return nil
+
+ case io.EOF:
+ if n >= need {
+ b.length = n
+ return nil
+ }
+ return io.ErrUnexpectedEOF
+
+ default:
+ return err
+ }
+ }
+}
+
+// returns next N bytes from buffer.
+// The returned slice is only guaranteed to be valid until the next read
+func (b *buffer) readNext(need int) ([]byte, error) {
+ if b.length < need {
+ // refill
+ if err := b.fill(need); err != nil {
+ return nil, err
+ }
+ }
+
+ offset := b.idx
+ b.idx += need
+ b.length -= need
+ return b.buf[offset:b.idx], nil
+}
+
+// takeBuffer returns a buffer with the requested size.
+// If possible, a slice from the existing buffer is returned.
+// Otherwise a bigger buffer is made.
+// Only one buffer (total) can be used at a time.
+func (b *buffer) takeBuffer(length int) ([]byte, error) {
+ if b.length > 0 {
+ return nil, ErrBusyBuffer
+ }
+
+ // test (cheap) general case first
+ if length <= cap(b.buf) {
+ return b.buf[:length], nil
+ }
+
+ if length < maxPacketSize {
+ b.buf = make([]byte, length)
+ return b.buf, nil
+ }
+
+ // buffer is larger than we want to store.
+ return make([]byte, length), nil
+}
+
+// takeSmallBuffer is shortcut which can be used if length is
+// known to be smaller than defaultBufSize.
+// Only one buffer (total) can be used at a time.
+func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
+ if b.length > 0 {
+ return nil, ErrBusyBuffer
+ }
+ return b.buf[:length], nil
+}
+
+// takeCompleteBuffer returns the complete existing buffer.
+// This can be used if the necessary buffer size is unknown.
+// cap and len of the returned buffer will be equal.
+// Only one buffer (total) can be used at a time.
+func (b *buffer) takeCompleteBuffer() ([]byte, error) {
+ if b.length > 0 {
+ return nil, ErrBusyBuffer
+ }
+ return b.buf, nil
+}
+
+// store stores buf, an updated buffer, if its suitable to do so.
+func (b *buffer) store(buf []byte) error {
+ if b.length > 0 {
+ return ErrBusyBuffer
+ } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
+ b.buf = buf[:cap(buf)]
+ }
+ return nil
+}
diff --git a/vendor/mysql/collations.go b/vendor/mysql/collations.go
new file mode 100644
index 0000000..326a9f7
--- /dev/null
+++ b/vendor/mysql/collations.go
@@ -0,0 +1,265 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+const defaultCollation = "utf8mb4_general_ci"
+const binaryCollation = "binary"
+
+// A list of available collations mapped to the internal ID.
+// To update this map use the following MySQL query:
+// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID
+//
+// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255.
+//
+// ucs2, utf16, and utf32 can't be used for connection charset.
+// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset
+// They are commented out to reduce this map.
+var collations = map[string]byte{
+ "big5_chinese_ci": 1,
+ "latin2_czech_cs": 2,
+ "dec8_swedish_ci": 3,
+ "cp850_general_ci": 4,
+ "latin1_german1_ci": 5,
+ "hp8_english_ci": 6,
+ "koi8r_general_ci": 7,
+ "latin1_swedish_ci": 8,
+ "latin2_general_ci": 9,
+ "swe7_swedish_ci": 10,
+ "ascii_general_ci": 11,
+ "ujis_japanese_ci": 12,
+ "sjis_japanese_ci": 13,
+ "cp1251_bulgarian_ci": 14,
+ "latin1_danish_ci": 15,
+ "hebrew_general_ci": 16,
+ "tis620_thai_ci": 18,
+ "euckr_korean_ci": 19,
+ "latin7_estonian_cs": 20,
+ "latin2_hungarian_ci": 21,
+ "koi8u_general_ci": 22,
+ "cp1251_ukrainian_ci": 23,
+ "gb2312_chinese_ci": 24,
+ "greek_general_ci": 25,
+ "cp1250_general_ci": 26,
+ "latin2_croatian_ci": 27,
+ "gbk_chinese_ci": 28,
+ "cp1257_lithuanian_ci": 29,
+ "latin5_turkish_ci": 30,
+ "latin1_german2_ci": 31,
+ "armscii8_general_ci": 32,
+ "utf8_general_ci": 33,
+ "cp1250_czech_cs": 34,
+ //"ucs2_general_ci": 35,
+ "cp866_general_ci": 36,
+ "keybcs2_general_ci": 37,
+ "macce_general_ci": 38,
+ "macroman_general_ci": 39,
+ "cp852_general_ci": 40,
+ "latin7_general_ci": 41,
+ "latin7_general_cs": 42,
+ "macce_bin": 43,
+ "cp1250_croatian_ci": 44,
+ "utf8mb4_general_ci": 45,
+ "utf8mb4_bin": 46,
+ "latin1_bin": 47,
+ "latin1_general_ci": 48,
+ "latin1_general_cs": 49,
+ "cp1251_bin": 50,
+ "cp1251_general_ci": 51,
+ "cp1251_general_cs": 52,
+ "macroman_bin": 53,
+ //"utf16_general_ci": 54,
+ //"utf16_bin": 55,
+ //"utf16le_general_ci": 56,
+ "cp1256_general_ci": 57,
+ "cp1257_bin": 58,
+ "cp1257_general_ci": 59,
+ //"utf32_general_ci": 60,
+ //"utf32_bin": 61,
+ //"utf16le_bin": 62,
+ "binary": 63,
+ "armscii8_bin": 64,
+ "ascii_bin": 65,
+ "cp1250_bin": 66,
+ "cp1256_bin": 67,
+ "cp866_bin": 68,
+ "dec8_bin": 69,
+ "greek_bin": 70,
+ "hebrew_bin": 71,
+ "hp8_bin": 72,
+ "keybcs2_bin": 73,
+ "koi8r_bin": 74,
+ "koi8u_bin": 75,
+ "utf8_tolower_ci": 76,
+ "latin2_bin": 77,
+ "latin5_bin": 78,
+ "latin7_bin": 79,
+ "cp850_bin": 80,
+ "cp852_bin": 81,
+ "swe7_bin": 82,
+ "utf8_bin": 83,
+ "big5_bin": 84,
+ "euckr_bin": 85,
+ "gb2312_bin": 86,
+ "gbk_bin": 87,
+ "sjis_bin": 88,
+ "tis620_bin": 89,
+ //"ucs2_bin": 90,
+ "ujis_bin": 91,
+ "geostd8_general_ci": 92,
+ "geostd8_bin": 93,
+ "latin1_spanish_ci": 94,
+ "cp932_japanese_ci": 95,
+ "cp932_bin": 96,
+ "eucjpms_japanese_ci": 97,
+ "eucjpms_bin": 98,
+ "cp1250_polish_ci": 99,
+ //"utf16_unicode_ci": 101,
+ //"utf16_icelandic_ci": 102,
+ //"utf16_latvian_ci": 103,
+ //"utf16_romanian_ci": 104,
+ //"utf16_slovenian_ci": 105,
+ //"utf16_polish_ci": 106,
+ //"utf16_estonian_ci": 107,
+ //"utf16_spanish_ci": 108,
+ //"utf16_swedish_ci": 109,
+ //"utf16_turkish_ci": 110,
+ //"utf16_czech_ci": 111,
+ //"utf16_danish_ci": 112,
+ //"utf16_lithuanian_ci": 113,
+ //"utf16_slovak_ci": 114,
+ //"utf16_spanish2_ci": 115,
+ //"utf16_roman_ci": 116,
+ //"utf16_persian_ci": 117,
+ //"utf16_esperanto_ci": 118,
+ //"utf16_hungarian_ci": 119,
+ //"utf16_sinhala_ci": 120,
+ //"utf16_german2_ci": 121,
+ //"utf16_croatian_ci": 122,
+ //"utf16_unicode_520_ci": 123,
+ //"utf16_vietnamese_ci": 124,
+ //"ucs2_unicode_ci": 128,
+ //"ucs2_icelandic_ci": 129,
+ //"ucs2_latvian_ci": 130,
+ //"ucs2_romanian_ci": 131,
+ //"ucs2_slovenian_ci": 132,
+ //"ucs2_polish_ci": 133,
+ //"ucs2_estonian_ci": 134,
+ //"ucs2_spanish_ci": 135,
+ //"ucs2_swedish_ci": 136,
+ //"ucs2_turkish_ci": 137,
+ //"ucs2_czech_ci": 138,
+ //"ucs2_danish_ci": 139,
+ //"ucs2_lithuanian_ci": 140,
+ //"ucs2_slovak_ci": 141,
+ //"ucs2_spanish2_ci": 142,
+ //"ucs2_roman_ci": 143,
+ //"ucs2_persian_ci": 144,
+ //"ucs2_esperanto_ci": 145,
+ //"ucs2_hungarian_ci": 146,
+ //"ucs2_sinhala_ci": 147,
+ //"ucs2_german2_ci": 148,
+ //"ucs2_croatian_ci": 149,
+ //"ucs2_unicode_520_ci": 150,
+ //"ucs2_vietnamese_ci": 151,
+ //"ucs2_general_mysql500_ci": 159,
+ //"utf32_unicode_ci": 160,
+ //"utf32_icelandic_ci": 161,
+ //"utf32_latvian_ci": 162,
+ //"utf32_romanian_ci": 163,
+ //"utf32_slovenian_ci": 164,
+ //"utf32_polish_ci": 165,
+ //"utf32_estonian_ci": 166,
+ //"utf32_spanish_ci": 167,
+ //"utf32_swedish_ci": 168,
+ //"utf32_turkish_ci": 169,
+ //"utf32_czech_ci": 170,
+ //"utf32_danish_ci": 171,
+ //"utf32_lithuanian_ci": 172,
+ //"utf32_slovak_ci": 173,
+ //"utf32_spanish2_ci": 174,
+ //"utf32_roman_ci": 175,
+ //"utf32_persian_ci": 176,
+ //"utf32_esperanto_ci": 177,
+ //"utf32_hungarian_ci": 178,
+ //"utf32_sinhala_ci": 179,
+ //"utf32_german2_ci": 180,
+ //"utf32_croatian_ci": 181,
+ //"utf32_unicode_520_ci": 182,
+ //"utf32_vietnamese_ci": 183,
+ "utf8_unicode_ci": 192,
+ "utf8_icelandic_ci": 193,
+ "utf8_latvian_ci": 194,
+ "utf8_romanian_ci": 195,
+ "utf8_slovenian_ci": 196,
+ "utf8_polish_ci": 197,
+ "utf8_estonian_ci": 198,
+ "utf8_spanish_ci": 199,
+ "utf8_swedish_ci": 200,
+ "utf8_turkish_ci": 201,
+ "utf8_czech_ci": 202,
+ "utf8_danish_ci": 203,
+ "utf8_lithuanian_ci": 204,
+ "utf8_slovak_ci": 205,
+ "utf8_spanish2_ci": 206,
+ "utf8_roman_ci": 207,
+ "utf8_persian_ci": 208,
+ "utf8_esperanto_ci": 209,
+ "utf8_hungarian_ci": 210,
+ "utf8_sinhala_ci": 211,
+ "utf8_german2_ci": 212,
+ "utf8_croatian_ci": 213,
+ "utf8_unicode_520_ci": 214,
+ "utf8_vietnamese_ci": 215,
+ "utf8_general_mysql500_ci": 223,
+ "utf8mb4_unicode_ci": 224,
+ "utf8mb4_icelandic_ci": 225,
+ "utf8mb4_latvian_ci": 226,
+ "utf8mb4_romanian_ci": 227,
+ "utf8mb4_slovenian_ci": 228,
+ "utf8mb4_polish_ci": 229,
+ "utf8mb4_estonian_ci": 230,
+ "utf8mb4_spanish_ci": 231,
+ "utf8mb4_swedish_ci": 232,
+ "utf8mb4_turkish_ci": 233,
+ "utf8mb4_czech_ci": 234,
+ "utf8mb4_danish_ci": 235,
+ "utf8mb4_lithuanian_ci": 236,
+ "utf8mb4_slovak_ci": 237,
+ "utf8mb4_spanish2_ci": 238,
+ "utf8mb4_roman_ci": 239,
+ "utf8mb4_persian_ci": 240,
+ "utf8mb4_esperanto_ci": 241,
+ "utf8mb4_hungarian_ci": 242,
+ "utf8mb4_sinhala_ci": 243,
+ "utf8mb4_german2_ci": 244,
+ "utf8mb4_croatian_ci": 245,
+ "utf8mb4_unicode_520_ci": 246,
+ "utf8mb4_vietnamese_ci": 247,
+ "gb18030_chinese_ci": 248,
+ "gb18030_bin": 249,
+ "gb18030_unicode_520_ci": 250,
+ "utf8mb4_0900_ai_ci": 255,
+}
+
+// A denylist of collations which is unsafe to interpolate parameters.
+// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
+var unsafeCollations = map[string]bool{
+ "big5_chinese_ci": true,
+ "sjis_japanese_ci": true,
+ "gbk_chinese_ci": true,
+ "big5_bin": true,
+ "gb2312_bin": true,
+ "gbk_bin": true,
+ "sjis_bin": true,
+ "cp932_japanese_ci": true,
+ "cp932_bin": true,
+ "gb18030_chinese_ci": true,
+ "gb18030_bin": true,
+ "gb18030_unicode_520_ci": true,
+}
diff --git a/vendor/mysql/conncheck.go b/vendor/mysql/conncheck.go
new file mode 100644
index 0000000..024eb28
--- /dev/null
+++ b/vendor/mysql/conncheck.go
@@ -0,0 +1,54 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
+
+package mysql
+
+import (
+ "errors"
+ "io"
+ "net"
+ "syscall"
+)
+
+var errUnexpectedRead = errors.New("unexpected read from socket")
+
+func connCheck(conn net.Conn) error {
+ var sysErr error
+
+ sysConn, ok := conn.(syscall.Conn)
+ if !ok {
+ return nil
+ }
+ rawConn, err := sysConn.SyscallConn()
+ if err != nil {
+ return err
+ }
+
+ err = rawConn.Read(func(fd uintptr) bool {
+ var buf [1]byte
+ n, err := syscall.Read(int(fd), buf[:])
+ switch {
+ case n == 0 && err == nil:
+ sysErr = io.EOF
+ case n > 0:
+ sysErr = errUnexpectedRead
+ case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
+ sysErr = nil
+ default:
+ sysErr = err
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+
+ return sysErr
+}
diff --git a/vendor/mysql/conncheck_dummy.go b/vendor/mysql/conncheck_dummy.go
new file mode 100644
index 0000000..ea7fb60
--- /dev/null
+++ b/vendor/mysql/conncheck_dummy.go
@@ -0,0 +1,17 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
+
+package mysql
+
+import "net"
+
+func connCheck(conn net.Conn) error {
+ return nil
+}
diff --git a/vendor/mysql/conncheck_test.go b/vendor/mysql/conncheck_test.go
new file mode 100644
index 0000000..5399551
--- /dev/null
+++ b/vendor/mysql/conncheck_test.go
@@ -0,0 +1,38 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
+
+package mysql
+
+import (
+ "testing"
+ "time"
+)
+
+func TestStaleConnectionChecks(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("SET @@SESSION.wait_timeout = 2")
+
+ if err := dbt.db.Ping(); err != nil {
+ dbt.Fatal(err)
+ }
+
+ // wait for MySQL to close our connection
+ time.Sleep(3 * time.Second)
+
+ tx, err := dbt.db.Begin()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ if err := tx.Rollback(); err != nil {
+ dbt.Fatal(err)
+ }
+ })
+}
diff --git a/vendor/mysql/connection.go b/vendor/mysql/connection.go
new file mode 100644
index 0000000..835f897
--- /dev/null
+++ b/vendor/mysql/connection.go
@@ -0,0 +1,650 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "encoding/json"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type mysqlConn struct {
+ buf buffer
+ netConn net.Conn
+ rawConn net.Conn // underlying connection when netConn is TLS connection.
+ affectedRows uint64
+ insertId uint64
+ cfg *Config
+ maxAllowedPacket int
+ maxWriteSize int
+ writeTimeout time.Duration
+ flags clientFlag
+ status statusFlag
+ sequence uint8
+ parseTime bool
+ reset bool // set when the Go SQL package calls ResetSession
+
+ // for context support (Go 1.8+)
+ watching bool
+ watcher chan<- context.Context
+ closech chan struct{}
+ finished chan<- struct{}
+ canceled atomicError // set non-nil if conn is canceled
+ closed atomicBool // set when conn is closed, before closech is closed
+}
+
+// Handles parameters set in DSN after the connection is established
+func (mc *mysqlConn) handleParams() (err error) {
+ var cmdSet strings.Builder
+ for param, val := range mc.cfg.Params {
+ switch param {
+ // Charset: character_set_connection, character_set_client, character_set_results
+ case "charset":
+ charsets := strings.Split(val, ",")
+ for i := range charsets {
+ // ignore errors here - a charset may not exist
+ err = mc.exec("SET NAMES " + charsets[i])
+ if err == nil {
+ break
+ }
+ }
+ if err != nil {
+ return
+ }
+
+ // Other system vars accumulated in a single SET command
+ default:
+ if cmdSet.Len() == 0 {
+ // Heuristic: 29 chars for each other key=value to reduce reallocations
+ cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1))
+ cmdSet.WriteString("SET ")
+ } else {
+ cmdSet.WriteByte(',')
+ }
+ cmdSet.WriteString(param)
+ cmdSet.WriteByte('=')
+ cmdSet.WriteString(val)
+ }
+ }
+
+ if cmdSet.Len() > 0 {
+ err = mc.exec(cmdSet.String())
+ if err != nil {
+ return
+ }
+ }
+
+ return
+}
+
+func (mc *mysqlConn) markBadConn(err error) error {
+ if mc == nil {
+ return err
+ }
+ if err != errBadConnNoWrite {
+ return err
+ }
+ return driver.ErrBadConn
+}
+
+func (mc *mysqlConn) Begin() (driver.Tx, error) {
+ return mc.begin(false)
+}
+
+func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
+ if mc.closed.IsSet() {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ var q string
+ if readOnly {
+ q = "START TRANSACTION READ ONLY"
+ } else {
+ q = "START TRANSACTION"
+ }
+ err := mc.exec(q)
+ if err == nil {
+ return &mysqlTx{mc}, err
+ }
+ return nil, mc.markBadConn(err)
+}
+
+func (mc *mysqlConn) Close() (err error) {
+ // Makes Close idempotent
+ if !mc.closed.IsSet() {
+ err = mc.writeCommandPacket(comQuit)
+ }
+
+ mc.cleanup()
+
+ return
+}
+
+// Closes the network connection and unsets internal variables. Do not call this
+// function after successfully authentication, call Close instead. This function
+// is called before auth or on auth failure because MySQL will have already
+// closed the network connection.
+func (mc *mysqlConn) cleanup() {
+ if !mc.closed.TrySet(true) {
+ return
+ }
+
+ // Makes cleanup idempotent
+ close(mc.closech)
+ if mc.netConn == nil {
+ return
+ }
+ if err := mc.netConn.Close(); err != nil {
+ errLog.Print(err)
+ }
+}
+
+func (mc *mysqlConn) error() error {
+ if mc.closed.IsSet() {
+ if err := mc.canceled.Value(); err != nil {
+ return err
+ }
+ return ErrInvalidConn
+ }
+ return nil
+}
+
+func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
+ if mc.closed.IsSet() {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ // Send command
+ err := mc.writeCommandPacketStr(comStmtPrepare, query)
+ if err != nil {
+ // STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
+ errLog.Print(err)
+ return nil, driver.ErrBadConn
+ }
+
+ stmt := &mysqlStmt{
+ mc: mc,
+ }
+
+ // Read Result
+ columnCount, err := stmt.readPrepareResultPacket()
+ if err == nil {
+ if stmt.paramCount > 0 {
+ if err = mc.readUntilEOF(); err != nil {
+ return nil, err
+ }
+ }
+
+ if columnCount > 0 {
+ err = mc.readUntilEOF()
+ }
+ }
+
+ return stmt, err
+}
+
+func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
+ // Number of ? should be same to len(args)
+ if strings.Count(query, "?") != len(args) {
+ return "", driver.ErrSkip
+ }
+
+ buf, err := mc.buf.takeCompleteBuffer()
+ if err != nil {
+ // can not take the buffer. Something must be wrong with the connection
+ errLog.Print(err)
+ return "", ErrInvalidConn
+ }
+ buf = buf[:0]
+ argPos := 0
+
+ for i := 0; i < len(query); i++ {
+ q := strings.IndexByte(query[i:], '?')
+ if q == -1 {
+ buf = append(buf, query[i:]...)
+ break
+ }
+ buf = append(buf, query[i:i+q]...)
+ i += q
+
+ arg := args[argPos]
+ argPos++
+
+ if arg == nil {
+ buf = append(buf, "NULL"...)
+ continue
+ }
+
+ switch v := arg.(type) {
+ case int64:
+ buf = strconv.AppendInt(buf, v, 10)
+ case uint64:
+ // Handle uint64 explicitly because our custom ConvertValue emits unsigned values
+ buf = strconv.AppendUint(buf, v, 10)
+ case float64:
+ buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+ case bool:
+ if v {
+ buf = append(buf, '1')
+ } else {
+ buf = append(buf, '0')
+ }
+ case time.Time:
+ if v.IsZero() {
+ buf = append(buf, "'0000-00-00'"...)
+ } else {
+ buf = append(buf, '\'')
+ buf, err = appendDateTime(buf, v.In(mc.cfg.Loc))
+ if err != nil {
+ return "", err
+ }
+ buf = append(buf, '\'')
+ }
+ case json.RawMessage:
+ buf = append(buf, '\'')
+ if mc.status&statusNoBackslashEscapes == 0 {
+ buf = escapeBytesBackslash(buf, v)
+ } else {
+ buf = escapeBytesQuotes(buf, v)
+ }
+ buf = append(buf, '\'')
+ case []byte:
+ if v == nil {
+ buf = append(buf, "NULL"...)
+ } else {
+ buf = append(buf, "_binary'"...)
+ if mc.status&statusNoBackslashEscapes == 0 {
+ buf = escapeBytesBackslash(buf, v)
+ } else {
+ buf = escapeBytesQuotes(buf, v)
+ }
+ buf = append(buf, '\'')
+ }
+ case string:
+ buf = append(buf, '\'')
+ if mc.status&statusNoBackslashEscapes == 0 {
+ buf = escapeStringBackslash(buf, v)
+ } else {
+ buf = escapeStringQuotes(buf, v)
+ }
+ buf = append(buf, '\'')
+ default:
+ return "", driver.ErrSkip
+ }
+
+ if len(buf)+4 > mc.maxAllowedPacket {
+ return "", driver.ErrSkip
+ }
+ }
+ if argPos != len(args) {
+ return "", driver.ErrSkip
+ }
+ return string(buf), nil
+}
+
+func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ if mc.closed.IsSet() {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ if len(args) != 0 {
+ if !mc.cfg.InterpolateParams {
+ return nil, driver.ErrSkip
+ }
+ // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
+ prepared, err := mc.interpolateParams(query, args)
+ if err != nil {
+ return nil, err
+ }
+ query = prepared
+ }
+ mc.affectedRows = 0
+ mc.insertId = 0
+
+ err := mc.exec(query)
+ if err == nil {
+ return &mysqlResult{
+ affectedRows: int64(mc.affectedRows),
+ insertId: int64(mc.insertId),
+ }, err
+ }
+ return nil, mc.markBadConn(err)
+}
+
+// Internal function to execute commands
+func (mc *mysqlConn) exec(query string) error {
+ // Send command
+ if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
+ return mc.markBadConn(err)
+ }
+
+ // Read Result
+ resLen, err := mc.readResultSetHeaderPacket()
+ if err != nil {
+ return err
+ }
+
+ if resLen > 0 {
+ // columns
+ if err := mc.readUntilEOF(); err != nil {
+ return err
+ }
+
+ // rows
+ if err := mc.readUntilEOF(); err != nil {
+ return err
+ }
+ }
+
+ return mc.discardResults()
+}
+
+func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ return mc.query(query, args)
+}
+
+func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
+ if mc.closed.IsSet() {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ if len(args) != 0 {
+ if !mc.cfg.InterpolateParams {
+ return nil, driver.ErrSkip
+ }
+ // try client-side prepare to reduce roundtrip
+ prepared, err := mc.interpolateParams(query, args)
+ if err != nil {
+ return nil, err
+ }
+ query = prepared
+ }
+ // Send command
+ err := mc.writeCommandPacketStr(comQuery, query)
+ if err == nil {
+ // Read Result
+ var resLen int
+ resLen, err = mc.readResultSetHeaderPacket()
+ if err == nil {
+ rows := new(textRows)
+ rows.mc = mc
+
+ if resLen == 0 {
+ rows.rs.done = true
+
+ switch err := rows.NextResultSet(); err {
+ case nil, io.EOF:
+ return rows, nil
+ default:
+ return nil, err
+ }
+ }
+
+ // Columns
+ rows.rs.columns, err = mc.readColumns(resLen)
+ return rows, err
+ }
+ }
+ return nil, mc.markBadConn(err)
+}
+
+// Gets the value of the given MySQL System Variable
+// The returned byte slice is only valid until the next read
+func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
+ // Send command
+ if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
+ return nil, err
+ }
+
+ // Read Result
+ resLen, err := mc.readResultSetHeaderPacket()
+ if err == nil {
+ rows := new(textRows)
+ rows.mc = mc
+ rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
+
+ if resLen > 0 {
+ // Columns
+ if err := mc.readUntilEOF(); err != nil {
+ return nil, err
+ }
+ }
+
+ dest := make([]driver.Value, resLen)
+ if err = rows.readRow(dest); err == nil {
+ return dest[0].([]byte), mc.readUntilEOF()
+ }
+ }
+ return nil, err
+}
+
+// finish is called when the query has canceled.
+func (mc *mysqlConn) cancel(err error) {
+ mc.canceled.Set(err)
+ mc.cleanup()
+}
+
+// finish is called when the query has succeeded.
+func (mc *mysqlConn) finish() {
+ if !mc.watching || mc.finished == nil {
+ return
+ }
+ select {
+ case mc.finished <- struct{}{}:
+ mc.watching = false
+ case <-mc.closech:
+ }
+}
+
+// Ping implements driver.Pinger interface
+func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
+ if mc.closed.IsSet() {
+ errLog.Print(ErrInvalidConn)
+ return driver.ErrBadConn
+ }
+
+ if err = mc.watchCancel(ctx); err != nil {
+ return
+ }
+ defer mc.finish()
+
+ if err = mc.writeCommandPacket(comPing); err != nil {
+ return mc.markBadConn(err)
+ }
+
+ return mc.readResultOK()
+}
+
+// BeginTx implements driver.ConnBeginTx interface
+func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+ if mc.closed.IsSet() {
+ return nil, driver.ErrBadConn
+ }
+
+ if err := mc.watchCancel(ctx); err != nil {
+ return nil, err
+ }
+ defer mc.finish()
+
+ if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
+ level, err := mapIsolationLevel(opts.Isolation)
+ if err != nil {
+ return nil, err
+ }
+ err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return mc.begin(opts.ReadOnly)
+}
+
+func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+ dargs, err := namedValueToValue(args)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := mc.watchCancel(ctx); err != nil {
+ return nil, err
+ }
+
+ rows, err := mc.query(query, dargs)
+ if err != nil {
+ mc.finish()
+ return nil, err
+ }
+ rows.finish = mc.finish
+ return rows, err
+}
+
+func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+ dargs, err := namedValueToValue(args)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := mc.watchCancel(ctx); err != nil {
+ return nil, err
+ }
+ defer mc.finish()
+
+ return mc.Exec(query, dargs)
+}
+
+func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+ if err := mc.watchCancel(ctx); err != nil {
+ return nil, err
+ }
+
+ stmt, err := mc.Prepare(query)
+ mc.finish()
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ stmt.Close()
+ return nil, ctx.Err()
+ }
+ return stmt, nil
+}
+
+func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+ dargs, err := namedValueToValue(args)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := stmt.mc.watchCancel(ctx); err != nil {
+ return nil, err
+ }
+
+ rows, err := stmt.query(dargs)
+ if err != nil {
+ stmt.mc.finish()
+ return nil, err
+ }
+ rows.finish = stmt.mc.finish
+ return rows, err
+}
+
+func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+ dargs, err := namedValueToValue(args)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := stmt.mc.watchCancel(ctx); err != nil {
+ return nil, err
+ }
+ defer stmt.mc.finish()
+
+ return stmt.Exec(dargs)
+}
+
+func (mc *mysqlConn) watchCancel(ctx context.Context) error {
+ if mc.watching {
+ // Reach here if canceled,
+ // so the connection is already invalid
+ mc.cleanup()
+ return nil
+ }
+ // When ctx is already cancelled, don't watch it.
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ // When ctx is not cancellable, don't watch it.
+ if ctx.Done() == nil {
+ return nil
+ }
+ // When watcher is not alive, can't watch it.
+ if mc.watcher == nil {
+ return nil
+ }
+
+ mc.watching = true
+ mc.watcher <- ctx
+ return nil
+}
+
+func (mc *mysqlConn) startWatcher() {
+ watcher := make(chan context.Context, 1)
+ mc.watcher = watcher
+ finished := make(chan struct{})
+ mc.finished = finished
+ go func() {
+ for {
+ var ctx context.Context
+ select {
+ case ctx = <-watcher:
+ case <-mc.closech:
+ return
+ }
+
+ select {
+ case <-ctx.Done():
+ mc.cancel(ctx.Err())
+ case <-finished:
+ case <-mc.closech:
+ return
+ }
+ }
+ }()
+}
+
+func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
+ nv.Value, err = converter{}.ConvertValue(nv.Value)
+ return
+}
+
+// ResetSession implements driver.SessionResetter.
+// (From Go 1.10)
+func (mc *mysqlConn) ResetSession(ctx context.Context) error {
+ if mc.closed.IsSet() {
+ return driver.ErrBadConn
+ }
+ mc.reset = true
+ return nil
+}
+
+// IsValid implements driver.Validator interface
+// (From Go 1.15)
+func (mc *mysqlConn) IsValid() bool {
+ return !mc.closed.IsSet()
+}
diff --git a/vendor/mysql/connection_test.go b/vendor/mysql/connection_test.go
new file mode 100644
index 0000000..a6d6773
--- /dev/null
+++ b/vendor/mysql/connection_test.go
@@ -0,0 +1,203 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "context"
+ "database/sql/driver"
+ "encoding/json"
+ "errors"
+ "net"
+ "testing"
+)
+
+func TestInterpolateParams(t *testing.T) {
+ mc := &mysqlConn{
+ buf: newBuffer(nil),
+ maxAllowedPacket: maxPacketSize,
+ cfg: &Config{
+ InterpolateParams: true,
+ },
+ }
+
+ q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
+ if err != nil {
+ t.Errorf("Expected err=nil, got %#v", err)
+ return
+ }
+ expected := `SELECT 42+'gopher'`
+ if q != expected {
+ t.Errorf("Expected: %q\nGot: %q", expected, q)
+ }
+}
+
+func TestInterpolateParamsJSONRawMessage(t *testing.T) {
+ mc := &mysqlConn{
+ buf: newBuffer(nil),
+ maxAllowedPacket: maxPacketSize,
+ cfg: &Config{
+ InterpolateParams: true,
+ },
+ }
+
+ buf, err := json.Marshal(struct {
+ Value int `json:"value"`
+ }{Value: 42})
+ if err != nil {
+ t.Errorf("Expected err=nil, got %#v", err)
+ return
+ }
+ q, err := mc.interpolateParams("SELECT ?", []driver.Value{json.RawMessage(buf)})
+ if err != nil {
+ t.Errorf("Expected err=nil, got %#v", err)
+ return
+ }
+ expected := `SELECT '{\"value\":42}'`
+ if q != expected {
+ t.Errorf("Expected: %q\nGot: %q", expected, q)
+ }
+}
+
+func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
+ mc := &mysqlConn{
+ buf: newBuffer(nil),
+ maxAllowedPacket: maxPacketSize,
+ cfg: &Config{
+ InterpolateParams: true,
+ },
+ }
+
+ q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
+ if err != driver.ErrSkip {
+ t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
+ }
+}
+
+// We don't support placeholder in string literal for now.
+// https://github.com/go-sql-driver/mysql/pull/490
+func TestInterpolateParamsPlaceholderInString(t *testing.T) {
+ mc := &mysqlConn{
+ buf: newBuffer(nil),
+ maxAllowedPacket: maxPacketSize,
+ cfg: &Config{
+ InterpolateParams: true,
+ },
+ }
+
+ q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
+ // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
+ if err != driver.ErrSkip {
+ t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
+ }
+}
+
+func TestInterpolateParamsUint64(t *testing.T) {
+ mc := &mysqlConn{
+ buf: newBuffer(nil),
+ maxAllowedPacket: maxPacketSize,
+ cfg: &Config{
+ InterpolateParams: true,
+ },
+ }
+
+ q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)})
+ if err != nil {
+ t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q)
+ }
+ if q != "SELECT 42" {
+ t.Errorf("Expected uint64 interpolation to work, got q=%#v", q)
+ }
+}
+
+func TestCheckNamedValue(t *testing.T) {
+ value := driver.NamedValue{Value: ^uint64(0)}
+ x := &mysqlConn{}
+ err := x.CheckNamedValue(&value)
+
+ if err != nil {
+ t.Fatal("uint64 high-bit not convertible", err)
+ }
+
+ if value.Value != ^uint64(0) {
+ t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value)
+ }
+}
+
+// TestCleanCancel tests passed context is cancelled at start.
+// No packet should be sent. Connection should keep current status.
+func TestCleanCancel(t *testing.T) {
+ mc := &mysqlConn{
+ closech: make(chan struct{}),
+ }
+ mc.startWatcher()
+ defer mc.cleanup()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ for i := 0; i < 3; i++ { // Repeat same behavior
+ err := mc.Ping(ctx)
+ if err != context.Canceled {
+ t.Errorf("expected context.Canceled, got %#v", err)
+ }
+
+ if mc.closed.IsSet() {
+ t.Error("expected mc is not closed, closed actually")
+ }
+
+ if mc.watching {
+ t.Error("expected watching is false, but true")
+ }
+ }
+}
+
+func TestPingMarkBadConnection(t *testing.T) {
+ nc := badConnection{err: errors.New("boom")}
+ ms := &mysqlConn{
+ netConn: nc,
+ buf: newBuffer(nc),
+ maxAllowedPacket: defaultMaxAllowedPacket,
+ }
+
+ err := ms.Ping(context.Background())
+
+ if err != driver.ErrBadConn {
+ t.Errorf("expected driver.ErrBadConn, got %#v", err)
+ }
+}
+
+func TestPingErrInvalidConn(t *testing.T) {
+ nc := badConnection{err: errors.New("failed to write"), n: 10}
+ ms := &mysqlConn{
+ netConn: nc,
+ buf: newBuffer(nc),
+ maxAllowedPacket: defaultMaxAllowedPacket,
+ closech: make(chan struct{}),
+ }
+
+ err := ms.Ping(context.Background())
+
+ if err != ErrInvalidConn {
+ t.Errorf("expected ErrInvalidConn, got %#v", err)
+ }
+}
+
+type badConnection struct {
+ n int
+ err error
+ net.Conn
+}
+
+func (bc badConnection) Write(b []byte) (n int, err error) {
+ return bc.n, bc.err
+}
+
+func (bc badConnection) Close() error {
+ return nil
+}
diff --git a/vendor/mysql/connector.go b/vendor/mysql/connector.go
new file mode 100644
index 0000000..d567b4e
--- /dev/null
+++ b/vendor/mysql/connector.go
@@ -0,0 +1,146 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "context"
+ "database/sql/driver"
+ "net"
+)
+
+type connector struct {
+ cfg *Config // immutable private copy.
+}
+
+// Connect implements driver.Connector interface.
+// Connect returns a connection to the database.
+func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
+ var err error
+
+ // New mysqlConn
+ mc := &mysqlConn{
+ maxAllowedPacket: maxPacketSize,
+ maxWriteSize: maxPacketSize - 1,
+ closech: make(chan struct{}),
+ cfg: c.cfg,
+ }
+ mc.parseTime = mc.cfg.ParseTime
+
+ // Connect to Server
+ dialsLock.RLock()
+ dial, ok := dials[mc.cfg.Net]
+ dialsLock.RUnlock()
+ if ok {
+ dctx := ctx
+ if mc.cfg.Timeout > 0 {
+ var cancel context.CancelFunc
+ dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
+ defer cancel()
+ }
+ mc.netConn, err = dial(dctx, mc.cfg.Addr)
+ } else {
+ nd := net.Dialer{Timeout: mc.cfg.Timeout}
+ mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ // Enable TCP Keepalives on TCP connections
+ if tc, ok := mc.netConn.(*net.TCPConn); ok {
+ if err := tc.SetKeepAlive(true); err != nil {
+ // Don't send COM_QUIT before handshake.
+ mc.netConn.Close()
+ mc.netConn = nil
+ return nil, err
+ }
+ }
+
+ // Call startWatcher for context support (From Go 1.8)
+ mc.startWatcher()
+ if err := mc.watchCancel(ctx); err != nil {
+ mc.cleanup()
+ return nil, err
+ }
+ defer mc.finish()
+
+ mc.buf = newBuffer(mc.netConn)
+
+ // Set I/O timeouts
+ mc.buf.timeout = mc.cfg.ReadTimeout
+ mc.writeTimeout = mc.cfg.WriteTimeout
+
+ // Reading Handshake Initialization Packet
+ authData, plugin, err := mc.readHandshakePacket()
+ if err != nil {
+ mc.cleanup()
+ return nil, err
+ }
+
+ if plugin == "" {
+ plugin = defaultAuthPlugin
+ }
+
+ // Send Client Authentication Packet
+ authResp, err := mc.auth(authData, plugin)
+ if err != nil {
+ // try the default auth plugin, if using the requested plugin failed
+ errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
+ plugin = defaultAuthPlugin
+ authResp, err = mc.auth(authData, plugin)
+ if err != nil {
+ mc.cleanup()
+ return nil, err
+ }
+ }
+ if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
+ mc.cleanup()
+ return nil, err
+ }
+
+ // Handle response to auth packet, switch methods if possible
+ if err = mc.handleAuthResult(authData, plugin); err != nil {
+ // Authentication failed and MySQL has already closed the connection
+ // (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
+ // Do not send COM_QUIT, just cleanup and return the error.
+ mc.cleanup()
+ return nil, err
+ }
+
+ if mc.cfg.MaxAllowedPacket > 0 {
+ mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
+ } else {
+ // Get max allowed packet size
+ maxap, err := mc.getSystemVar("max_allowed_packet")
+ if err != nil {
+ mc.Close()
+ return nil, err
+ }
+ mc.maxAllowedPacket = stringToInt(maxap) - 1
+ }
+ if mc.maxAllowedPacket < maxPacketSize {
+ mc.maxWriteSize = mc.maxAllowedPacket
+ }
+
+ // Handle DSN Params
+ err = mc.handleParams()
+ if err != nil {
+ mc.Close()
+ return nil, err
+ }
+
+ return mc, nil
+}
+
+// Driver implements driver.Connector interface.
+// Driver returns &MySQLDriver{}.
+func (c *connector) Driver() driver.Driver {
+ return &MySQLDriver{}
+}
diff --git a/vendor/mysql/connector_test.go b/vendor/mysql/connector_test.go
new file mode 100644
index 0000000..976903c
--- /dev/null
+++ b/vendor/mysql/connector_test.go
@@ -0,0 +1,30 @@
+package mysql
+
+import (
+ "context"
+ "net"
+ "testing"
+ "time"
+)
+
+func TestConnectorReturnsTimeout(t *testing.T) {
+ connector := &connector{&Config{
+ Net: "tcp",
+ Addr: "1.1.1.1:1234",
+ Timeout: 10 * time.Millisecond,
+ }}
+
+ _, err := connector.Connect(context.Background())
+ if err == nil {
+ t.Fatal("error expected")
+ }
+
+ if nerr, ok := err.(*net.OpError); ok {
+ expected := "dial tcp 1.1.1.1:1234: i/o timeout"
+ if nerr.Error() != expected {
+ t.Fatalf("expected %q, got %q", expected, nerr.Error())
+ }
+ } else {
+ t.Fatalf("expected %T, got %T", nerr, err)
+ }
+}
diff --git a/vendor/mysql/const.go b/vendor/mysql/const.go
new file mode 100644
index 0000000..b1e6b85
--- /dev/null
+++ b/vendor/mysql/const.go
@@ -0,0 +1,174 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+const (
+ defaultAuthPlugin = "mysql_native_password"
+ defaultMaxAllowedPacket = 4 << 20 // 4 MiB
+ minProtocolVersion = 10
+ maxPacketSize = 1<<24 - 1
+ timeFormat = "2006-01-02 15:04:05.999999"
+)
+
+// MySQL constants documentation:
+// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
+
+const (
+ iOK byte = 0x00
+ iAuthMoreData byte = 0x01
+ iLocalInFile byte = 0xfb
+ iEOF byte = 0xfe
+ iERR byte = 0xff
+)
+
+// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
+type clientFlag uint32
+
+const (
+ clientLongPassword clientFlag = 1 << iota
+ clientFoundRows
+ clientLongFlag
+ clientConnectWithDB
+ clientNoSchema
+ clientCompress
+ clientODBC
+ clientLocalFiles
+ clientIgnoreSpace
+ clientProtocol41
+ clientInteractive
+ clientSSL
+ clientIgnoreSIGPIPE
+ clientTransactions
+ clientReserved
+ clientSecureConn
+ clientMultiStatements
+ clientMultiResults
+ clientPSMultiResults
+ clientPluginAuth
+ clientConnectAttrs
+ clientPluginAuthLenEncClientData
+ clientCanHandleExpiredPasswords
+ clientSessionTrack
+ clientDeprecateEOF
+)
+
+const (
+ comQuit byte = iota + 1
+ comInitDB
+ comQuery
+ comFieldList
+ comCreateDB
+ comDropDB
+ comRefresh
+ comShutdown
+ comStatistics
+ comProcessInfo
+ comConnect
+ comProcessKill
+ comDebug
+ comPing
+ comTime
+ comDelayedInsert
+ comChangeUser
+ comBinlogDump
+ comTableDump
+ comConnectOut
+ comRegisterSlave
+ comStmtPrepare
+ comStmtExecute
+ comStmtSendLongData
+ comStmtClose
+ comStmtReset
+ comSetOption
+ comStmtFetch
+)
+
+// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
+type fieldType byte
+
+const (
+ fieldTypeDecimal fieldType = iota
+ fieldTypeTiny
+ fieldTypeShort
+ fieldTypeLong
+ fieldTypeFloat
+ fieldTypeDouble
+ fieldTypeNULL
+ fieldTypeTimestamp
+ fieldTypeLongLong
+ fieldTypeInt24
+ fieldTypeDate
+ fieldTypeTime
+ fieldTypeDateTime
+ fieldTypeYear
+ fieldTypeNewDate
+ fieldTypeVarChar
+ fieldTypeBit
+)
+const (
+ fieldTypeJSON fieldType = iota + 0xf5
+ fieldTypeNewDecimal
+ fieldTypeEnum
+ fieldTypeSet
+ fieldTypeTinyBLOB
+ fieldTypeMediumBLOB
+ fieldTypeLongBLOB
+ fieldTypeBLOB
+ fieldTypeVarString
+ fieldTypeString
+ fieldTypeGeometry
+)
+
+type fieldFlag uint16
+
+const (
+ flagNotNULL fieldFlag = 1 << iota
+ flagPriKey
+ flagUniqueKey
+ flagMultipleKey
+ flagBLOB
+ flagUnsigned
+ flagZeroFill
+ flagBinary
+ flagEnum
+ flagAutoIncrement
+ flagTimestamp
+ flagSet
+ flagUnknown1
+ flagUnknown2
+ flagUnknown3
+ flagUnknown4
+)
+
+// http://dev.mysql.com/doc/internals/en/status-flags.html
+type statusFlag uint16
+
+const (
+ statusInTrans statusFlag = 1 << iota
+ statusInAutocommit
+ statusReserved // Not in documentation
+ statusMoreResultsExists
+ statusNoGoodIndexUsed
+ statusNoIndexUsed
+ statusCursorExists
+ statusLastRowSent
+ statusDbDropped
+ statusNoBackslashEscapes
+ statusMetadataChanged
+ statusQueryWasSlow
+ statusPsOutParams
+ statusInTransReadonly
+ statusSessionStateChanged
+)
+
+const (
+ cachingSha2PasswordRequestPublicKey = 2
+ cachingSha2PasswordFastAuthSuccess = 3
+ cachingSha2PasswordPerformFullAuthentication = 4
+)
diff --git a/vendor/mysql/driver.go b/vendor/mysql/driver.go
new file mode 100644
index 0000000..c1bdf11
--- /dev/null
+++ b/vendor/mysql/driver.go
@@ -0,0 +1,107 @@
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// Package mysql provides a MySQL driver for Go's database/sql package.
+//
+// The driver should be used via the database/sql package:
+//
+// import "database/sql"
+// import _ "github.com/go-sql-driver/mysql"
+//
+// db, err := sql.Open("mysql", "user:password@/dbname")
+//
+// See https://github.com/go-sql-driver/mysql#usage for details
+package mysql
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "net"
+ "sync"
+)
+
+// MySQLDriver is exported to make the driver directly accessible.
+// In general the driver is used via the database/sql package.
+type MySQLDriver struct{}
+
+// DialFunc is a function which can be used to establish the network connection.
+// Custom dial functions must be registered with RegisterDial
+//
+// Deprecated: users should register a DialContextFunc instead
+type DialFunc func(addr string) (net.Conn, error)
+
+// DialContextFunc is a function which can be used to establish the network connection.
+// Custom dial functions must be registered with RegisterDialContext
+type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
+
+var (
+ dialsLock sync.RWMutex
+ dials map[string]DialContextFunc
+)
+
+// RegisterDialContext registers a custom dial function. It can then be used by the
+// network address mynet(addr), where mynet is the registered new network.
+// The current context for the connection and its address is passed to the dial function.
+func RegisterDialContext(net string, dial DialContextFunc) {
+ dialsLock.Lock()
+ defer dialsLock.Unlock()
+ if dials == nil {
+ dials = make(map[string]DialContextFunc)
+ }
+ dials[net] = dial
+}
+
+// RegisterDial registers a custom dial function. It can then be used by the
+// network address mynet(addr), where mynet is the registered new network.
+// addr is passed as a parameter to the dial function.
+//
+// Deprecated: users should call RegisterDialContext instead
+func RegisterDial(network string, dial DialFunc) {
+ RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) {
+ return dial(addr)
+ })
+}
+
+// Open new Connection.
+// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
+// the DSN string is formatted
+func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
+ cfg, err := ParseDSN(dsn)
+ if err != nil {
+ return nil, err
+ }
+ c := &connector{
+ cfg: cfg,
+ }
+ return c.Connect(context.Background())
+}
+
+func init() {
+ sql.Register("mysql", &MySQLDriver{})
+}
+
+// NewConnector returns new driver.Connector.
+func NewConnector(cfg *Config) (driver.Connector, error) {
+ cfg = cfg.Clone()
+ // normalize the contents of cfg so calls to NewConnector have the same
+ // behavior as MySQLDriver.OpenConnector
+ if err := cfg.normalize(); err != nil {
+ return nil, err
+ }
+ return &connector{cfg: cfg}, nil
+}
+
+// OpenConnector implements driver.DriverContext.
+func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
+ cfg, err := ParseDSN(dsn)
+ if err != nil {
+ return nil, err
+ }
+ return &connector{
+ cfg: cfg,
+ }, nil
+}
diff --git a/vendor/mysql/driver_test.go b/vendor/mysql/driver_test.go
new file mode 100644
index 0000000..54f7cd1
--- /dev/null
+++ b/vendor/mysql/driver_test.go
@@ -0,0 +1,3211 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "database/sql"
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math"
+ "net"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// Ensure that all the driver interfaces are implemented
+var (
+ _ driver.Rows = &binaryRows{}
+ _ driver.Rows = &textRows{}
+)
+
+var (
+ user string
+ pass string
+ prot string
+ addr string
+ dbname string
+ dsn string
+ netAddr string
+ available bool
+)
+
+var (
+ tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC)
+ sDate = "2012-06-14"
+ tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)
+ sDateTime = "2011-11-20 21:27:37"
+ tDate0 = time.Time{}
+ sDate0 = "0000-00-00"
+ sDateTime0 = "0000-00-00 00:00:00"
+)
+
+// See https://github.com/go-sql-driver/mysql/wiki/Testing
+func init() {
+ // get environment variables
+ env := func(key, defaultValue string) string {
+ if value := os.Getenv(key); value != "" {
+ return value
+ }
+ return defaultValue
+ }
+ user = env("MYSQL_TEST_USER", "root")
+ pass = env("MYSQL_TEST_PASS", "")
+ prot = env("MYSQL_TEST_PROT", "tcp")
+ addr = env("MYSQL_TEST_ADDR", "localhost:3306")
+ dbname = env("MYSQL_TEST_DBNAME", "gotest")
+ netAddr = fmt.Sprintf("%s(%s)", prot, addr)
+ dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname)
+ c, err := net.Dial(prot, addr)
+ if err == nil {
+ available = true
+ c.Close()
+ }
+}
+
+type DBTest struct {
+ *testing.T
+ db *sql.DB
+}
+
+type netErrorMock struct {
+ temporary bool
+ timeout bool
+}
+
+func (e netErrorMock) Temporary() bool {
+ return e.temporary
+}
+
+func (e netErrorMock) Timeout() bool {
+ return e.timeout
+}
+
+func (e netErrorMock) Error() string {
+ return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
+}
+
+func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ dsn += "&multiStatements=true"
+ var db *sql.DB
+ if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
+ db, err = sql.Open("mysql", dsn)
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db.Close()
+ }
+
+ dbt := &DBTest{t, db}
+ for _, test := range tests {
+ test(dbt)
+ dbt.db.Exec("DROP TABLE IF EXISTS test")
+ }
+}
+
+func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ db, err := sql.Open("mysql", dsn)
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db.Close()
+
+ db.Exec("DROP TABLE IF EXISTS test")
+
+ dsn2 := dsn + "&interpolateParams=true"
+ var db2 *sql.DB
+ if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
+ db2, err = sql.Open("mysql", dsn2)
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db2.Close()
+ }
+
+ dsn3 := dsn + "&multiStatements=true"
+ var db3 *sql.DB
+ if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
+ db3, err = sql.Open("mysql", dsn3)
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db3.Close()
+ }
+
+ dbt := &DBTest{t, db}
+ dbt2 := &DBTest{t, db2}
+ dbt3 := &DBTest{t, db3}
+ for _, test := range tests {
+ test(dbt)
+ dbt.db.Exec("DROP TABLE IF EXISTS test")
+ if db2 != nil {
+ test(dbt2)
+ dbt2.db.Exec("DROP TABLE IF EXISTS test")
+ }
+ if db3 != nil {
+ test(dbt3)
+ dbt3.db.Exec("DROP TABLE IF EXISTS test")
+ }
+ }
+}
+
+func (dbt *DBTest) fail(method, query string, err error) {
+ if len(query) > 300 {
+ query = "[query too large to print]"
+ }
+ dbt.Fatalf("error on %s %s: %s", method, query, err.Error())
+}
+
+func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
+ res, err := dbt.db.Exec(query, args...)
+ if err != nil {
+ dbt.fail("exec", query, err)
+ }
+ return res
+}
+
+func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
+ rows, err := dbt.db.Query(query, args...)
+ if err != nil {
+ dbt.fail("query", query, err)
+ }
+ return rows
+}
+
+func maybeSkip(t *testing.T, err error, skipErrno uint16) {
+ mySQLErr, ok := err.(*MySQLError)
+ if !ok {
+ return
+ }
+
+ if mySQLErr.Number == skipErrno {
+ t.Skipf("skipping test for error: %v", err)
+ }
+}
+
+func TestEmptyQuery(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ // just a comment, no query
+ rows := dbt.mustQuery("--")
+ defer rows.Close()
+ // will hang before #255
+ if rows.Next() {
+ dbt.Errorf("next on rows must be false")
+ }
+ })
+}
+
+func TestCRUD(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ // Create Table
+ dbt.mustExec("CREATE TABLE test (value BOOL)")
+
+ // Test for unexpected data
+ var out bool
+ rows := dbt.mustQuery("SELECT * FROM test")
+ if rows.Next() {
+ dbt.Error("unexpected data in empty table")
+ }
+ rows.Close()
+
+ // Create Data
+ res := dbt.mustExec("INSERT INTO test VALUES (1)")
+ count, err := res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 1 {
+ dbt.Fatalf("expected 1 affected row, got %d", count)
+ }
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
+ }
+ if id != 0 {
+ dbt.Fatalf("expected InsertId 0, got %d", id)
+ }
+
+ // Read
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if true != out {
+ dbt.Errorf("true != %t", out)
+ }
+
+ if rows.Next() {
+ dbt.Error("unexpected data")
+ }
+ } else {
+ dbt.Error("no data")
+ }
+ rows.Close()
+
+ // Update
+ res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
+ count, err = res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 1 {
+ dbt.Fatalf("expected 1 affected row, got %d", count)
+ }
+
+ // Check Update
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if false != out {
+ dbt.Errorf("false != %t", out)
+ }
+
+ if rows.Next() {
+ dbt.Error("unexpected data")
+ }
+ } else {
+ dbt.Error("no data")
+ }
+ rows.Close()
+
+ // Delete
+ res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
+ count, err = res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 1 {
+ dbt.Fatalf("expected 1 affected row, got %d", count)
+ }
+
+ // Check for unexpected rows
+ res = dbt.mustExec("DELETE FROM test")
+ count, err = res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 0 {
+ dbt.Fatalf("expected 0 affected row, got %d", count)
+ }
+ })
+}
+
+func TestMultiQuery(t *testing.T) {
+ runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+ // Create Table
+ dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
+
+ // Create Data
+ res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
+ count, err := res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 1 {
+ dbt.Fatalf("expected 1 affected row, got %d", count)
+ }
+
+ // Update
+ res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
+ count, err = res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 1 {
+ dbt.Fatalf("expected 1 affected row, got %d", count)
+ }
+
+ // Read
+ var out int
+ rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
+ if rows.Next() {
+ rows.Scan(&out)
+ if 5 != out {
+ dbt.Errorf("5 != %d", out)
+ }
+
+ if rows.Next() {
+ dbt.Error("unexpected data")
+ }
+ } else {
+ dbt.Error("no data")
+ }
+ rows.Close()
+
+ })
+}
+
+func TestInt(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
+ in := int64(42)
+ var out int64
+ var rows *sql.Rows
+
+ // SIGNED
+ for _, v := range types {
+ dbt.mustExec("CREATE TABLE test (value " + v + ")")
+
+ dbt.mustExec("INSERT INTO test VALUES (?)", in)
+
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if in != out {
+ dbt.Errorf("%s: %d != %d", v, in, out)
+ }
+ } else {
+ dbt.Errorf("%s: no data", v)
+ }
+ rows.Close()
+
+ dbt.mustExec("DROP TABLE IF EXISTS test")
+ }
+
+ // UNSIGNED ZEROFILL
+ for _, v := range types {
+ dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)")
+
+ dbt.mustExec("INSERT INTO test VALUES (?)", in)
+
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if in != out {
+ dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out)
+ }
+ } else {
+ dbt.Errorf("%s ZEROFILL: no data", v)
+ }
+ rows.Close()
+
+ dbt.mustExec("DROP TABLE IF EXISTS test")
+ }
+ })
+}
+
+func TestFloat32(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ types := [2]string{"FLOAT", "DOUBLE"}
+ in := float32(42.23)
+ var out float32
+ var rows *sql.Rows
+ for _, v := range types {
+ dbt.mustExec("CREATE TABLE test (value " + v + ")")
+ dbt.mustExec("INSERT INTO test VALUES (?)", in)
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if in != out {
+ dbt.Errorf("%s: %g != %g", v, in, out)
+ }
+ } else {
+ dbt.Errorf("%s: no data", v)
+ }
+ rows.Close()
+ dbt.mustExec("DROP TABLE IF EXISTS test")
+ }
+ })
+}
+
+func TestFloat64(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ types := [2]string{"FLOAT", "DOUBLE"}
+ var expected float64 = 42.23
+ var out float64
+ var rows *sql.Rows
+ for _, v := range types {
+ dbt.mustExec("CREATE TABLE test (value " + v + ")")
+ dbt.mustExec("INSERT INTO test VALUES (42.23)")
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if expected != out {
+ dbt.Errorf("%s: %g != %g", v, expected, out)
+ }
+ } else {
+ dbt.Errorf("%s: no data", v)
+ }
+ rows.Close()
+ dbt.mustExec("DROP TABLE IF EXISTS test")
+ }
+ })
+}
+
+func TestFloat64Placeholder(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ types := [2]string{"FLOAT", "DOUBLE"}
+ var expected float64 = 42.23
+ var out float64
+ var rows *sql.Rows
+ for _, v := range types {
+ dbt.mustExec("CREATE TABLE test (id int, value " + v + ")")
+ dbt.mustExec("INSERT INTO test VALUES (1, 42.23)")
+ rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1)
+ if rows.Next() {
+ rows.Scan(&out)
+ if expected != out {
+ dbt.Errorf("%s: %g != %g", v, expected, out)
+ }
+ } else {
+ dbt.Errorf("%s: no data", v)
+ }
+ rows.Close()
+ dbt.mustExec("DROP TABLE IF EXISTS test")
+ }
+ })
+}
+
+func TestString(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
+ in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย"
+ var out string
+ var rows *sql.Rows
+
+ for _, v := range types {
+ dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8")
+
+ dbt.mustExec("INSERT INTO test VALUES (?)", in)
+
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if in != out {
+ dbt.Errorf("%s: %s != %s", v, in, out)
+ }
+ } else {
+ dbt.Errorf("%s: no data", v)
+ }
+ rows.Close()
+
+ dbt.mustExec("DROP TABLE IF EXISTS test")
+ }
+
+ // BLOB
+ dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
+
+ id := 2
+ in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
+ "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
+ "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
+ "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " +
+ "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
+ "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
+ "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
+ "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet."
+ dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in)
+
+ err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out)
+ if err != nil {
+ dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
+ } else if out != in {
+ dbt.Errorf("BLOB: %s != %s", in, out)
+ }
+ })
+}
+
+func TestRawBytes(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ v1 := []byte("aaa")
+ v2 := []byte("bbb")
+ rows := dbt.mustQuery("SELECT ?, ?", v1, v2)
+ defer rows.Close()
+ if rows.Next() {
+ var o1, o2 sql.RawBytes
+ if err := rows.Scan(&o1, &o2); err != nil {
+ dbt.Errorf("Got error: %v", err)
+ }
+ if !bytes.Equal(v1, o1) {
+ dbt.Errorf("expected %v, got %v", v1, o1)
+ }
+ if !bytes.Equal(v2, o2) {
+ dbt.Errorf("expected %v, got %v", v2, o2)
+ }
+ // https://github.com/go-sql-driver/mysql/issues/765
+ // Appending to RawBytes shouldn't overwrite next RawBytes.
+ o1 = append(o1, "xyzzy"...)
+ if !bytes.Equal(v2, o2) {
+ dbt.Errorf("expected %v, got %v", v2, o2)
+ }
+ } else {
+ dbt.Errorf("no data")
+ }
+ })
+}
+
+func TestRawMessage(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ v1 := json.RawMessage("{}")
+ v2 := json.RawMessage("[]")
+ rows := dbt.mustQuery("SELECT ?, ?", v1, v2)
+ defer rows.Close()
+ if rows.Next() {
+ var o1, o2 json.RawMessage
+ if err := rows.Scan(&o1, &o2); err != nil {
+ dbt.Errorf("Got error: %v", err)
+ }
+ if !bytes.Equal(v1, o1) {
+ dbt.Errorf("expected %v, got %v", v1, o1)
+ }
+ if !bytes.Equal(v2, o2) {
+ dbt.Errorf("expected %v, got %v", v2, o2)
+ }
+ } else {
+ dbt.Errorf("no data")
+ }
+ })
+}
+
+type testValuer struct {
+ value string
+}
+
+func (tv testValuer) Value() (driver.Value, error) {
+ return tv.value, nil
+}
+
+func TestValuer(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ in := testValuer{"a_value"}
+ var out string
+ var rows *sql.Rows
+
+ dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8")
+ dbt.mustExec("INSERT INTO test VALUES (?)", in)
+ rows = dbt.mustQuery("SELECT value FROM test")
+ if rows.Next() {
+ rows.Scan(&out)
+ if in.value != out {
+ dbt.Errorf("Valuer: %v != %s", in, out)
+ }
+ } else {
+ dbt.Errorf("Valuer: no data")
+ }
+ rows.Close()
+
+ dbt.mustExec("DROP TABLE IF EXISTS test")
+ })
+}
+
+type testValuerWithValidation struct {
+ value string
+}
+
+func (tv testValuerWithValidation) Value() (driver.Value, error) {
+ if len(tv.value) == 0 {
+ return nil, fmt.Errorf("Invalid string valuer. Value must not be empty")
+ }
+
+ return tv.value, nil
+}
+
+func TestValuerWithValidation(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ in := testValuerWithValidation{"a_value"}
+ var out string
+ var rows *sql.Rows
+
+ dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8")
+ dbt.mustExec("INSERT INTO testValuer VALUES (?)", in)
+
+ rows = dbt.mustQuery("SELECT value FROM testValuer")
+ defer rows.Close()
+
+ if rows.Next() {
+ rows.Scan(&out)
+ if in.value != out {
+ dbt.Errorf("Valuer: %v != %s", in, out)
+ }
+ } else {
+ dbt.Errorf("Valuer: no data")
+ }
+
+ if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil {
+ dbt.Errorf("Failed to check valuer error")
+ }
+
+ if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil {
+ dbt.Errorf("Failed to check nil")
+ }
+
+ if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil {
+ dbt.Errorf("Failed to check not valuer")
+ }
+
+ dbt.mustExec("DROP TABLE IF EXISTS testValuer")
+ })
+}
+
+type timeTests struct {
+ dbtype string
+ tlayout string
+ tests []timeTest
+}
+
+type timeTest struct {
+ s string // leading "!": do not use t as value in queries
+ t time.Time
+}
+
+type timeMode byte
+
+func (t timeMode) String() string {
+ switch t {
+ case binaryString:
+ return "binary:string"
+ case binaryTime:
+ return "binary:time.Time"
+ case textString:
+ return "text:string"
+ }
+ panic("unsupported timeMode")
+}
+
+func (t timeMode) Binary() bool {
+ switch t {
+ case binaryString, binaryTime:
+ return true
+ }
+ return false
+}
+
+const (
+ binaryString timeMode = iota
+ binaryTime
+ textString
+)
+
+func (t timeTest) genQuery(dbtype string, mode timeMode) string {
+ var inner string
+ if mode.Binary() {
+ inner = "?"
+ } else {
+ inner = `"%s"`
+ }
+ return `SELECT cast(` + inner + ` as ` + dbtype + `)`
+}
+
+func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) {
+ var rows *sql.Rows
+ query := t.genQuery(dbtype, mode)
+ switch mode {
+ case binaryString:
+ rows = dbt.mustQuery(query, t.s)
+ case binaryTime:
+ rows = dbt.mustQuery(query, t.t)
+ case textString:
+ query = fmt.Sprintf(query, t.s)
+ rows = dbt.mustQuery(query)
+ default:
+ panic("unsupported mode")
+ }
+ defer rows.Close()
+ var err error
+ if !rows.Next() {
+ err = rows.Err()
+ if err == nil {
+ err = fmt.Errorf("no data")
+ }
+ dbt.Errorf("%s [%s]: %s", dbtype, mode, err)
+ return
+ }
+ var dst interface{}
+ err = rows.Scan(&dst)
+ if err != nil {
+ dbt.Errorf("%s [%s]: %s", dbtype, mode, err)
+ return
+ }
+ switch val := dst.(type) {
+ case []uint8:
+ str := string(val)
+ if str == t.s {
+ return
+ }
+ if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s {
+ // a fix mainly for TravisCI:
+ // accept full microsecond resolution in result for DATETIME columns
+ // where the binary protocol was used
+ return
+ }
+ dbt.Errorf("%s [%s] to string: expected %q, got %q",
+ dbtype, mode,
+ t.s, str,
+ )
+ case time.Time:
+ if val == t.t {
+ return
+ }
+ dbt.Errorf("%s [%s] to string: expected %q, got %q",
+ dbtype, mode,
+ t.s, val.Format(tlayout),
+ )
+ default:
+ fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t})
+ dbt.Errorf("%s [%s]: unhandled type %T (is '%v')",
+ dbtype, mode,
+ val, val,
+ )
+ }
+}
+
+func TestDateTime(t *testing.T) {
+ afterTime := func(t time.Time, d string) time.Time {
+ dur, err := time.ParseDuration(d)
+ if err != nil {
+ panic(err)
+ }
+ return t.Add(dur)
+ }
+ // NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests
+ format := "2006-01-02 15:04:05.999999"
+ t0 := time.Time{}
+ tstr0 := "0000-00-00 00:00:00.000000"
+ testcases := []timeTests{
+ {"DATE", format[:10], []timeTest{
+ {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)},
+ {t: t0, s: tstr0[:10]},
+ }},
+ {"DATETIME", format[:19], []timeTest{
+ {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
+ {t: t0, s: tstr0[:19]},
+ }},
+ {"DATETIME(0)", format[:21], []timeTest{
+ {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
+ {t: t0, s: tstr0[:19]},
+ }},
+ {"DATETIME(1)", format[:21], []timeTest{
+ {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)},
+ {t: t0, s: tstr0[:21]},
+ }},
+ {"DATETIME(6)", format, []timeTest{
+ {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)},
+ {t: t0, s: tstr0},
+ }},
+ {"TIME", format[11:19], []timeTest{
+ {t: afterTime(t0, "12345s")},
+ {s: "!-12:34:56"},
+ {s: "!-838:59:59"},
+ {s: "!838:59:59"},
+ {t: t0, s: tstr0[11:19]},
+ }},
+ {"TIME(0)", format[11:19], []timeTest{
+ {t: afterTime(t0, "12345s")},
+ {s: "!-12:34:56"},
+ {s: "!-838:59:59"},
+ {s: "!838:59:59"},
+ {t: t0, s: tstr0[11:19]},
+ }},
+ {"TIME(1)", format[11:21], []timeTest{
+ {t: afterTime(t0, "12345600ms")},
+ {s: "!-12:34:56.7"},
+ {s: "!-838:59:58.9"},
+ {s: "!838:59:58.9"},
+ {t: t0, s: tstr0[11:21]},
+ }},
+ {"TIME(6)", format[11:], []timeTest{
+ {t: afterTime(t0, "1234567890123000ns")},
+ {s: "!-12:34:56.789012"},
+ {s: "!-838:59:58.999999"},
+ {s: "!838:59:58.999999"},
+ {t: t0, s: tstr0[11:]},
+ }},
+ }
+ dsns := []string{
+ dsn + "&parseTime=true",
+ dsn + "&parseTime=false",
+ }
+ for _, testdsn := range dsns {
+ runTests(t, testdsn, func(dbt *DBTest) {
+ microsecsSupported := false
+ zeroDateSupported := false
+ var rows *sql.Rows
+ var err error
+ rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`)
+ if err == nil {
+ rows.Scan(&microsecsSupported)
+ rows.Close()
+ }
+ rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`)
+ if err == nil {
+ rows.Scan(&zeroDateSupported)
+ rows.Close()
+ }
+ for _, setups := range testcases {
+ if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" {
+ // skip fractional second tests if unsupported by server
+ continue
+ }
+ for _, setup := range setups.tests {
+ allowBinTime := true
+ if setup.s == "" {
+ // fill time string wherever Go can reliable produce it
+ setup.s = setup.t.Format(setups.tlayout)
+ } else if setup.s[0] == '!' {
+ // skip tests using setup.t as source in queries
+ allowBinTime = false
+ // fix setup.s - remove the "!"
+ setup.s = setup.s[1:]
+ }
+ if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] {
+ // skip disallowed 0000-00-00 date
+ continue
+ }
+ setup.run(dbt, setups.dbtype, setups.tlayout, textString)
+ setup.run(dbt, setups.dbtype, setups.tlayout, binaryString)
+ if allowBinTime {
+ setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime)
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestTimestampMicros(t *testing.T) {
+ format := "2006-01-02 15:04:05.999999"
+ f0 := format[:19]
+ f1 := format[:21]
+ f6 := format[:26]
+ runTests(t, dsn, func(dbt *DBTest) {
+ // check if microseconds are supported.
+ // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width
+ // and not precision.
+ // Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
+ microsecsSupported := false
+ if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil {
+ rows.Scan(&microsecsSupported)
+ rows.Close()
+ }
+ if !microsecsSupported {
+ // skip test
+ return
+ }
+ _, err := dbt.db.Exec(`
+ CREATE TABLE test (
+ value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `',
+ value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `',
+ value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `'
+ )`,
+ )
+ if err != nil {
+ dbt.Error(err)
+ }
+ defer dbt.mustExec("DROP TABLE IF EXISTS test")
+ dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6)
+ var res0, res1, res6 string
+ rows := dbt.mustQuery("SELECT * FROM test")
+ defer rows.Close()
+ if !rows.Next() {
+ dbt.Errorf("test contained no selectable values")
+ }
+ err = rows.Scan(&res0, &res1, &res6)
+ if err != nil {
+ dbt.Error(err)
+ }
+ if res0 != f0 {
+ dbt.Errorf("expected %q, got %q", f0, res0)
+ }
+ if res1 != f1 {
+ dbt.Errorf("expected %q, got %q", f1, res1)
+ }
+ if res6 != f6 {
+ dbt.Errorf("expected %q, got %q", f6, res6)
+ }
+ })
+}
+
+func TestNULL(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ nullStmt, err := dbt.db.Prepare("SELECT NULL")
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ defer nullStmt.Close()
+
+ nonNullStmt, err := dbt.db.Prepare("SELECT 1")
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ defer nonNullStmt.Close()
+
+ // NullBool
+ var nb sql.NullBool
+ // Invalid
+ if err = nullStmt.QueryRow().Scan(&nb); err != nil {
+ dbt.Fatal(err)
+ }
+ if nb.Valid {
+ dbt.Error("valid NullBool which should be invalid")
+ }
+ // Valid
+ if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
+ dbt.Fatal(err)
+ }
+ if !nb.Valid {
+ dbt.Error("invalid NullBool which should be valid")
+ } else if nb.Bool != true {
+ dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
+ }
+
+ // NullFloat64
+ var nf sql.NullFloat64
+ // Invalid
+ if err = nullStmt.QueryRow().Scan(&nf); err != nil {
+ dbt.Fatal(err)
+ }
+ if nf.Valid {
+ dbt.Error("valid NullFloat64 which should be invalid")
+ }
+ // Valid
+ if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
+ dbt.Fatal(err)
+ }
+ if !nf.Valid {
+ dbt.Error("invalid NullFloat64 which should be valid")
+ } else if nf.Float64 != float64(1) {
+ dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
+ }
+
+ // NullInt64
+ var ni sql.NullInt64
+ // Invalid
+ if err = nullStmt.QueryRow().Scan(&ni); err != nil {
+ dbt.Fatal(err)
+ }
+ if ni.Valid {
+ dbt.Error("valid NullInt64 which should be invalid")
+ }
+ // Valid
+ if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
+ dbt.Fatal(err)
+ }
+ if !ni.Valid {
+ dbt.Error("invalid NullInt64 which should be valid")
+ } else if ni.Int64 != int64(1) {
+ dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64)
+ }
+
+ // NullString
+ var ns sql.NullString
+ // Invalid
+ if err = nullStmt.QueryRow().Scan(&ns); err != nil {
+ dbt.Fatal(err)
+ }
+ if ns.Valid {
+ dbt.Error("valid NullString which should be invalid")
+ }
+ // Valid
+ if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
+ dbt.Fatal(err)
+ }
+ if !ns.Valid {
+ dbt.Error("invalid NullString which should be valid")
+ } else if ns.String != `1` {
+ dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)")
+ }
+
+ // nil-bytes
+ var b []byte
+ // Read nil
+ if err = nullStmt.QueryRow().Scan(&b); err != nil {
+ dbt.Fatal(err)
+ }
+ if b != nil {
+ dbt.Error("non-nil []byte which should be nil")
+ }
+ // Read non-nil
+ if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
+ dbt.Fatal(err)
+ }
+ if b == nil {
+ dbt.Error("nil []byte which should be non-nil")
+ }
+ // Insert nil
+ b = nil
+ success := false
+ if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil {
+ dbt.Fatal(err)
+ }
+ if !success {
+ dbt.Error("inserting []byte(nil) as NULL failed")
+ }
+ // Check input==output with input==nil
+ b = nil
+ if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
+ dbt.Fatal(err)
+ }
+ if b != nil {
+ dbt.Error("non-nil echo from nil input")
+ }
+ // Check input==output with input!=nil
+ b = []byte("")
+ if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
+ dbt.Fatal(err)
+ }
+ if b == nil {
+ dbt.Error("nil echo from non-nil input")
+ }
+
+ // Insert NULL
+ dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
+
+ dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2)
+
+ var out interface{}
+ rows := dbt.mustQuery("SELECT * FROM test")
+ defer rows.Close()
+ if rows.Next() {
+ rows.Scan(&out)
+ if out != nil {
+ dbt.Errorf("%v != nil", out)
+ }
+ } else {
+ dbt.Error("no data")
+ }
+ })
+}
+
+func TestUint64(t *testing.T) {
+ const (
+ u0 = uint64(0)
+ uall = ^u0
+ uhigh = uall >> 1
+ utop = ^uhigh
+ s0 = int64(0)
+ sall = ^s0
+ shigh = int64(uhigh)
+ stop = ^shigh
+ )
+ runTests(t, dsn, func(dbt *DBTest) {
+ stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ defer stmt.Close()
+ row := stmt.QueryRow(
+ u0, uhigh, utop, uall,
+ s0, shigh, stop, sall,
+ )
+
+ var ua, ub, uc, ud uint64
+ var sa, sb, sc, sd int64
+
+ err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ switch {
+ case ua != u0,
+ ub != uhigh,
+ uc != utop,
+ ud != uall,
+ sa != s0,
+ sb != shigh,
+ sc != stop,
+ sd != sall:
+ dbt.Fatal("unexpected result value")
+ }
+ })
+}
+
+func TestLongData(t *testing.T) {
+ runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) {
+ var maxAllowedPacketSize int
+ err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ maxAllowedPacketSize--
+
+ // don't get too ambitious
+ if maxAllowedPacketSize > 1<<25 {
+ maxAllowedPacketSize = 1 << 25
+ }
+
+ dbt.mustExec("CREATE TABLE test (value LONGBLOB)")
+
+ in := strings.Repeat(`a`, maxAllowedPacketSize+1)
+ var out string
+ var rows *sql.Rows
+
+ // Long text data
+ const nonDataQueryLen = 28 // length query w/o value
+ inS := in[:maxAllowedPacketSize-nonDataQueryLen]
+ dbt.mustExec("INSERT INTO test VALUES('" + inS + "')")
+ rows = dbt.mustQuery("SELECT value FROM test")
+ defer rows.Close()
+ if rows.Next() {
+ rows.Scan(&out)
+ if inS != out {
+ dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out))
+ }
+ if rows.Next() {
+ dbt.Error("LONGBLOB: unexpexted row")
+ }
+ } else {
+ dbt.Fatalf("LONGBLOB: no data")
+ }
+
+ // Empty table
+ dbt.mustExec("TRUNCATE TABLE test")
+
+ // Long binary data
+ dbt.mustExec("INSERT INTO test VALUES(?)", in)
+ rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1)
+ defer rows.Close()
+ if rows.Next() {
+ rows.Scan(&out)
+ if in != out {
+ dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out))
+ }
+ if rows.Next() {
+ dbt.Error("LONGBLOB: unexpexted row")
+ }
+ } else {
+ if err = rows.Err(); err != nil {
+ dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error())
+ } else {
+ dbt.Fatal("LONGBLOB: no data (err: <nil>)")
+ }
+ }
+ })
+}
+
+func TestLoadData(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ verifyLoadDataResult := func() {
+ rows, err := dbt.db.Query("SELECT * FROM test")
+ if err != nil {
+ dbt.Fatal(err.Error())
+ }
+
+ i := 0
+ values := [4]string{
+ "a string",
+ "a string containing a \t",
+ "a string containing a \n",
+ "a string containing both \t\n",
+ }
+
+ var id int
+ var value string
+
+ for rows.Next() {
+ i++
+ err = rows.Scan(&id, &value)
+ if err != nil {
+ dbt.Fatal(err.Error())
+ }
+ if i != id {
+ dbt.Fatalf("%d != %d", i, id)
+ }
+ if values[i-1] != value {
+ dbt.Fatalf("%q != %q", values[i-1], value)
+ }
+ }
+ err = rows.Err()
+ if err != nil {
+ dbt.Fatal(err.Error())
+ }
+
+ if i != 4 {
+ dbt.Fatalf("rows count mismatch. Got %d, want 4", i)
+ }
+ }
+
+ dbt.db.Exec("DROP TABLE IF EXISTS test")
+ dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
+
+ // Local File
+ file, err := ioutil.TempFile("", "gotest")
+ defer os.Remove(file.Name())
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ RegisterLocalFile(file.Name())
+
+ // Try first with empty file
+ dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
+ var count int
+ err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&count)
+ if err != nil {
+ dbt.Fatal(err.Error())
+ }
+ if count != 0 {
+ dbt.Fatalf("unexpected row count: got %d, want 0", count)
+ }
+
+ // Then fille File with data and try to load it
+ file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
+ file.Close()
+ dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
+ verifyLoadDataResult()
+
+ // Try with non-existing file
+ _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
+ if err == nil {
+ dbt.Fatal("load non-existent file didn't fail")
+ } else if err.Error() != "local file 'doesnotexist' is not registered" {
+ dbt.Fatal(err.Error())
+ }
+
+ // Empty table
+ dbt.mustExec("TRUNCATE TABLE test")
+
+ // Reader
+ RegisterReaderHandler("test", func() io.Reader {
+ file, err = os.Open(file.Name())
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ return file
+ })
+ dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
+ verifyLoadDataResult()
+ // negative test
+ _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test")
+ if err == nil {
+ dbt.Fatal("load non-existent Reader didn't fail")
+ } else if err.Error() != "Reader 'doesnotexist' is not registered" {
+ dbt.Fatal(err.Error())
+ }
+ })
+}
+
+func TestFoundRows(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+ dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+
+ res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+ count, err := res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 2 {
+ dbt.Fatalf("Expected 2 affected rows, got %d", count)
+ }
+ res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+ count, err = res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 2 {
+ dbt.Fatalf("Expected 2 affected rows, got %d", count)
+ }
+ })
+ runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+ dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+
+ res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+ count, err := res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 2 {
+ dbt.Fatalf("Expected 2 matched rows, got %d", count)
+ }
+ res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+ count, err = res.RowsAffected()
+ if err != nil {
+ dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+ }
+ if count != 3 {
+ dbt.Fatalf("Expected 3 matched rows, got %d", count)
+ }
+ })
+}
+
+func TestTLS(t *testing.T) {
+ tlsTestReq := func(dbt *DBTest) {
+ if err := dbt.db.Ping(); err != nil {
+ if err == ErrNoTLS {
+ dbt.Skip("server does not support TLS")
+ } else {
+ dbt.Fatalf("error on Ping: %s", err.Error())
+ }
+ }
+
+ rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
+ defer rows.Close()
+
+ var variable, value *sql.RawBytes
+ for rows.Next() {
+ if err := rows.Scan(&variable, &value); err != nil {
+ dbt.Fatal(err.Error())
+ }
+
+ if (*value == nil) || (len(*value) == 0) {
+ dbt.Fatalf("no Cipher")
+ } else {
+ dbt.Logf("Cipher: %s", *value)
+ }
+ }
+ }
+ tlsTestOpt := func(dbt *DBTest) {
+ if err := dbt.db.Ping(); err != nil {
+ dbt.Fatalf("error on Ping: %s", err.Error())
+ }
+ }
+
+ runTests(t, dsn+"&tls=preferred", tlsTestOpt)
+ runTests(t, dsn+"&tls=skip-verify", tlsTestReq)
+
+ // Verify that registering / using a custom cfg works
+ RegisterTLSConfig("custom-skip-verify", &tls.Config{
+ InsecureSkipVerify: true,
+ })
+ runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq)
+}
+
+func TestReuseClosedConnection(t *testing.T) {
+ // this test does not use sql.database, it uses the driver directly
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ md := &MySQLDriver{}
+ conn, err := md.Open(dsn)
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ stmt, err := conn.Prepare("DO 1")
+ if err != nil {
+ t.Fatalf("error preparing statement: %s", err.Error())
+ }
+ _, err = stmt.Exec(nil)
+ if err != nil {
+ t.Fatalf("error executing statement: %s", err.Error())
+ }
+ err = conn.Close()
+ if err != nil {
+ t.Fatalf("error closing connection: %s", err.Error())
+ }
+
+ defer func() {
+ if err := recover(); err != nil {
+ t.Errorf("panic after reusing a closed connection: %v", err)
+ }
+ }()
+ _, err = stmt.Exec(nil)
+ if err != nil && err != driver.ErrBadConn {
+ t.Errorf("unexpected error '%s', expected '%s'",
+ err.Error(), driver.ErrBadConn.Error())
+ }
+}
+
+func TestCharset(t *testing.T) {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ mustSetCharset := func(charsetParam, expected string) {
+ runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
+ rows := dbt.mustQuery("SELECT @@character_set_connection")
+ defer rows.Close()
+
+ if !rows.Next() {
+ dbt.Fatalf("error getting connection charset: %s", rows.Err())
+ }
+
+ var got string
+ rows.Scan(&got)
+
+ if got != expected {
+ dbt.Fatalf("expected connection charset %s but got %s", expected, got)
+ }
+ })
+ }
+
+ // non utf8 test
+ mustSetCharset("charset=ascii", "ascii")
+
+ // when the first charset is invalid, use the second
+ mustSetCharset("charset=none,utf8", "utf8")
+
+ // when the first charset is valid, use it
+ mustSetCharset("charset=ascii,utf8", "ascii")
+ mustSetCharset("charset=utf8,ascii", "utf8")
+}
+
+func TestFailingCharset(t *testing.T) {
+ runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
+ // run query to really establish connection...
+ _, err := dbt.db.Exec("SELECT 1")
+ if err == nil {
+ dbt.db.Close()
+ t.Fatalf("connection must not succeed without a valid charset")
+ }
+ })
+}
+
+func TestCollation(t *testing.T) {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ defaultCollation := "utf8mb4_general_ci"
+ testCollations := []string{
+ "", // do not set
+ defaultCollation, // driver default
+ "latin1_general_ci",
+ "binary",
+ "utf8_unicode_ci",
+ "cp1257_bin",
+ }
+
+ for _, collation := range testCollations {
+ var expected, tdsn string
+ if collation != "" {
+ tdsn = dsn + "&collation=" + collation
+ expected = collation
+ } else {
+ tdsn = dsn
+ expected = defaultCollation
+ }
+
+ runTests(t, tdsn, func(dbt *DBTest) {
+ var got string
+ if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil {
+ dbt.Fatal(err)
+ }
+
+ if got != expected {
+ dbt.Fatalf("expected connection collation %s but got %s", expected, got)
+ }
+ })
+ }
+}
+
+func TestColumnsWithAlias(t *testing.T) {
+ runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) {
+ rows := dbt.mustQuery("SELECT 1 AS A")
+ defer rows.Close()
+ cols, _ := rows.Columns()
+ if len(cols) != 1 {
+ t.Fatalf("expected 1 column, got %d", len(cols))
+ }
+ if cols[0] != "A" {
+ t.Fatalf("expected column name \"A\", got \"%s\"", cols[0])
+ }
+
+ rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A")
+ defer rows.Close()
+ cols, _ = rows.Columns()
+ if len(cols) != 1 {
+ t.Fatalf("expected 1 column, got %d", len(cols))
+ }
+ if cols[0] != "A.one" {
+ t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0])
+ }
+ })
+}
+
+func TestRawBytesResultExceedsBuffer(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ // defaultBufSize from buffer.go
+ expected := strings.Repeat("abc", defaultBufSize)
+
+ rows := dbt.mustQuery("SELECT '" + expected + "'")
+ defer rows.Close()
+ if !rows.Next() {
+ dbt.Error("expected result, got none")
+ }
+ var result sql.RawBytes
+ rows.Scan(&result)
+ if expected != string(result) {
+ dbt.Error("result did not match expected value")
+ }
+ })
+}
+
+func TestTimezoneConversion(t *testing.T) {
+ zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
+
+ // Regression test for timezone handling
+ tzTest := func(dbt *DBTest) {
+ // Create table
+ dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
+
+ // Insert local time into database (should be converted)
+ usCentral, _ := time.LoadLocation("US/Central")
+ reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral)
+ dbt.mustExec("INSERT INTO test VALUE (?)", reftime)
+
+ // Retrieve time from DB
+ rows := dbt.mustQuery("SELECT ts FROM test")
+ defer rows.Close()
+ if !rows.Next() {
+ dbt.Fatal("did not get any rows out")
+ }
+
+ var dbTime time.Time
+ err := rows.Scan(&dbTime)
+ if err != nil {
+ dbt.Fatal("Err", err)
+ }
+
+ // Check that dates match
+ if reftime.Unix() != dbTime.Unix() {
+ dbt.Errorf("times do not match.\n")
+ dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime)
+ dbt.Errorf(" Now(UTC)=%v\n", dbTime)
+ }
+ }
+
+ for _, tz := range zones {
+ runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
+ }
+}
+
+// Special cases
+
+func TestRowsClose(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ rows, err := dbt.db.Query("SELECT 1")
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ err = rows.Close()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ if rows.Next() {
+ dbt.Fatal("unexpected row after rows.Close()")
+ }
+
+ err = rows.Err()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ })
+}
+
+// dangling statements
+// http://code.google.com/p/go/issues/detail?id=3865
+func TestCloseStmtBeforeRows(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ stmt, err := dbt.db.Prepare("SELECT 1")
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ rows, err := stmt.Query()
+ if err != nil {
+ stmt.Close()
+ dbt.Fatal(err)
+ }
+ defer rows.Close()
+
+ err = stmt.Close()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ if !rows.Next() {
+ dbt.Fatal("getting row failed")
+ } else {
+ err = rows.Err()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ var out bool
+ err = rows.Scan(&out)
+ if err != nil {
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
+ }
+ if out != true {
+ dbt.Errorf("true != %t", out)
+ }
+ }
+ })
+}
+
+// It is valid to have multiple Rows for the same Stmt
+// http://code.google.com/p/go/issues/detail?id=3734
+func TestStmtMultiRows(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ rows1, err := stmt.Query()
+ if err != nil {
+ stmt.Close()
+ dbt.Fatal(err)
+ }
+ defer rows1.Close()
+
+ rows2, err := stmt.Query()
+ if err != nil {
+ stmt.Close()
+ dbt.Fatal(err)
+ }
+ defer rows2.Close()
+
+ var out bool
+
+ // 1
+ if !rows1.Next() {
+ dbt.Fatal("first rows1.Next failed")
+ } else {
+ err = rows1.Err()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ err = rows1.Scan(&out)
+ if err != nil {
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
+ }
+ if out != true {
+ dbt.Errorf("true != %t", out)
+ }
+ }
+
+ if !rows2.Next() {
+ dbt.Fatal("first rows2.Next failed")
+ } else {
+ err = rows2.Err()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ err = rows2.Scan(&out)
+ if err != nil {
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
+ }
+ if out != true {
+ dbt.Errorf("true != %t", out)
+ }
+ }
+
+ // 2
+ if !rows1.Next() {
+ dbt.Fatal("second rows1.Next failed")
+ } else {
+ err = rows1.Err()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ err = rows1.Scan(&out)
+ if err != nil {
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
+ }
+ if out != false {
+ dbt.Errorf("false != %t", out)
+ }
+
+ if rows1.Next() {
+ dbt.Fatal("unexpected row on rows1")
+ }
+ err = rows1.Close()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ }
+
+ if !rows2.Next() {
+ dbt.Fatal("second rows2.Next failed")
+ } else {
+ err = rows2.Err()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ err = rows2.Scan(&out)
+ if err != nil {
+ dbt.Fatalf("error on rows.Scan(): %s", err.Error())
+ }
+ if out != false {
+ dbt.Errorf("false != %t", out)
+ }
+
+ if rows2.Next() {
+ dbt.Fatal("unexpected row on rows2")
+ }
+ err = rows2.Close()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ }
+ })
+}
+
+// Regression test for
+// * more than 32 NULL parameters (issue 209)
+// * more parameters than fit into the buffer (issue 201)
+// * parameters * 64 > max_allowed_packet (issue 734)
+func TestPreparedManyCols(t *testing.T) {
+ numParams := 65535
+ runTests(t, dsn, func(dbt *DBTest) {
+ query := "SELECT ?" + strings.Repeat(",?", numParams-1)
+ stmt, err := dbt.db.Prepare(query)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ defer stmt.Close()
+
+ // create more parameters than fit into the buffer
+ // which will take nil-values
+ params := make([]interface{}, numParams)
+ rows, err := stmt.Query(params...)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ rows.Close()
+
+ // Create 0byte string which we can't send via STMT_LONG_DATA.
+ for i := 0; i < numParams; i++ {
+ params[i] = ""
+ }
+ rows, err = stmt.Query(params...)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ rows.Close()
+ })
+}
+
+func TestConcurrent(t *testing.T) {
+ if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
+ t.Skip("MYSQL_TEST_CONCURRENT env var not set")
+ }
+
+ runTests(t, dsn, func(dbt *DBTest) {
+ var version string
+ if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ if strings.Contains(strings.ToLower(version), "mariadb") {
+ t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`)
+ }
+
+ var max int
+ err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
+ if err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ dbt.Logf("testing up to %d concurrent connections \r\n", max)
+
+ var remaining, succeeded int32 = int32(max), 0
+
+ var wg sync.WaitGroup
+ wg.Add(max)
+
+ var fatalError string
+ var once sync.Once
+ fatalf := func(s string, vals ...interface{}) {
+ once.Do(func() {
+ fatalError = fmt.Sprintf(s, vals...)
+ })
+ }
+
+ for i := 0; i < max; i++ {
+ go func(id int) {
+ defer wg.Done()
+
+ tx, err := dbt.db.Begin()
+ atomic.AddInt32(&remaining, -1)
+
+ if err != nil {
+ if err.Error() != "Error 1040: Too many connections" {
+ fatalf("error on conn %d: %s", id, err.Error())
+ }
+ return
+ }
+
+ // keep the connection busy until all connections are open
+ for remaining > 0 {
+ if _, err = tx.Exec("DO 1"); err != nil {
+ fatalf("error on conn %d: %s", id, err.Error())
+ return
+ }
+ }
+
+ if err = tx.Commit(); err != nil {
+ fatalf("error on conn %d: %s", id, err.Error())
+ return
+ }
+
+ // everything went fine with this connection
+ atomic.AddInt32(&succeeded, 1)
+ }(i)
+ }
+
+ // wait until all conections are open
+ wg.Wait()
+
+ if fatalError != "" {
+ dbt.Fatal(fatalError)
+ }
+
+ dbt.Logf("reached %d concurrent connections\r\n", succeeded)
+ })
+}
+
+func testDialError(t *testing.T, dialErr error, expectErr error) {
+ RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
+ return nil, dialErr
+ })
+
+ db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db.Close()
+
+ _, err = db.Exec("DO 1")
+ if err != expectErr {
+ t.Fatalf("was expecting %s. Got: %s", dialErr, err)
+ }
+}
+
+func TestDialUnknownError(t *testing.T) {
+ testErr := fmt.Errorf("test")
+ testDialError(t, testErr, testErr)
+}
+
+func TestDialNonRetryableNetErr(t *testing.T) {
+ testErr := netErrorMock{}
+ testDialError(t, testErr, testErr)
+}
+
+func TestDialTemporaryNetErr(t *testing.T) {
+ testErr := netErrorMock{temporary: true}
+ testDialError(t, testErr, testErr)
+}
+
+// Tests custom dial functions
+func TestCustomDial(t *testing.T) {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ // our custom dial function which justs wraps net.Dial here
+ RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
+ var d net.Dialer
+ return d.DialContext(ctx, prot, addr)
+ })
+
+ db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db.Close()
+
+ if _, err = db.Exec("DO 1"); err != nil {
+ t.Fatalf("connection failed: %s", err.Error())
+ }
+}
+
+func TestSQLInjection(t *testing.T) {
+ createTest := func(arg string) func(dbt *DBTest) {
+ return func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ dbt.mustExec("INSERT INTO test VALUES (?)", 1)
+
+ var v int
+ // NULL can't be equal to anything, the idea here is to inject query so it returns row
+ // This test verifies that escapeQuotes and escapeBackslash are working properly
+ err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v)
+ if err == sql.ErrNoRows {
+ return // success, sql injection failed
+ } else if err == nil {
+ dbt.Errorf("sql injection successful with arg: %s", arg)
+ } else {
+ dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error())
+ }
+ }
+ }
+
+ dsns := []string{
+ dsn,
+ dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
+ }
+ for _, testdsn := range dsns {
+ runTests(t, testdsn, createTest("1 OR 1=1"))
+ runTests(t, testdsn, createTest("' OR '1'='1"))
+ }
+}
+
+// Test if inserted data is correctly retrieved after being escaped
+func TestInsertRetrieveEscapedData(t *testing.T) {
+ testData := func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v VARCHAR(255))")
+
+ // All sequences that are escaped by escapeQuotes and escapeBackslash
+ v := "foo \x00\n\r\x1a\"'\\"
+ dbt.mustExec("INSERT INTO test VALUES (?)", v)
+
+ var out string
+ err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out)
+ if err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+
+ if out != v {
+ dbt.Errorf("%q != %q", out, v)
+ }
+ }
+
+ dsns := []string{
+ dsn,
+ dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
+ }
+ for _, testdsn := range dsns {
+ runTests(t, testdsn, testData)
+ }
+}
+
+func TestUnixSocketAuthFail(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ // Save the current logger so we can restore it.
+ oldLogger := errLog
+
+ // Set a new logger so we can capture its output.
+ buffer := bytes.NewBuffer(make([]byte, 0, 64))
+ newLogger := log.New(buffer, "prefix: ", 0)
+ SetLogger(newLogger)
+
+ // Restore the logger.
+ defer SetLogger(oldLogger)
+
+ // Make a new DSN that uses the MySQL socket file and a bad password, which
+ // we can make by simply appending any character to the real password.
+ badPass := pass + "x"
+ socket := ""
+ if prot == "unix" {
+ socket = addr
+ } else {
+ // Get socket file from MySQL.
+ err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket)
+ if err != nil {
+ t.Fatalf("error on SELECT @@socket: %s", err.Error())
+ }
+ }
+ t.Logf("socket: %s", socket)
+ badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
+ db, err := sql.Open("mysql", badDSN)
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db.Close()
+
+ // Connect to MySQL for real. This will cause an auth failure.
+ err = db.Ping()
+ if err == nil {
+ t.Error("expected Ping() to return an error")
+ }
+
+ // The driver should not log anything.
+ if actual := buffer.String(); actual != "" {
+ t.Errorf("expected no output, got %q", actual)
+ }
+ })
+}
+
+// See Issue #422
+func TestInterruptBySignal(t *testing.T) {
+ runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec(`
+ DROP PROCEDURE IF EXISTS test_signal;
+ CREATE PROCEDURE test_signal(ret INT)
+ BEGIN
+ SELECT ret;
+ SIGNAL SQLSTATE
+ '45001'
+ SET
+ MESSAGE_TEXT = "an error",
+ MYSQL_ERRNO = 45001;
+ END
+ `)
+ defer dbt.mustExec("DROP PROCEDURE test_signal")
+
+ var val int
+
+ // text protocol
+ rows, err := dbt.db.Query("CALL test_signal(42)")
+ if err != nil {
+ dbt.Fatalf("error on text query: %s", err.Error())
+ }
+ for rows.Next() {
+ if err := rows.Scan(&val); err != nil {
+ dbt.Error(err)
+ } else if val != 42 {
+ dbt.Errorf("expected val to be 42")
+ }
+ }
+ rows.Close()
+
+ // binary protocol
+ rows, err = dbt.db.Query("CALL test_signal(?)", 42)
+ if err != nil {
+ dbt.Fatalf("error on binary query: %s", err.Error())
+ }
+ for rows.Next() {
+ if err := rows.Scan(&val); err != nil {
+ dbt.Error(err)
+ } else if val != 42 {
+ dbt.Errorf("expected val to be 42")
+ }
+ }
+ rows.Close()
+ })
+}
+
+func TestColumnsReusesSlice(t *testing.T) {
+ rows := mysqlRows{
+ rs: resultSet{
+ columns: []mysqlField{
+ {
+ tableName: "test",
+ name: "A",
+ },
+ {
+ tableName: "test",
+ name: "B",
+ },
+ },
+ },
+ }
+
+ allocs := testing.AllocsPerRun(1, func() {
+ cols := rows.Columns()
+
+ if len(cols) != 2 {
+ t.Fatalf("expected 2 columns, got %d", len(cols))
+ }
+ })
+
+ if allocs != 0 {
+ t.Fatalf("expected 0 allocations, got %d", int(allocs))
+ }
+
+ if rows.rs.columnNames == nil {
+ t.Fatalf("expected columnNames to be set, got nil")
+ }
+}
+
+func TestRejectReadOnly(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ // Create Table
+ dbt.mustExec("CREATE TABLE test (value BOOL)")
+ // Set the session to read-only. We didn't set the `rejectReadOnly`
+ // option, so any writes after this should fail.
+ _, err := dbt.db.Exec("SET SESSION TRANSACTION READ ONLY")
+ // Error 1193: Unknown system variable 'TRANSACTION' => skip test,
+ // MySQL server version is too old
+ maybeSkip(t, err, 1193)
+ if _, err := dbt.db.Exec("DROP TABLE test"); err == nil {
+ t.Fatalf("writing to DB in read-only session without " +
+ "rejectReadOnly did not error")
+ }
+ // Set the session back to read-write so runTests() can properly clean
+ // up the table `test`.
+ dbt.mustExec("SET SESSION TRANSACTION READ WRITE")
+ })
+
+ // Enable the `rejectReadOnly` option.
+ runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) {
+ // Create Table
+ dbt.mustExec("CREATE TABLE test (value BOOL)")
+ // Set the session to read only. Any writes after this should error on
+ // a driver.ErrBadConn, and cause `database/sql` to initiate a new
+ // connection.
+ dbt.mustExec("SET SESSION TRANSACTION READ ONLY")
+ // This would error, but `database/sql` should automatically retry on a
+ // new connection which is not read-only, and eventually succeed.
+ dbt.mustExec("DROP TABLE test")
+ })
+}
+
+func TestPing(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ if err := dbt.db.Ping(); err != nil {
+ dbt.fail("Ping", "Ping", err)
+ }
+ })
+}
+
+// See Issue #799
+func TestEmptyPassword(t *testing.T) {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname)
+ db, err := sql.Open("mysql", dsn)
+ if err == nil {
+ defer db.Close()
+ err = db.Ping()
+ }
+
+ if pass == "" {
+ if err != nil {
+ t.Fatal(err.Error())
+ }
+ } else {
+ if err == nil {
+ t.Fatal("expected authentication error")
+ }
+ if !strings.HasPrefix(err.Error(), "Error 1045") {
+ t.Fatal(err.Error())
+ }
+ }
+}
+
+// static interface implementation checks of mysqlConn
+var (
+ _ driver.ConnBeginTx = &mysqlConn{}
+ _ driver.ConnPrepareContext = &mysqlConn{}
+ _ driver.ExecerContext = &mysqlConn{}
+ _ driver.Pinger = &mysqlConn{}
+ _ driver.QueryerContext = &mysqlConn{}
+)
+
+// static interface implementation checks of mysqlStmt
+var (
+ _ driver.StmtExecContext = &mysqlStmt{}
+ _ driver.StmtQueryContext = &mysqlStmt{}
+)
+
+// Ensure that all the driver interfaces are implemented
+var (
+ // _ driver.RowsColumnTypeLength = &binaryRows{}
+ // _ driver.RowsColumnTypeLength = &textRows{}
+ _ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{}
+ _ driver.RowsColumnTypeDatabaseTypeName = &textRows{}
+ _ driver.RowsColumnTypeNullable = &binaryRows{}
+ _ driver.RowsColumnTypeNullable = &textRows{}
+ _ driver.RowsColumnTypePrecisionScale = &binaryRows{}
+ _ driver.RowsColumnTypePrecisionScale = &textRows{}
+ _ driver.RowsColumnTypeScanType = &binaryRows{}
+ _ driver.RowsColumnTypeScanType = &textRows{}
+ _ driver.RowsNextResultSet = &binaryRows{}
+ _ driver.RowsNextResultSet = &textRows{}
+)
+
+func TestMultiResultSet(t *testing.T) {
+ type result struct {
+ values [][]int
+ columns []string
+ }
+
+ // checkRows is a helper test function to validate rows containing 3 result
+ // sets with specific values and columns. The basic query would look like this:
+ //
+ // SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+ // SELECT 0 UNION SELECT 1;
+ // SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
+ //
+ // to distinguish test cases the first string argument is put in front of
+ // every error or fatal message.
+ checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) {
+ expected := []result{
+ {
+ values: [][]int{{1, 2}, {3, 4}},
+ columns: []string{"col1", "col2"},
+ },
+ {
+ values: [][]int{{1, 2, 3}, {4, 5, 6}},
+ columns: []string{"col1", "col2", "col3"},
+ },
+ }
+
+ var res1 result
+ for rows.Next() {
+ var res [2]int
+ if err := rows.Scan(&res[0], &res[1]); err != nil {
+ dbt.Fatal(err)
+ }
+ res1.values = append(res1.values, res[:])
+ }
+
+ cols, err := rows.Columns()
+ if err != nil {
+ dbt.Fatal(desc, err)
+ }
+ res1.columns = cols
+
+ if !reflect.DeepEqual(expected[0], res1) {
+ dbt.Error(desc, "want =", expected[0], "got =", res1)
+ }
+
+ if !rows.NextResultSet() {
+ dbt.Fatal(desc, "expected next result set")
+ }
+
+ // ignoring one result set
+
+ if !rows.NextResultSet() {
+ dbt.Fatal(desc, "expected next result set")
+ }
+
+ var res2 result
+ cols, err = rows.Columns()
+ if err != nil {
+ dbt.Fatal(desc, err)
+ }
+ res2.columns = cols
+
+ for rows.Next() {
+ var res [3]int
+ if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
+ dbt.Fatal(desc, err)
+ }
+ res2.values = append(res2.values, res[:])
+ }
+
+ if !reflect.DeepEqual(expected[1], res2) {
+ dbt.Error(desc, "want =", expected[1], "got =", res2)
+ }
+
+ if rows.NextResultSet() {
+ dbt.Error(desc, "unexpected next result set")
+ }
+
+ if err := rows.Err(); err != nil {
+ dbt.Error(desc, err)
+ }
+ }
+
+ runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+ rows := dbt.mustQuery(`DO 1;
+ SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+ DO 1;
+ SELECT 0 UNION SELECT 1;
+ SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`)
+ defer rows.Close()
+ checkRows("query: ", rows, dbt)
+ })
+
+ runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+ queries := []string{
+ `
+ DROP PROCEDURE IF EXISTS test_mrss;
+ CREATE PROCEDURE test_mrss()
+ BEGIN
+ DO 1;
+ SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+ DO 1;
+ SELECT 0 UNION SELECT 1;
+ SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
+ END
+ `,
+ `
+ DROP PROCEDURE IF EXISTS test_mrss;
+ CREATE PROCEDURE test_mrss()
+ BEGIN
+ SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
+ SELECT 0 UNION SELECT 1;
+ SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
+ END
+ `,
+ }
+
+ defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss")
+
+ for i, query := range queries {
+ dbt.mustExec(query)
+
+ stmt, err := dbt.db.Prepare("CALL test_mrss()")
+ if err != nil {
+ dbt.Fatalf("%v (i=%d)", err, i)
+ }
+ defer stmt.Close()
+
+ for j := 0; j < 2; j++ {
+ rows, err := stmt.Query()
+ if err != nil {
+ dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j)
+ }
+ checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt)
+ }
+ }
+ })
+}
+
+func TestMultiResultSetNoSelect(t *testing.T) {
+ runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
+ rows := dbt.mustQuery("DO 1; DO 2;")
+ defer rows.Close()
+
+ if rows.Next() {
+ dbt.Error("unexpected row")
+ }
+
+ if rows.NextResultSet() {
+ dbt.Error("unexpected next result set")
+ }
+
+ if err := rows.Err(); err != nil {
+ dbt.Error("expected nil; got ", err)
+ }
+ })
+}
+
+// tests if rows are set in a proper state if some results were ignored before
+// calling rows.NextResultSet.
+func TestSkipResults(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ rows := dbt.mustQuery("SELECT 1, 2")
+ defer rows.Close()
+
+ if !rows.Next() {
+ dbt.Error("expected row")
+ }
+
+ if rows.NextResultSet() {
+ dbt.Error("unexpected next result set")
+ }
+
+ if err := rows.Err(); err != nil {
+ dbt.Error("expected nil; got ", err)
+ }
+ })
+}
+
+func TestPingContext(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ if err := dbt.db.PingContext(ctx); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ })
+}
+
+func TestContextCancelExec(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Delay execution for just a bit until db.ExecContext has begun.
+ defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
+
+ // This query will be canceled.
+ startTime := time.Now()
+ if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ if d := time.Since(startTime); d > 500*time.Millisecond {
+ dbt.Errorf("too long execution time: %s", d)
+ }
+
+ // Wait for the INSERT query to be done.
+ time.Sleep(time.Second)
+
+ // Check how many times the query is executed.
+ var v int
+ if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ if v != 1 { // TODO: need to kill the query, and v should be 0.
+ dbt.Skipf("[WARN] expected val to be 1, got %d", v)
+ }
+
+ // Context is already canceled, so error should come before execution.
+ if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil {
+ dbt.Error("expected error")
+ } else if err.Error() != "context canceled" {
+ dbt.Fatalf("unexpected error: %s", err)
+ }
+
+ // The second insert query will fail, so the table has no changes.
+ if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ if v != 1 {
+ dbt.Skipf("[WARN] expected val to be 1, got %d", v)
+ }
+ })
+}
+
+func TestContextCancelQuery(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Delay execution for just a bit until db.ExecContext has begun.
+ defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
+
+ // This query will be canceled.
+ startTime := time.Now()
+ if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ if d := time.Since(startTime); d > 500*time.Millisecond {
+ dbt.Errorf("too long execution time: %s", d)
+ }
+
+ // Wait for the INSERT query to be done.
+ time.Sleep(time.Second)
+
+ // Check how many times the query is executed.
+ var v int
+ if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ if v != 1 { // TODO: need to kill the query, and v should be 0.
+ dbt.Skipf("[WARN] expected val to be 1, got %d", v)
+ }
+
+ // Context is already canceled, so error should come before execution.
+ if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+
+ // The second insert query will fail, so the table has no changes.
+ if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ if v != 1 {
+ dbt.Skipf("[WARN] expected val to be 1, got %d", v)
+ }
+ })
+}
+
+func TestContextCancelQueryRow(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)")
+ ctx, cancel := context.WithCancel(context.Background())
+
+ rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test")
+ if err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+
+ // the first row will be succeed.
+ var v int
+ if !rows.Next() {
+ dbt.Fatalf("unexpected end")
+ }
+ if err := rows.Scan(&v); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+
+ cancel()
+ // make sure the driver receives the cancel request.
+ time.Sleep(100 * time.Millisecond)
+
+ if rows.Next() {
+ dbt.Errorf("expected end, but not")
+ }
+ if err := rows.Err(); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ })
+}
+
+func TestContextCancelPrepare(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ })
+}
+
+func TestContextCancelStmtExec(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ ctx, cancel := context.WithCancel(context.Background())
+ stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
+ if err != nil {
+ dbt.Fatalf("unexpected error: %v", err)
+ }
+
+ // Delay execution for just a bit until db.ExecContext has begun.
+ defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
+
+ // This query will be canceled.
+ startTime := time.Now()
+ if _, err := stmt.ExecContext(ctx); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ if d := time.Since(startTime); d > 500*time.Millisecond {
+ dbt.Errorf("too long execution time: %s", d)
+ }
+
+ // Wait for the INSERT query to be done.
+ time.Sleep(time.Second)
+
+ // Check how many times the query is executed.
+ var v int
+ if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ if v != 1 { // TODO: need to kill the query, and v should be 0.
+ dbt.Skipf("[WARN] expected val to be 1, got %d", v)
+ }
+ })
+}
+
+func TestContextCancelStmtQuery(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ ctx, cancel := context.WithCancel(context.Background())
+ stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
+ if err != nil {
+ dbt.Fatalf("unexpected error: %v", err)
+ }
+
+ // Delay execution for just a bit until db.ExecContext has begun.
+ defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
+
+ // This query will be canceled.
+ startTime := time.Now()
+ if _, err := stmt.QueryContext(ctx); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ if d := time.Since(startTime); d > 500*time.Millisecond {
+ dbt.Errorf("too long execution time: %s", d)
+ }
+
+ // Wait for the INSERT query has done.
+ time.Sleep(time.Second)
+
+ // Check how many times the query is executed.
+ var v int
+ if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+ dbt.Fatalf("%s", err.Error())
+ }
+ if v != 1 { // TODO: need to kill the query, and v should be 0.
+ dbt.Skipf("[WARN] expected val to be 1, got %d", v)
+ }
+ })
+}
+
+func TestContextCancelBegin(t *testing.T) {
+ if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
+ t.Skip(`FIXME: it sometime fails with "expected driver.ErrBadConn, got sql: connection is already closed" on windows and macOS`)
+ }
+
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ ctx, cancel := context.WithCancel(context.Background())
+ conn, err := dbt.db.Conn(ctx)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+ defer conn.Close()
+ tx, err := conn.BeginTx(ctx, nil)
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ // Delay execution for just a bit until db.ExecContext has begun.
+ defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
+
+ // This query will be canceled.
+ startTime := time.Now()
+ if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ if d := time.Since(startTime); d > 500*time.Millisecond {
+ dbt.Errorf("too long execution time: %s", d)
+ }
+
+ // Transaction is canceled, so expect an error.
+ switch err := tx.Commit(); err {
+ case sql.ErrTxDone:
+ // because the transaction has already been rollbacked.
+ // the database/sql package watches ctx
+ // and rollbacks when ctx is canceled.
+ case context.Canceled:
+ // the database/sql package rollbacks on another goroutine,
+ // so the transaction may not be rollbacked depending on goroutine scheduling.
+ default:
+ dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err)
+ }
+
+ // The connection is now in an inoperable state - so performing other
+ // operations should fail with ErrBadConn
+ // Important to exercise isolation level too - it runs SET TRANSACTION ISOLATION
+ // LEVEL XXX first, which needs to return ErrBadConn if the connection's context
+ // is cancelled
+ _, err = conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted})
+ if err != driver.ErrBadConn {
+ dbt.Errorf("expected driver.ErrBadConn, got %v", err)
+ }
+
+ // cannot begin a transaction (on a different conn) with a canceled context
+ if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled {
+ dbt.Errorf("expected context.Canceled, got %v", err)
+ }
+ })
+}
+
+func TestContextBeginIsolationLevel(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tx1, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
+ Isolation: sql.LevelRepeatableRead,
+ })
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ tx2, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
+ Isolation: sql.LevelReadCommitted,
+ })
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ _, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)")
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ var v int
+ row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
+ if err := row.Scan(&v); err != nil {
+ dbt.Fatal(err)
+ }
+ // Because writer transaction wasn't commited yet, it should be available
+ if v != 0 {
+ dbt.Errorf("expected val to be 0, got %d", v)
+ }
+
+ err = tx1.Commit()
+ if err != nil {
+ dbt.Fatal(err)
+ }
+
+ row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
+ if err := row.Scan(&v); err != nil {
+ dbt.Fatal(err)
+ }
+ // Data written by writer transaction is already commited, it should be selectable
+ if v != 1 {
+ dbt.Errorf("expected val to be 1, got %d", v)
+ }
+ tx2.Commit()
+ })
+}
+
+func TestContextBeginReadOnly(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (v INTEGER)")
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
+ ReadOnly: true,
+ })
+ if _, ok := err.(*MySQLError); ok {
+ dbt.Skip("It seems that your MySQL does not support READ ONLY transactions")
+ return
+ } else if err != nil {
+ dbt.Fatal(err)
+ }
+
+ // INSERT queries fail in a READ ONLY transaction.
+ _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)")
+ if _, ok := err.(*MySQLError); !ok {
+ dbt.Errorf("expected MySQLError, got %v", err)
+ }
+
+ // SELECT queries can be executed.
+ var v int
+ row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
+ if err := row.Scan(&v); err != nil {
+ dbt.Fatal(err)
+ }
+ if v != 0 {
+ dbt.Errorf("expected val to be 0, got %d", v)
+ }
+
+ if err := tx.Commit(); err != nil {
+ dbt.Fatal(err)
+ }
+ })
+}
+
+func TestRowsColumnTypes(t *testing.T) {
+ niNULL := sql.NullInt64{Int64: 0, Valid: false}
+ ni0 := sql.NullInt64{Int64: 0, Valid: true}
+ ni1 := sql.NullInt64{Int64: 1, Valid: true}
+ ni42 := sql.NullInt64{Int64: 42, Valid: true}
+ nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false}
+ nf0 := sql.NullFloat64{Float64: 0.0, Valid: true}
+ nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true}
+ nt0 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
+ nt1 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
+ nt2 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
+ nt6 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
+ nd1 := nullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
+ nd2 := nullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
+ ndNULL := nullTime{Time: time.Time{}, Valid: false}
+ rbNULL := sql.RawBytes(nil)
+ rb0 := sql.RawBytes("0")
+ rb42 := sql.RawBytes("42")
+ rbTest := sql.RawBytes("Test")
+ rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00
+ rbx0 := sql.RawBytes("\x00")
+ rbx42 := sql.RawBytes("\x42")
+
+ var columns = []struct {
+ name string
+ fieldType string // type used when creating table schema
+ databaseTypeName string // actual type used by MySQL
+ scanType reflect.Type
+ nullable bool
+ precision int64 // 0 if not ok
+ scale int64
+ valuesIn [3]string
+ valuesOut [3]interface{}
+ }{
+ {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}},
+ {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}},
+ {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}},
+ {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
+ {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}},
+ {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
+ {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
+ {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}},
+ {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}},
+ {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}},
+ {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}},
+ {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}},
+ {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}},
+ {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}},
+ {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}},
+ {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}},
+ {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
+ {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
+ {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}},
+ {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
+ {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}},
+ {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}},
+ {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}},
+ {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}},
+ {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}},
+ {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}},
+ {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+ {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+ {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}},
+ {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+ {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+ {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+ {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+ {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+ {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+ {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+ {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+ {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+ {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}},
+ {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}},
+ {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}},
+ {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}},
+ {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}},
+ }
+
+ schema := ""
+ values1 := ""
+ values2 := ""
+ values3 := ""
+ for _, column := range columns {
+ schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType)
+ values1 += column.valuesIn[0] + ", "
+ values2 += column.valuesIn[1] + ", "
+ values3 += column.valuesIn[2] + ", "
+ }
+ schema = schema[:len(schema)-2]
+ values1 = values1[:len(values1)-2]
+ values2 = values2[:len(values2)-2]
+ values3 = values3[:len(values3)-2]
+
+ runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (" + schema + ")")
+ dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")
+
+ rows, err := dbt.db.Query("SELECT * FROM test")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+
+ tt, err := rows.ColumnTypes()
+ if err != nil {
+ t.Fatalf("ColumnTypes: %v", err)
+ }
+
+ if len(tt) != len(columns) {
+ t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
+ }
+
+ types := make([]reflect.Type, len(tt))
+ for i, tp := range tt {
+ column := columns[i]
+
+ // Name
+ name := tp.Name()
+ if name != column.name {
+ t.Errorf("column name mismatch %s != %s", name, column.name)
+ continue
+ }
+
+ // DatabaseTypeName
+ databaseTypeName := tp.DatabaseTypeName()
+ if databaseTypeName != column.databaseTypeName {
+ t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
+ continue
+ }
+
+ // ScanType
+ scanType := tp.ScanType()
+ if scanType != column.scanType {
+ if scanType == nil {
+ t.Errorf("scantype is null for column %q", name)
+ } else {
+ t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
+ }
+ continue
+ }
+ types[i] = scanType
+
+ // Nullable
+ nullable, ok := tp.Nullable()
+ if !ok {
+ t.Errorf("nullable not ok %q", name)
+ continue
+ }
+ if nullable != column.nullable {
+ t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
+ }
+
+ // Length
+ // length, ok := tp.Length()
+ // if length != column.length {
+ // if !ok {
+ // t.Errorf("length not ok for column %q", name)
+ // } else {
+ // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
+ // }
+ // continue
+ // }
+
+ // Precision and Scale
+ precision, scale, ok := tp.DecimalSize()
+ if precision != column.precision {
+ if !ok {
+ t.Errorf("precision not ok for column %q", name)
+ } else {
+ t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
+ }
+ continue
+ }
+ if scale != column.scale {
+ if !ok {
+ t.Errorf("scale not ok for column %q", name)
+ } else {
+ t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
+ }
+ continue
+ }
+ }
+
+ values := make([]interface{}, len(tt))
+ for i := range values {
+ values[i] = reflect.New(types[i]).Interface()
+ }
+ i := 0
+ for rows.Next() {
+ err = rows.Scan(values...)
+ if err != nil {
+ t.Fatalf("failed to scan values in %v", err)
+ }
+ for j := range values {
+ value := reflect.ValueOf(values[j]).Elem().Interface()
+ if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
+ if columns[j].scanType == scanTypeRawBytes {
+ t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
+ } else {
+ t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
+ }
+ }
+ }
+ i++
+ }
+ if i != 3 {
+ t.Errorf("expected 3 rows, got %d", i)
+ }
+
+ if err := rows.Close(); err != nil {
+ t.Errorf("error closing rows: %s", err)
+ }
+ })
+}
+
+func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (value VARCHAR(255))")
+ dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil))
+ // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
+ })
+}
+
+// TestRawBytesAreNotModified checks for a race condition that arises when a query context
+// is canceled while a user is calling rows.Scan. This is a more stringent test than the one
+// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using
+// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit
+// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers.
+func TestRawBytesAreNotModified(t *testing.T) {
+ const blob = "abcdefghijklmnop"
+ const contextRaceIterations = 20
+ const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row.
+ const insertRows = 4
+
+ var sqlBlobs = [2]string{
+ strings.Repeat(blob, blobSize/len(blob)),
+ strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)),
+ }
+
+ runTests(t, dsn, func(dbt *DBTest) {
+ dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
+ for i := 0; i < insertRows; i++ {
+ dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1])
+ }
+
+ for i := 0; i < contextRaceIterations; i++ {
+ func() {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var b int
+ var raw sql.RawBytes
+ for rows.Next() {
+ if err := rows.Scan(&b, &raw); err != nil {
+ t.Fatal(err)
+ }
+
+ before := string(raw)
+ // Ensure cancelling the query does not corrupt the contents of `raw`
+ cancel()
+ time.Sleep(time.Microsecond * 100)
+ after := string(raw)
+
+ if before != after {
+ t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
+ }
+ }
+ rows.Close()
+ }()
+ }
+ })
+}
+
+var _ driver.DriverContext = &MySQLDriver{}
+
+type dialCtxKey struct{}
+
+func TestConnectorObeysDialTimeouts(t *testing.T) {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) {
+ var d net.Dialer
+ if !ctx.Value(dialCtxKey{}).(bool) {
+ return nil, fmt.Errorf("test error: query context is not propagated to our dialer")
+ }
+ return d.DialContext(ctx, prot, addr)
+ })
+
+ db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
+ if err != nil {
+ t.Fatalf("error connecting: %s", err.Error())
+ }
+ defer db.Close()
+
+ ctx := context.WithValue(context.Background(), dialCtxKey{}, true)
+
+ _, err = db.ExecContext(ctx, "DO 1")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func configForTests(t *testing.T) *Config {
+ if !available {
+ t.Skipf("MySQL server not running on %s", netAddr)
+ }
+
+ mycnf := NewConfig()
+ mycnf.User = user
+ mycnf.Passwd = pass
+ mycnf.Addr = addr
+ mycnf.Net = prot
+ mycnf.DBName = dbname
+ return mycnf
+}
+
+func TestNewConnector(t *testing.T) {
+ mycnf := configForTests(t)
+ conn, err := NewConnector(mycnf)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ db := sql.OpenDB(conn)
+ defer db.Close()
+
+ if err := db.Ping(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+type slowConnection struct {
+ net.Conn
+ slowdown time.Duration
+}
+
+func (sc *slowConnection) Read(b []byte) (int, error) {
+ time.Sleep(sc.slowdown)
+ return sc.Conn.Read(b)
+}
+
+type connectorHijack struct {
+ driver.Connector
+ connErr error
+}
+
+func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) {
+ var conn driver.Conn
+ conn, cw.connErr = cw.Connector.Connect(ctx)
+ return conn, cw.connErr
+}
+
+func TestConnectorTimeoutsDuringOpen(t *testing.T) {
+ RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) {
+ var d net.Dialer
+ conn, err := d.DialContext(ctx, prot, addr)
+ if err != nil {
+ return nil, err
+ }
+ return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil
+ })
+
+ mycnf := configForTests(t)
+ mycnf.Net = "slowconn"
+
+ conn, err := NewConnector(mycnf)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ hijack := &connectorHijack{Connector: conn}
+
+ db := sql.OpenDB(hijack)
+ defer db.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ defer cancel()
+
+ _, err = db.ExecContext(ctx, "DO 1")
+ if err != context.DeadlineExceeded {
+ t.Fatalf("ExecContext should have timed out")
+ }
+ if hijack.connErr != context.DeadlineExceeded {
+ t.Fatalf("(*Connector).Connect should have timed out")
+ }
+}
+
+// A connection which can only be closed.
+type dummyConnection struct {
+ net.Conn
+ closed bool
+}
+
+func (d *dummyConnection) Close() error {
+ d.closed = true
+ return nil
+}
+
+func TestConnectorTimeoutsWatchCancel(t *testing.T) {
+ var (
+ cancel func() // Used to cancel the context just after connecting.
+ created *dummyConnection // The created connection.
+ )
+
+ RegisterDialContext("TestConnectorTimeoutsWatchCancel", func(ctx context.Context, addr string) (net.Conn, error) {
+ // Canceling at this time triggers the watchCancel error branch in Connect().
+ cancel()
+ created = &dummyConnection{}
+ return created, nil
+ })
+
+ mycnf := NewConfig()
+ mycnf.User = "root"
+ mycnf.Addr = "foo"
+ mycnf.Net = "TestConnectorTimeoutsWatchCancel"
+
+ conn, err := NewConnector(mycnf)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ db := sql.OpenDB(conn)
+ defer db.Close()
+
+ var ctx context.Context
+ ctx, cancel = context.WithCancel(context.Background())
+ defer cancel()
+
+ if _, err := db.Conn(ctx); err != context.Canceled {
+ t.Errorf("got %v, want context.Canceled", err)
+ }
+
+ if created == nil {
+ t.Fatal("no connection created")
+ }
+ if !created.closed {
+ t.Errorf("connection not closed")
+ }
+}
diff --git a/vendor/mysql/dsn.go b/vendor/mysql/dsn.go
new file mode 100644
index 0000000..93f3548
--- /dev/null
+++ b/vendor/mysql/dsn.go
@@ -0,0 +1,560 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "crypto/rsa"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "math/big"
+ "net"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+)
+
+var (
+ errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
+ errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
+ errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
+ errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
+)
+
+// Config is a configuration parsed from a DSN string.
+// If a new Config is created instead of being parsed from a DSN string,
+// the NewConfig function should be used, which sets default values.
+type Config struct {
+ User string // Username
+ Passwd string // Password (requires User)
+ Net string // Network type
+ Addr string // Network address (requires Net)
+ DBName string // Database name
+ Params map[string]string // Connection parameters
+ Collation string // Connection collation
+ Loc *time.Location // Location for time.Time values
+ MaxAllowedPacket int // Max packet size allowed
+ ServerPubKey string // Server public key name
+ pubKey *rsa.PublicKey // Server public key
+ TLSConfig string // TLS configuration name
+ tls *tls.Config // TLS configuration
+ Timeout time.Duration // Dial timeout
+ ReadTimeout time.Duration // I/O read timeout
+ WriteTimeout time.Duration // I/O write timeout
+
+ AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
+ AllowCleartextPasswords bool // Allows the cleartext client side plugin
+ AllowNativePasswords bool // Allows the native password authentication method
+ AllowOldPasswords bool // Allows the old insecure password method
+ CheckConnLiveness bool // Check connections for liveness before using them
+ ClientFoundRows bool // Return number of matching rows instead of rows changed
+ ColumnsWithAlias bool // Prepend table alias to column names
+ InterpolateParams bool // Interpolate placeholders into query string
+ MultiStatements bool // Allow multiple statements in one query
+ ParseTime bool // Parse time values to time.Time
+ RejectReadOnly bool // Reject read-only connections
+}
+
+// NewConfig creates a new Config and sets default values.
+func NewConfig() *Config {
+ return &Config{
+ Collation: defaultCollation,
+ Loc: time.UTC,
+ MaxAllowedPacket: defaultMaxAllowedPacket,
+ AllowNativePasswords: true,
+ CheckConnLiveness: true,
+ }
+}
+
+func (cfg *Config) Clone() *Config {
+ cp := *cfg
+ if cp.tls != nil {
+ cp.tls = cfg.tls.Clone()
+ }
+ if len(cp.Params) > 0 {
+ cp.Params = make(map[string]string, len(cfg.Params))
+ for k, v := range cfg.Params {
+ cp.Params[k] = v
+ }
+ }
+ if cfg.pubKey != nil {
+ cp.pubKey = &rsa.PublicKey{
+ N: new(big.Int).Set(cfg.pubKey.N),
+ E: cfg.pubKey.E,
+ }
+ }
+ return &cp
+}
+
+func (cfg *Config) normalize() error {
+ if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
+ return errInvalidDSNUnsafeCollation
+ }
+
+ // Set default network if empty
+ if cfg.Net == "" {
+ cfg.Net = "tcp"
+ }
+
+ // Set default address if empty
+ if cfg.Addr == "" {
+ switch cfg.Net {
+ case "tcp":
+ cfg.Addr = "127.0.0.1:3306"
+ case "unix":
+ cfg.Addr = "/tmp/mysql.sock"
+ default:
+ return errors.New("default addr for network '" + cfg.Net + "' unknown")
+ }
+ } else if cfg.Net == "tcp" {
+ cfg.Addr = ensureHavePort(cfg.Addr)
+ }
+
+ switch cfg.TLSConfig {
+ case "false", "":
+ // don't set anything
+ case "true":
+ cfg.tls = &tls.Config{}
+ case "skip-verify", "preferred":
+ cfg.tls = &tls.Config{InsecureSkipVerify: true}
+ default:
+ cfg.tls = getTLSConfigClone(cfg.TLSConfig)
+ if cfg.tls == nil {
+ return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
+ }
+ }
+
+ if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
+ host, _, err := net.SplitHostPort(cfg.Addr)
+ if err == nil {
+ cfg.tls.ServerName = host
+ }
+ }
+
+ if cfg.ServerPubKey != "" {
+ cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
+ if cfg.pubKey == nil {
+ return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
+ }
+ }
+
+ return nil
+}
+
+func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
+ buf.Grow(1 + len(name) + 1 + len(value))
+ if !*hasParam {
+ *hasParam = true
+ buf.WriteByte('?')
+ } else {
+ buf.WriteByte('&')
+ }
+ buf.WriteString(name)
+ buf.WriteByte('=')
+ buf.WriteString(value)
+}
+
+// FormatDSN formats the given Config into a DSN string which can be passed to
+// the driver.
+func (cfg *Config) FormatDSN() string {
+ var buf bytes.Buffer
+
+ // [username[:password]@]
+ if len(cfg.User) > 0 {
+ buf.WriteString(cfg.User)
+ if len(cfg.Passwd) > 0 {
+ buf.WriteByte(':')
+ buf.WriteString(cfg.Passwd)
+ }
+ buf.WriteByte('@')
+ }
+
+ // [protocol[(address)]]
+ if len(cfg.Net) > 0 {
+ buf.WriteString(cfg.Net)
+ if len(cfg.Addr) > 0 {
+ buf.WriteByte('(')
+ buf.WriteString(cfg.Addr)
+ buf.WriteByte(')')
+ }
+ }
+
+ // /dbname
+ buf.WriteByte('/')
+ buf.WriteString(cfg.DBName)
+
+ // [?param1=value1&...&paramN=valueN]
+ hasParam := false
+
+ if cfg.AllowAllFiles {
+ hasParam = true
+ buf.WriteString("?allowAllFiles=true")
+ }
+
+ if cfg.AllowCleartextPasswords {
+ writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
+ }
+
+ if !cfg.AllowNativePasswords {
+ writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
+ }
+
+ if cfg.AllowOldPasswords {
+ writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true")
+ }
+
+ if !cfg.CheckConnLiveness {
+ writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false")
+ }
+
+ if cfg.ClientFoundRows {
+ writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
+ }
+
+ if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
+ writeDSNParam(&buf, &hasParam, "collation", col)
+ }
+
+ if cfg.ColumnsWithAlias {
+ writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
+ }
+
+ if cfg.InterpolateParams {
+ writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
+ }
+
+ if cfg.Loc != time.UTC && cfg.Loc != nil {
+ writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String()))
+ }
+
+ if cfg.MultiStatements {
+ writeDSNParam(&buf, &hasParam, "multiStatements", "true")
+ }
+
+ if cfg.ParseTime {
+ writeDSNParam(&buf, &hasParam, "parseTime", "true")
+ }
+
+ if cfg.ReadTimeout > 0 {
+ writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
+ }
+
+ if cfg.RejectReadOnly {
+ writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
+ }
+
+ if len(cfg.ServerPubKey) > 0 {
+ writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
+ }
+
+ if cfg.Timeout > 0 {
+ writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String())
+ }
+
+ if len(cfg.TLSConfig) > 0 {
+ writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig))
+ }
+
+ if cfg.WriteTimeout > 0 {
+ writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String())
+ }
+
+ if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
+ writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
+ }
+
+ // other params
+ if cfg.Params != nil {
+ var params []string
+ for param := range cfg.Params {
+ params = append(params, param)
+ }
+ sort.Strings(params)
+ for _, param := range params {
+ writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param]))
+ }
+ }
+
+ return buf.String()
+}
+
+// ParseDSN parses the DSN string to a Config
+func ParseDSN(dsn string) (cfg *Config, err error) {
+ // New config with some default values
+ cfg = NewConfig()
+
+ // [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
+ // Find the last '/' (since the password or the net addr might contain a '/')
+ foundSlash := false
+ for i := len(dsn) - 1; i >= 0; i-- {
+ if dsn[i] == '/' {
+ foundSlash = true
+ var j, k int
+
+ // left part is empty if i <= 0
+ if i > 0 {
+ // [username[:password]@][protocol[(address)]]
+ // Find the last '@' in dsn[:i]
+ for j = i; j >= 0; j-- {
+ if dsn[j] == '@' {
+ // username[:password]
+ // Find the first ':' in dsn[:j]
+ for k = 0; k < j; k++ {
+ if dsn[k] == ':' {
+ cfg.Passwd = dsn[k+1 : j]
+ break
+ }
+ }
+ cfg.User = dsn[:k]
+
+ break
+ }
+ }
+
+ // [protocol[(address)]]
+ // Find the first '(' in dsn[j+1:i]
+ for k = j + 1; k < i; k++ {
+ if dsn[k] == '(' {
+ // dsn[i-1] must be == ')' if an address is specified
+ if dsn[i-1] != ')' {
+ if strings.ContainsRune(dsn[k+1:i], ')') {
+ return nil, errInvalidDSNUnescaped
+ }
+ return nil, errInvalidDSNAddr
+ }
+ cfg.Addr = dsn[k+1 : i-1]
+ break
+ }
+ }
+ cfg.Net = dsn[j+1 : k]
+ }
+
+ // dbname[?param1=value1&...&paramN=valueN]
+ // Find the first '?' in dsn[i+1:]
+ for j = i + 1; j < len(dsn); j++ {
+ if dsn[j] == '?' {
+ if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
+ return
+ }
+ break
+ }
+ }
+ cfg.DBName = dsn[i+1 : j]
+
+ break
+ }
+ }
+
+ if !foundSlash && len(dsn) > 0 {
+ return nil, errInvalidDSNNoSlash
+ }
+
+ if err = cfg.normalize(); err != nil {
+ return nil, err
+ }
+ return
+}
+
+// parseDSNParams parses the DSN "query string"
+// Values must be url.QueryEscape'ed
+func parseDSNParams(cfg *Config, params string) (err error) {
+ for _, v := range strings.Split(params, "&") {
+ param := strings.SplitN(v, "=", 2)
+ if len(param) != 2 {
+ continue
+ }
+
+ // cfg params
+ switch value := param[1]; param[0] {
+ // Disable INFILE allowlist / enable all files
+ case "allowAllFiles":
+ var isBool bool
+ cfg.AllowAllFiles, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Use cleartext authentication mode (MySQL 5.5.10+)
+ case "allowCleartextPasswords":
+ var isBool bool
+ cfg.AllowCleartextPasswords, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Use native password authentication
+ case "allowNativePasswords":
+ var isBool bool
+ cfg.AllowNativePasswords, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Use old authentication mode (pre MySQL 4.1)
+ case "allowOldPasswords":
+ var isBool bool
+ cfg.AllowOldPasswords, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Check connections for Liveness before using them
+ case "checkConnLiveness":
+ var isBool bool
+ cfg.CheckConnLiveness, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Switch "rowsAffected" mode
+ case "clientFoundRows":
+ var isBool bool
+ cfg.ClientFoundRows, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Collation
+ case "collation":
+ cfg.Collation = value
+ break
+
+ case "columnsWithAlias":
+ var isBool bool
+ cfg.ColumnsWithAlias, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Compression
+ case "compress":
+ return errors.New("compression not implemented yet")
+
+ // Enable client side placeholder substitution
+ case "interpolateParams":
+ var isBool bool
+ cfg.InterpolateParams, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Time Location
+ case "loc":
+ if value, err = url.QueryUnescape(value); err != nil {
+ return
+ }
+ cfg.Loc, err = time.LoadLocation(value)
+ if err != nil {
+ return
+ }
+
+ // multiple statements in one query
+ case "multiStatements":
+ var isBool bool
+ cfg.MultiStatements, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // time.Time parsing
+ case "parseTime":
+ var isBool bool
+ cfg.ParseTime, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // I/O read Timeout
+ case "readTimeout":
+ cfg.ReadTimeout, err = time.ParseDuration(value)
+ if err != nil {
+ return
+ }
+
+ // Reject read-only connections
+ case "rejectReadOnly":
+ var isBool bool
+ cfg.RejectReadOnly, isBool = readBool(value)
+ if !isBool {
+ return errors.New("invalid bool value: " + value)
+ }
+
+ // Server public key
+ case "serverPubKey":
+ name, err := url.QueryUnescape(value)
+ if err != nil {
+ return fmt.Errorf("invalid value for server pub key name: %v", err)
+ }
+ cfg.ServerPubKey = name
+
+ // Strict mode
+ case "strict":
+ panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
+
+ // Dial Timeout
+ case "timeout":
+ cfg.Timeout, err = time.ParseDuration(value)
+ if err != nil {
+ return
+ }
+
+ // TLS-Encryption
+ case "tls":
+ boolValue, isBool := readBool(value)
+ if isBool {
+ if boolValue {
+ cfg.TLSConfig = "true"
+ } else {
+ cfg.TLSConfig = "false"
+ }
+ } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
+ cfg.TLSConfig = vl
+ } else {
+ name, err := url.QueryUnescape(value)
+ if err != nil {
+ return fmt.Errorf("invalid value for TLS config name: %v", err)
+ }
+ cfg.TLSConfig = name
+ }
+
+ // I/O write Timeout
+ case "writeTimeout":
+ cfg.WriteTimeout, err = time.ParseDuration(value)
+ if err != nil {
+ return
+ }
+ case "maxAllowedPacket":
+ cfg.MaxAllowedPacket, err = strconv.Atoi(value)
+ if err != nil {
+ return
+ }
+ default:
+ // lazy init
+ if cfg.Params == nil {
+ cfg.Params = make(map[string]string)
+ }
+
+ if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
+ return
+ }
+ }
+ }
+
+ return
+}
+
+func ensureHavePort(addr string) string {
+ if _, _, err := net.SplitHostPort(addr); err != nil {
+ return net.JoinHostPort(addr, "3306")
+ }
+ return addr
+}
diff --git a/vendor/mysql/dsn_test.go b/vendor/mysql/dsn_test.go
new file mode 100644
index 0000000..89815b3
--- /dev/null
+++ b/vendor/mysql/dsn_test.go
@@ -0,0 +1,415 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "crypto/tls"
+ "fmt"
+ "net/url"
+ "reflect"
+ "testing"
+ "time"
+)
+
+var testDSNs = []struct {
+ in string
+ out *Config
+}{{
+ "username:password@protocol(address)/dbname?param=value",
+ &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true",
+ &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true},
+}, {
+ "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true",
+ &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true},
+}, {
+ "user@unix(/path/to/socket)/dbname?charset=utf8",
+ &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
+ &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"},
+}, {
+ "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
+ &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"},
+}, {
+ "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
+ &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
+}, {
+ "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0",
+ &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false},
+}, {
+ "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
+ &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "/dbname",
+ &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "@/",
+ &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "/",
+ &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "",
+ &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "user:p@/ssword@/",
+ &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "unix/?arg=%2Fsome%2Fpath.ext",
+ &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "tcp(127.0.0.1)/dbname",
+ &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+}, {
+ "tcp(de:ad:be:ef::ca:fe)/dbname",
+ &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
+},
+}
+
+func TestDSNParser(t *testing.T) {
+ for i, tst := range testDSNs {
+ cfg, err := ParseDSN(tst.in)
+ if err != nil {
+ t.Error(err.Error())
+ }
+
+ // pointer not static
+ cfg.tls = nil
+
+ if !reflect.DeepEqual(cfg, tst.out) {
+ t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
+ }
+ }
+}
+
+func TestDSNParserInvalid(t *testing.T) {
+ var invalidDSNs = []string{
+ "@net(addr/", // no closing brace
+ "@tcp(/", // no closing brace
+ "tcp(/", // no closing brace
+ "(/", // no closing brace
+ "net(addr)//", // unescaped
+ "User:pass@tcp(1.2.3.4:3306)", // no trailing slash
+ "net()/", // unknown default addr
+ //"/dbname?arg=/some/unescaped/path",
+ }
+
+ for i, tst := range invalidDSNs {
+ if _, err := ParseDSN(tst); err == nil {
+ t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
+ }
+ }
+}
+
+func TestDSNReformat(t *testing.T) {
+ for i, tst := range testDSNs {
+ dsn1 := tst.in
+ cfg1, err := ParseDSN(dsn1)
+ if err != nil {
+ t.Error(err.Error())
+ continue
+ }
+ cfg1.tls = nil // pointer not static
+ res1 := fmt.Sprintf("%+v", cfg1)
+
+ dsn2 := cfg1.FormatDSN()
+ cfg2, err := ParseDSN(dsn2)
+ if err != nil {
+ t.Error(err.Error())
+ continue
+ }
+ cfg2.tls = nil // pointer not static
+ res2 := fmt.Sprintf("%+v", cfg2)
+
+ if res1 != res2 {
+ t.Errorf("%d. %q does not match %q", i, res2, res1)
+ }
+ }
+}
+
+func TestDSNServerPubKey(t *testing.T) {
+ baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey="
+
+ RegisterServerPubKey("testKey", testPubKeyRSA)
+ defer DeregisterServerPubKey("testKey")
+
+ tst := baseDSN + "testKey"
+ cfg, err := ParseDSN(tst)
+ if err != nil {
+ t.Error(err.Error())
+ }
+
+ if cfg.ServerPubKey != "testKey" {
+ t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey)
+ }
+ if cfg.pubKey != testPubKeyRSA {
+ t.Error("pub key pointer doesn't match")
+ }
+
+ // Key is missing
+ tst = baseDSN + "invalid_name"
+ cfg, err = ParseDSN(tst)
+ if err == nil {
+ t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg)
+ }
+}
+
+func TestDSNServerPubKeyQueryEscape(t *testing.T) {
+ const name = "&%!:"
+ dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name)
+
+ RegisterServerPubKey(name, testPubKeyRSA)
+ defer DeregisterServerPubKey(name)
+
+ cfg, err := ParseDSN(dsn)
+ if err != nil {
+ t.Error(err.Error())
+ }
+
+ if cfg.pubKey != testPubKeyRSA {
+ t.Error("pub key pointer doesn't match")
+ }
+}
+
+func TestDSNWithCustomTLS(t *testing.T) {
+ baseDSN := "User:password@tcp(localhost:5555)/dbname?tls="
+ tlsCfg := tls.Config{}
+
+ RegisterTLSConfig("utils_test", &tlsCfg)
+ defer DeregisterTLSConfig("utils_test")
+
+ // Custom TLS is missing
+ tst := baseDSN + "invalid_tls"
+ cfg, err := ParseDSN(tst)
+ if err == nil {
+ t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
+ }
+
+ tst = baseDSN + "utils_test"
+
+ // Custom TLS with a server name
+ name := "foohost"
+ tlsCfg.ServerName = name
+ cfg, err = ParseDSN(tst)
+
+ if err != nil {
+ t.Error(err.Error())
+ } else if cfg.tls.ServerName != name {
+ t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
+ }
+
+ // Custom TLS without a server name
+ name = "localhost"
+ tlsCfg.ServerName = ""
+ cfg, err = ParseDSN(tst)
+
+ if err != nil {
+ t.Error(err.Error())
+ } else if cfg.tls.ServerName != name {
+ t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
+ } else if tlsCfg.ServerName != "" {
+ t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst)
+ }
+}
+
+func TestDSNTLSConfig(t *testing.T) {
+ expectedServerName := "example.com"
+ dsn := "tcp(example.com:1234)/?tls=true"
+
+ cfg, err := ParseDSN(dsn)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if cfg.tls == nil {
+ t.Error("cfg.tls should not be nil")
+ }
+ if cfg.tls.ServerName != expectedServerName {
+ t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
+ }
+
+ dsn = "tcp(example.com)/?tls=true"
+ cfg, err = ParseDSN(dsn)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if cfg.tls == nil {
+ t.Error("cfg.tls should not be nil")
+ }
+ if cfg.tls.ServerName != expectedServerName {
+ t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
+ }
+}
+
+func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
+ const configKey = "&%!:"
+ dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)
+ name := "foohost"
+ tlsCfg := tls.Config{ServerName: name}
+
+ RegisterTLSConfig(configKey, &tlsCfg)
+ defer DeregisterTLSConfig(configKey)
+
+ cfg, err := ParseDSN(dsn)
+
+ if err != nil {
+ t.Error(err.Error())
+ } else if cfg.tls.ServerName != name {
+ t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
+ }
+}
+
+func TestDSNUnsafeCollation(t *testing.T) {
+ _, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
+ if err != errInvalidDSNUnsafeCollation {
+ t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err)
+ }
+
+ _, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
+ if err != nil {
+ t.Errorf("expected %v, got %v", nil, err)
+ }
+
+ _, err = ParseDSN("/dbname?collation=gbk_chinese_ci")
+ if err != nil {
+ t.Errorf("expected %v, got %v", nil, err)
+ }
+
+ _, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
+ if err != nil {
+ t.Errorf("expected %v, got %v", nil, err)
+ }
+
+ _, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
+ if err != nil {
+ t.Errorf("expected %v, got %v", nil, err)
+ }
+
+ _, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
+ if err != nil {
+ t.Errorf("expected %v, got %v", nil, err)
+ }
+
+ _, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
+ if err != nil {
+ t.Errorf("expected %v, got %v", nil, err)
+ }
+}
+
+func TestParamsAreSorted(t *testing.T) {
+ expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo"
+ cfg := NewConfig()
+ cfg.DBName = "dbname"
+ cfg.InterpolateParams = true
+ cfg.Params = map[string]string{
+ "quux": "loo",
+ "foobar": "baz",
+ }
+ actual := cfg.FormatDSN()
+ if actual != expected {
+ t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual)
+ }
+}
+
+func TestCloneConfig(t *testing.T) {
+ RegisterServerPubKey("testKey", testPubKeyRSA)
+ defer DeregisterServerPubKey("testKey")
+
+ expectedServerName := "example.com"
+ dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey"
+ cfg, err := ParseDSN(dsn)
+ if err != nil {
+ t.Fatal(err.Error())
+ }
+
+ cfg2 := cfg.Clone()
+ if cfg == cfg2 {
+ t.Errorf("Config.Clone did not create a separate config struct")
+ }
+
+ if cfg2.tls.ServerName != expectedServerName {
+ t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
+ }
+
+ cfg2.tls.ServerName = "example2.com"
+ if cfg.tls.ServerName == cfg2.tls.ServerName {
+ t.Errorf("changed cfg.tls.Server name should not propagate to original Config")
+ }
+
+ if _, ok := cfg2.Params["foobar"]; !ok {
+ t.Errorf("cloned Config is missing custom params")
+ }
+
+ delete(cfg2.Params, "foobar")
+
+ if _, ok := cfg.Params["foobar"]; !ok {
+ t.Errorf("custom params in cloned Config should not propagate to original Config")
+ }
+
+ if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) {
+ t.Errorf("public key in Config should be identical")
+ }
+}
+
+func TestNormalizeTLSConfig(t *testing.T) {
+ tt := []struct {
+ tlsConfig string
+ want *tls.Config
+ }{
+ {"", nil},
+ {"false", nil},
+ {"true", &tls.Config{ServerName: "myserver"}},
+ {"skip-verify", &tls.Config{InsecureSkipVerify: true}},
+ {"preferred", &tls.Config{InsecureSkipVerify: true}},
+ {"test_tls_config", &tls.Config{ServerName: "myServerName"}},
+ }
+
+ RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"})
+ defer func() { DeregisterTLSConfig("test_tls_config") }()
+
+ for _, tc := range tt {
+ t.Run(tc.tlsConfig, func(t *testing.T) {
+ cfg := &Config{
+ Addr: "myserver:3306",
+ TLSConfig: tc.tlsConfig,
+ }
+
+ cfg.normalize()
+
+ if cfg.tls == nil {
+ if tc.want != nil {
+ t.Fatal("wanted a tls config but got nil instead")
+ }
+ return
+ }
+
+ if cfg.tls.ServerName != tc.want.ServerName {
+ t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
+ tc.want.ServerName, cfg.tls.ServerName)
+ }
+ if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
+ t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
+ tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
+ }
+ })
+ }
+}
+
+func BenchmarkParseDSN(b *testing.B) {
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ for _, tst := range testDSNs {
+ if _, err := ParseDSN(tst.in); err != nil {
+ b.Error(err.Error())
+ }
+ }
+ }
+}
diff --git a/vendor/mysql/errors.go b/vendor/mysql/errors.go
new file mode 100644
index 0000000..760782f
--- /dev/null
+++ b/vendor/mysql/errors.go
@@ -0,0 +1,65 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "os"
+)
+
+// Various errors the driver might return. Can change between driver versions.
+var (
+ ErrInvalidConn = errors.New("invalid connection")
+ ErrMalformPkt = errors.New("malformed packet")
+ ErrNoTLS = errors.New("TLS requested but server does not support TLS")
+ ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
+ ErrNativePassword = errors.New("this user requires mysql native password authentication.")
+ ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
+ ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
+ ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
+ ErrPktSync = errors.New("commands out of sync. You can't run this command now")
+ ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
+ ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
+ ErrBusyBuffer = errors.New("busy buffer")
+
+ // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
+ // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
+ // to trigger a resend.
+ // See https://github.com/go-sql-driver/mysql/pull/302
+ errBadConnNoWrite = errors.New("bad connection")
+)
+
+var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
+
+// Logger is used to log critical error messages.
+type Logger interface {
+ Print(v ...interface{})
+}
+
+// SetLogger is used to set the logger for critical errors.
+// The initial logger is os.Stderr.
+func SetLogger(logger Logger) error {
+ if logger == nil {
+ return errors.New("logger is nil")
+ }
+ errLog = logger
+ return nil
+}
+
+// MySQLError is an error type which represents a single MySQL error
+type MySQLError struct {
+ Number uint16
+ Message string
+}
+
+func (me *MySQLError) Error() string {
+ return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
+}
diff --git a/vendor/mysql/errors_test.go b/vendor/mysql/errors_test.go
new file mode 100644
index 0000000..96f9126
--- /dev/null
+++ b/vendor/mysql/errors_test.go
@@ -0,0 +1,42 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "log"
+ "testing"
+)
+
+func TestErrorsSetLogger(t *testing.T) {
+ previous := errLog
+ defer func() {
+ errLog = previous
+ }()
+
+ // set up logger
+ const expected = "prefix: test\n"
+ buffer := bytes.NewBuffer(make([]byte, 0, 64))
+ logger := log.New(buffer, "prefix: ", 0)
+
+ // print
+ SetLogger(logger)
+ errLog.Print("test")
+
+ // check result
+ if actual := buffer.String(); actual != expected {
+ t.Errorf("expected %q, got %q", expected, actual)
+ }
+}
+
+func TestErrorsStrictIgnoreNotes(t *testing.T) {
+ runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) {
+ dbt.mustExec("DROP TABLE IF EXISTS does_not_exist")
+ })
+}
diff --git a/vendor/mysql/fields.go b/vendor/mysql/fields.go
new file mode 100644
index 0000000..ed6c7a3
--- /dev/null
+++ b/vendor/mysql/fields.go
@@ -0,0 +1,194 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "database/sql"
+ "reflect"
+)
+
+func (mf *mysqlField) typeDatabaseName() string {
+ switch mf.fieldType {
+ case fieldTypeBit:
+ return "BIT"
+ case fieldTypeBLOB:
+ if mf.charSet != collations[binaryCollation] {
+ return "TEXT"
+ }
+ return "BLOB"
+ case fieldTypeDate:
+ return "DATE"
+ case fieldTypeDateTime:
+ return "DATETIME"
+ case fieldTypeDecimal:
+ return "DECIMAL"
+ case fieldTypeDouble:
+ return "DOUBLE"
+ case fieldTypeEnum:
+ return "ENUM"
+ case fieldTypeFloat:
+ return "FLOAT"
+ case fieldTypeGeometry:
+ return "GEOMETRY"
+ case fieldTypeInt24:
+ return "MEDIUMINT"
+ case fieldTypeJSON:
+ return "JSON"
+ case fieldTypeLong:
+ return "INT"
+ case fieldTypeLongBLOB:
+ if mf.charSet != collations[binaryCollation] {
+ return "LONGTEXT"
+ }
+ return "LONGBLOB"
+ case fieldTypeLongLong:
+ return "BIGINT"
+ case fieldTypeMediumBLOB:
+ if mf.charSet != collations[binaryCollation] {
+ return "MEDIUMTEXT"
+ }
+ return "MEDIUMBLOB"
+ case fieldTypeNewDate:
+ return "DATE"
+ case fieldTypeNewDecimal:
+ return "DECIMAL"
+ case fieldTypeNULL:
+ return "NULL"
+ case fieldTypeSet:
+ return "SET"
+ case fieldTypeShort:
+ return "SMALLINT"
+ case fieldTypeString:
+ if mf.charSet == collations[binaryCollation] {
+ return "BINARY"
+ }
+ return "CHAR"
+ case fieldTypeTime:
+ return "TIME"
+ case fieldTypeTimestamp:
+ return "TIMESTAMP"
+ case fieldTypeTiny:
+ return "TINYINT"
+ case fieldTypeTinyBLOB:
+ if mf.charSet != collations[binaryCollation] {
+ return "TINYTEXT"
+ }
+ return "TINYBLOB"
+ case fieldTypeVarChar:
+ if mf.charSet == collations[binaryCollation] {
+ return "VARBINARY"
+ }
+ return "VARCHAR"
+ case fieldTypeVarString:
+ if mf.charSet == collations[binaryCollation] {
+ return "VARBINARY"
+ }
+ return "VARCHAR"
+ case fieldTypeYear:
+ return "YEAR"
+ default:
+ return ""
+ }
+}
+
+var (
+ scanTypeFloat32 = reflect.TypeOf(float32(0))
+ scanTypeFloat64 = reflect.TypeOf(float64(0))
+ scanTypeInt8 = reflect.TypeOf(int8(0))
+ scanTypeInt16 = reflect.TypeOf(int16(0))
+ scanTypeInt32 = reflect.TypeOf(int32(0))
+ scanTypeInt64 = reflect.TypeOf(int64(0))
+ scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
+ scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
+ scanTypeNullTime = reflect.TypeOf(nullTime{})
+ scanTypeUint8 = reflect.TypeOf(uint8(0))
+ scanTypeUint16 = reflect.TypeOf(uint16(0))
+ scanTypeUint32 = reflect.TypeOf(uint32(0))
+ scanTypeUint64 = reflect.TypeOf(uint64(0))
+ scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
+ scanTypeUnknown = reflect.TypeOf(new(interface{}))
+)
+
+type mysqlField struct {
+ tableName string
+ name string
+ length uint32
+ flags fieldFlag
+ fieldType fieldType
+ decimals byte
+ charSet uint8
+}
+
+func (mf *mysqlField) scanType() reflect.Type {
+ switch mf.fieldType {
+ case fieldTypeTiny:
+ if mf.flags&flagNotNULL != 0 {
+ if mf.flags&flagUnsigned != 0 {
+ return scanTypeUint8
+ }
+ return scanTypeInt8
+ }
+ return scanTypeNullInt
+
+ case fieldTypeShort, fieldTypeYear:
+ if mf.flags&flagNotNULL != 0 {
+ if mf.flags&flagUnsigned != 0 {
+ return scanTypeUint16
+ }
+ return scanTypeInt16
+ }
+ return scanTypeNullInt
+
+ case fieldTypeInt24, fieldTypeLong:
+ if mf.flags&flagNotNULL != 0 {
+ if mf.flags&flagUnsigned != 0 {
+ return scanTypeUint32
+ }
+ return scanTypeInt32
+ }
+ return scanTypeNullInt
+
+ case fieldTypeLongLong:
+ if mf.flags&flagNotNULL != 0 {
+ if mf.flags&flagUnsigned != 0 {
+ return scanTypeUint64
+ }
+ return scanTypeInt64
+ }
+ return scanTypeNullInt
+
+ case fieldTypeFloat:
+ if mf.flags&flagNotNULL != 0 {
+ return scanTypeFloat32
+ }
+ return scanTypeNullFloat
+
+ case fieldTypeDouble:
+ if mf.flags&flagNotNULL != 0 {
+ return scanTypeFloat64
+ }
+ return scanTypeNullFloat
+
+ case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
+ fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
+ fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
+ fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
+ fieldTypeTime:
+ return scanTypeRawBytes
+
+ case fieldTypeDate, fieldTypeNewDate,
+ fieldTypeTimestamp, fieldTypeDateTime:
+ // NullTime is always returned for more consistent behavior as it can
+ // handle both cases of parseTime regardless if the field is nullable.
+ return scanTypeNullTime
+
+ default:
+ return scanTypeUnknown
+ }
+}
diff --git a/vendor/mysql/fuzz.go b/vendor/mysql/fuzz.go
new file mode 100644
index 0000000..fa75adf
--- /dev/null
+++ b/vendor/mysql/fuzz.go
@@ -0,0 +1,24 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
+//
+// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build gofuzz
+
+package mysql
+
+import (
+ "database/sql"
+)
+
+func Fuzz(data []byte) int {
+ db, err := sql.Open("mysql", string(data))
+ if err != nil {
+ return 0
+ }
+ db.Close()
+ return 1
+}
diff --git a/vendor/mysql/infile.go b/vendor/mysql/infile.go
new file mode 100644
index 0000000..60effdf
--- /dev/null
+++ b/vendor/mysql/infile.go
@@ -0,0 +1,182 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "strings"
+ "sync"
+)
+
+var (
+ fileRegister map[string]bool
+ fileRegisterLock sync.RWMutex
+ readerRegister map[string]func() io.Reader
+ readerRegisterLock sync.RWMutex
+)
+
+// RegisterLocalFile adds the given file to the file allowlist,
+// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
+// Alternatively you can allow the use of all local files with
+// the DSN parameter 'allowAllFiles=true'
+//
+// filePath := "/home/gopher/data.csv"
+// mysql.RegisterLocalFile(filePath)
+// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
+// if err != nil {
+// ...
+//
+func RegisterLocalFile(filePath string) {
+ fileRegisterLock.Lock()
+ // lazy map init
+ if fileRegister == nil {
+ fileRegister = make(map[string]bool)
+ }
+
+ fileRegister[strings.Trim(filePath, `"`)] = true
+ fileRegisterLock.Unlock()
+}
+
+// DeregisterLocalFile removes the given filepath from the allowlist.
+func DeregisterLocalFile(filePath string) {
+ fileRegisterLock.Lock()
+ delete(fileRegister, strings.Trim(filePath, `"`))
+ fileRegisterLock.Unlock()
+}
+
+// RegisterReaderHandler registers a handler function which is used
+// to receive a io.Reader.
+// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
+// If the handler returns a io.ReadCloser Close() is called when the
+// request is finished.
+//
+// mysql.RegisterReaderHandler("data", func() io.Reader {
+// var csvReader io.Reader // Some Reader that returns CSV data
+// ... // Open Reader here
+// return csvReader
+// })
+// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
+// if err != nil {
+// ...
+//
+func RegisterReaderHandler(name string, handler func() io.Reader) {
+ readerRegisterLock.Lock()
+ // lazy map init
+ if readerRegister == nil {
+ readerRegister = make(map[string]func() io.Reader)
+ }
+
+ readerRegister[name] = handler
+ readerRegisterLock.Unlock()
+}
+
+// DeregisterReaderHandler removes the ReaderHandler function with
+// the given name from the registry.
+func DeregisterReaderHandler(name string) {
+ readerRegisterLock.Lock()
+ delete(readerRegister, name)
+ readerRegisterLock.Unlock()
+}
+
+func deferredClose(err *error, closer io.Closer) {
+ closeErr := closer.Close()
+ if *err == nil {
+ *err = closeErr
+ }
+}
+
+func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
+ var rdr io.Reader
+ var data []byte
+ packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
+ if mc.maxWriteSize < packetSize {
+ packetSize = mc.maxWriteSize
+ }
+
+ if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
+ // The server might return an an absolute path. See issue #355.
+ name = name[idx+8:]
+
+ readerRegisterLock.RLock()
+ handler, inMap := readerRegister[name]
+ readerRegisterLock.RUnlock()
+
+ if inMap {
+ rdr = handler()
+ if rdr != nil {
+ if cl, ok := rdr.(io.Closer); ok {
+ defer deferredClose(&err, cl)
+ }
+ } else {
+ err = fmt.Errorf("Reader '%s' is <nil>", name)
+ }
+ } else {
+ err = fmt.Errorf("Reader '%s' is not registered", name)
+ }
+ } else { // File
+ name = strings.Trim(name, `"`)
+ fileRegisterLock.RLock()
+ fr := fileRegister[name]
+ fileRegisterLock.RUnlock()
+ if mc.cfg.AllowAllFiles || fr {
+ var file *os.File
+ var fi os.FileInfo
+
+ if file, err = os.Open(name); err == nil {
+ defer deferredClose(&err, file)
+
+ // get file size
+ if fi, err = file.Stat(); err == nil {
+ rdr = file
+ if fileSize := int(fi.Size()); fileSize < packetSize {
+ packetSize = fileSize
+ }
+ }
+ }
+ } else {
+ err = fmt.Errorf("local file '%s' is not registered", name)
+ }
+ }
+
+ // send content packets
+ // if packetSize == 0, the Reader contains no data
+ if err == nil && packetSize > 0 {
+ data := make([]byte, 4+packetSize)
+ var n int
+ for err == nil {
+ n, err = rdr.Read(data[4:])
+ if n > 0 {
+ if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
+ return ioErr
+ }
+ }
+ }
+ if err == io.EOF {
+ err = nil
+ }
+ }
+
+ // send empty packet (termination)
+ if data == nil {
+ data = make([]byte, 4)
+ }
+ if ioErr := mc.writePacket(data[:4]); ioErr != nil {
+ return ioErr
+ }
+
+ // read OK packet
+ if err == nil {
+ return mc.readResultOK()
+ }
+
+ mc.readPacket()
+ return err
+}
diff --git a/vendor/mysql/nulltime.go b/vendor/mysql/nulltime.go
new file mode 100644
index 0000000..651723a
--- /dev/null
+++ b/vendor/mysql/nulltime.go
@@ -0,0 +1,50 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "database/sql/driver"
+ "fmt"
+ "time"
+)
+
+// Scan implements the Scanner interface.
+// The value type must be time.Time or string / []byte (formatted time-string),
+// otherwise Scan fails.
+func (nt *NullTime) Scan(value interface{}) (err error) {
+ if value == nil {
+ nt.Time, nt.Valid = time.Time{}, false
+ return
+ }
+
+ switch v := value.(type) {
+ case time.Time:
+ nt.Time, nt.Valid = v, true
+ return
+ case []byte:
+ nt.Time, err = parseDateTime(v, time.UTC)
+ nt.Valid = (err == nil)
+ return
+ case string:
+ nt.Time, err = parseDateTime([]byte(v), time.UTC)
+ nt.Valid = (err == nil)
+ return
+ }
+
+ nt.Valid = false
+ return fmt.Errorf("Can't convert %T to time.Time", value)
+}
+
+// Value implements the driver Valuer interface.
+func (nt NullTime) Value() (driver.Value, error) {
+ if !nt.Valid {
+ return nil, nil
+ }
+ return nt.Time, nil
+}
diff --git a/vendor/mysql/nulltime_go113.go b/vendor/mysql/nulltime_go113.go
new file mode 100644
index 0000000..453b4b3
--- /dev/null
+++ b/vendor/mysql/nulltime_go113.go
@@ -0,0 +1,40 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build go1.13
+
+package mysql
+
+import (
+ "database/sql"
+)
+
+// NullTime represents a time.Time that may be NULL.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination:
+//
+// var nt NullTime
+// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
+// ...
+// if nt.Valid {
+// // use nt.Time
+// } else {
+// // NULL value
+// }
+//
+// This NullTime implementation is not driver-specific
+//
+// Deprecated: NullTime doesn't honor the loc DSN parameter.
+// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
+// Use sql.NullTime instead.
+type NullTime sql.NullTime
+
+// for internal use.
+// the mysql package uses sql.NullTime if it is available.
+// if not, the package uses mysql.NullTime.
+type nullTime = sql.NullTime // sql.NullTime is available
diff --git a/vendor/mysql/nulltime_legacy.go b/vendor/mysql/nulltime_legacy.go
new file mode 100644
index 0000000..9f7ae27
--- /dev/null
+++ b/vendor/mysql/nulltime_legacy.go
@@ -0,0 +1,39 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build !go1.13
+
+package mysql
+
+import (
+ "time"
+)
+
+// NullTime represents a time.Time that may be NULL.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination:
+//
+// var nt NullTime
+// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
+// ...
+// if nt.Valid {
+// // use nt.Time
+// } else {
+// // NULL value
+// }
+//
+// This NullTime implementation is not driver-specific
+type NullTime struct {
+ Time time.Time
+ Valid bool // Valid is true if Time is not NULL
+}
+
+// for internal use.
+// the mysql package uses sql.NullTime if it is available.
+// if not, the package uses mysql.NullTime.
+type nullTime = NullTime // sql.NullTime is not available
diff --git a/vendor/mysql/nulltime_test.go b/vendor/mysql/nulltime_test.go
new file mode 100644
index 0000000..a14ec06
--- /dev/null
+++ b/vendor/mysql/nulltime_test.go
@@ -0,0 +1,62 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "database/sql"
+ "database/sql/driver"
+ "testing"
+ "time"
+)
+
+var (
+ // Check implementation of interfaces
+ _ driver.Valuer = NullTime{}
+ _ sql.Scanner = (*NullTime)(nil)
+)
+
+func TestScanNullTime(t *testing.T) {
+ var scanTests = []struct {
+ in interface{}
+ error bool
+ valid bool
+ time time.Time
+ }{
+ {tDate, false, true, tDate},
+ {sDate, false, true, tDate},
+ {[]byte(sDate), false, true, tDate},
+ {tDateTime, false, true, tDateTime},
+ {sDateTime, false, true, tDateTime},
+ {[]byte(sDateTime), false, true, tDateTime},
+ {tDate0, false, true, tDate0},
+ {sDate0, false, true, tDate0},
+ {[]byte(sDate0), false, true, tDate0},
+ {sDateTime0, false, true, tDate0},
+ {[]byte(sDateTime0), false, true, tDate0},
+ {"", true, false, tDate0},
+ {"1234", true, false, tDate0},
+ {0, true, false, tDate0},
+ }
+
+ var nt = NullTime{}
+ var err error
+
+ for _, tst := range scanTests {
+ err = nt.Scan(tst.in)
+ if (err != nil) != tst.error {
+ t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
+ }
+ if nt.Valid != tst.valid {
+ t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
+ }
+ if nt.Time != tst.time {
+ t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
+ }
+ }
+}
diff --git a/vendor/mysql/packets.go b/vendor/mysql/packets.go
new file mode 100644
index 0000000..6664e5a
--- /dev/null
+++ b/vendor/mysql/packets.go
@@ -0,0 +1,1349 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "crypto/tls"
+ "database/sql/driver"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "time"
+)
+
+// Packets documentation:
+// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
+
+// Read packet to buffer 'data'
+func (mc *mysqlConn) readPacket() ([]byte, error) {
+ var prevData []byte
+ for {
+ // read packet header
+ data, err := mc.buf.readNext(4)
+ if err != nil {
+ if cerr := mc.canceled.Value(); cerr != nil {
+ return nil, cerr
+ }
+ errLog.Print(err)
+ mc.Close()
+ return nil, ErrInvalidConn
+ }
+
+ // packet length [24 bit]
+ pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
+
+ // check packet sync [8 bit]
+ if data[3] != mc.sequence {
+ if data[3] > mc.sequence {
+ return nil, ErrPktSyncMul
+ }
+ return nil, ErrPktSync
+ }
+ mc.sequence++
+
+ // packets with length 0 terminate a previous packet which is a
+ // multiple of (2^24)-1 bytes long
+ if pktLen == 0 {
+ // there was no previous packet
+ if prevData == nil {
+ errLog.Print(ErrMalformPkt)
+ mc.Close()
+ return nil, ErrInvalidConn
+ }
+
+ return prevData, nil
+ }
+
+ // read packet body [pktLen bytes]
+ data, err = mc.buf.readNext(pktLen)
+ if err != nil {
+ if cerr := mc.canceled.Value(); cerr != nil {
+ return nil, cerr
+ }
+ errLog.Print(err)
+ mc.Close()
+ return nil, ErrInvalidConn
+ }
+
+ // return data if this was the last packet
+ if pktLen < maxPacketSize {
+ // zero allocations for non-split packets
+ if prevData == nil {
+ return data, nil
+ }
+
+ return append(prevData, data...), nil
+ }
+
+ prevData = append(prevData, data...)
+ }
+}
+
+// Write packet buffer 'data'
+func (mc *mysqlConn) writePacket(data []byte) error {
+ pktLen := len(data) - 4
+
+ if pktLen > mc.maxAllowedPacket {
+ return ErrPktTooLarge
+ }
+
+ // Perform a stale connection check. We only perform this check for
+ // the first query on a connection that has been checked out of the
+ // connection pool: a fresh connection from the pool is more likely
+ // to be stale, and it has not performed any previous writes that
+ // could cause data corruption, so it's safe to return ErrBadConn
+ // if the check fails.
+ if mc.reset {
+ mc.reset = false
+ conn := mc.netConn
+ if mc.rawConn != nil {
+ conn = mc.rawConn
+ }
+ var err error
+ // If this connection has a ReadTimeout which we've been setting on
+ // reads, reset it to its default value before we attempt a non-blocking
+ // read, otherwise the scheduler will just time us out before we can read
+ if mc.cfg.ReadTimeout != 0 {
+ err = conn.SetReadDeadline(time.Time{})
+ }
+ if err == nil && mc.cfg.CheckConnLiveness {
+ err = connCheck(conn)
+ }
+ if err != nil {
+ errLog.Print("closing bad idle connection: ", err)
+ mc.Close()
+ return driver.ErrBadConn
+ }
+ }
+
+ for {
+ var size int
+ if pktLen >= maxPacketSize {
+ data[0] = 0xff
+ data[1] = 0xff
+ data[2] = 0xff
+ size = maxPacketSize
+ } else {
+ data[0] = byte(pktLen)
+ data[1] = byte(pktLen >> 8)
+ data[2] = byte(pktLen >> 16)
+ size = pktLen
+ }
+ data[3] = mc.sequence
+
+ // Write packet
+ if mc.writeTimeout > 0 {
+ if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
+ return err
+ }
+ }
+
+ n, err := mc.netConn.Write(data[:4+size])
+ if err == nil && n == 4+size {
+ mc.sequence++
+ if size != maxPacketSize {
+ return nil
+ }
+ pktLen -= size
+ data = data[size:]
+ continue
+ }
+
+ // Handle error
+ if err == nil { // n != len(data)
+ mc.cleanup()
+ errLog.Print(ErrMalformPkt)
+ } else {
+ if cerr := mc.canceled.Value(); cerr != nil {
+ return cerr
+ }
+ if n == 0 && pktLen == len(data)-4 {
+ // only for the first loop iteration when nothing was written yet
+ return errBadConnNoWrite
+ }
+ mc.cleanup()
+ errLog.Print(err)
+ }
+ return ErrInvalidConn
+ }
+}
+
+/******************************************************************************
+* Initialization Process *
+******************************************************************************/
+
+// Handshake Initialization Packet
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
+func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
+ data, err = mc.readPacket()
+ if err != nil {
+ // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
+ // in connection initialization we don't risk retrying non-idempotent actions.
+ if err == ErrInvalidConn {
+ return nil, "", driver.ErrBadConn
+ }
+ return
+ }
+
+ if data[0] == iERR {
+ return nil, "", mc.handleErrorPacket(data)
+ }
+
+ // protocol version [1 byte]
+ if data[0] < minProtocolVersion {
+ return nil, "", fmt.Errorf(
+ "unsupported protocol version %d. Version %d or higher is required",
+ data[0],
+ minProtocolVersion,
+ )
+ }
+
+ // server version [null terminated string]
+ // connection id [4 bytes]
+ pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
+
+ // first part of the password cipher [8 bytes]
+ authData := data[pos : pos+8]
+
+ // (filler) always 0x00 [1 byte]
+ pos += 8 + 1
+
+ // capability flags (lower 2 bytes) [2 bytes]
+ mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+ if mc.flags&clientProtocol41 == 0 {
+ return nil, "", ErrOldProtocol
+ }
+ if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
+ if mc.cfg.TLSConfig == "preferred" {
+ mc.cfg.tls = nil
+ } else {
+ return nil, "", ErrNoTLS
+ }
+ }
+ pos += 2
+
+ if len(data) > pos {
+ // character set [1 byte]
+ // status flags [2 bytes]
+ // capability flags (upper 2 bytes) [2 bytes]
+ // length of auth-plugin-data [1 byte]
+ // reserved (all [00]) [10 bytes]
+ pos += 1 + 2 + 2 + 1 + 10
+
+ // second part of the password cipher [mininum 13 bytes],
+ // where len=MAX(13, length of auth-plugin-data - 8)
+ //
+ // The web documentation is ambiguous about the length. However,
+ // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
+ // the 13th byte is "\0 byte, terminating the second part of
+ // a scramble". So the second part of the password cipher is
+ // a NULL terminated string that's at least 13 bytes with the
+ // last byte being NULL.
+ //
+ // The official Python library uses the fixed length 12
+ // which seems to work but technically could have a hidden bug.
+ authData = append(authData, data[pos:pos+12]...)
+ pos += 13
+
+ // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
+ // \NUL otherwise
+ if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
+ plugin = string(data[pos : pos+end])
+ } else {
+ plugin = string(data[pos:])
+ }
+
+ // make a memory safe copy of the cipher slice
+ var b [20]byte
+ copy(b[:], authData)
+ return b[:], plugin, nil
+ }
+
+ // make a memory safe copy of the cipher slice
+ var b [8]byte
+ copy(b[:], authData)
+ return b[:], plugin, nil
+}
+
+// Client Authentication Packet
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
+func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
+ // Adjust client flags based on server support
+ clientFlags := clientProtocol41 |
+ clientSecureConn |
+ clientLongPassword |
+ clientTransactions |
+ clientLocalFiles |
+ clientPluginAuth |
+ clientMultiResults |
+ mc.flags&clientLongFlag
+
+ if mc.cfg.ClientFoundRows {
+ clientFlags |= clientFoundRows
+ }
+
+ // To enable TLS / SSL
+ if mc.cfg.tls != nil {
+ clientFlags |= clientSSL
+ }
+
+ if mc.cfg.MultiStatements {
+ clientFlags |= clientMultiStatements
+ }
+
+ // encode length of the auth plugin data
+ var authRespLEIBuf [9]byte
+ authRespLen := len(authResp)
+ authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
+ if len(authRespLEI) > 1 {
+ // if the length can not be written in 1 byte, it must be written as a
+ // length encoded integer
+ clientFlags |= clientPluginAuthLenEncClientData
+ }
+
+ pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
+
+ // To specify a db name
+ if n := len(mc.cfg.DBName); n > 0 {
+ clientFlags |= clientConnectWithDB
+ pktLen += n + 1
+ }
+
+ // Calculate packet length and get buffer with that size
+ data, err := mc.buf.takeSmallBuffer(pktLen + 4)
+ if err != nil {
+ // cannot take the buffer. Something must be wrong with the connection
+ errLog.Print(err)
+ return errBadConnNoWrite
+ }
+
+ // ClientFlags [32 bit]
+ data[4] = byte(clientFlags)
+ data[5] = byte(clientFlags >> 8)
+ data[6] = byte(clientFlags >> 16)
+ data[7] = byte(clientFlags >> 24)
+
+ // MaxPacketSize [32 bit] (none)
+ data[8] = 0x00
+ data[9] = 0x00
+ data[10] = 0x00
+ data[11] = 0x00
+
+ // Charset [1 byte]
+ var found bool
+ data[12], found = collations[mc.cfg.Collation]
+ if !found {
+ // Note possibility for false negatives:
+ // could be triggered although the collation is valid if the
+ // collations map does not contain entries the server supports.
+ return errors.New("unknown collation")
+ }
+
+ // Filler [23 bytes] (all 0x00)
+ pos := 13
+ for ; pos < 13+23; pos++ {
+ data[pos] = 0
+ }
+
+ // SSL Connection Request Packet
+ // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
+ if mc.cfg.tls != nil {
+ // Send TLS / SSL request packet
+ if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
+ return err
+ }
+
+ // Switch to TLS
+ tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
+ if err := tlsConn.Handshake(); err != nil {
+ return err
+ }
+ mc.rawConn = mc.netConn
+ mc.netConn = tlsConn
+ mc.buf.nc = tlsConn
+ }
+
+ // User [null terminated string]
+ if len(mc.cfg.User) > 0 {
+ pos += copy(data[pos:], mc.cfg.User)
+ }
+ data[pos] = 0x00
+ pos++
+
+ // Auth Data [length encoded integer]
+ pos += copy(data[pos:], authRespLEI)
+ pos += copy(data[pos:], authResp)
+
+ // Databasename [null terminated string]
+ if len(mc.cfg.DBName) > 0 {
+ pos += copy(data[pos:], mc.cfg.DBName)
+ data[pos] = 0x00
+ pos++
+ }
+
+ pos += copy(data[pos:], plugin)
+ data[pos] = 0x00
+ pos++
+
+ // Send Auth packet
+ return mc.writePacket(data[:pos])
+}
+
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
+func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
+ pktLen := 4 + len(authData)
+ data, err := mc.buf.takeSmallBuffer(pktLen)
+ if err != nil {
+ // cannot take the buffer. Something must be wrong with the connection
+ errLog.Print(err)
+ return errBadConnNoWrite
+ }
+
+ // Add the auth data [EOF]
+ copy(data[4:], authData)
+ return mc.writePacket(data)
+}
+
+/******************************************************************************
+* Command Packets *
+******************************************************************************/
+
+func (mc *mysqlConn) writeCommandPacket(command byte) error {
+ // Reset Packet Sequence
+ mc.sequence = 0
+
+ data, err := mc.buf.takeSmallBuffer(4 + 1)
+ if err != nil {
+ // cannot take the buffer. Something must be wrong with the connection
+ errLog.Print(err)
+ return errBadConnNoWrite
+ }
+
+ // Add command byte
+ data[4] = command
+
+ // Send CMD packet
+ return mc.writePacket(data)
+}
+
+func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
+ // Reset Packet Sequence
+ mc.sequence = 0
+
+ pktLen := 1 + len(arg)
+ data, err := mc.buf.takeBuffer(pktLen + 4)
+ if err != nil {
+ // cannot take the buffer. Something must be wrong with the connection
+ errLog.Print(err)
+ return errBadConnNoWrite
+ }
+
+ // Add command byte
+ data[4] = command
+
+ // Add arg
+ copy(data[5:], arg)
+
+ // Send CMD packet
+ return mc.writePacket(data)
+}
+
+func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
+ // Reset Packet Sequence
+ mc.sequence = 0
+
+ data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
+ if err != nil {
+ // cannot take the buffer. Something must be wrong with the connection
+ errLog.Print(err)
+ return errBadConnNoWrite
+ }
+
+ // Add command byte
+ data[4] = command
+
+ // Add arg [32 bit]
+ data[5] = byte(arg)
+ data[6] = byte(arg >> 8)
+ data[7] = byte(arg >> 16)
+ data[8] = byte(arg >> 24)
+
+ // Send CMD packet
+ return mc.writePacket(data)
+}
+
+/******************************************************************************
+* Result Packets *
+******************************************************************************/
+
+func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
+ data, err := mc.readPacket()
+ if err != nil {
+ return nil, "", err
+ }
+
+ // packet indicator
+ switch data[0] {
+
+ case iOK:
+ return nil, "", mc.handleOkPacket(data)
+
+ case iAuthMoreData:
+ return data[1:], "", err
+
+ case iEOF:
+ if len(data) == 1 {
+ // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
+ return nil, "mysql_old_password", nil
+ }
+ pluginEndIndex := bytes.IndexByte(data, 0x00)
+ if pluginEndIndex < 0 {
+ return nil, "", ErrMalformPkt
+ }
+ plugin := string(data[1:pluginEndIndex])
+ authData := data[pluginEndIndex+1:]
+ return authData, plugin, nil
+
+ default: // Error otherwise
+ return nil, "", mc.handleErrorPacket(data)
+ }
+}
+
+// Returns error if Packet is not an 'Result OK'-Packet
+func (mc *mysqlConn) readResultOK() error {
+ data, err := mc.readPacket()
+ if err != nil {
+ return err
+ }
+
+ if data[0] == iOK {
+ return mc.handleOkPacket(data)
+ }
+ return mc.handleErrorPacket(data)
+}
+
+// Result Set Header Packet
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
+func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
+ data, err := mc.readPacket()
+ if err == nil {
+ switch data[0] {
+
+ case iOK:
+ return 0, mc.handleOkPacket(data)
+
+ case iERR:
+ return 0, mc.handleErrorPacket(data)
+
+ case iLocalInFile:
+ return 0, mc.handleInFileRequest(string(data[1:]))
+ }
+
+ // column count
+ num, _, n := readLengthEncodedInteger(data)
+ if n-len(data) == 0 {
+ return int(num), nil
+ }
+
+ return 0, ErrMalformPkt
+ }
+ return 0, err
+}
+
+// Error Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
+func (mc *mysqlConn) handleErrorPacket(data []byte) error {
+ if data[0] != iERR {
+ return ErrMalformPkt
+ }
+
+ // 0xff [1 byte]
+
+ // Error Number [16 bit uint]
+ errno := binary.LittleEndian.Uint16(data[1:3])
+
+ // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
+ // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
+ if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
+ // Oops; we are connected to a read-only connection, and won't be able
+ // to issue any write statements. Since RejectReadOnly is configured,
+ // we throw away this connection hoping this one would have write
+ // permission. This is specifically for a possible race condition
+ // during failover (e.g. on AWS Aurora). See README.md for more.
+ //
+ // We explicitly close the connection before returning
+ // driver.ErrBadConn to ensure that `database/sql` purges this
+ // connection and initiates a new one for next statement next time.
+ mc.Close()
+ return driver.ErrBadConn
+ }
+
+ pos := 3
+
+ // SQL State [optional: # + 5bytes string]
+ if data[3] == 0x23 {
+ //sqlstate := string(data[4 : 4+5])
+ pos = 9
+ }
+
+ // Error Message [string]
+ return &MySQLError{
+ Number: errno,
+ Message: string(data[pos:]),
+ }
+}
+
+func readStatus(b []byte) statusFlag {
+ return statusFlag(b[0]) | statusFlag(b[1])<<8
+}
+
+// Ok Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
+func (mc *mysqlConn) handleOkPacket(data []byte) error {
+ var n, m int
+
+ // 0x00 [1 byte]
+
+ // Affected rows [Length Coded Binary]
+ mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
+
+ // Insert id [Length Coded Binary]
+ mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
+
+ // server_status [2 bytes]
+ mc.status = readStatus(data[1+n+m : 1+n+m+2])
+ if mc.status&statusMoreResultsExists != 0 {
+ return nil
+ }
+
+ // warning count [2 bytes]
+
+ return nil
+}
+
+// Read Packets as Field Packets until EOF-Packet or an Error appears
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
+func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
+ columns := make([]mysqlField, count)
+
+ for i := 0; ; i++ {
+ data, err := mc.readPacket()
+ if err != nil {
+ return nil, err
+ }
+
+ // EOF Packet
+ if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
+ if i == count {
+ return columns, nil
+ }
+ return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
+ }
+
+ // Catalog
+ pos, err := skipLengthEncodedString(data)
+ if err != nil {
+ return nil, err
+ }
+
+ // Database [len coded string]
+ n, err := skipLengthEncodedString(data[pos:])
+ if err != nil {
+ return nil, err
+ }
+ pos += n
+
+ // Table [len coded string]
+ if mc.cfg.ColumnsWithAlias {
+ tableName, _, n, err := readLengthEncodedString(data[pos:])
+ if err != nil {
+ return nil, err
+ }
+ pos += n
+ columns[i].tableName = string(tableName)
+ } else {
+ n, err = skipLengthEncodedString(data[pos:])
+ if err != nil {
+ return nil, err
+ }
+ pos += n
+ }
+
+ // Original table [len coded string]
+ n, err = skipLengthEncodedString(data[pos:])
+ if err != nil {
+ return nil, err
+ }
+ pos += n
+
+ // Name [len coded string]
+ name, _, n, err := readLengthEncodedString(data[pos:])
+ if err != nil {
+ return nil, err
+ }
+ columns[i].name = string(name)
+ pos += n
+
+ // Original name [len coded string]
+ n, err = skipLengthEncodedString(data[pos:])
+ if err != nil {
+ return nil, err
+ }
+ pos += n
+
+ // Filler [uint8]
+ pos++
+
+ // Charset [charset, collation uint8]
+ columns[i].charSet = data[pos]
+ pos += 2
+
+ // Length [uint32]
+ columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
+ pos += 4
+
+ // Field type [uint8]
+ columns[i].fieldType = fieldType(data[pos])
+ pos++
+
+ // Flags [uint16]
+ columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+ pos += 2
+
+ // Decimals [uint8]
+ columns[i].decimals = data[pos]
+ //pos++
+
+ // Default value [len coded binary]
+ //if pos < len(data) {
+ // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
+ //}
+ }
+}
+
+// Read Packets as Field Packets until EOF-Packet or an Error appears
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
+func (rows *textRows) readRow(dest []driver.Value) error {
+ mc := rows.mc
+
+ if rows.rs.done {
+ return io.EOF
+ }
+
+ data, err := mc.readPacket()
+ if err != nil {
+ return err
+ }
+
+ // EOF Packet
+ if data[0] == iEOF && len(data) == 5 {
+ // server_status [2 bytes]
+ rows.mc.status = readStatus(data[3:])
+ rows.rs.done = true
+ if !rows.HasNextResultSet() {
+ rows.mc = nil
+ }
+ return io.EOF
+ }
+ if data[0] == iERR {
+ rows.mc = nil
+ return mc.handleErrorPacket(data)
+ }
+
+ // RowSet Packet
+ var n int
+ var isNull bool
+ pos := 0
+
+ for i := range dest {
+ // Read bytes and convert to string
+ dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
+ pos += n
+ if err == nil {
+ if !isNull {
+ if !mc.parseTime {
+ continue
+ } else {
+ switch rows.rs.columns[i].fieldType {
+ case fieldTypeTimestamp, fieldTypeDateTime,
+ fieldTypeDate, fieldTypeNewDate:
+ dest[i], err = parseDateTime(
+ dest[i].([]byte),
+ mc.cfg.Loc,
+ )
+ if err == nil {
+ continue
+ }
+ default:
+ continue
+ }
+ }
+
+ } else {
+ dest[i] = nil
+ continue
+ }
+ }
+ return err // err != nil
+ }
+
+ return nil
+}
+
+// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
+func (mc *mysqlConn) readUntilEOF() error {
+ for {
+ data, err := mc.readPacket()
+ if err != nil {
+ return err
+ }
+
+ switch data[0] {
+ case iERR:
+ return mc.handleErrorPacket(data)
+ case iEOF:
+ if len(data) == 5 {
+ mc.status = readStatus(data[3:])
+ }
+ return nil
+ }
+ }
+}
+
+/******************************************************************************
+* Prepared Statements *
+******************************************************************************/
+
+// Prepare Result Packets
+// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
+func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
+ data, err := stmt.mc.readPacket()
+ if err == nil {
+ // packet indicator [1 byte]
+ if data[0] != iOK {
+ return 0, stmt.mc.handleErrorPacket(data)
+ }
+
+ // statement id [4 bytes]
+ stmt.id = binary.LittleEndian.Uint32(data[1:5])
+
+ // Column count [16 bit uint]
+ columnCount := binary.LittleEndian.Uint16(data[5:7])
+
+ // Param count [16 bit uint]
+ stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
+
+ // Reserved [8 bit]
+
+ // Warning count [16 bit uint]
+
+ return columnCount, nil
+ }
+ return 0, err
+}
+
+// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
+func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
+ maxLen := stmt.mc.maxAllowedPacket - 1
+ pktLen := maxLen
+
+ // After the header (bytes 0-3) follows before the data:
+ // 1 byte command
+ // 4 bytes stmtID
+ // 2 bytes paramID
+ const dataOffset = 1 + 4 + 2
+
+ // Cannot use the write buffer since
+ // a) the buffer is too small
+ // b) it is in use
+ data := make([]byte, 4+1+4+2+len(arg))
+
+ copy(data[4+dataOffset:], arg)
+
+ for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
+ if dataOffset+argLen < maxLen {
+ pktLen = dataOffset + argLen
+ }
+
+ stmt.mc.sequence = 0
+ // Add command byte [1 byte]
+ data[4] = comStmtSendLongData
+
+ // Add stmtID [32 bit]
+ data[5] = byte(stmt.id)
+ data[6] = byte(stmt.id >> 8)
+ data[7] = byte(stmt.id >> 16)
+ data[8] = byte(stmt.id >> 24)
+
+ // Add paramID [16 bit]
+ data[9] = byte(paramID)
+ data[10] = byte(paramID >> 8)
+
+ // Send CMD packet
+ err := stmt.mc.writePacket(data[:4+pktLen])
+ if err == nil {
+ data = data[pktLen-dataOffset:]
+ continue
+ }
+ return err
+
+ }
+
+ // Reset Packet Sequence
+ stmt.mc.sequence = 0
+ return nil
+}
+
+// Execute Prepared Statement
+// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
+func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
+ if len(args) != stmt.paramCount {
+ return fmt.Errorf(
+ "argument count mismatch (got: %d; has: %d)",
+ len(args),
+ stmt.paramCount,
+ )
+ }
+
+ const minPktLen = 4 + 1 + 4 + 1 + 4
+ mc := stmt.mc
+
+ // Determine threshold dynamically to avoid packet size shortage.
+ longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
+ if longDataSize < 64 {
+ longDataSize = 64
+ }
+
+ // Reset packet-sequence
+ mc.sequence = 0
+
+ var data []byte
+ var err error
+
+ if len(args) == 0 {
+ data, err = mc.buf.takeBuffer(minPktLen)
+ } else {
+ data, err = mc.buf.takeCompleteBuffer()
+ // In this case the len(data) == cap(data) which is used to optimise the flow below.
+ }
+ if err != nil {
+ // cannot take the buffer. Something must be wrong with the connection
+ errLog.Print(err)
+ return errBadConnNoWrite
+ }
+
+ // command [1 byte]
+ data[4] = comStmtExecute
+
+ // statement_id [4 bytes]
+ data[5] = byte(stmt.id)
+ data[6] = byte(stmt.id >> 8)
+ data[7] = byte(stmt.id >> 16)
+ data[8] = byte(stmt.id >> 24)
+
+ // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
+ data[9] = 0x00
+
+ // iteration_count (uint32(1)) [4 bytes]
+ data[10] = 0x01
+ data[11] = 0x00
+ data[12] = 0x00
+ data[13] = 0x00
+
+ if len(args) > 0 {
+ pos := minPktLen
+
+ var nullMask []byte
+ if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
+ // buffer has to be extended but we don't know by how much so
+ // we depend on append after all data with known sizes fit.
+ // We stop at that because we deal with a lot of columns here
+ // which makes the required allocation size hard to guess.
+ tmp := make([]byte, pos+maskLen+typesLen)
+ copy(tmp[:pos], data[:pos])
+ data = tmp
+ nullMask = data[pos : pos+maskLen]
+ // No need to clean nullMask as make ensures that.
+ pos += maskLen
+ } else {
+ nullMask = data[pos : pos+maskLen]
+ for i := range nullMask {
+ nullMask[i] = 0
+ }
+ pos += maskLen
+ }
+
+ // newParameterBoundFlag 1 [1 byte]
+ data[pos] = 0x01
+ pos++
+
+ // type of each parameter [len(args)*2 bytes]
+ paramTypes := data[pos:]
+ pos += len(args) * 2
+
+ // value of each parameter [n bytes]
+ paramValues := data[pos:pos]
+ valuesCap := cap(paramValues)
+
+ for i, arg := range args {
+ // build NULL-bitmap
+ if arg == nil {
+ nullMask[i/8] |= 1 << (uint(i) & 7)
+ paramTypes[i+i] = byte(fieldTypeNULL)
+ paramTypes[i+i+1] = 0x00
+ continue
+ }
+
+ if v, ok := arg.(json.RawMessage); ok {
+ arg = []byte(v)
+ }
+ // cache types and values
+ switch v := arg.(type) {
+ case int64:
+ paramTypes[i+i] = byte(fieldTypeLongLong)
+ paramTypes[i+i+1] = 0x00
+
+ if cap(paramValues)-len(paramValues)-8 >= 0 {
+ paramValues = paramValues[:len(paramValues)+8]
+ binary.LittleEndian.PutUint64(
+ paramValues[len(paramValues)-8:],
+ uint64(v),
+ )
+ } else {
+ paramValues = append(paramValues,
+ uint64ToBytes(uint64(v))...,
+ )
+ }
+
+ case uint64:
+ paramTypes[i+i] = byte(fieldTypeLongLong)
+ paramTypes[i+i+1] = 0x80 // type is unsigned
+
+ if cap(paramValues)-len(paramValues)-8 >= 0 {
+ paramValues = paramValues[:len(paramValues)+8]
+ binary.LittleEndian.PutUint64(
+ paramValues[len(paramValues)-8:],
+ uint64(v),
+ )
+ } else {
+ paramValues = append(paramValues,
+ uint64ToBytes(uint64(v))...,
+ )
+ }
+
+ case float64:
+ paramTypes[i+i] = byte(fieldTypeDouble)
+ paramTypes[i+i+1] = 0x00
+
+ if cap(paramValues)-len(paramValues)-8 >= 0 {
+ paramValues = paramValues[:len(paramValues)+8]
+ binary.LittleEndian.PutUint64(
+ paramValues[len(paramValues)-8:],
+ math.Float64bits(v),
+ )
+ } else {
+ paramValues = append(paramValues,
+ uint64ToBytes(math.Float64bits(v))...,
+ )
+ }
+
+ case bool:
+ paramTypes[i+i] = byte(fieldTypeTiny)
+ paramTypes[i+i+1] = 0x00
+
+ if v {
+ paramValues = append(paramValues, 0x01)
+ } else {
+ paramValues = append(paramValues, 0x00)
+ }
+
+ case []byte:
+ // Common case (non-nil value) first
+ if v != nil {
+ paramTypes[i+i] = byte(fieldTypeString)
+ paramTypes[i+i+1] = 0x00
+
+ if len(v) < longDataSize {
+ paramValues = appendLengthEncodedInteger(paramValues,
+ uint64(len(v)),
+ )
+ paramValues = append(paramValues, v...)
+ } else {
+ if err := stmt.writeCommandLongData(i, v); err != nil {
+ return err
+ }
+ }
+ continue
+ }
+
+ // Handle []byte(nil) as a NULL value
+ nullMask[i/8] |= 1 << (uint(i) & 7)
+ paramTypes[i+i] = byte(fieldTypeNULL)
+ paramTypes[i+i+1] = 0x00
+
+ case string:
+ paramTypes[i+i] = byte(fieldTypeString)
+ paramTypes[i+i+1] = 0x00
+
+ if len(v) < longDataSize {
+ paramValues = appendLengthEncodedInteger(paramValues,
+ uint64(len(v)),
+ )
+ paramValues = append(paramValues, v...)
+ } else {
+ if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
+ return err
+ }
+ }
+
+ case time.Time:
+ paramTypes[i+i] = byte(fieldTypeString)
+ paramTypes[i+i+1] = 0x00
+
+ var a [64]byte
+ var b = a[:0]
+
+ if v.IsZero() {
+ b = append(b, "0000-00-00"...)
+ } else {
+ b, err = appendDateTime(b, v.In(mc.cfg.Loc))
+ if err != nil {
+ return err
+ }
+ }
+
+ paramValues = appendLengthEncodedInteger(paramValues,
+ uint64(len(b)),
+ )
+ paramValues = append(paramValues, b...)
+
+ default:
+ return fmt.Errorf("cannot convert type: %T", arg)
+ }
+ }
+
+ // Check if param values exceeded the available buffer
+ // In that case we must build the data packet with the new values buffer
+ if valuesCap != cap(paramValues) {
+ data = append(data[:pos], paramValues...)
+ if err = mc.buf.store(data); err != nil {
+ errLog.Print(err)
+ return errBadConnNoWrite
+ }
+ }
+
+ pos += len(paramValues)
+ data = data[:pos]
+ }
+
+ return mc.writePacket(data)
+}
+
+func (mc *mysqlConn) discardResults() error {
+ for mc.status&statusMoreResultsExists != 0 {
+ resLen, err := mc.readResultSetHeaderPacket()
+ if err != nil {
+ return err
+ }
+ if resLen > 0 {
+ // columns
+ if err := mc.readUntilEOF(); err != nil {
+ return err
+ }
+ // rows
+ if err := mc.readUntilEOF(); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
+func (rows *binaryRows) readRow(dest []driver.Value) error {
+ data, err := rows.mc.readPacket()
+ if err != nil {
+ return err
+ }
+
+ // packet indicator [1 byte]
+ if data[0] != iOK {
+ // EOF Packet
+ if data[0] == iEOF && len(data) == 5 {
+ rows.mc.status = readStatus(data[3:])
+ rows.rs.done = true
+ if !rows.HasNextResultSet() {
+ rows.mc = nil
+ }
+ return io.EOF
+ }
+ mc := rows.mc
+ rows.mc = nil
+
+ // Error otherwise
+ return mc.handleErrorPacket(data)
+ }
+
+ // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
+ pos := 1 + (len(dest)+7+2)>>3
+ nullMask := data[1:pos]
+
+ for i := range dest {
+ // Field is NULL
+ // (byte >> bit-pos) % 2 == 1
+ if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
+ dest[i] = nil
+ continue
+ }
+
+ // Convert to byte-coded string
+ switch rows.rs.columns[i].fieldType {
+ case fieldTypeNULL:
+ dest[i] = nil
+ continue
+
+ // Numeric Types
+ case fieldTypeTiny:
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ dest[i] = int64(data[pos])
+ } else {
+ dest[i] = int64(int8(data[pos]))
+ }
+ pos++
+ continue
+
+ case fieldTypeShort, fieldTypeYear:
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
+ } else {
+ dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
+ }
+ pos += 2
+ continue
+
+ case fieldTypeInt24, fieldTypeLong:
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
+ } else {
+ dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
+ }
+ pos += 4
+ continue
+
+ case fieldTypeLongLong:
+ if rows.rs.columns[i].flags&flagUnsigned != 0 {
+ val := binary.LittleEndian.Uint64(data[pos : pos+8])
+ if val > math.MaxInt64 {
+ dest[i] = uint64ToString(val)
+ } else {
+ dest[i] = int64(val)
+ }
+ } else {
+ dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
+ }
+ pos += 8
+ continue
+
+ case fieldTypeFloat:
+ dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
+ pos += 4
+ continue
+
+ case fieldTypeDouble:
+ dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
+ pos += 8
+ continue
+
+ // Length coded Binary Strings
+ case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
+ fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
+ fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
+ fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
+ var isNull bool
+ var n int
+ dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
+ pos += n
+ if err == nil {
+ if !isNull {
+ continue
+ } else {
+ dest[i] = nil
+ continue
+ }
+ }
+ return err
+
+ case
+ fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
+ fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal]
+ fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
+
+ num, isNull, n := readLengthEncodedInteger(data[pos:])
+ pos += n
+
+ switch {
+ case isNull:
+ dest[i] = nil
+ continue
+ case rows.rs.columns[i].fieldType == fieldTypeTime:
+ // database/sql does not support an equivalent to TIME, return a string
+ var dstlen uint8
+ switch decimals := rows.rs.columns[i].decimals; decimals {
+ case 0x00, 0x1f:
+ dstlen = 8
+ case 1, 2, 3, 4, 5, 6:
+ dstlen = 8 + 1 + decimals
+ default:
+ return fmt.Errorf(
+ "protocol error, illegal decimals value %d",
+ rows.rs.columns[i].decimals,
+ )
+ }
+ dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
+ case rows.mc.parseTime:
+ dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
+ default:
+ var dstlen uint8
+ if rows.rs.columns[i].fieldType == fieldTypeDate {
+ dstlen = 10
+ } else {
+ switch decimals := rows.rs.columns[i].decimals; decimals {
+ case 0x00, 0x1f:
+ dstlen = 19
+ case 1, 2, 3, 4, 5, 6:
+ dstlen = 19 + 1 + decimals
+ default:
+ return fmt.Errorf(
+ "protocol error, illegal decimals value %d",
+ rows.rs.columns[i].decimals,
+ )
+ }
+ }
+ dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
+ }
+
+ if err == nil {
+ pos += int(num)
+ continue
+ } else {
+ return err
+ }
+
+ // Please report if this happens!
+ default:
+ return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
+ }
+ }
+
+ return nil
+}
diff --git a/vendor/mysql/packets_test.go b/vendor/mysql/packets_test.go
new file mode 100644
index 0000000..b61e4db
--- /dev/null
+++ b/vendor/mysql/packets_test.go
@@ -0,0 +1,336 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "errors"
+ "net"
+ "testing"
+ "time"
+)
+
+var (
+ errConnClosed = errors.New("connection is closed")
+ errConnTooManyReads = errors.New("too many reads")
+ errConnTooManyWrites = errors.New("too many writes")
+)
+
+// struct to mock a net.Conn for testing purposes
+type mockConn struct {
+ laddr net.Addr
+ raddr net.Addr
+ data []byte
+ written []byte
+ queuedReplies [][]byte
+ closed bool
+ read int
+ reads int
+ writes int
+ maxReads int
+ maxWrites int
+}
+
+func (m *mockConn) Read(b []byte) (n int, err error) {
+ if m.closed {
+ return 0, errConnClosed
+ }
+
+ m.reads++
+ if m.maxReads > 0 && m.reads > m.maxReads {
+ return 0, errConnTooManyReads
+ }
+
+ n = copy(b, m.data)
+ m.read += n
+ m.data = m.data[n:]
+ return
+}
+func (m *mockConn) Write(b []byte) (n int, err error) {
+ if m.closed {
+ return 0, errConnClosed
+ }
+
+ m.writes++
+ if m.maxWrites > 0 && m.writes > m.maxWrites {
+ return 0, errConnTooManyWrites
+ }
+
+ n = len(b)
+ m.written = append(m.written, b...)
+
+ if n > 0 && len(m.queuedReplies) > 0 {
+ m.data = m.queuedReplies[0]
+ m.queuedReplies = m.queuedReplies[1:]
+ }
+ return
+}
+func (m *mockConn) Close() error {
+ m.closed = true
+ return nil
+}
+func (m *mockConn) LocalAddr() net.Addr {
+ return m.laddr
+}
+func (m *mockConn) RemoteAddr() net.Addr {
+ return m.raddr
+}
+func (m *mockConn) SetDeadline(t time.Time) error {
+ return nil
+}
+func (m *mockConn) SetReadDeadline(t time.Time) error {
+ return nil
+}
+func (m *mockConn) SetWriteDeadline(t time.Time) error {
+ return nil
+}
+
+// make sure mockConn implements the net.Conn interface
+var _ net.Conn = new(mockConn)
+
+func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
+ conn := new(mockConn)
+ mc := &mysqlConn{
+ buf: newBuffer(conn),
+ cfg: NewConfig(),
+ netConn: conn,
+ closech: make(chan struct{}),
+ maxAllowedPacket: defaultMaxAllowedPacket,
+ sequence: sequence,
+ }
+ return conn, mc
+}
+
+func TestReadPacketSingleByte(t *testing.T) {
+ conn := new(mockConn)
+ mc := &mysqlConn{
+ buf: newBuffer(conn),
+ }
+
+ conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
+ conn.maxReads = 1
+ packet, err := mc.readPacket()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(packet) != 1 {
+ t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet))
+ }
+ if packet[0] != 0xff {
+ t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
+ }
+}
+
+func TestReadPacketWrongSequenceID(t *testing.T) {
+ conn := new(mockConn)
+ mc := &mysqlConn{
+ buf: newBuffer(conn),
+ }
+
+ // too low sequence id
+ conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
+ conn.maxReads = 1
+ mc.sequence = 1
+ _, err := mc.readPacket()
+ if err != ErrPktSync {
+ t.Errorf("expected ErrPktSync, got %v", err)
+ }
+
+ // reset
+ conn.reads = 0
+ mc.sequence = 0
+ mc.buf = newBuffer(conn)
+
+ // too high sequence id
+ conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
+ _, err = mc.readPacket()
+ if err != ErrPktSyncMul {
+ t.Errorf("expected ErrPktSyncMul, got %v", err)
+ }
+}
+
+func TestReadPacketSplit(t *testing.T) {
+ conn := new(mockConn)
+ mc := &mysqlConn{
+ buf: newBuffer(conn),
+ }
+
+ data := make([]byte, maxPacketSize*2+4*3)
+ const pkt2ofs = maxPacketSize + 4
+ const pkt3ofs = 2 * (maxPacketSize + 4)
+
+ // case 1: payload has length maxPacketSize
+ data = data[:pkt2ofs+4]
+
+ // 1st packet has maxPacketSize length and sequence id 0
+ // ff ff ff 00 ...
+ data[0] = 0xff
+ data[1] = 0xff
+ data[2] = 0xff
+
+ // mark the payload start and end of 1st packet so that we can check if the
+ // content was correctly appended
+ data[4] = 0x11
+ data[maxPacketSize+3] = 0x22
+
+ // 2nd packet has payload length 0 and squence id 1
+ // 00 00 00 01
+ data[pkt2ofs+3] = 0x01
+
+ conn.data = data
+ conn.maxReads = 3
+ packet, err := mc.readPacket()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(packet) != maxPacketSize {
+ t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet))
+ }
+ if packet[0] != 0x11 {
+ t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
+ }
+ if packet[maxPacketSize-1] != 0x22 {
+ t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
+ }
+
+ // case 2: payload has length which is a multiple of maxPacketSize
+ data = data[:cap(data)]
+
+ // 2nd packet now has maxPacketSize length
+ data[pkt2ofs] = 0xff
+ data[pkt2ofs+1] = 0xff
+ data[pkt2ofs+2] = 0xff
+
+ // mark the payload start and end of the 2nd packet
+ data[pkt2ofs+4] = 0x33
+ data[pkt2ofs+maxPacketSize+3] = 0x44
+
+ // 3rd packet has payload length 0 and squence id 2
+ // 00 00 00 02
+ data[pkt3ofs+3] = 0x02
+
+ conn.data = data
+ conn.reads = 0
+ conn.maxReads = 5
+ mc.sequence = 0
+ packet, err = mc.readPacket()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(packet) != 2*maxPacketSize {
+ t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet))
+ }
+ if packet[0] != 0x11 {
+ t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
+ }
+ if packet[2*maxPacketSize-1] != 0x44 {
+ t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
+ }
+
+ // case 3: payload has a length larger maxPacketSize, which is not an exact
+ // multiple of it
+ data = data[:pkt2ofs+4+42]
+ data[pkt2ofs] = 0x2a
+ data[pkt2ofs+1] = 0x00
+ data[pkt2ofs+2] = 0x00
+ data[pkt2ofs+4+41] = 0x44
+
+ conn.data = data
+ conn.reads = 0
+ conn.maxReads = 4
+ mc.sequence = 0
+ packet, err = mc.readPacket()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(packet) != maxPacketSize+42 {
+ t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet))
+ }
+ if packet[0] != 0x11 {
+ t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
+ }
+ if packet[maxPacketSize+41] != 0x44 {
+ t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
+ }
+}
+
+func TestReadPacketFail(t *testing.T) {
+ conn := new(mockConn)
+ mc := &mysqlConn{
+ buf: newBuffer(conn),
+ closech: make(chan struct{}),
+ }
+
+ // illegal empty (stand-alone) packet
+ conn.data = []byte{0x00, 0x00, 0x00, 0x00}
+ conn.maxReads = 1
+ _, err := mc.readPacket()
+ if err != ErrInvalidConn {
+ t.Errorf("expected ErrInvalidConn, got %v", err)
+ }
+
+ // reset
+ conn.reads = 0
+ mc.sequence = 0
+ mc.buf = newBuffer(conn)
+
+ // fail to read header
+ conn.closed = true
+ _, err = mc.readPacket()
+ if err != ErrInvalidConn {
+ t.Errorf("expected ErrInvalidConn, got %v", err)
+ }
+
+ // reset
+ conn.closed = false
+ conn.reads = 0
+ mc.sequence = 0
+ mc.buf = newBuffer(conn)
+
+ // fail to read body
+ conn.maxReads = 1
+ _, err = mc.readPacket()
+ if err != ErrInvalidConn {
+ t.Errorf("expected ErrInvalidConn, got %v", err)
+ }
+}
+
+// https://github.com/go-sql-driver/mysql/pull/801
+// not-NUL terminated plugin_name in init packet
+func TestRegression801(t *testing.T) {
+ conn := new(mockConn)
+ mc := &mysqlConn{
+ buf: newBuffer(conn),
+ cfg: new(Config),
+ sequence: 42,
+ closech: make(chan struct{}),
+ }
+
+ conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
+ 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77,
+ 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95,
+ 112, 97, 115, 115, 119, 111, 114, 100}
+ conn.maxReads = 1
+
+ authData, pluginName, err := mc.readHandshakePacket()
+ if err != nil {
+ t.Fatalf("got error: %v", err)
+ }
+
+ if pluginName != "mysql_native_password" {
+ t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
+ }
+
+ expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114,
+ 47, 85, 75, 109, 99, 51, 77, 50, 64}
+ if !bytes.Equal(authData, expectedAuthData) {
+ t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
+ }
+}
diff --git a/vendor/mysql/result.go b/vendor/mysql/result.go
new file mode 100644
index 0000000..c6438d0
--- /dev/null
+++ b/vendor/mysql/result.go
@@ -0,0 +1,22 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+type mysqlResult struct {
+ affectedRows int64
+ insertId int64
+}
+
+func (res *mysqlResult) LastInsertId() (int64, error) {
+ return res.insertId, nil
+}
+
+func (res *mysqlResult) RowsAffected() (int64, error) {
+ return res.affectedRows, nil
+}
diff --git a/vendor/mysql/rows.go b/vendor/mysql/rows.go
new file mode 100644
index 0000000..888bdb5
--- /dev/null
+++ b/vendor/mysql/rows.go
@@ -0,0 +1,223 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "database/sql/driver"
+ "io"
+ "math"
+ "reflect"
+)
+
+type resultSet struct {
+ columns []mysqlField
+ columnNames []string
+ done bool
+}
+
+type mysqlRows struct {
+ mc *mysqlConn
+ rs resultSet
+ finish func()
+}
+
+type binaryRows struct {
+ mysqlRows
+}
+
+type textRows struct {
+ mysqlRows
+}
+
+func (rows *mysqlRows) Columns() []string {
+ if rows.rs.columnNames != nil {
+ return rows.rs.columnNames
+ }
+
+ columns := make([]string, len(rows.rs.columns))
+ if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
+ for i := range columns {
+ if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
+ columns[i] = tableName + "." + rows.rs.columns[i].name
+ } else {
+ columns[i] = rows.rs.columns[i].name
+ }
+ }
+ } else {
+ for i := range columns {
+ columns[i] = rows.rs.columns[i].name
+ }
+ }
+
+ rows.rs.columnNames = columns
+ return columns
+}
+
+func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
+ return rows.rs.columns[i].typeDatabaseName()
+}
+
+// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
+// return int64(rows.rs.columns[i].length), true
+// }
+
+func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
+ return rows.rs.columns[i].flags&flagNotNULL == 0, true
+}
+
+func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
+ column := rows.rs.columns[i]
+ decimals := int64(column.decimals)
+
+ switch column.fieldType {
+ case fieldTypeDecimal, fieldTypeNewDecimal:
+ if decimals > 0 {
+ return int64(column.length) - 2, decimals, true
+ }
+ return int64(column.length) - 1, decimals, true
+ case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
+ return decimals, decimals, true
+ case fieldTypeFloat, fieldTypeDouble:
+ if decimals == 0x1f {
+ return math.MaxInt64, math.MaxInt64, true
+ }
+ return math.MaxInt64, decimals, true
+ }
+
+ return 0, 0, false
+}
+
+func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
+ return rows.rs.columns[i].scanType()
+}
+
+func (rows *mysqlRows) Close() (err error) {
+ if f := rows.finish; f != nil {
+ f()
+ rows.finish = nil
+ }
+
+ mc := rows.mc
+ if mc == nil {
+ return nil
+ }
+ if err := mc.error(); err != nil {
+ return err
+ }
+
+ // flip the buffer for this connection if we need to drain it.
+ // note that for a successful query (i.e. one where rows.next()
+ // has been called until it returns false), `rows.mc` will be nil
+ // by the time the user calls `(*Rows).Close`, so we won't reach this
+ // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
+ mc.buf.flip()
+
+ // Remove unread packets from stream
+ if !rows.rs.done {
+ err = mc.readUntilEOF()
+ }
+ if err == nil {
+ if err = mc.discardResults(); err != nil {
+ return err
+ }
+ }
+
+ rows.mc = nil
+ return err
+}
+
+func (rows *mysqlRows) HasNextResultSet() (b bool) {
+ if rows.mc == nil {
+ return false
+ }
+ return rows.mc.status&statusMoreResultsExists != 0
+}
+
+func (rows *mysqlRows) nextResultSet() (int, error) {
+ if rows.mc == nil {
+ return 0, io.EOF
+ }
+ if err := rows.mc.error(); err != nil {
+ return 0, err
+ }
+
+ // Remove unread packets from stream
+ if !rows.rs.done {
+ if err := rows.mc.readUntilEOF(); err != nil {
+ return 0, err
+ }
+ rows.rs.done = true
+ }
+
+ if !rows.HasNextResultSet() {
+ rows.mc = nil
+ return 0, io.EOF
+ }
+ rows.rs = resultSet{}
+ return rows.mc.readResultSetHeaderPacket()
+}
+
+func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
+ for {
+ resLen, err := rows.nextResultSet()
+ if err != nil {
+ return 0, err
+ }
+
+ if resLen > 0 {
+ return resLen, nil
+ }
+
+ rows.rs.done = true
+ }
+}
+
+func (rows *binaryRows) NextResultSet() error {
+ resLen, err := rows.nextNotEmptyResultSet()
+ if err != nil {
+ return err
+ }
+
+ rows.rs.columns, err = rows.mc.readColumns(resLen)
+ return err
+}
+
+func (rows *binaryRows) Next(dest []driver.Value) error {
+ if mc := rows.mc; mc != nil {
+ if err := mc.error(); err != nil {
+ return err
+ }
+
+ // Fetch next row from stream
+ return rows.readRow(dest)
+ }
+ return io.EOF
+}
+
+func (rows *textRows) NextResultSet() (err error) {
+ resLen, err := rows.nextNotEmptyResultSet()
+ if err != nil {
+ return err
+ }
+
+ rows.rs.columns, err = rows.mc.readColumns(resLen)
+ return err
+}
+
+func (rows *textRows) Next(dest []driver.Value) error {
+ if mc := rows.mc; mc != nil {
+ if err := mc.error(); err != nil {
+ return err
+ }
+
+ // Fetch next row from stream
+ return rows.readRow(dest)
+ }
+ return io.EOF
+}
diff --git a/vendor/mysql/statement.go b/vendor/mysql/statement.go
new file mode 100644
index 0000000..18a3ae4
--- /dev/null
+++ b/vendor/mysql/statement.go
@@ -0,0 +1,220 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+ "io"
+ "reflect"
+)
+
+type mysqlStmt struct {
+ mc *mysqlConn
+ id uint32
+ paramCount int
+}
+
+func (stmt *mysqlStmt) Close() error {
+ if stmt.mc == nil || stmt.mc.closed.IsSet() {
+ // driver.Stmt.Close can be called more than once, thus this function
+ // has to be idempotent.
+ // See also Issue #450 and golang/go#16019.
+ //errLog.Print(ErrInvalidConn)
+ return driver.ErrBadConn
+ }
+
+ err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
+ stmt.mc = nil
+ return err
+}
+
+func (stmt *mysqlStmt) NumInput() int {
+ return stmt.paramCount
+}
+
+func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
+ return converter{}
+}
+
+func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
+ nv.Value, err = converter{}.ConvertValue(nv.Value)
+ return
+}
+
+func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+ if stmt.mc.closed.IsSet() {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ // Send command
+ err := stmt.writeExecutePacket(args)
+ if err != nil {
+ return nil, stmt.mc.markBadConn(err)
+ }
+
+ mc := stmt.mc
+
+ mc.affectedRows = 0
+ mc.insertId = 0
+
+ // Read Result
+ resLen, err := mc.readResultSetHeaderPacket()
+ if err != nil {
+ return nil, err
+ }
+
+ if resLen > 0 {
+ // Columns
+ if err = mc.readUntilEOF(); err != nil {
+ return nil, err
+ }
+
+ // Rows
+ if err := mc.readUntilEOF(); err != nil {
+ return nil, err
+ }
+ }
+
+ if err := mc.discardResults(); err != nil {
+ return nil, err
+ }
+
+ return &mysqlResult{
+ affectedRows: int64(mc.affectedRows),
+ insertId: int64(mc.insertId),
+ }, nil
+}
+
+func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+ return stmt.query(args)
+}
+
+func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
+ if stmt.mc.closed.IsSet() {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ // Send command
+ err := stmt.writeExecutePacket(args)
+ if err != nil {
+ return nil, stmt.mc.markBadConn(err)
+ }
+
+ mc := stmt.mc
+
+ // Read Result
+ resLen, err := mc.readResultSetHeaderPacket()
+ if err != nil {
+ return nil, err
+ }
+
+ rows := new(binaryRows)
+
+ if resLen > 0 {
+ rows.mc = mc
+ rows.rs.columns, err = mc.readColumns(resLen)
+ } else {
+ rows.rs.done = true
+
+ switch err := rows.NextResultSet(); err {
+ case nil, io.EOF:
+ return rows, nil
+ default:
+ return nil, err
+ }
+ }
+
+ return rows, err
+}
+
+var jsonType = reflect.TypeOf(json.RawMessage{})
+
+type converter struct{}
+
+// ConvertValue mirrors the reference/default converter in database/sql/driver
+// with _one_ exception. We support uint64 with their high bit and the default
+// implementation does not. This function should be kept in sync with
+// database/sql/driver defaultConverter.ConvertValue() except for that
+// deliberate difference.
+func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
+ if driver.IsValue(v) {
+ return v, nil
+ }
+
+ if vr, ok := v.(driver.Valuer); ok {
+ sv, err := callValuerValue(vr)
+ if err != nil {
+ return nil, err
+ }
+ if driver.IsValue(sv) {
+ return sv, nil
+ }
+ // A value returend from the Valuer interface can be "a type handled by
+ // a database driver's NamedValueChecker interface" so we should accept
+ // uint64 here as well.
+ if u, ok := sv.(uint64); ok {
+ return u, nil
+ }
+ return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
+ }
+ rv := reflect.ValueOf(v)
+ switch rv.Kind() {
+ case reflect.Ptr:
+ // indirect pointers
+ if rv.IsNil() {
+ return nil, nil
+ } else {
+ return c.ConvertValue(rv.Elem().Interface())
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return rv.Int(), nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return rv.Uint(), nil
+ case reflect.Float32, reflect.Float64:
+ return rv.Float(), nil
+ case reflect.Bool:
+ return rv.Bool(), nil
+ case reflect.Slice:
+ switch t := rv.Type(); {
+ case t == jsonType:
+ return v, nil
+ case t.Elem().Kind() == reflect.Uint8:
+ return rv.Bytes(), nil
+ default:
+ return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
+ }
+ case reflect.String:
+ return rv.String(), nil
+ }
+ return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
+}
+
+var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This is an exact copy of the same-named unexported function from the
+// database/sql package.
+func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
+ if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
+ rv.IsNil() &&
+ rv.Type().Elem().Implements(valuerReflectType) {
+ return nil, nil
+ }
+ return vr.Value()
+}
diff --git a/vendor/mysql/statement_test.go b/vendor/mysql/statement_test.go
new file mode 100644
index 0000000..ac6b92d
--- /dev/null
+++ b/vendor/mysql/statement_test.go
@@ -0,0 +1,151 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "database/sql/driver"
+ "encoding/json"
+ "testing"
+)
+
+func TestConvertDerivedString(t *testing.T) {
+ type derived string
+
+ output, err := converter{}.ConvertValue(derived("value"))
+ if err != nil {
+ t.Fatal("Derived string type not convertible", err)
+ }
+
+ if output != "value" {
+ t.Fatalf("Derived string type not converted, got %#v %T", output, output)
+ }
+}
+
+func TestConvertDerivedByteSlice(t *testing.T) {
+ type derived []uint8
+
+ output, err := converter{}.ConvertValue(derived("value"))
+ if err != nil {
+ t.Fatal("Byte slice not convertible", err)
+ }
+
+ if bytes.Compare(output.([]byte), []byte("value")) != 0 {
+ t.Fatalf("Byte slice not converted, got %#v %T", output, output)
+ }
+}
+
+func TestConvertDerivedUnsupportedSlice(t *testing.T) {
+ type derived []int
+
+ _, err := converter{}.ConvertValue(derived{1})
+ if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
+ t.Fatal("Unexpected error", err)
+ }
+}
+
+func TestConvertDerivedBool(t *testing.T) {
+ type derived bool
+
+ output, err := converter{}.ConvertValue(derived(true))
+ if err != nil {
+ t.Fatal("Derived bool type not convertible", err)
+ }
+
+ if output != true {
+ t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
+ }
+}
+
+func TestConvertPointer(t *testing.T) {
+ str := "value"
+
+ output, err := converter{}.ConvertValue(&str)
+ if err != nil {
+ t.Fatal("Pointer type not convertible", err)
+ }
+
+ if output != "value" {
+ t.Fatalf("Pointer type not converted, got %#v %T", output, output)
+ }
+}
+
+func TestConvertSignedIntegers(t *testing.T) {
+ values := []interface{}{
+ int8(-42),
+ int16(-42),
+ int32(-42),
+ int64(-42),
+ int(-42),
+ }
+
+ for _, value := range values {
+ output, err := converter{}.ConvertValue(value)
+ if err != nil {
+ t.Fatalf("%T type not convertible %s", value, err)
+ }
+
+ if output != int64(-42) {
+ t.Fatalf("%T type not converted, got %#v %T", value, output, output)
+ }
+ }
+}
+
+type myUint64 struct {
+ value uint64
+}
+
+func (u myUint64) Value() (driver.Value, error) {
+ return u.value, nil
+}
+
+func TestConvertUnsignedIntegers(t *testing.T) {
+ values := []interface{}{
+ uint8(42),
+ uint16(42),
+ uint32(42),
+ uint64(42),
+ uint(42),
+ myUint64{uint64(42)},
+ }
+
+ for _, value := range values {
+ output, err := converter{}.ConvertValue(value)
+ if err != nil {
+ t.Fatalf("%T type not convertible %s", value, err)
+ }
+
+ if output != uint64(42) {
+ t.Fatalf("%T type not converted, got %#v %T", value, output, output)
+ }
+ }
+
+ output, err := converter{}.ConvertValue(^uint64(0))
+ if err != nil {
+ t.Fatal("uint64 high-bit not convertible", err)
+ }
+
+ if output != ^uint64(0) {
+ t.Fatalf("uint64 high-bit converted, got %#v %T", output, output)
+ }
+}
+
+func TestConvertJSON(t *testing.T) {
+ raw := json.RawMessage("{}")
+
+ out, err := converter{}.ConvertValue(raw)
+
+ if err != nil {
+ t.Fatal("json.RawMessage was failed in convert", err)
+ }
+
+ if _, ok := out.(json.RawMessage); !ok {
+ t.Fatalf("json.RawMessage converted, got %#v %T", out, out)
+ }
+}
diff --git a/vendor/mysql/transaction.go b/vendor/mysql/transaction.go
new file mode 100644
index 0000000..417d727
--- /dev/null
+++ b/vendor/mysql/transaction.go
@@ -0,0 +1,31 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+type mysqlTx struct {
+ mc *mysqlConn
+}
+
+func (tx *mysqlTx) Commit() (err error) {
+ if tx.mc == nil || tx.mc.closed.IsSet() {
+ return ErrInvalidConn
+ }
+ err = tx.mc.exec("COMMIT")
+ tx.mc = nil
+ return
+}
+
+func (tx *mysqlTx) Rollback() (err error) {
+ if tx.mc == nil || tx.mc.closed.IsSet() {
+ return ErrInvalidConn
+ }
+ err = tx.mc.exec("ROLLBACK")
+ tx.mc = nil
+ return
+}
diff --git a/vendor/mysql/utils.go b/vendor/mysql/utils.go
new file mode 100644
index 0000000..d6545f5
--- /dev/null
+++ b/vendor/mysql/utils.go
@@ -0,0 +1,868 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "crypto/tls"
+ "database/sql"
+ "database/sql/driver"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Registry for custom tls.Configs
+var (
+ tlsConfigLock sync.RWMutex
+ tlsConfigRegistry map[string]*tls.Config
+)
+
+// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
+// Use the key as a value in the DSN where tls=value.
+//
+// Note: The provided tls.Config is exclusively owned by the driver after
+// registering it.
+//
+// rootCertPool := x509.NewCertPool()
+// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
+// if err != nil {
+// log.Fatal(err)
+// }
+// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
+// log.Fatal("Failed to append PEM.")
+// }
+// clientCert := make([]tls.Certificate, 0, 1)
+// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
+// if err != nil {
+// log.Fatal(err)
+// }
+// clientCert = append(clientCert, certs)
+// mysql.RegisterTLSConfig("custom", &tls.Config{
+// RootCAs: rootCertPool,
+// Certificates: clientCert,
+// })
+// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
+//
+func RegisterTLSConfig(key string, config *tls.Config) error {
+ if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
+ return fmt.Errorf("key '%s' is reserved", key)
+ }
+
+ tlsConfigLock.Lock()
+ if tlsConfigRegistry == nil {
+ tlsConfigRegistry = make(map[string]*tls.Config)
+ }
+
+ tlsConfigRegistry[key] = config
+ tlsConfigLock.Unlock()
+ return nil
+}
+
+// DeregisterTLSConfig removes the tls.Config associated with key.
+func DeregisterTLSConfig(key string) {
+ tlsConfigLock.Lock()
+ if tlsConfigRegistry != nil {
+ delete(tlsConfigRegistry, key)
+ }
+ tlsConfigLock.Unlock()
+}
+
+func getTLSConfigClone(key string) (config *tls.Config) {
+ tlsConfigLock.RLock()
+ if v, ok := tlsConfigRegistry[key]; ok {
+ config = v.Clone()
+ }
+ tlsConfigLock.RUnlock()
+ return
+}
+
+// Returns the bool value of the input.
+// The 2nd return value indicates if the input was a valid bool value
+func readBool(input string) (value bool, valid bool) {
+ switch input {
+ case "1", "true", "TRUE", "True":
+ return true, true
+ case "0", "false", "FALSE", "False":
+ return false, true
+ }
+
+ // Not a valid bool value
+ return
+}
+
+/******************************************************************************
+* Time related utils *
+******************************************************************************/
+
+func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
+ const base = "0000-00-00 00:00:00.000000"
+ switch len(b) {
+ case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
+ if string(b) == base[:len(b)] {
+ return time.Time{}, nil
+ }
+
+ year, err := parseByteYear(b)
+ if err != nil {
+ return time.Time{}, err
+ }
+ if year <= 0 {
+ year = 1
+ }
+
+ if b[4] != '-' {
+ return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
+ }
+
+ m, err := parseByte2Digits(b[5], b[6])
+ if err != nil {
+ return time.Time{}, err
+ }
+ if m <= 0 {
+ m = 1
+ }
+ month := time.Month(m)
+
+ if b[7] != '-' {
+ return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
+ }
+
+ day, err := parseByte2Digits(b[8], b[9])
+ if err != nil {
+ return time.Time{}, err
+ }
+ if day <= 0 {
+ day = 1
+ }
+ if len(b) == 10 {
+ return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
+ }
+
+ if b[10] != ' ' {
+ return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
+ }
+
+ hour, err := parseByte2Digits(b[11], b[12])
+ if err != nil {
+ return time.Time{}, err
+ }
+ if b[13] != ':' {
+ return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
+ }
+
+ min, err := parseByte2Digits(b[14], b[15])
+ if err != nil {
+ return time.Time{}, err
+ }
+ if b[16] != ':' {
+ return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
+ }
+
+ sec, err := parseByte2Digits(b[17], b[18])
+ if err != nil {
+ return time.Time{}, err
+ }
+ if len(b) == 19 {
+ return time.Date(year, month, day, hour, min, sec, 0, loc), nil
+ }
+
+ if b[19] != '.' {
+ return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
+ }
+ nsec, err := parseByteNanoSec(b[20:])
+ if err != nil {
+ return time.Time{}, err
+ }
+ return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
+ default:
+ return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
+ }
+}
+
+func parseByteYear(b []byte) (int, error) {
+ year, n := 0, 1000
+ for i := 0; i < 4; i++ {
+ v, err := bToi(b[i])
+ if err != nil {
+ return 0, err
+ }
+ year += v * n
+ n = n / 10
+ }
+ return year, nil
+}
+
+func parseByte2Digits(b1, b2 byte) (int, error) {
+ d1, err := bToi(b1)
+ if err != nil {
+ return 0, err
+ }
+ d2, err := bToi(b2)
+ if err != nil {
+ return 0, err
+ }
+ return d1*10 + d2, nil
+}
+
+func parseByteNanoSec(b []byte) (int, error) {
+ ns, digit := 0, 100000 // max is 6-digits
+ for i := 0; i < len(b); i++ {
+ v, err := bToi(b[i])
+ if err != nil {
+ return 0, err
+ }
+ ns += v * digit
+ digit /= 10
+ }
+ // nanoseconds has 10-digits. (needs to scale digits)
+ // 10 - 6 = 4, so we have to multiple 1000.
+ return ns * 1000, nil
+}
+
+func bToi(b byte) (int, error) {
+ if b < '0' || b > '9' {
+ return 0, errors.New("not [0-9]")
+ }
+ return int(b - '0'), nil
+}
+
+func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
+ switch num {
+ case 0:
+ return time.Time{}, nil
+ case 4:
+ return time.Date(
+ int(binary.LittleEndian.Uint16(data[:2])), // year
+ time.Month(data[2]), // month
+ int(data[3]), // day
+ 0, 0, 0, 0,
+ loc,
+ ), nil
+ case 7:
+ return time.Date(
+ int(binary.LittleEndian.Uint16(data[:2])), // year
+ time.Month(data[2]), // month
+ int(data[3]), // day
+ int(data[4]), // hour
+ int(data[5]), // minutes
+ int(data[6]), // seconds
+ 0,
+ loc,
+ ), nil
+ case 11:
+ return time.Date(
+ int(binary.LittleEndian.Uint16(data[:2])), // year
+ time.Month(data[2]), // month
+ int(data[3]), // day
+ int(data[4]), // hour
+ int(data[5]), // minutes
+ int(data[6]), // seconds
+ int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
+ loc,
+ ), nil
+ }
+ return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
+}
+
+func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
+ year, month, day := t.Date()
+ hour, min, sec := t.Clock()
+ nsec := t.Nanosecond()
+
+ if year < 1 || year > 9999 {
+ return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap
+ }
+ year100 := year / 100
+ year1 := year % 100
+
+ var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape
+ localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1]
+ localBuf[4] = '-'
+ localBuf[5], localBuf[6] = digits10[month], digits01[month]
+ localBuf[7] = '-'
+ localBuf[8], localBuf[9] = digits10[day], digits01[day]
+
+ if hour == 0 && min == 0 && sec == 0 && nsec == 0 {
+ return append(buf, localBuf[:10]...), nil
+ }
+
+ localBuf[10] = ' '
+ localBuf[11], localBuf[12] = digits10[hour], digits01[hour]
+ localBuf[13] = ':'
+ localBuf[14], localBuf[15] = digits10[min], digits01[min]
+ localBuf[16] = ':'
+ localBuf[17], localBuf[18] = digits10[sec], digits01[sec]
+
+ if nsec == 0 {
+ return append(buf, localBuf[:19]...), nil
+ }
+ nsec100000000 := nsec / 100000000
+ nsec1000000 := (nsec / 1000000) % 100
+ nsec10000 := (nsec / 10000) % 100
+ nsec100 := (nsec / 100) % 100
+ nsec1 := nsec % 100
+ localBuf[19] = '.'
+
+ // milli second
+ localBuf[20], localBuf[21], localBuf[22] =
+ digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
+ // micro second
+ localBuf[23], localBuf[24], localBuf[25] =
+ digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
+ // nano second
+ localBuf[26], localBuf[27], localBuf[28] =
+ digits01[nsec100], digits10[nsec1], digits01[nsec1]
+
+ // trim trailing zeros
+ n := len(localBuf)
+ for n > 0 && localBuf[n-1] == '0' {
+ n--
+ }
+
+ return append(buf, localBuf[:n]...), nil
+}
+
+// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
+// if the DATE or DATETIME has the zero value.
+// It must never be changed.
+// The current behavior depends on database/sql copying the result.
+var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
+
+const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
+const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
+
+func appendMicrosecs(dst, src []byte, decimals int) []byte {
+ if decimals <= 0 {
+ return dst
+ }
+ if len(src) == 0 {
+ return append(dst, ".000000"[:decimals+1]...)
+ }
+
+ microsecs := binary.LittleEndian.Uint32(src[:4])
+ p1 := byte(microsecs / 10000)
+ microsecs -= 10000 * uint32(p1)
+ p2 := byte(microsecs / 100)
+ microsecs -= 100 * uint32(p2)
+ p3 := byte(microsecs)
+
+ switch decimals {
+ default:
+ return append(dst, '.',
+ digits10[p1], digits01[p1],
+ digits10[p2], digits01[p2],
+ digits10[p3], digits01[p3],
+ )
+ case 1:
+ return append(dst, '.',
+ digits10[p1],
+ )
+ case 2:
+ return append(dst, '.',
+ digits10[p1], digits01[p1],
+ )
+ case 3:
+ return append(dst, '.',
+ digits10[p1], digits01[p1],
+ digits10[p2],
+ )
+ case 4:
+ return append(dst, '.',
+ digits10[p1], digits01[p1],
+ digits10[p2], digits01[p2],
+ )
+ case 5:
+ return append(dst, '.',
+ digits10[p1], digits01[p1],
+ digits10[p2], digits01[p2],
+ digits10[p3],
+ )
+ }
+}
+
+func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
+ // length expects the deterministic length of the zero value,
+ // negative time and 100+ hours are automatically added if needed
+ if len(src) == 0 {
+ return zeroDateTime[:length], nil
+ }
+ var dst []byte // return value
+ var p1, p2, p3 byte // current digit pair
+
+ switch length {
+ case 10, 19, 21, 22, 23, 24, 25, 26:
+ default:
+ t := "DATE"
+ if length > 10 {
+ t += "TIME"
+ }
+ return nil, fmt.Errorf("illegal %s length %d", t, length)
+ }
+ switch len(src) {
+ case 4, 7, 11:
+ default:
+ t := "DATE"
+ if length > 10 {
+ t += "TIME"
+ }
+ return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
+ }
+ dst = make([]byte, 0, length)
+ // start with the date
+ year := binary.LittleEndian.Uint16(src[:2])
+ pt := year / 100
+ p1 = byte(year - 100*uint16(pt))
+ p2, p3 = src[2], src[3]
+ dst = append(dst,
+ digits10[pt], digits01[pt],
+ digits10[p1], digits01[p1], '-',
+ digits10[p2], digits01[p2], '-',
+ digits10[p3], digits01[p3],
+ )
+ if length == 10 {
+ return dst, nil
+ }
+ if len(src) == 4 {
+ return append(dst, zeroDateTime[10:length]...), nil
+ }
+ dst = append(dst, ' ')
+ p1 = src[4] // hour
+ src = src[5:]
+
+ // p1 is 2-digit hour, src is after hour
+ p2, p3 = src[0], src[1]
+ dst = append(dst,
+ digits10[p1], digits01[p1], ':',
+ digits10[p2], digits01[p2], ':',
+ digits10[p3], digits01[p3],
+ )
+ return appendMicrosecs(dst, src[2:], int(length)-20), nil
+}
+
+func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
+ // length expects the deterministic length of the zero value,
+ // negative time and 100+ hours are automatically added if needed
+ if len(src) == 0 {
+ return zeroDateTime[11 : 11+length], nil
+ }
+ var dst []byte // return value
+
+ switch length {
+ case
+ 8, // time (can be up to 10 when negative and 100+ hours)
+ 10, 11, 12, 13, 14, 15: // time with fractional seconds
+ default:
+ return nil, fmt.Errorf("illegal TIME length %d", length)
+ }
+ switch len(src) {
+ case 8, 12:
+ default:
+ return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
+ }
+ // +2 to enable negative time and 100+ hours
+ dst = make([]byte, 0, length+2)
+ if src[0] == 1 {
+ dst = append(dst, '-')
+ }
+ days := binary.LittleEndian.Uint32(src[1:5])
+ hours := int64(days)*24 + int64(src[5])
+
+ if hours >= 100 {
+ dst = strconv.AppendInt(dst, hours, 10)
+ } else {
+ dst = append(dst, digits10[hours], digits01[hours])
+ }
+
+ min, sec := src[6], src[7]
+ dst = append(dst, ':',
+ digits10[min], digits01[min], ':',
+ digits10[sec], digits01[sec],
+ )
+ return appendMicrosecs(dst, src[8:], int(length)-9), nil
+}
+
+/******************************************************************************
+* Convert from and to bytes *
+******************************************************************************/
+
+func uint64ToBytes(n uint64) []byte {
+ return []byte{
+ byte(n),
+ byte(n >> 8),
+ byte(n >> 16),
+ byte(n >> 24),
+ byte(n >> 32),
+ byte(n >> 40),
+ byte(n >> 48),
+ byte(n >> 56),
+ }
+}
+
+func uint64ToString(n uint64) []byte {
+ var a [20]byte
+ i := 20
+
+ // U+0030 = 0
+ // ...
+ // U+0039 = 9
+
+ var q uint64
+ for n >= 10 {
+ i--
+ q = n / 10
+ a[i] = uint8(n-q*10) + 0x30
+ n = q
+ }
+
+ i--
+ a[i] = uint8(n) + 0x30
+
+ return a[i:]
+}
+
+// treats string value as unsigned integer representation
+func stringToInt(b []byte) int {
+ val := 0
+ for i := range b {
+ val *= 10
+ val += int(b[i] - 0x30)
+ }
+ return val
+}
+
+// returns the string read as a bytes slice, wheter the value is NULL,
+// the number of bytes read and an error, in case the string is longer than
+// the input slice
+func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
+ // Get length
+ num, isNull, n := readLengthEncodedInteger(b)
+ if num < 1 {
+ return b[n:n], isNull, n, nil
+ }
+
+ n += int(num)
+
+ // Check data length
+ if len(b) >= n {
+ return b[n-int(num) : n : n], false, n, nil
+ }
+ return nil, false, n, io.EOF
+}
+
+// returns the number of bytes skipped and an error, in case the string is
+// longer than the input slice
+func skipLengthEncodedString(b []byte) (int, error) {
+ // Get length
+ num, _, n := readLengthEncodedInteger(b)
+ if num < 1 {
+ return n, nil
+ }
+
+ n += int(num)
+
+ // Check data length
+ if len(b) >= n {
+ return n, nil
+ }
+ return n, io.EOF
+}
+
+// returns the number read, whether the value is NULL and the number of bytes read
+func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
+ // See issue #349
+ if len(b) == 0 {
+ return 0, true, 1
+ }
+
+ switch b[0] {
+ // 251: NULL
+ case 0xfb:
+ return 0, true, 1
+
+ // 252: value of following 2
+ case 0xfc:
+ return uint64(b[1]) | uint64(b[2])<<8, false, 3
+
+ // 253: value of following 3
+ case 0xfd:
+ return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
+
+ // 254: value of following 8
+ case 0xfe:
+ return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
+ uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
+ uint64(b[7])<<48 | uint64(b[8])<<56,
+ false, 9
+ }
+
+ // 0-250: value of first byte
+ return uint64(b[0]), false, 1
+}
+
+// encodes a uint64 value and appends it to the given bytes slice
+func appendLengthEncodedInteger(b []byte, n uint64) []byte {
+ switch {
+ case n <= 250:
+ return append(b, byte(n))
+
+ case n <= 0xffff:
+ return append(b, 0xfc, byte(n), byte(n>>8))
+
+ case n <= 0xffffff:
+ return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
+ }
+ return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
+ byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
+}
+
+// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
+// If cap(buf) is not enough, reallocate new buffer.
+func reserveBuffer(buf []byte, appendSize int) []byte {
+ newSize := len(buf) + appendSize
+ if cap(buf) < newSize {
+ // Grow buffer exponentially
+ newBuf := make([]byte, len(buf)*2+appendSize)
+ copy(newBuf, buf)
+ buf = newBuf
+ }
+ return buf[:newSize]
+}
+
+// escapeBytesBackslash escapes []byte with backslashes (\)
+// This escapes the contents of a string (provided as []byte) by adding backslashes before special
+// characters, and turning others into specific escape sequences, such as
+// turning newlines into \n and null bytes into \0.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
+func escapeBytesBackslash(buf, v []byte) []byte {
+ pos := len(buf)
+ buf = reserveBuffer(buf, len(v)*2)
+
+ for _, c := range v {
+ switch c {
+ case '\x00':
+ buf[pos] = '\\'
+ buf[pos+1] = '0'
+ pos += 2
+ case '\n':
+ buf[pos] = '\\'
+ buf[pos+1] = 'n'
+ pos += 2
+ case '\r':
+ buf[pos] = '\\'
+ buf[pos+1] = 'r'
+ pos += 2
+ case '\x1a':
+ buf[pos] = '\\'
+ buf[pos+1] = 'Z'
+ pos += 2
+ case '\'':
+ buf[pos] = '\\'
+ buf[pos+1] = '\''
+ pos += 2
+ case '"':
+ buf[pos] = '\\'
+ buf[pos+1] = '"'
+ pos += 2
+ case '\\':
+ buf[pos] = '\\'
+ buf[pos+1] = '\\'
+ pos += 2
+ default:
+ buf[pos] = c
+ pos++
+ }
+ }
+
+ return buf[:pos]
+}
+
+// escapeStringBackslash is similar to escapeBytesBackslash but for string.
+func escapeStringBackslash(buf []byte, v string) []byte {
+ pos := len(buf)
+ buf = reserveBuffer(buf, len(v)*2)
+
+ for i := 0; i < len(v); i++ {
+ c := v[i]
+ switch c {
+ case '\x00':
+ buf[pos] = '\\'
+ buf[pos+1] = '0'
+ pos += 2
+ case '\n':
+ buf[pos] = '\\'
+ buf[pos+1] = 'n'
+ pos += 2
+ case '\r':
+ buf[pos] = '\\'
+ buf[pos+1] = 'r'
+ pos += 2
+ case '\x1a':
+ buf[pos] = '\\'
+ buf[pos+1] = 'Z'
+ pos += 2
+ case '\'':
+ buf[pos] = '\\'
+ buf[pos+1] = '\''
+ pos += 2
+ case '"':
+ buf[pos] = '\\'
+ buf[pos+1] = '"'
+ pos += 2
+ case '\\':
+ buf[pos] = '\\'
+ buf[pos+1] = '\\'
+ pos += 2
+ default:
+ buf[pos] = c
+ pos++
+ }
+ }
+
+ return buf[:pos]
+}
+
+// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
+// This escapes the contents of a string by doubling up any apostrophes that
+// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
+// effect on the server.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
+func escapeBytesQuotes(buf, v []byte) []byte {
+ pos := len(buf)
+ buf = reserveBuffer(buf, len(v)*2)
+
+ for _, c := range v {
+ if c == '\'' {
+ buf[pos] = '\''
+ buf[pos+1] = '\''
+ pos += 2
+ } else {
+ buf[pos] = c
+ pos++
+ }
+ }
+
+ return buf[:pos]
+}
+
+// escapeStringQuotes is similar to escapeBytesQuotes but for string.
+func escapeStringQuotes(buf []byte, v string) []byte {
+ pos := len(buf)
+ buf = reserveBuffer(buf, len(v)*2)
+
+ for i := 0; i < len(v); i++ {
+ c := v[i]
+ if c == '\'' {
+ buf[pos] = '\''
+ buf[pos+1] = '\''
+ pos += 2
+ } else {
+ buf[pos] = c
+ pos++
+ }
+ }
+
+ return buf[:pos]
+}
+
+/******************************************************************************
+* Sync utils *
+******************************************************************************/
+
+// noCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://github.com/golang/go/issues/8005#issuecomment-190753527
+// for details.
+type noCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*noCopy) Lock() {}
+
+// atomicBool is a wrapper around uint32 for usage as a boolean value with
+// atomic access.
+type atomicBool struct {
+ _noCopy noCopy
+ value uint32
+}
+
+// IsSet returns whether the current boolean value is true
+func (ab *atomicBool) IsSet() bool {
+ return atomic.LoadUint32(&ab.value) > 0
+}
+
+// Set sets the value of the bool regardless of the previous value
+func (ab *atomicBool) Set(value bool) {
+ if value {
+ atomic.StoreUint32(&ab.value, 1)
+ } else {
+ atomic.StoreUint32(&ab.value, 0)
+ }
+}
+
+// TrySet sets the value of the bool and returns whether the value changed
+func (ab *atomicBool) TrySet(value bool) bool {
+ if value {
+ return atomic.SwapUint32(&ab.value, 1) == 0
+ }
+ return atomic.SwapUint32(&ab.value, 0) > 0
+}
+
+// atomicError is a wrapper for atomically accessed error values
+type atomicError struct {
+ _noCopy noCopy
+ value atomic.Value
+}
+
+// Set sets the error value regardless of the previous value.
+// The value must not be nil
+func (ae *atomicError) Set(value error) {
+ ae.value.Store(value)
+}
+
+// Value returns the current error value
+func (ae *atomicError) Value() error {
+ if v := ae.value.Load(); v != nil {
+ // this will panic if the value doesn't implement the error interface
+ return v.(error)
+ }
+ return nil
+}
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+ dargs := make([]driver.Value, len(named))
+ for n, param := range named {
+ if len(param.Name) > 0 {
+ // TODO: support the use of Named Parameters #561
+ return nil, errors.New("mysql: driver does not support the use of Named Parameters")
+ }
+ dargs[n] = param.Value
+ }
+ return dargs, nil
+}
+
+func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
+ switch sql.IsolationLevel(level) {
+ case sql.LevelRepeatableRead:
+ return "REPEATABLE READ", nil
+ case sql.LevelReadCommitted:
+ return "READ COMMITTED", nil
+ case sql.LevelReadUncommitted:
+ return "READ UNCOMMITTED", nil
+ case sql.LevelSerializable:
+ return "SERIALIZABLE", nil
+ default:
+ return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
+ }
+}
diff --git a/vendor/mysql/utils_test.go b/vendor/mysql/utils_test.go
new file mode 100644
index 0000000..67b132d
--- /dev/null
+++ b/vendor/mysql/utils_test.go
@@ -0,0 +1,508 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "bytes"
+ "database/sql"
+ "database/sql/driver"
+ "encoding/binary"
+ "testing"
+ "time"
+)
+
+func TestLengthEncodedInteger(t *testing.T) {
+ var integerTests = []struct {
+ num uint64
+ encoded []byte
+ }{
+ {0x0000000000000000, []byte{0x00}},
+ {0x0000000000000012, []byte{0x12}},
+ {0x00000000000000fa, []byte{0xfa}},
+ {0x0000000000000100, []byte{0xfc, 0x00, 0x01}},
+ {0x0000000000001234, []byte{0xfc, 0x34, 0x12}},
+ {0x000000000000ffff, []byte{0xfc, 0xff, 0xff}},
+ {0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}},
+ {0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}},
+ {0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}},
+ {0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}},
+ {0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}},
+ {0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ }
+
+ for _, tst := range integerTests {
+ num, isNull, numLen := readLengthEncodedInteger(tst.encoded)
+ if isNull {
+ t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num)
+ }
+ if num != tst.num {
+ t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num)
+ }
+ if numLen != len(tst.encoded) {
+ t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen)
+ }
+ encoded := appendLengthEncodedInteger(nil, num)
+ if !bytes.Equal(encoded, tst.encoded) {
+ t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded)
+ }
+ }
+}
+
+func TestFormatBinaryDateTime(t *testing.T) {
+ rawDate := [11]byte{}
+ binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
+ rawDate[2] = 12 // months
+ rawDate[3] = 30 // days
+ rawDate[4] = 15 // hours
+ rawDate[5] = 46 // minutes
+ rawDate[6] = 23 // seconds
+ binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds
+ expect := func(expected string, inlen, outlen uint8) {
+ actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen)
+ bytes, ok := actual.([]byte)
+ if !ok {
+ t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
+ }
+ if string(bytes) != expected {
+ t.Errorf(
+ "expected %q, got %q for length in %d, out %d",
+ expected, actual, inlen, outlen,
+ )
+ }
+ }
+ expect("0000-00-00", 0, 10)
+ expect("0000-00-00 00:00:00", 0, 19)
+ expect("1978-12-30", 4, 10)
+ expect("1978-12-30 15:46:23", 7, 19)
+ expect("1978-12-30 15:46:23.987654", 11, 26)
+}
+
+func TestFormatBinaryTime(t *testing.T) {
+ expect := func(expected string, src []byte, outlen uint8) {
+ actual, _ := formatBinaryTime(src, outlen)
+ bytes, ok := actual.([]byte)
+ if !ok {
+ t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
+ }
+ if string(bytes) != expected {
+ t.Errorf(
+ "expected %q, got %q for src=%q and outlen=%d",
+ expected, actual, src, outlen)
+ }
+ }
+
+ // binary format:
+ // sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4)
+
+ // Zeros
+ expect("00:00:00", []byte{}, 8)
+ expect("00:00:00.0", []byte{}, 10)
+ expect("00:00:00.000000", []byte{}, 15)
+
+ // Without micro(4)
+ expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8)
+ expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8)
+ expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11)
+ expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8)
+ expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8)
+ expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8)
+
+ // With micro(4)
+ expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11)
+ expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15)
+}
+
+func TestEscapeBackslash(t *testing.T) {
+ expect := func(expected, value string) {
+ actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
+ if actual != expected {
+ t.Errorf(
+ "expected %s, got %s",
+ expected, actual,
+ )
+ }
+
+ actual = string(escapeStringBackslash([]byte{}, value))
+ if actual != expected {
+ t.Errorf(
+ "expected %s, got %s",
+ expected, actual,
+ )
+ }
+ }
+
+ expect("foo\\0bar", "foo\x00bar")
+ expect("foo\\nbar", "foo\nbar")
+ expect("foo\\rbar", "foo\rbar")
+ expect("foo\\Zbar", "foo\x1abar")
+ expect("foo\\\"bar", "foo\"bar")
+ expect("foo\\\\bar", "foo\\bar")
+ expect("foo\\'bar", "foo'bar")
+}
+
+func TestEscapeQuotes(t *testing.T) {
+ expect := func(expected, value string) {
+ actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
+ if actual != expected {
+ t.Errorf(
+ "expected %s, got %s",
+ expected, actual,
+ )
+ }
+
+ actual = string(escapeStringQuotes([]byte{}, value))
+ if actual != expected {
+ t.Errorf(
+ "expected %s, got %s",
+ expected, actual,
+ )
+ }
+ }
+
+ expect("foo\x00bar", "foo\x00bar") // not affected
+ expect("foo\nbar", "foo\nbar") // not affected
+ expect("foo\rbar", "foo\rbar") // not affected
+ expect("foo\x1abar", "foo\x1abar") // not affected
+ expect("foo''bar", "foo'bar") // affected
+ expect("foo\"bar", "foo\"bar") // not affected
+}
+
+func TestAtomicBool(t *testing.T) {
+ var ab atomicBool
+ if ab.IsSet() {
+ t.Fatal("Expected value to be false")
+ }
+
+ ab.Set(true)
+ if ab.value != 1 {
+ t.Fatal("Set(true) did not set value to 1")
+ }
+ if !ab.IsSet() {
+ t.Fatal("Expected value to be true")
+ }
+
+ ab.Set(true)
+ if !ab.IsSet() {
+ t.Fatal("Expected value to be true")
+ }
+
+ ab.Set(false)
+ if ab.value != 0 {
+ t.Fatal("Set(false) did not set value to 0")
+ }
+ if ab.IsSet() {
+ t.Fatal("Expected value to be false")
+ }
+
+ ab.Set(false)
+ if ab.IsSet() {
+ t.Fatal("Expected value to be false")
+ }
+ if ab.TrySet(false) {
+ t.Fatal("Expected TrySet(false) to fail")
+ }
+ if !ab.TrySet(true) {
+ t.Fatal("Expected TrySet(true) to succeed")
+ }
+ if !ab.IsSet() {
+ t.Fatal("Expected value to be true")
+ }
+
+ ab.Set(true)
+ if !ab.IsSet() {
+ t.Fatal("Expected value to be true")
+ }
+ if ab.TrySet(true) {
+ t.Fatal("Expected TrySet(true) to fail")
+ }
+ if !ab.TrySet(false) {
+ t.Fatal("Expected TrySet(false) to succeed")
+ }
+ if ab.IsSet() {
+ t.Fatal("Expected value to be false")
+ }
+
+ ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
+}
+
+func TestAtomicError(t *testing.T) {
+ var ae atomicError
+ if ae.Value() != nil {
+ t.Fatal("Expected value to be nil")
+ }
+
+ ae.Set(ErrMalformPkt)
+ if v := ae.Value(); v != ErrMalformPkt {
+ if v == nil {
+ t.Fatal("Value is still nil")
+ }
+ t.Fatal("Error did not match")
+ }
+ ae.Set(ErrPktSync)
+ if ae.Value() == ErrMalformPkt {
+ t.Fatal("Error still matches old error")
+ }
+ if v := ae.Value(); v != ErrPktSync {
+ t.Fatal("Error did not match")
+ }
+}
+
+func TestIsolationLevelMapping(t *testing.T) {
+ data := []struct {
+ level driver.IsolationLevel
+ expected string
+ }{
+ {
+ level: driver.IsolationLevel(sql.LevelReadCommitted),
+ expected: "READ COMMITTED",
+ },
+ {
+ level: driver.IsolationLevel(sql.LevelRepeatableRead),
+ expected: "REPEATABLE READ",
+ },
+ {
+ level: driver.IsolationLevel(sql.LevelReadUncommitted),
+ expected: "READ UNCOMMITTED",
+ },
+ {
+ level: driver.IsolationLevel(sql.LevelSerializable),
+ expected: "SERIALIZABLE",
+ },
+ }
+
+ for i, td := range data {
+ if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil {
+ t.Fatal(i, td.expected, actual, err)
+ }
+ }
+
+ // check unsupported mapping
+ expectedErr := "mysql: unsupported isolation level: 7"
+ actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable))
+ if actual != "" || err == nil {
+ t.Fatal("Expected error on unsupported isolation level")
+ }
+ if err.Error() != expectedErr {
+ t.Fatalf("Expected error to be %q, got %q", expectedErr, err)
+ }
+}
+
+func TestAppendDateTime(t *testing.T) {
+ tests := []struct {
+ t time.Time
+ str string
+ }{
+ {
+ t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC),
+ str: "1234-05-06",
+ },
+ {
+ t: time.Date(4567, 12, 31, 12, 0, 0, 0, time.UTC),
+ str: "4567-12-31 12:00:00",
+ },
+ {
+ t: time.Date(2020, 5, 30, 12, 34, 0, 0, time.UTC),
+ str: "2020-05-30 12:34:00",
+ },
+ {
+ t: time.Date(2020, 5, 30, 12, 34, 56, 0, time.UTC),
+ str: "2020-05-30 12:34:56",
+ },
+ {
+ t: time.Date(2020, 5, 30, 22, 33, 44, 123000000, time.UTC),
+ str: "2020-05-30 22:33:44.123",
+ },
+ {
+ t: time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC),
+ str: "2020-05-30 22:33:44.123456",
+ },
+ {
+ t: time.Date(2020, 5, 30, 22, 33, 44, 123456789, time.UTC),
+ str: "2020-05-30 22:33:44.123456789",
+ },
+ {
+ t: time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC),
+ str: "9999-12-31 23:59:59.999999999",
+ },
+ {
+ t: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
+ str: "0001-01-01",
+ },
+ }
+ for _, v := range tests {
+ buf := make([]byte, 0, 32)
+ buf, _ = appendDateTime(buf, v.t)
+ if str := string(buf); str != v.str {
+ t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str)
+ }
+ }
+
+ // year out of range
+ {
+ v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
+ buf := make([]byte, 0, 32)
+ _, err := appendDateTime(buf, v)
+ if err == nil {
+ t.Error("want an error")
+ return
+ }
+ }
+ {
+ v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)
+ buf := make([]byte, 0, 32)
+ _, err := appendDateTime(buf, v)
+ if err == nil {
+ t.Error("want an error")
+ return
+ }
+ }
+}
+
+func TestParseDateTime(t *testing.T) {
+ cases := []struct {
+ name string
+ str string
+ }{
+ {
+ name: "parse date",
+ str: "2020-05-13",
+ },
+ {
+ name: "parse null date",
+ str: sDate0,
+ },
+ {
+ name: "parse datetime",
+ str: "2020-05-13 21:30:45",
+ },
+ {
+ name: "parse null datetime",
+ str: sDateTime0,
+ },
+ {
+ name: "parse datetime nanosec 1-digit",
+ str: "2020-05-25 23:22:01.1",
+ },
+ {
+ name: "parse datetime nanosec 2-digits",
+ str: "2020-05-25 23:22:01.15",
+ },
+ {
+ name: "parse datetime nanosec 3-digits",
+ str: "2020-05-25 23:22:01.159",
+ },
+ {
+ name: "parse datetime nanosec 4-digits",
+ str: "2020-05-25 23:22:01.1594",
+ },
+ {
+ name: "parse datetime nanosec 5-digits",
+ str: "2020-05-25 23:22:01.15949",
+ },
+ {
+ name: "parse datetime nanosec 6-digits",
+ str: "2020-05-25 23:22:01.159491",
+ },
+ }
+
+ for _, loc := range []*time.Location{
+ time.UTC,
+ time.FixedZone("test", 8*60*60),
+ } {
+ for _, cc := range cases {
+ t.Run(cc.name+"-"+loc.String(), func(t *testing.T) {
+ var want time.Time
+ if cc.str != sDate0 && cc.str != sDateTime0 {
+ var err error
+ want, err = time.ParseInLocation(timeFormat[:len(cc.str)], cc.str, loc)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ got, err := parseDateTime([]byte(cc.str), loc)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !want.Equal(got) {
+ t.Fatalf("want: %v, but got %v", want, got)
+ }
+ })
+ }
+ }
+}
+
+func TestParseDateTimeFail(t *testing.T) {
+ cases := []struct {
+ name string
+ str string
+ wantErr string
+ }{
+ {
+ name: "parse invalid time",
+ str: "hello",
+ wantErr: "invalid time bytes: hello",
+ },
+ {
+ name: "parse year",
+ str: "000!-00-00 00:00:00.000000",
+ wantErr: "not [0-9]",
+ },
+ {
+ name: "parse month",
+ str: "0000-!0-00 00:00:00.000000",
+ wantErr: "not [0-9]",
+ },
+ {
+ name: `parse "-" after parsed year`,
+ str: "0000:00-00 00:00:00.000000",
+ wantErr: "bad value for field: `:`",
+ },
+ {
+ name: `parse "-" after parsed month`,
+ str: "0000-00:00 00:00:00.000000",
+ wantErr: "bad value for field: `:`",
+ },
+ {
+ name: `parse " " after parsed date`,
+ str: "0000-00-00+00:00:00.000000",
+ wantErr: "bad value for field: `+`",
+ },
+ {
+ name: `parse ":" after parsed date`,
+ str: "0000-00-00 00-00:00.000000",
+ wantErr: "bad value for field: `-`",
+ },
+ {
+ name: `parse ":" after parsed hour`,
+ str: "0000-00-00 00:00-00.000000",
+ wantErr: "bad value for field: `-`",
+ },
+ {
+ name: `parse "." after parsed sec`,
+ str: "0000-00-00 00:00:00?000000",
+ wantErr: "bad value for field: `?`",
+ },
+ }
+
+ for _, cc := range cases {
+ t.Run(cc.name, func(t *testing.T) {
+ got, err := parseDateTime([]byte(cc.str), time.UTC)
+ if err == nil {
+ t.Fatal("want error")
+ }
+ if cc.wantErr != err.Error() {
+ t.Fatalf("want `%s`, but got `%s`", cc.wantErr, err)
+ }
+ if !got.IsZero() {
+ t.Fatal("want zero time")
+ }
+ })
+ }
+}
diff --git a/vendor/syscon/syscon.go b/vendor/syscon/syscon.go
new file mode 100644
index 0000000..50254e0
--- /dev/null
+++ b/vendor/syscon/syscon.go
@@ -0,0 +1,203 @@
+package syscon
+
+import (
+ "bufio"
+ "bytes"
+ "config"
+ "datax"
+ "errors"
+ "fmt"
+ "net"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "syslog"
+ "time"
+)
+
+func StartNewService(conf *config.Config) {
+ if err := datax.SQLInit(conf); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ defer datax.DBClose()
+ lisfunc, isthere := networkListenType[conf.GetConf("listener_type", 0)]
+ if !isthere {
+ panic(syslog.BigError{Why: errors.New("invalid listener_type in config file, should be either inet or unix"), Cod: 1})
+ }
+ l := lisfunc(conf)
+ defer l.Close()
+ syslog.InformGreen("connection to sql server established")
+ syn := regexp.MustCompile(`(^sasl_username=.*[^\s].*)$`)
+
+ for {
+ conn, err := l.Accept()
+ if err != nil {
+ syslog.InformError(err)
+ continue
+ }
+
+ go handlePostfixConn(conn, syn)
+ }
+
+}
+
+func handlePostfixConn(conn net.Conn, syn *regexp.Regexp) {
+ defer conn.Close()
+ scanner := bufio.NewScanner(conn)
+ for scanner.Scan() {
+ if find := syn.MatchString(scanner.Text()); find {
+ pureName := strings.ReplaceAll(scanner.Text(), "sasl_username=", "")
+ userAcc, err := datax.GetUserFromDatabase(&pureName)
+ if err != nil {
+ syslog.InformError(err)
+ _, err := conn.Write([]byte("action=REJECT fixrate error: internal server error <1>, contact administrator\n\n"))
+ if err != nil {
+ syslog.InformError(err)
+ return
+ }
+ continue
+ }
+
+ if userAcc.LastReset.Add(time.Duration(userAcc.Reset) * time.Second).Before(time.Now()) {
+ if err := userAcc.UpdateUserLastReset(time.Now()); err != nil {
+ syslog.InformError(err)
+ _, err := conn.Write([]byte("action=REJECT fixrate error: internal server error <2>, contact administrator\n\n"))
+ if err != nil {
+ syslog.InformError(err)
+ return
+ }
+ continue
+ }
+
+ if err := userAcc.UpdateUserCounter(0); err != nil {
+ syslog.InformError(err)
+ _, err := conn.Write([]byte("action=REJECT fixrate error: internal server error <3>, contact administrator\n\n"))
+ if err != nil {
+ syslog.InformError(err)
+ return
+ }
+ continue
+ }
+ userAcc.Counter = 0
+ }
+ if userAcc.Limit <= userAcc.Counter {
+ diff := userAcc.LastReset.Add(time.Duration(userAcc.Reset) * time.Second).Sub(time.Now())
+ _, err := conn.Write([]byte(fmt.Sprintf("action=REJECT fixrate: sending limit exceeded, you can't send anything until next %v\n\n", diff.Round(time.Second))))
+ if err != nil {
+ syslog.InformError(err)
+ return
+ }
+ continue
+ }
+ if err := userAcc.UpdateUserCounter(userAcc.Counter + 1); err != nil {
+ syslog.InformError(err)
+ _, err := conn.Write([]byte("action=REJECT fixrate error: internal server error <4>, contact administrator\n\n"))
+ if err != nil {
+ syslog.InformError(err)
+ return
+ }
+ continue
+ }
+ if _, err := conn.Write([]byte("action=OK\n\n")); err != nil {
+ syslog.InformError(err)
+ return
+ }
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+}
+
+func checkUinixSocket(socketAddr string) (string, bool) {
+ unix, err := os.Open("/proc/net/unix")
+ if err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ defer unix.Close()
+ u := bufio.NewScanner(unix)
+ for u.Scan() {
+ if bytes.Contains(u.Bytes(), []byte(socketAddr)) {
+ pid := findPid(strings.Fields(u.Text())[6])
+ return pid, true
+ }
+ }
+ if err := u.Err(); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ return "", false
+}
+
+var networkListenType = map[string]func(conf *config.Config) net.Listener{
+ "unix": func(conf *config.Config) net.Listener {
+ socketAddr := conf.GetConf("socket_path", 0)
+ if pid, exists := checkUinixSocket(socketAddr); exists {
+ panic(syslog.BigError{Why: fmt.Errorf("unix socket %v is held by another process (pid: %v), is another fixrate daemon runing?", socketAddr, pid), Cod: 1})
+ }
+ if err := os.RemoveAll(socketAddr); err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ SockPerm, err := strconv.Atoi(conf.GetConf("socket_perm", 0))
+ if err != nil {
+ panic(syslog.BigError{Why: errors.New("socket_path should hold an integer"), Cod: 1})
+ }
+ l, err := net.Listen("unix", socketAddr)
+ if err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+
+ mod := os.FileMode(SockPerm)
+ if err = os.Chmod(socketAddr, mod); err != nil {
+ l.Close()
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ syslog.InformGreen("start listening on unix socket", socketAddr)
+ return l
+ },
+ "inet": func(conf *config.Config) net.Listener {
+ listenAddr := conf.GetConf("listen_addr", 0)
+ l, err := net.Listen("tcp", listenAddr)
+ if err != nil {
+ panic(syslog.BigError{Why: err, Cod: 1})
+ }
+ syslog.InformGreen("start listening on tcp address", listenAddr)
+ return l
+ },
+}
+
+func findPid(inode string) string {
+ type sysInodes struct {
+ path string
+ link string
+ }
+ var pid string
+ fd, err := filepath.Glob("/proc/[0-9]*/fd/[0-9]*")
+ if err != nil {
+ return pid
+ }
+
+ inodes := make([]sysInodes, len(fd))
+ mx := make(chan sysInodes, len(fd))
+
+ go func(fd *[]string, outchan chan<- sysInodes) {
+ for _, item := range *fd {
+ link, _ := os.Readlink(item)
+ outchan <- sysInodes{item, link}
+ }
+ }(&fd, mx)
+
+ for range fd {
+ inodes = append(inodes, <-mx)
+ }
+
+ re := regexp.MustCompile(inode)
+ for _, item := range inodes {
+ out := re.FindString(item.link)
+ if len(out) != 0 {
+ pid = strings.Split(item.path, "/")[2]
+ }
+ }
+ return pid
+}
diff --git a/vendor/syslog/logger.go b/vendor/syslog/logger.go
new file mode 100644
index 0000000..fb3354c
--- /dev/null
+++ b/vendor/syslog/logger.go
@@ -0,0 +1,38 @@
+package syslog
+
+import (
+ "fmt"
+ "log"
+ "os"
+)
+
+type BigError struct {
+ Why error
+ Cod int
+ Pid int
+}
+
+func HandlePan() {
+ if hap := recover(); hap != nil {
+ if ms, owkey := hap.(BigError); owkey {
+ ms.Pid = os.Getpid()
+ fmt.Println("\033[31mfatal:\033[0m", ms.Why, "\nprocess", ms.Pid, "exit with ststus", ms.Cod)
+ os.Exit(ms.Cod)
+ }
+ panic(hap)
+ }
+}
+
+func InformYellow(h ...interface{}) {
+ h = append([]interface{}{("\033[33minfo:\033[0m")}, h...)
+ log.Println(h...)
+}
+
+func InformGreen(h ...interface{}) {
+ h = append([]interface{}{("\033[32minfo:\033[0m")}, h...)
+ log.Println(h...)
+}
+func InformError(h ...interface{}) {
+ h = append([]interface{}{("\033[31merror:\033[0m")}, h...)
+ log.Println(h...)
+}

Snix LLC Git Repository Holder Copyright(C) 2022 All Rights Reserved Email To Snix.IR