用于 Go 的 Databricks SQL 驱动程序

Databricks SQL Driver for Go 是一个 Go 库,它让你可以使用 Go 代码在 Azure Databricks 计算资源上运行 SQL 命令。 本文是对 Databricks SQL Driver for Go READMEAPI 参考示例的补充。

要求

Databricks SQL Driver for Go 入门

  1. 在已安装 Go 1.20 或更高版本且已创建现有 Go 代码项目的开发计算机上,通过运行 go mod init 命令创建一个 go.mod 文件来跟踪 Go 代码的依赖项,例如:

    go mod init sample
    
  2. 通过运行 go mod edit -require 命令(将 v1.5.2 替换为版本中列出的 Databricks SQL Driver for Go 包的最新版本)来依赖于 Databricks SQL Driver for Go 包:

    go mod edit -require github.com/databricks/databricks-sql-go@v1.5.2
    

    go.mod 文件现在应如下所示:

    module sample
    
    go 1.20
    
    require github.com/databricks/databricks-sql-go v1.5.2
    
  3. 在你的项目中,创建一个 Go 代码文件,用于导入 Databricks SQL Driver for Go。 以下示例位于包含以下内容的名为 main.go 的文件中,它列出 Azure Databricks 工作区中的所有群集:

    package main
    
    import (
      "database/sql"
      "os"
      _ "github.com/databricks/databricks-sql-go"
    )
    
    func main() {
      dsn := os.Getenv("DATABRICKS_DSN")
    
      if dsn == "" {
        panic("No connection string found. " +
         "Set the DATABRICKS_DSN environment variable, and try again.")
      }
    
      db, err := sql.Open("databricks", dsn)
      if err != nil {
        panic(err)
      }
      defer db.Close()
    
      if err := db.Ping(); err != nil {
        panic(err)
      }
    }
    
  4. 通过运行 go mod tidy 命令添加任何缺少的模块依赖项:

    go mod tidy
    

    注意

    如果收到错误 go: warning: "all" matched no packages,则表示你忘记添加用于导入 Databricks SQL Driver for Go 的 Go 代码文件。

  5. 通过运行 go mod vendor 命令,制作支持生成和测试 main 模块中的包所需的所有包的副本:

    go mod vendor
    
  6. 根据需要修改代码以设置 Azure Databricks 身份验证DATABRICKS_DSN 环境变量。 另请参阅使用 DSN 连接字符串进行连接

  7. 通过运行 go run 命令来运行 Go 代码文件(假设文件名为 main.go):

    go run main.go
    
  8. 如果未返回任何错误,则表示你已成功使用 Azure Databricks 工作区对 Databricks SQL Driver for Go 进行了身份验证,并连接到了该工作区中正在运行的 Azure Databricks 群集或 SQL 仓库。

使用 DSN 连接字符串进行连接

若要访问群集和 SQL 仓库,请使用 sql.Open() 通过数据源名称 (DSN) 连接字符串创建数据库句柄。 此代码示例从名为 DATABRICKS_DSN 的环境变量中检索 DSN 连接字符串:

package main

import (
  "database/sql"
  "os"
  _ "github.com/databricks/databricks-sql-go"
)

func main() {
  dsn := os.Getenv("DATABRICKS_DSN")

  if dsn == "" {
    panic("No connection string found. " +
          "Set the DATABRICKS_DSN environment variable, and try again.")
  }

  db, err := sql.Open("databricks", dsn)
  if err != nil {
    panic(err)
  }
  defer db.Close()

  if err := db.Ping(); err != nil {
    panic(err)
  }
}

要以正确的格式指定 DSN 连接字符串,请参阅身份验证中的 DSN 连接字符串示例。 例如,对于 Azure Databricks 个人访问令牌身份验证,请使用以下语法,其中:

  • <personal-access-token> 是要求中的 Azure Databricks 个人访问令牌。
  • <server-hostname> 是要求中的“服务器主机名”值。
  • <port-number> 是要求中的端口值,通常是 443
  • <http-path> 是要求中的“HTTP 路径”值。
  • <paramX=valueX> 是本文稍后列出的一个或多个可选参数
