chore: use signal.signal and add test
Signed-off-by: Jesús Fernández <7312236+fernandezcuesta@users.noreply.github.com>
This commit is contained in:
parent
869f7d4439
commit
50b187cae8
|
@ -45,8 +45,8 @@ def update(r: fnv1.Resource, source: dict | structpb.Struct | pydantic.BaseModel
|
|||
# apiVersion is set to its default value 's3.aws.upbound.io/v1beta2'
|
||||
# (and not explicitly provided during initialization), it will be
|
||||
# excluded from the serialized output.
|
||||
data['apiVersion'] = source.apiVersion
|
||||
data['kind'] = source.kind
|
||||
data["apiVersion"] = source.apiVersion
|
||||
data["kind"] = source.kind
|
||||
r.resource.update(data)
|
||||
case structpb.Struct():
|
||||
# TODO(negz): Use struct_to_dict and update to match other semantics?
|
||||
|
|
|
@ -66,8 +66,8 @@ def load_credentials(tls_certs_dir: str) -> grpc.ServerCredentials:
|
|||
)
|
||||
|
||||
|
||||
async def _stop(server, timeout): # noqa: ASYNC109
|
||||
await server.stop(grace=timeout)
|
||||
async def _stop(server, grace=GRACE_PERIOD):
|
||||
await server.stop(grace=grace)
|
||||
|
||||
|
||||
def serve(
|
||||
|
@ -96,8 +96,9 @@ def serve(
|
|||
|
||||
server = grpc.aio.server()
|
||||
|
||||
loop.add_signal_handler(
|
||||
signal.SIGTERM, lambda: asyncio.create_task(_stop(server, timeout=GRACE_PERIOD))
|
||||
signal.signal(
|
||||
signal.SIGTERM,
|
||||
lambda _, __: asyncio.create_task(_stop(server)),
|
||||
)
|
||||
|
||||
grpcv1.add_FunctionRunnerServiceServicer_to_server(function, server)
|
||||
|
@ -126,7 +127,8 @@ def serve(
|
|||
try:
|
||||
loop.run_until_complete(start())
|
||||
finally:
|
||||
loop.run_until_complete(server.stop(grace=GRACE_PERIOD))
|
||||
if server._server.is_running():
|
||||
loop.run_until_complete(server.stop(grace=GRACE_PERIOD))
|
||||
loop.close()
|
||||
|
||||
|
||||
|
|
|
@ -12,7 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import os
|
||||
import signal
|
||||
import unittest
|
||||
|
||||
import grpc
|
||||
|
@ -52,6 +55,25 @@ class TestRuntime(unittest.IsolatedAsyncioTestCase):
|
|||
|
||||
self.assertEqual(rsp, case.want, "-want, +got")
|
||||
|
||||
async def test_sigterm_handling(self) -> None:
|
||||
async def mock_server():
|
||||
await server.start()
|
||||
await asyncio.sleep(1)
|
||||
self.assertTrue(server._server.is_running(), "Server should be running")
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
await server.wait_for_termination()
|
||||
self.assertFalse(
|
||||
server._server.is_running(),
|
||||
"Server should have been stopped on SIGTERM",
|
||||
)
|
||||
|
||||
server = grpc.aio.server()
|
||||
signal.signal(
|
||||
signal.SIGTERM,
|
||||
lambda _, __: asyncio.create_task(runtime._stop(server)),
|
||||
)
|
||||
await mock_server()
|
||||
|
||||
|
||||
class EchoRunner(grpcv1.FunctionRunnerService):
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue