diff --git a/examples/sdk/sqa/with_interactive_sql_executor.go b/examples/sdk/sqa/with_interactive_sql_executor.go new file mode 100644 index 0000000..78db031 --- /dev/null +++ b/examples/sdk/sqa/with_interactive_sql_executor.go @@ -0,0 +1,60 @@ +package main + +import ( + "fmt" + "log" + "os" + + "github.com/aliyun/aliyun-odps-go-sdk/odps" + "github.com/aliyun/aliyun-odps-go-sdk/odps/account" + "github.com/aliyun/aliyun-odps-go-sdk/odps/sqa" +) + +func main() { + conf, err := odps.NewConfigFromIni(os.Args[1]) + if err != nil { + log.Fatalf("%+v", err) + } + + aliAccount := account.NewAliyunAccount(conf.AccessId, conf.AccessKey) + odpsIns := odps.NewOdps(aliAccount, conf.Endpoint) + odpsIns.SetDefaultProjectName(conf.ProjectName) + sql := `select * from all_types_demo_no_parition;` + // + paramInfo := sqa.SQLExecutorQueryParam{ + OdpsIns: odpsIns, + TaskName: "test_mcqa", + ServiceName: sqa.DEFAULT_SERVICE, + RunningCluster: "", + } + ie := sqa.NewInteractiveSQLExecutor(¶mInfo) + err = ie.Run(sql, nil) + if err != nil { + log.Fatalf("%+v", err) + } + // + records, err := ie.GetResult(0, 10, 10000, true) + if err != nil { + log.Fatalf("%+v", err) + } + // + for _, record := range records { + for i, d := range record { + if d == nil { + fmt.Printf("null") + } else { + fmt.Printf("%s", d.Sql()) + } + + if i < record.Len()-1 { + fmt.Printf(", ") + } else { + fmt.Println() + } + } + } + err = ie.Close() + if err != nil { + log.Fatalf("%+v", err) + } +} diff --git a/odps/instance.go b/odps/instance.go index ea21f69..f13bf8b 100644 --- a/odps/instance.go +++ b/odps/instance.go @@ -19,13 +19,15 @@ package odps import ( "encoding/json" "encoding/xml" - "github.com/aliyun/aliyun-odps-go-sdk/odps/common" - "github.com/pkg/errors" + "io" "io/ioutil" "net/http" "net/url" "strings" "time" + + "github.com/aliyun/aliyun-odps-go-sdk/odps/common" + "github.com/pkg/errors" ) type InstanceStatus int @@ -270,6 +272,10 @@ func (instance *Instance) Id() string { return instance.id } +func (instance *Instance) ResourceUrl() string { + return instance.resourceUrl +} + func (instance *Instance) Owner() string { return instance.owner } @@ -358,6 +364,73 @@ func (instance *Instance) GetResult() ([]TaskResult, error) { return resModel.Tasks, nil } +type UpdateInfoResult struct { + Result string `json:"result"` + Status string `json:"status"` +} + +// UpdateInfo set information to running instance +func (instance *Instance) UpdateInfo(taskName, infoKey, infoValue string) (UpdateInfoResult, error) { + // instance set information + queryArgs := make(url.Values, 2) + queryArgs.Set("info", "") + queryArgs.Set("taskname", taskName) + // + instanceTaskInfoModel := struct { + XMLName xml.Name `xml:"Instance"` + Key string `xml:"Key"` + Value string `xml:"Value"` + }{ + Key: infoKey, + Value: infoValue, + } + // + var res UpdateInfoResult + client := instance.odpsIns.RestClient() + err := client.DoXmlWithParseRes( + common.HttpMethod.PutMethod, + instance.resourceUrl, + queryArgs, + instanceTaskInfoModel, + func(httpRes *http.Response) error { + err := json.NewDecoder(httpRes.Body).Decode(&res) + if err != nil { + return errors.Wrapf(err, "Parse http response failed, body: %+v", httpRes.Body) + } + return nil + }, + ) + // + return res, err +} + +// GetTaskInfo get the specific info of a task in the instance +func (instance *Instance) GetTaskInfo(taskName, infoKey string) (string, error) { + queryArgs := make(url.Values, 3) + queryArgs.Set("info", "") + queryArgs.Set("taskname", taskName) + queryArgs.Set("key", infoKey) + + client := instance.odpsIns.RestClient() + var bodyStr string + err := client.DoXmlWithParseRes( + common.HttpMethod.GetMethod, + instance.resourceUrl, + queryArgs, + nil, + func(httpRes *http.Response) error { + body, err := io.ReadAll(httpRes.Body) + if err != nil { + return errors.Wrapf(err, "Parse http response failed, body: %+v", httpRes.Body) + } + bodyStr = string(body) + return nil + }, + ) + + return bodyStr, err +} + func InstancesStatusFromStr(s string) InstanceStatus { switch strings.ToLower(s) { case "running": diff --git a/odps/sqa/sql_executor.go b/odps/sqa/sql_executor.go new file mode 100644 index 0000000..f99b860 --- /dev/null +++ b/odps/sqa/sql_executor.go @@ -0,0 +1,293 @@ +package sqa + +import ( + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "time" + + "github.com/aliyun/aliyun-odps-go-sdk/odps" + "github.com/aliyun/aliyun-odps-go-sdk/odps/data" + tunnel2 "github.com/aliyun/aliyun-odps-go-sdk/odps/tunnel" + "github.com/google/uuid" + "github.com/pkg/errors" +) + +const ( + DEFAULT_TASK_NAME = "sqlrt_task" + DEFAULT_SERVICE = "public.default" + + SESSION_TIMEOUT = 60 + + OBJECT_STATUS_RUNNING = 2 + OBJECT_STATUS_FAILED = 4 + OBJECT_STATUS_TERMINATED = 5 + OBJECT_STATUS_CANCELLED = 6 +) + +type SQLExecutor interface { + Close() + Cancel() + Run(sql string, hints map[string]string) + GetResult() ([]data.Record, error) +} + +type InteractiveSQLExecutor struct { + odpsIns *odps.Odps + taskName string + serviceName string + hints map[string]string // for create instance + queryHints map[string]string // for query + id string + sql string + runningCluster string + instance *odps.Instance + subQueryInfo *SubQueryInfo +} + +type SQLExecutorQueryParam struct { + OdpsIns *odps.Odps + TaskName string + ServiceName string + RunningCluster string + Hints map[string]string +} + +func NewInteractiveSQLExecutor(params *SQLExecutorQueryParam) *InteractiveSQLExecutor { + id := uuid.New().String() + + hints := params.Hints + if hints == nil { + hints = make(map[string]string) + } + + return &InteractiveSQLExecutor{ + odpsIns: params.OdpsIns, + id: id, + taskName: params.TaskName, + serviceName: params.ServiceName, + runningCluster: params.RunningCluster, + hints: hints, + } +} + +// Run submit a query to instance +func (ie *InteractiveSQLExecutor) Run(sql string, queryHints map[string]string) error { + // + if queryHints == nil { + queryHints = make(map[string]string) + } + // init InteractiveSQLExecutor + ie.sql = sql + ie.queryHints = queryHints + if ie.queryHints == nil { + ie.queryHints = make(map[string]string) + } + // + var err error + ie.instance, err = ie.createInstance() + if err != nil { + return errors.Wrapf(err, "Get error when creating instance") + } + // wait for attach success + err = ie.waitAttachSuccess(SESSION_TIMEOUT) + if err != nil { + return err + } + // + err = ie.runQueryInternal() + if err != nil { + return errors.Wrapf(err, "Get error when creating running query: %v", ie.sql) + } + return nil +} + +func (ie *InteractiveSQLExecutor) createInstance() (*odps.Instance, error) { + if ie.serviceName != "" { + ie.hints["odps.sql.session.share.id"] = ie.serviceName + ie.hints["odps.sql.session.name"] = strings.TrimSpace(ie.serviceName) + } else { + return nil, errors.New("service name cannot be empty.") + } + + if ie.taskName == "" { + ie.taskName = DEFAULT_TASK_NAME + } + // + projectName := ie.odpsIns.DefaultProjectName() + // change "odps.sql.submit.mode" flag + userSubmitMode, ok := ie.hints["odps.sql.submit.mode"] + ie.hints["odps.sql.submit.mode"] = "script" + // + task := odps.NewSqlRTTask(ie.taskName, "", ie.hints) + + if ok { + ie.hints["odps.sql.submit.mode"] = userSubmitMode + } + // + instances := odps.NewInstances(ie.odpsIns, projectName) + instance, err := instances.CreateTask(projectName, &task) + if err != nil { + return nil, err + } + return instance, err +} + +type SubQueryResponse struct { + Status int + Result string + warnings string + SubQueryId int +} + +func (ie *InteractiveSQLExecutor) waitAttachSuccess(timeout int64) error { + if timeout < 1 { + timeout = SESSION_TIMEOUT + } + // + start := time.Now() + end := start.Add(time.Second * time.Duration(timeout)) + // + for time.Now().Before(end) { + infoStr, err := ie.instance.GetTaskInfo(ie.taskName, "wait_attach_success") + // + var subQueryResp SubQueryResponse + _ = json.Unmarshal([]byte(infoStr), &subQueryResp) + if err != nil || subQueryResp.Status == 0 { + // check task status + tasks, err := ie.instance.GetTasks() + if err != nil { + return err + } + // + var status odps.TaskStatus + for _, task := range tasks { + if task.Name == ie.taskName { + status = task.Status + break + } + } + if status != odps.TaskRunning { + return errors.New(fmt.Sprintf("instance id: %v, task name: %s, status: %v", + ie.instance.Id(), ie.taskName, status.String())) + } + } else if subQueryResp.Status == OBJECT_STATUS_FAILED || subQueryResp.Status == OBJECT_STATUS_TERMINATED { + return errors.New(fmt.Sprintf("attach instance [id: %v] failed, %s", ie.instance.Id(), subQueryResp.Result)) + } + // running + return nil + } + // + _ = ie.instance.Terminate() + return errors.New(fmt.Sprintf("attach instance [id: %v] timeout.", ie.instance.Id())) +} + +type SubQueryInfo struct { + QueryId int `json:"queryId"` + Status string `json:"status"` + Result string `json:"result"` +} + +func (ie *InteractiveSQLExecutor) runQueryInternal() error { + request := make(map[string]interface{}) + // + request["query"] = ie.sql + if ie.hints == nil { + ie.hints = make(map[string]string) + } + request["settings"] = ie.queryHints + requestJson, _ := json.Marshal(request) + // instance set information + res, err := ie.instance.UpdateInfo(ie.taskName, "query", string(requestJson)) + if err != nil { + return err + } + // + var subQueryInfo SubQueryInfo + if res.Status != "ok" { + subQueryInfo.Status = res.Status + subQueryInfo.Result = res.Result + } else if res.Result != "" { + err = json.Unmarshal([]byte(res.Result), &subQueryInfo) + if err != nil { + return errors.Wrapf(err, "%+v", res.Result) + } + } else { + return errors.Errorf("Invalid result: %+v", res) + } + ie.subQueryInfo = &subQueryInfo + // + return nil +} + +// GetResult get query result by instance tunnel +func (ie *InteractiveSQLExecutor) GetResult(offset, countLimit, sizeLimit int, limitEnabled bool) ([]data.Record, error) { + // + ds, err := ie.GetDownloadSession(limitEnabled) + if err != nil { + return nil, err + } + // + reader, err := ds.OpenRecordReader(offset, countLimit, sizeLimit, []string{}) + if err != nil { + return nil, err + } + // + results := make([]data.Record, 0, countLimit) + for { + record, err := reader.Read() + if err != nil { + isEOF := errors.Is(err, io.EOF) + if isEOF { + break + } + return nil, err + } + results = append(results, record) + } + // + return results, nil +} + +func (ie *InteractiveSQLExecutor) GetDownloadSession(limitEnabled bool) (*tunnel2.InstanceResultDownloadSession, error) { + if ie.instance == nil { + return nil, errors.New("InteractiveSQLExecutor.instance is nil, please create instance first") + } + // + projects := ie.odpsIns.Projects() + project := projects.Get(ie.instance.ProjectName()) + tunnelEndpoint, err := project.GetTunnelEndpoint() + if err != nil { + return nil, err + } + tunnel := tunnel2.NewTunnel(ie.odpsIns, tunnelEndpoint) + // + opts := make([]tunnel2.InstanceOption, 0) + opts = append(opts, tunnel2.InstanceSessionCfg.WithTaskName(ie.taskName)) + opts = append(opts, tunnel2.InstanceSessionCfg.WithQueryId(ie.subQueryInfo.QueryId)) + if limitEnabled { + opts = append(opts, tunnel2.InstanceSessionCfg.EnableLimit()) + } + + // + return tunnel.CreateInstanceResultDownloadSession(project.Name(), ie.instance.Id(), opts...) +} + +func (ie *InteractiveSQLExecutor) Close() error { + return ie.instance.Terminate() +} + +func (ie *InteractiveSQLExecutor) Cancel() error { + updateInfoResult, err := ie.instance.UpdateInfo(ie.taskName, "cancel", strconv.Itoa(ie.subQueryInfo.QueryId)) + if err != nil { + return err + } + // + if updateInfoResult.Status != "ok" { + return errors.New(fmt.Sprintf("cancel failed, message: %s", updateInfoResult.Result)) + } + + return nil +} diff --git a/odps/sql_rt_task.go b/odps/sql_rt_task.go new file mode 100644 index 0000000..7c2dc23 --- /dev/null +++ b/odps/sql_rt_task.go @@ -0,0 +1,18 @@ +package odps + +import "encoding/xml" + +type SQLRTTask struct { + XMLName xml.Name `xml:"SQLRT"` + BaseTask +} + +func (t *SQLRTTask) TaskType() string { + return "SQLRT" +} + +func NewSqlRTTask(name string, comment string, hints map[string]string) SQLRTTask { + return SQLRTTask{ + BaseTask: newBaseTask(name, comment, hints), + } +} diff --git a/odps/sql_task.go b/odps/sql_task.go index 4930c1c..9b86109 100644 --- a/odps/sql_task.go +++ b/odps/sql_task.go @@ -20,9 +20,10 @@ import ( "encoding/csv" "encoding/json" "encoding/xml" + "strings" + "github.com/aliyun/aliyun-odps-go-sdk/odps/common" "github.com/pkg/errors" - "strings" ) type SQLTask struct { diff --git a/odps/task.go b/odps/task.go index 50c6a81..7baa2cc 100644 --- a/odps/task.go +++ b/odps/task.go @@ -19,6 +19,7 @@ package odps import ( "encoding/json" "encoding/xml" + "github.com/aliyun/aliyun-odps-go-sdk/odps/common" ) @@ -37,13 +38,33 @@ func (n TaskName) GetName() string { // TaskConfig 作为embedding filed使用时,使用者自动实现Task接口的AddProperty方法 type TaskConfig struct { - Config []common.Property `xml:"Config>Property"` + Config []common.Property `xml:"Config>Property,omitempty"` } func (t *TaskConfig) AddProperty(key, value string) { t.Config = append(t.Config, common.Property{Name: key, Value: value}) } +type BaseTask struct { + TaskName `xml:"Name"` + TaskConfig + Comment string `xml:"Comment,omitempty"` +} + +func newBaseTask(name string, comment string, hints map[string]string) BaseTask { + baseTask := BaseTask{ + TaskName: TaskName(name), + Comment: comment, + } + + if hints != nil { + hintsJson, _ := json.Marshal(hints) + baseTask.Config = append(baseTask.Config, common.Property{Name: "settings", Value: string(hintsJson)}) + } + + return baseTask +} + type SQLCostTask struct { XMLName xml.Name `xml:"SQLCost"` SQLTask @@ -78,15 +99,6 @@ func (t *SQLPlanTask) TaskType() string { return "SQLPlan" } -type SQLRTTask struct { - XMLName xml.Name `xml:"SQLRT"` - SQLTask -} - -func (t *SQLRTTask) TaskType() string { - return "SQLRT" -} - type MergeTask struct { XMLName xml.Name `xml:"Merge"` TaskName `xml:"Name"` diff --git a/odps/tunnel/instance_result_download_session.go b/odps/tunnel/instance_result_download_session.go index a98bebd..0ebdd24 100644 --- a/odps/tunnel/instance_result_download_session.go +++ b/odps/tunnel/instance_result_download_session.go @@ -65,18 +65,19 @@ func CreateInstanceResultDownloadSession( Compressor: cfg.Compressor, } + // long pooling session if cfg.QueryId != -1 { session.IsLongPolling = true - } - - req, err := session.newInitiationRequest() - if err != nil { - return nil, errors.WithStack(err) - } + } else { + req, err := session.newInitiationRequest() + if err != nil { + return nil, errors.WithStack(err) + } - err = session.loadInformation(req) - if err != nil { - return nil, errors.WithStack(err) + err = session.loadInformation(req) + if err != nil { + return nil, errors.WithStack(err) + } } return &session, nil @@ -145,6 +146,11 @@ func (is *InstanceResultDownloadSession) OpenRecordReader( start, count, sizeLimit int, columnNames []string, ) (*RecordProtocReader, error) { + res, err := is.newDownloadConnection(start, count, sizeLimit, columnNames) + if err != nil { + return nil, errors.WithStack(err) + } + if len(columnNames) == 0 { columnNames = make([]string, len(is.schema.Columns)) for i, c := range is.schema.Columns { @@ -162,11 +168,6 @@ func (is *InstanceResultDownloadSession) OpenRecordReader( columns[i] = c } - res, err := is.newDownloadConnection(start, count, sizeLimit, columnNames) - if err != nil { - return nil, errors.WithStack(err) - } - reader := newRecordProtocReader(res, columns, is.shouldTransformDate) return &reader, nil } @@ -182,7 +183,7 @@ func (is *InstanceResultDownloadSession) newInitiationRequest() (*http.Request, if is.TaskName != "" { queryArgs.Set("cached", "") - queryArgs.Set("taskname", "") + queryArgs.Set("taskname", is.TaskName) if is.QueryId != -1 { queryArgs.Set("queryid", strconv.Itoa(is.QueryId)) @@ -267,10 +268,15 @@ func (is *InstanceResultDownloadSession) newDownloadConnection( queryArgs.Set("instance_tunnel_limit_enabled", "") } + if is.QuotaName != "" { + queryArgs.Set("quotaName", is.QuotaName) + } + queryArgs.Set("data", "") if is.IsLongPolling { + queryArgs.Set("schema_in_stream", "") queryArgs.Set("cached", "") - queryArgs.Set("taskname", "") + queryArgs.Set("taskname", is.TaskName) if is.QueryId != -1 { queryArgs.Set("queryid", strconv.Itoa(is.QueryId)) @@ -322,5 +328,44 @@ func (is *InstanceResultDownloadSession) newDownloadConnection( res.Body = WrapByCompressor(res.Body, contentEncoding) } + if is.IsLongPolling { + err = is.procLongPollingResp(res) + if err != nil { + return nil, err + } + } + return res, nil } + +func (is *InstanceResultDownloadSession) procLongPollingResp(res *http.Response) error { + // + if res.Header.Get("odps-tunnel-record-count") != "" { + recordCount, err := strconv.ParseInt(res.Header.Get("odps-tunnel-record-count"), 10, 32) + if err != nil { + return errors.WithStack(err) + } + is.recordCount = int(recordCount) + } + // + var tableSchema tableschema.TableSchema + if res.Header.Get("odps-tunnel-schema") != "" { + schemaStr := res.Header.Get("odps-tunnel-schema") + + err := json.Unmarshal([]byte(schemaStr), &tableSchema) + if err != nil { + return errors.WithStack(err) + } + } else { + reader := newRecordProtocReader(res, nil, false) + tableSchemaPtr, err := reader.readTableSchema() + if err != nil { + return err + } + tableSchema = *tableSchemaPtr + } + // + is.schema = tableSchema + + return nil +} diff --git a/odps/tunnel/protoc_common.go b/odps/tunnel/protoc_common.go index 9497ed0..abbbd38 100644 --- a/odps/tunnel/protoc_common.go +++ b/odps/tunnel/protoc_common.go @@ -17,8 +17,9 @@ package tunnel import ( - "google.golang.org/protobuf/encoding/protowire" "time" + + "google.golang.org/protobuf/encoding/protowire" ) var epochDay time.Time @@ -31,4 +32,5 @@ const ( MetaCount = protowire.Number(33554430) // magic num 2^25-2 MetaChecksum = protowire.Number(33554431) // magic num 2^25-1 EndRecord = protowire.Number(33553408) // magic num 2^25-1024 + SchemaEndTag = protowire.Number(33553920) // maigc num 2^25-512 ) diff --git a/odps/tunnel/protoc_reader_test.go b/odps/tunnel/protoc_reader_test.go index 765ea45..2d14a8a 100644 --- a/odps/tunnel/protoc_reader_test.go +++ b/odps/tunnel/protoc_reader_test.go @@ -18,11 +18,12 @@ package tunnel import ( "bytes" + "strings" + "testing" + "github.com/aliyun/aliyun-odps-go-sdk/odps/data" "github.com/aliyun/aliyun-odps-go-sdk/odps/datatype" "github.com/aliyun/aliyun-odps-go-sdk/odps/tableschema" - "strings" - "testing" ) var structTypeProtocData = []byte{ diff --git a/odps/tunnel/record_protoc_reader.go b/odps/tunnel/record_protoc_reader.go index e18dfa9..0d24329 100644 --- a/odps/tunnel/record_protoc_reader.go +++ b/odps/tunnel/record_protoc_reader.go @@ -17,6 +17,7 @@ package tunnel import ( + "encoding/json" "io" "net/http" "time" @@ -403,3 +404,42 @@ func (r *RecordProtocReader) readStruct(t datatype.StructType) (*data.Struct, er return sd, nil } + +func (r *RecordProtocReader) readTableSchema() (*tableschema.TableSchema, error) { + var schemaJson string + for { + tag, _, err := r.protocReader.ReadTag() + if err != nil { + return nil, err + } + // todo: validate crc failed, checking + if tag == SchemaEndTag { + crc := r.recordCrc.Value() + readUInt32, _ := r.protocReader.ReadUInt32() + if readUInt32 != crc { + return nil, errors.New("crc value is error") + } + r.recordCrc.Reset() + break + } + // + if tag > 1 { + return nil, errors.New("invalid tag") + } + // + r.recordCrc.Update(tag) + + v, err := r.protocReader.ReadBytes() + if err != nil { + return nil, errors.WithStack(err) + } + + r.recordCrc.Update(v) + schemaJson = string(v) + } + + var tableSchema tableschema.TableSchema + err := json.Unmarshal([]byte(schemaJson), &tableSchema) + + return &tableSchema, err +}