token:<personal-access-token>@<server-hostname>:<port-number>/<http-path>?<param1=value1>&<param2=value2>

例如,对于群集,请执行以下操作:

token:<36bit_Your_Token_PlaceHolder>@adb-1234567890123456.7.databricks.azure.cn:443/sql/protocolv1/o/1234567890123456/1234-567890-abcdefgh

例如,对于 SQL 仓库:

token:<36bit_Your_Token_PlaceHolder>@adb-1234567890123456.7.databricks.azure.cn:443/sql/1.0/endpoints/a1b234c5678901d2

注意

作为安全最佳做法,不应将此 DSN 连接字符串硬编码到 Go 代码中。 而应从安全位置检索此 DSN 连接字符串。 例如,本文前面部分中的代码示例使用了环境变量。

可选参数

  • 可以在 <param=value> 中指定受支持的可选连接参数。 一些较常用的方法包括:
    • catalog:设置会话中的初始目录名称。
    • schema:设置会话中的初始架构名称。
    • maxRows:设置每个请求提取的最大行数。 默认为 10000
    • timeout:为服务器查询执行添加超时(以秒为单位)。 默认设置为无超时。
    • userAgentEntry:用于标识合作伙伴。 有关详细信息,请参阅合作伙伴的文档。
  • 可以在 param=value 中指定受支持的可选会话参数。 一些较常用的方法包括:
    • ansi_mode:一个布尔字符串。 true 值表示会话语句符合 ANSI SQL 规范指定的规则。 系统默认值为 false。
    • timezone:一个字符串,例如 China/Beijing。 设置会话的时区。 系统默认值为 UTC。

例如,对于 SQL 仓库:

token:<36bit_Your_Token_PlaceHolder>@adb-1234567890123456.7.databricks.azure.cn:443/sql/1.0/endpoints/a1b234c5678901d2?catalog=hive_metastore&schema=example&maxRows=100&timeout=60&timezone=China/Beijing&ansi_mode=true

使用 NewConnector 函数进行连接

或者,使用 sql.OpenDB() 通过使用 dbsql.NewConnector() 创建的新连接器对象创建数据库句柄(使用新连接器对象连接到 Azure Databricks 群集和 SQL 仓库需要 v1.0.0 或更高版本的适用于 Go 的 Databricks SQL 驱动程序)。 例如:

package main

import (
  "database/sql"
  "os"
  dbsql "github.com/databricks/databricks-sql-go"
)

func main() {
  connector, err := dbsql.NewConnector(
    dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESS_TOKEN")),
    dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
    dbsql.WithPort(443),
    dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTP_PATH")),
  )
  if err != nil {
    panic(err)
  }

  db := sql.OpenDB(connector)
  defer db.Close()

  if err := db.Ping(); err != nil {
    panic(err)
  }
}

要指定正确的 NewConnector 设置集,请参阅身份验证中的示例。

注意

作为安全最佳做法,不应将 NewConnector 设置硬编码到 Go 代码中。 而应该从安全位置检索这些值。 例如,前面的代码使用环境变量。

一些最常用的函数选项包括:

  • WithAccessToken(<access-token>):要求中你的 Azure Databricks 个人访问令牌。 要求 string
  • WithServerHostname(<server-hostname>):要求中的“服务器主机名”值。 必需 string
  • WithPort(<port>):服务器的端口号,通常为 443。 必需 int
  • WithHTTPPath(<http-path>):要求中的“HTTP 路径”值。 要求 string
  • WithInitialNamespace(<catalog>, <schema>):会话中的目录和架构名称。 可选 string, string
  • WithMaxRows(<max-rows>):每个请求提取的最大行数。 默认为 10000. 可选 int
  • WithSessionParams(<params-map>):会话参数,包括“timezone”和“ansi_mode”。 可选 map[string]string
  • WithTimeout(<timeout>)。 服务器查询执行的超时(以 time.Duration 表示)。 默认设置为无超时。 可选。
  • WithUserAgentEntry(<isv-name-plus-product-name>)。 用于标识合作伙伴。 有关详细信息,请参阅合作伙伴的文档。 可选 string

