Skip to content

Commit 0a09bb3

Browse files
feat: implement Flow control batcher (#15)
* feat: Implement FlowControlBatcher This handles aggregating flow control requests without allowing them to get above the max int64 value. * Use correct request for comparison.
1 parent b2d0d36 commit 0a09bb3

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import NamedTuple, List, Optional
2+
3+
from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage
4+
5+
_EXPEDITE_BATCH_REQUEST_RATIO = 0.5
6+
_MAX_INT64 = 0x7FFFFFFFFFFFFFFF
7+
8+
9+
class _AggregateRequest:
10+
request: FlowControlRequest
11+
12+
def __init__(self):
13+
self.request = FlowControlRequest()
14+
15+
def __add__(self, other: FlowControlRequest):
16+
self.request.allowed_bytes += other.allowed_bytes
17+
self.request.allowed_bytes = min(self.request.allowed_bytes, _MAX_INT64)
18+
self.request.allowed_messages += other.allowed_messages
19+
self.request.allowed_messages = min(self.request.allowed_messages, _MAX_INT64)
20+
return self
21+
22+
23+
def _exceeds_expedite_ratio(pending: int, client: int):
24+
if client <= 0:
25+
return False
26+
return (pending/client) >= _EXPEDITE_BATCH_REQUEST_RATIO
27+
28+
29+
def _to_optional(req: FlowControlRequest) -> Optional[FlowControlRequest]:
30+
if req.allowed_messages == 0 and req.allowed_bytes == 0:
31+
return None
32+
return req
33+
34+
35+
class FlowControlBatcher:
36+
_client_tokens: _AggregateRequest
37+
_pending_tokens: _AggregateRequest
38+
39+
def __init__(self):
40+
self._client_tokens = _AggregateRequest()
41+
self._pending_tokens = _AggregateRequest()
42+
43+
def add(self, request: FlowControlRequest):
44+
self._client_tokens += request
45+
self._pending_tokens += request
46+
47+
def on_messages(self, messages: List[SequencedMessage]):
48+
byte_size = sum(message.size_bytes for message in messages)
49+
self._client_tokens += FlowControlRequest(allowed_bytes=-byte_size, allowed_messages=-len(messages))
50+
51+
def request_for_restart(self) -> Optional[FlowControlRequest]:
52+
self._pending_tokens = _AggregateRequest()
53+
return _to_optional(self._client_tokens.request)
54+
55+
def release_pending_request(self) -> Optional[FlowControlRequest]:
56+
request = self._pending_tokens.request
57+
self._pending_tokens = _AggregateRequest()
58+
return _to_optional(request)
59+
60+
def should_expedite(self):
61+
pending_request = self._pending_tokens.request
62+
client_request = self._client_tokens.request
63+
if _exceeds_expedite_ratio(pending_request.allowed_bytes, client_request.allowed_bytes):
64+
return True
65+
if _exceeds_expedite_ratio(pending_request.allowed_messages, client_request.allowed_messages):
66+
return True
67+
return False
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from google.cloud.pubsublite.internal.wire.flow_control_batcher import FlowControlBatcher
2+
from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage
3+
4+
5+
def test_restart_clears_send():
6+
batcher = FlowControlBatcher()
7+
batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3))
8+
assert batcher.should_expedite()
9+
to_send = batcher.release_pending_request()
10+
assert to_send.allowed_bytes == 10
11+
assert to_send.allowed_messages == 3
12+
restart_1 = batcher.request_for_restart()
13+
assert restart_1.allowed_bytes == 10
14+
assert restart_1.allowed_messages == 3
15+
assert not batcher.should_expedite()
16+
assert batcher.release_pending_request() is None
17+
18+
19+
def test_add_remove():
20+
batcher = FlowControlBatcher()
21+
batcher.add(FlowControlRequest(allowed_bytes=10, allowed_messages=3))
22+
restart_1 = batcher.request_for_restart()
23+
assert restart_1.allowed_bytes == 10
24+
assert restart_1.allowed_messages == 3
25+
batcher.on_messages([SequencedMessage(size_bytes=2), SequencedMessage(size_bytes=3)])
26+
restart_2 = batcher.request_for_restart()
27+
assert restart_2.allowed_bytes == 5
28+
assert restart_2.allowed_messages == 1

0 commit comments

Comments
 (0)