/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package tests import ( "context" "errors" "testing" "github.com/apache/thrift/lib/go/test/gopath/src/clientmiddlewareexceptiontest" "github.com/apache/thrift/lib/go/thrift" ) type fakeClientMiddlewareExceptionTestHandler func(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) func (f fakeClientMiddlewareExceptionTestHandler) Foo(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { return f(ctx) } type clientMiddlewareErrorChecker func(err error) error var clientMiddlewareExceptionCases = []struct { label string handler fakeClientMiddlewareExceptionTestHandler checker clientMiddlewareErrorChecker }{ { label: "no-error", handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { return new(clientmiddlewareexceptiontest.FooResponse), nil }, checker: func(err error) error { if err != nil { return errors.New("expected err to be nil") } return nil }, }, { label: "exception-1", handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { return nil, new(clientmiddlewareexceptiontest.Exception1) }, checker: func(err error) error { if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception1)) { return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception1") } return nil }, }, { label: "no-error", handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { return nil, new(clientmiddlewareexceptiontest.Exception2) }, checker: func(err error) error { if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception2)) { return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception2") } return nil }, }, } func TestClientMiddlewareException(t *testing.T) { for _, c := range clientMiddlewareExceptionCases { t.Run(c.label, func(t *testing.T) { serverSocket, err := thrift.NewTServerSocket(":0") if err != nil { t.Fatalf("failed to create server socket: %v", err) } processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(c.handler) server := thrift.NewTSimpleServer2(processor, serverSocket) if err := server.Listen(); err != nil { t.Fatalf("failed to listen server: %v", err) } addr := serverSocket.Addr().String() go server.Serve() t.Cleanup(func() { server.Stop() }) var cfg *thrift.TConfiguration socket := thrift.NewTSocketConf(addr, cfg) if err := socket.Open(); err != nil { t.Fatalf("failed to create client connection: %v", err) } t.Cleanup(func() { socket.Close() }) inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) middleware := func(next thrift.TClient) thrift.TClient { return thrift.WrappedTClient{ Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) { defer func() { if checkErr := c.checker(err); checkErr != nil { t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) } }() return next.Call(ctx, method, args, result) }, } } client := thrift.WrapClient( thrift.NewTStandardClient(inProtocol, outProtocol), middleware, thrift.ExtractIDLExceptionClientMiddleware, ) result, err := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(context.Background()) if checkErr := c.checker(err); checkErr != nil { t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) } }) } } func TestExtractExceptionFromResult(t *testing.T) { for _, c := range clientMiddlewareExceptionCases { t.Run(c.label, func(t *testing.T) { serverSocket, err := thrift.NewTServerSocket(":0") if err != nil { t.Fatalf("failed to create server socket: %v", err) } processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(c.handler) server := thrift.NewTSimpleServer2(processor, serverSocket) if err := server.Listen(); err != nil { t.Fatalf("failed to listen server: %v", err) } addr := serverSocket.Addr().String() go server.Serve() t.Cleanup(func() { server.Stop() }) var cfg *thrift.TConfiguration socket := thrift.NewTSocketConf(addr, cfg) if err := socket.Open(); err != nil { t.Fatalf("failed to create client connection: %v", err) } t.Cleanup(func() { socket.Close() }) inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) middleware := func(next thrift.TClient) thrift.TClient { return thrift.WrappedTClient{ Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) { defer func() { if err == nil { err = thrift.ExtractExceptionFromResult(result) } if checkErr := c.checker(err); checkErr != nil { t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) } }() return next.Call(ctx, method, args, result) }, } } client := thrift.WrapClient( thrift.NewTStandardClient(inProtocol, outProtocol), middleware, ) result, err := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(context.Background()) if checkErr := c.checker(err); checkErr != nil { t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err) } }) } }