例如:

connector, err := dbsql.NewConnector(
  dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESS_TOKEN")),
  dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
  dbsql.WithPort(443),
  dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTP_PATH")),
  dbsql.WithInitialNamespace("samples", "nyctaxi"),
  dbsql.WithMaxRows(100),
  dbsql.SessionParams(map[string]string{"timezone": "America/Sao_Paulo", "ansi_mode": "true"}),
  dbsql.WithTimeout(time.Minute),
  dbsql.WithUserAgentEntry("example-user"),
)

身份验证

Databricks SQL Driver for Go 支持以下 Azure Databricks 身份验证类型:

Databricks SQL Driver for Go 尚不支持以下 Azure Databricks 身份验证类型:

Databricks 个人访问令牌身份验证

要将 Databricks SQL Driver for Go 与 Azure Databricks 个人访问令牌身份验证配合使用,你必须先创建一个 Azure Databricks 个人访问令牌,如下所示:

  1. 在 Azure Databricks 工作区中,单击顶部栏中的 Azure Databricks 用户名,然后从下拉列表中选择“设置”
  2. 单击“开发人员”。
  3. 在“访问令牌”旁边,单击“管理”。
  4. 单击“生成新令牌”。
  5. (可选)输入有助于将来识别此令牌的注释,并将令牌的默认生存期更改为 90 天。 若要创建没有生存期的令牌(不建议),请将“生存期(天)”框留空(保留空白)。
  6. 单击“生成” 。
  7. 将显示的令牌复制到安全位置,然后单击“完成”。

注意

请务必将复制的令牌保存到安全的位置。 请勿与他人共享复制的令牌。 如果丢失了复制的令牌,你将无法重新生成完全相同的令牌, 而必须重复此过程来创建新令牌。 如果丢失了复制的令牌,或者认为令牌已泄露,Databricks 强烈建议通过单击“访问令牌”页上令牌旁边的垃圾桶(撤销)图标立即从工作区中删除该令牌。

如果你无法在工作区中创建或使用令牌,可能是因为工作区管理员已禁用令牌或未授予你创建或使用令牌的权限。 请与工作区管理员联系,或参阅以下主题:

要使用 DSN 连接字符串和使用 DSN 连接字符串进行连接中的代码示例对 Databricks SQL Driver for Go 进行身份验证,请使用以下 DSN 连接字符串语法,其中:

  • <personal-access-token> 是要求中的 Azure Databricks 个人访问令牌。
  • <server-hostname> 是要求中的“服务器主机名”值。
  • <port-number> 是要求中的端口值,通常是 443
  • <http-path> 是要求中的“HTTP 路径”值。

还可以追加本文前面列出的一个或多个可选参数

token:<personal-access-token>@<server-hostname>:<port-number>/<http-path>

要使用 NewConnector 函数对 Databricks SQL Driver for Go 进行身份验证,请使用以下代码片段和使用 NewConnector 函数进行连接中的代码示例,它假定你已设置以下环境变量:

  • DATABRICKS_SERVER_HOSTNAME,设置为你的群集或 SQL 仓库的服务器主机名值。
  • DATABRICKS_HTTP_PATH,设置为你的群集或 SQL 仓库的 HTTP 路径值。
  • DATABRICKS_TOKEN,设置为 Azure Databricks 个人访问令牌。

若要设置环境变量,请参阅操作系统的文档。

connector, err := dbsql.NewConnector(
  dbsql.WithServerHostname(os.Getenv("DATABRICKS_SERVER_HOSTNAME")),
  dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTP_PATH")),
  dbsql.WithPort(443),
  dbsql.WithAccessToken(os.Getenv("DATABRICKS_TOKEN")),
)

OAuth 用户到计算机 (U2M) 身份验证

Databricks SQL Driver for Go 版本 1.5.0 及更高版本支持 OAuth 用户到计算机 (U2M) 身份验证

要通过 DSN 连接字符串和使用 DSN 连接字符串进行连接中的代码示例来使用 Databricks SQL Driver for Go,请使用以下 DSN 连接字符串语法,其中:

  • <server-hostname> 是要求中的“服务器主机名”值。
  • <port-number> 是要求中的端口值,通常是 443
  • <http-path> 是要求中的“HTTP 路径”值。

还可以追加本文前面列出的一个或多个可选参数

<server-hostname>:<port-number>/<http-path>?authType=OauthU2M

要使用 NewConnector 函数对 Databricks SQL Driver for Go 进行身份验证,必须先将以下内容添加到 import 声明中:

"github.com/databricks/databricks-sql-go/auth/oauth/u2m"

然后使用以下代码片段和使用 NewConnector 函数进行连接中的代码示例,它假定你已设置以下环境变量:

  • DATABRICKS_SERVER_HOSTNAME,设置为你的群集或 SQL 仓库的服务器主机名值。
  • DATABRICKS_HTTP_PATH,设置为你的群集或 SQL 仓库的 HTTP 路径值。

若要设置环境变量,请参阅操作系统对应的文档。

authenticator, err := u2m.NewAuthenticator(os.Getenv("DATABRICKS_SERVER_HOSTNAME"), 1*time.Minute)
if err != nil {
  panic(err)
}

connector, err := dbsql.NewConnector(
  dbsql.WithServerHostname(os.Getenv("DATABRICKS_SERVER_HOSTNAME")),
  dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTP_PATH")),
  dbsql.WithPort(443),
  dbsql.WithAuthenticator(authenticator),
)

OAuth 计算机到计算机 (M2M) 身份验证

Databricks SQL Driver for Go 版本 1.5.2 及更高版本支持 OAuth 计算机到计算机 (M2M) 身份验证

要将 Databricks SQL Driver for Go 与 OAuth M2M 身份验证配合使用,必须执行以下操作:

  1. 在 Azure Databricks 工作区中创建 Azure Databricks 服务主体,并为该服务主体创建 OAuth 机密。

    若要创建服务主体及其 OAuth 机密,请参阅使用 OAuth (OAuth M2M) 通过服务主体对 Azure Databricks 的访问进行身份验证。 记下服务主体的 UUID应用程序 ID 值,以及服务主体的OAuth 机密的机密值。

  2. 授予该服务主体对群集或仓库的访问权限。

    若要向服务主体授予访问群集或仓库的权限,请参阅计算权限管理 SQL 仓库

要使用 DSN 连接字符串和使用 DSN 连接字符串进行连接中的代码示例对 Databricks SQL Driver for Go 进行身份验证,请使用以下 DSN 连接字符串语法,其中:

  • <server-hostname> 是要求中的“服务器主机名”值。
  • <port-number> 是要求中的端口值,通常是 443
  • <http-path> 是要求中的“HTTP 路径”值。
  • <client-id> 为服务主体的 UUID应用程序 ID 值。
  • <client-secret> 是服务主体的 OAuth 机密的机密值。

还可以追加本文前面列出的一个或多个可选参数

<server-hostname>:<port-number>/<http-path>?authType=OAuthM2M&clientID=<client-id>&clientSecret=<client-secret>

要使用 NewConnector 函数对 Databricks SQL Driver for Go 进行身份验证,必须先将以下内容添加到 import 声明中:

"github.com/databricks/databricks-sql-go/auth/oauth/m2m"

然后使用以下代码片段和使用 NewConnector 函数进行连接中的代码示例,它假定你已设置以下环境变量:

  • DATABRICKS_SERVER_HOSTNAME,设置为你的群集或 SQL 仓库的服务器主机名值。
  • DATABRICKS_HTTP_PATH,设置为你的群集或 SQL 仓库的 HTTP 路径值。
  • DATABRICKS_CLIENT_ID,设置为服务主体的 UUID应用程序 ID 值。
  • DATABRICKS_CLIENT_SECRET,设置为服务主体的 OAuth 机密的机密值。

若要设置环境变量,请参阅操作系统对应的文档。

authenticator := m2m.NewAuthenticator(
  os.Getenv("DATABRICKS_CLIENT_ID"),
  os.Getenv("DATABRICKS_CLIENT_SECRET"),
  os.Getenv("DATABRICKS_SERVER_HOSTNAME"),
)

connector, err := dbsql.NewConnector(
  dbsql.WithServerHostname(os.Getenv("DATABRICKS_SERVER_HOSTNAME")),
  dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTP_PATH")),
  dbsql.WithPort(443),
  dbsql.WithAuthenticator(authenticator),
)

查询数据

下面的代码示例演示如何调用 Databricks SQL Driver for Go 以在 Azure Databricks 计算资源上运行基本 SQL 查询。 此命令返回 samples 目录的 nyctaxi 架构中的trips 表中的前两行。

此代码示例从名为 DATABRICKS_DSN 的环境变量中检索 DSN 连接字符串

package main

import (
  "database/sql"
  "fmt"
  "os"
  "time"

  _ "github.com/databricks/databricks-sql-go"
)

func main() {
  dsn := os.Getenv("DATABRICKS_DSN")

  if dsn == "" {
    panic("No connection string found." +
          "Set the DATABRICKS_DSN environment variable, and try again.")
  }

  db, err := sql.Open("databricks", dsn)
  if err != nil {
    panic(err)
  }

  defer db.Close()

  var (
    tpep_pickup_datetime  time.Time
    tpep_dropoff_datetime time.Time
    trip_distance         float64
    fare_amount           float64
    pickup_zip            int
    dropoff_zip           int
  )

  rows, err := db.Query("SELECT * FROM samples.nyctaxi.trips LIMIT 2")
  if err != nil {
    panic(err)
  }

  defer rows.Close()

  fmt.Print("tpep_pickup_datetime,",
    "tpep_dropoff_datetime,",
    "trip_distance,",
    "fare_amount,",
    "pickup_zip,",
    "dropoff_zip\n")

  for rows.Next() {
    err := rows.Scan(&tpep_pickup_datetime,
      &tpep_dropoff_datetime,
      &trip_distance,
      &fare_amount,
      &pickup_zip,
      &dropoff_zip)
    if err != nil {
      panic(err)
    }

    fmt.Print(tpep_pickup_datetime, ",",
      tpep_dropoff_datetime, ",",
      trip_distance, ",",
      fare_amount, ",",
      pickup_zip, ",",
      dropoff_zip, "\n")
  }

  err = rows.Err()
  if err != nil {
    panic(err)
  }
}

管理 Unity Catalog 卷中的文件

Databricks SQL Driver 可让你将本地文件写入 Unity Catalog 、从卷下载文件以及从卷中删除文件,如下例所示:

package main

import (
  "context"
  "database/sql"
  "os"

  _ "github.com/databricks/databricks-sql-go"
  "github.com/databricks/databricks-sql-go/driverctx"
)

func main() {
  dsn := os.Getenv("DATABRICKS_DSN")

  if dsn == "" {
    panic("No connection string found." +
      "Set the DATABRICKS_DSN environment variable, and try again.")
  }

  db, err := sql.Open("databricks", dsn)
  if err != nil {
    panic(err)
  }
  defer db.Close()

  // For writing local files to volumes and downloading files from volumes,
  // you must first specify the path to the local folder that contains the
  // files to be written or downloaded.
  // For multiple folders, add their paths to the following string array.
  // For deleting files in volumes, this string array is ignored but must
  // still be provided, so in that case its value can be set for example
  // to an empty string.
  ctx := driverctx.NewContextWithStagingInfo(
    context.Background(),
    []string{"/tmp/"},
  )

  // Write a local file to the path in the specified volume.
  // Specify OVERWRITE to overwrite any existing file in that path.
  db.ExecContext(ctx, "PUT '/tmp/my-data.csv' INTO '/Volumes/main/default/my-volume/my-data.csv' OVERWRITE")

  // Download a file from the path in the specified volume.
  db.ExecContext(ctx, "GET '/Volumes/main/default/my-volume/my-data.csv' TO '/tmp/my-downloaded-data.csv'")

  // Delete a file from the path in the specified volume.
  db.ExecContext(ctx, "REMOVE '/Volumes/main/default/my-volume/my-data.csv'")

  db.Close()
}

日志记录

使用 github.com/databricks/databricks-sql-go/logger 记录 Databricks SQL Driver for Go 发出的消息。 以下代码示例使用 sql.Open() 通过 DSN 连接字符串创建数据库句柄。 此代码示例从名为 DATABRICKS_DSN 的环境变量中检索 DSN 连接字符串。 在 debug 级别及以下级别发出的所有日志消息都会写入 results.log 文件。

package main

import (
  "database/sql"
  "io"
  "log"
  "os"

  _ "github.com/databricks/databricks-sql-go"
  dbsqllog "github.com/databricks/databricks-sql-go/logger"
)

func main() {
  dsn := os.Getenv("DATABRICKS_DSN")

  // Use the specified file for logging messages to.
  file, err := os.Create("results.log")
  if err != nil {
    log.Fatal(err)
  }
  defer file.Close()

  writer := io.Writer(file)

  // Log messages at the debug level and below.
  if err := dbsqllog.SetLogLevel("debug"); err != nil {
    log.Fatal(err)
  }

  // Log messages to the file.
  dbsqllog.SetLogOutput(writer)

  if dsn == "" {
    panic("Error: Cannot connect. No connection string found. " +
      "Set the DATABRICKS_DSN environment variable, and try again.")
  }

  db, err := sql.Open("databricks", dsn)
  if err != nil {
    panic(err)
  }
  defer db.Close()

  if err := db.Ping(); err != nil {
    panic(err)
  }
}

测试

若要测试代码,请使用 Go 测试框架,例如测试标准库。 若要在不调用 Azure Databricks REST API 终结点或更改 Azure Databricks 帐户或工作区的状态的情况下在模拟条件下测试代码,可以使用 Go 模拟库(如 testfify)。

例如,给定以下名为 helpers.go 的文件,它包含返回 Azure Databricks 工作区连接的 GetDBWithDSNPAT 函数、从 samples 目录的 nyctaxi 架构中的 trips 表返回数据的 GetNYCTaxiTrips 函数,以及输出返回的数据的 PrintNYCTaxiTrips

package main

import (
  "database/sql"
  "fmt"
  "strconv"
  "time"
)

func GetDBWithDSNPAT(dsn string) (*sql.DB, error) {
  db, err := sql.Open("databricks", dsn)
  if err != nil {
    return nil, err
  }
  return db, nil
}

func GetNYCTaxiTrips(db *sql.DB, numRows int) (*sql.Rows, error) {
  rows, err := db.Query("SELECT * FROM samples.nyctaxi.trips LIMIT " + strconv.Itoa(numRows))
  if err != nil {
    return nil, err
  }
  return rows, nil
}

func PrintNYCTaxiTrips(rows *sql.Rows) {
  var (
    tpep_pickup_datetime  time.Time
    tpep_dropoff_datetime time.Time
    trip_distance         float64
    fare_amount           float64
    pickup_zip            int
    dropoff_zip           int
  )

  fmt.Print(
    "tpep_pickup_datetime,",
    "tpep_dropoff_datetime,",
    "trip_distance,",
    "fare_amount,",
    "pickup_zip,",
    "dropoff_zip\n",
  )

  for rows.Next() {
    err := rows.Scan(
      &tpep_pickup_datetime,
      &tpep_dropoff_datetime,
      &trip_distance,
      &fare_amount,
      &pickup_zip,
      &dropoff_zip,
    )
    if err != nil {
      panic(err)
    }

    fmt.Print(
      tpep_pickup_datetime, ",",
      tpep_dropoff_datetime, ",",
      trip_distance, ",",
      fare_amount, ",",
      pickup_zip, ",",
      dropoff_zip, "\n",
    )
  }

  err := rows.Err()
  if err != nil {
    panic(err)
  }
}

给定以下名为 main.go 的文件,用于调用以下函数:

package main

import (
  "os"
)

func main() {
  db, err := GetDBWithDSNPAT(os.Getenv("DATABRICKS_DSN"))
  if err != nil {
    panic(err)
  }

  rows, err := GetNYCTaxiTrips(db, 2)
  if err != nil {
    panic(err)
  }

  PrintNYCTaxiTrips(rows)
}

以下名为 helpers_test.go 的文件测试 GetNYCTaxiTrips 函数是否返回预期的响应。 此测试将模拟 sql.DB 对象,而不是创建与目标工作区的真实连接。 该测试还模拟一些符合真实数据中的架构和值的数据。 该测试通过模拟连接返回模拟数据,然后检查其中一个模拟数据行的值是否与预期值匹配。

package main

import (
  "database/sql"
  "testing"

  "github.com/stretchr/testify/assert"
  "github.com/stretchr/testify/mock"
)

// Define an interface that contains a method with the same signature
// as the real GetNYCTaxiTrips function that you want to test.
type MockGetNYCTaxiTrips interface {
  GetNYCTaxiTrips(db *sql.DB, numRows int) (*sql.Rows, error)
}

// Define a struct that represents the receiver of the interface's method
// that you want to test.
type MockGetNYCTaxiTripsObj struct {
  mock.Mock
}

// Define the behavior of the interface's method that you want to test.
func (m *MockGetNYCTaxiTripsObj) GetNYCTaxiTrips(db *sql.DB, numRows int) (*sql.Rows, error) {
  args := m.Called(db, numRows)
  return args.Get(0).(*sql.Rows), args.Error(1)
}

func TestGetNYCTaxiTrips(t *testing.T) {
  // Instantiate the receiver.
  mockGetNYCTaxiTripsObj := new(MockGetNYCTaxiTripsObj)

  // Define how the mock function should be called and what it should return.
  // We're not concerned with whether the actual database is connected to--just
  // what is returned.
  mockGetNYCTaxiTripsObj.On("GetNYCTaxiTrips", mock.Anything, mock.AnythingOfType("int")).Return(&sql.Rows{}, nil)

  // Call the mock function that you want to test.
  rows, err := mockGetNYCTaxiTripsObj.GetNYCTaxiTrips(nil, 2)

  // Assert that the mock function was called as expected.
  mockGetNYCTaxiTripsObj.AssertExpectations(t)

  // Assert that the mock function returned what you expected.
  assert.NotNil(t, rows)
  assert.Nil(t, err)
}

GetNYCTaxiTrips 函数包含 SELECT 语句,因此不会更改 trips 表的状态,在此示例中并不是一定需要模拟。 但是,模拟让你能够快速运行测试,而无需等待与工作区建立实际连接。 此外,通过模拟,可以多次针对可能更改表状态的函数运行模拟测试,例如 INSERT INTOUPDATEDELETE FROM

其他资源