1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import struct
20
21import pytest
22
23from bumble import core
24from bumble import device
25from bumble import host
26from bumble import controller
27from bumble import link
28from bumble import avc
29from bumble import avrcp
30from bumble import avctp
31from bumble.transport import common
32
33
34# -----------------------------------------------------------------------------
class TwoDevices:
    """Test fixture holding two classic-enabled devices on one `link.LocalLink`.

    Attributes are kept as two-element lists indexed by device number:
    `controllers`, `devices`, `connections` and `protocols`. The last two
    hold `None` placeholders until `setup_connections()` has run.
    """

    def __init__(self):
        names = ('C1', 'C2')
        addresses = ('F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0')

        # Both controllers are attached to the same local link.
        self.link = link.LocalLink()
        self.controllers = [
            controller.Controller(name, link=self.link, public_address=address)
            for name, address in zip(names, addresses)
        ]

        # One device per controller; the controller is used both as the
        # host's first argument and as the target of its AsyncPipeSink.
        self.devices = [
            device.Device(
                address=address,
                host=host.Host(ctrl, common.AsyncPipeSink(ctrl)),
            )
            for address, ctrl in zip(addresses, self.controllers)
        ]
        for dev in self.devices:
            dev.classic_enabled = True

        # Filled in by setup_connections().
        self.connections = [None, None]
        self.protocols = [None, None]

    def on_connection(self, which, connection):
        """Store `connection` at index `which` of `self.connections`."""
        self.connections[which] = connection

    async def setup_connections(self):
        """Power on both devices, connect them, then attach AVRCP protocols.

        Device 0 initiates a BR/EDR connection to device 1 while device 1
        accepts it. Afterwards `protocols[0]` listens on device 1 and
        `protocols[1]` connects over `connections[0]`.
        """
        for dev in self.devices:
            await dev.power_on()

        initiator, acceptor = self.devices
        self.connections = await asyncio.gather(
            initiator.connect(acceptor.public_address, core.BT_BR_EDR_TRANSPORT),
            acceptor.accept(initiator.public_address),
        )

        listener = avrcp.Protocol()
        connector = avrcp.Protocol()
        self.protocols = [listener, connector]
        listener.listen(acceptor)
        await connector.connect(self.connections[0])
81
82
83# -----------------------------------------------------------------------------
def test_frame_parser():
    """Check `avc.Frame.from_bytes` on one rejected and three accepted inputs."""
    # This input is expected to be refused with a ValueError.
    with pytest.raises(ValueError):
        avc.Frame.from_bytes(bytes.fromhex("11480000"))

    # Two PANEL frames with opcode 8 whose expected subunit IDs are 7 and 260.
    for encoded, expected_subunit_id in (("014D0208", 7), ("014DFF0108", 260)):
        parsed = avc.Frame.from_bytes(bytes.fromhex(encoded))
        assert parsed.subunit_type == avc.Frame.SubunitType.PANEL
        assert parsed.subunit_id == expected_subunit_id
        assert parsed.opcode == 8

    # A longer frame that must come back as a STATUS command frame.
    command = avc.Frame.from_bytes(bytes.fromhex("0148000019581000000103"))
    assert isinstance(command, avc.CommandFrame)
    assert command.ctype == avc.CommandFrame.CommandType.STATUS
    assert command.subunit_type == avc.Frame.SubunitType.PANEL
    assert command.subunit_id == 0
    assert command.opcode == 0
109
110
111# -----------------------------------------------------------------------------
def test_vendor_dependent_command():
    """Parse a vendor-dependent command frame, then rebuild it from its fields."""
    encoded = bytes.fromhex("0148000019581000000103")
    vendor_data = bytes.fromhex("1000000103")

    # Parsing: the frame class and its vendor-specific fields.
    parsed = avc.Frame.from_bytes(encoded)
    assert isinstance(parsed, avc.VendorDependentCommandFrame)
    assert parsed.company_id == 0x1958
    assert parsed.vendor_dependent_data == vendor_data

    # Serializing: a frame built from the same values must give the same bytes.
    rebuilt = avc.VendorDependentCommandFrame(
        avc.CommandFrame.CommandType.STATUS,
        avc.Frame.SubunitType.PANEL,
        0,
        0x1958,
        vendor_data,
    )
    assert bytes(rebuilt) == encoded
127
128
129# -----------------------------------------------------------------------------
def test_avctp_message_assembler():
    """Feed PDUs to `avctp.MessageAssembler` and check the reported messages."""
    messages = []

    def on_message(transaction_label, is_response, ipid, pid, payload):
        messages.append((transaction_label, is_response, ipid, pid, payload))

    def first_byte(packet_type):
        # Only the 2-bit field at bits 2-3 varies between the PDUs below.
        return 1 << 4 | packet_type << 2 | 1 << 1 | 0

    assembler = avctp.MessageAssembler(on_message)

    # A single 0b00 PDU is reported right away.
    payload = bytes.fromhex("01")
    assembler.on_pdu(bytes([first_byte(0b00), 0x11, 0x22]) + payload)
    assert len(messages) > 0
    assert messages[0] == (1, False, False, 0x1122, payload)

    # A 0b01 PDU on its own reports nothing; the 0b00 PDU that follows is
    # reported with its own payload.
    messages.clear()
    payload = bytes.fromhex("010203")
    assembler.on_pdu(bytes([first_byte(0b01), 0x11, 0x22]) + payload)
    assert not messages
    assembler.on_pdu(bytes([first_byte(0b00), 0x11, 0x22]) + payload)
    assert len(messages) > 0
    assert messages[0] == (1, False, False, 0x1122, payload)

    # A 0b01 / 0b10 / 0b11 sequence carrying one payload byte each is
    # reported as one message with the three bytes joined.
    messages.clear()
    payload = bytes.fromhex("010203")
    fragments = (
        bytes([first_byte(0b01), 3, 0x11, 0x22]) + payload[0:1],
        bytes([first_byte(0b10), 0x11, 0x22]) + payload[1:2],
        bytes([first_byte(0b11), 0x11, 0x22]) + payload[2:3],
    )
    for fragment in fragments:
        assembler.on_pdu(fragment)
    assert len(messages) > 0
    assert messages[0] == (1, False, False, 0x1122, payload)

    # NOTE(review): a disabled case used to follow here. It fed a lone 0b11
    # PDU built with struct.pack(">BBH", 0x10, 0b11, len(parameter)) and
    # expected no message. That header layout is the one used by the AVRCP
    # PDU assembler test, not the byte layout used above, so it would need
    # porting before it could be enabled.
169
170
171# -----------------------------------------------------------------------------
def test_avrcp_pdu_assembler():
    """Feed PDUs to `avrcp.PduAssembler` and check which ones are reported."""
    pdus = []

    def on_pdu(pdu_id, parameter):
        pdus.append((pdu_id, parameter))

    def encode(packet_type, length, body):
        # Header packed as ">BBH": PDU ID 0x10, packet type, length; then body.
        return struct.pack(">BBH", 0x10, packet_type, length) + body

    assembler = avrcp.PduAssembler(on_pdu)

    # A single 0b00 PDU is reported right away.
    parameter = bytes.fromhex("01")
    assembler.on_pdu(encode(0b00, len(parameter), parameter))
    assert len(pdus) > 0
    assert pdus[0] == (0x10, parameter)

    # A 0b01 PDU on its own reports nothing; the 0b00 PDU that follows is
    # reported with its own parameter.
    pdus.clear()
    parameter = bytes.fromhex("010203")
    assembler.on_pdu(encode(0b01, len(parameter), parameter))
    assert not pdus
    assembler.on_pdu(encode(0b00, len(parameter), parameter))
    assert len(pdus) > 0
    assert pdus[0] == (0x10, parameter)

    # A 0b01 / 0b10 / 0b11 sequence carrying one byte each is reported as
    # one PDU with the three bytes joined.
    pdus.clear()
    parameter = bytes.fromhex("010203")
    for packet_type, chunk in (
        (0b01, parameter[0:1]),
        (0b10, parameter[1:2]),
        (0b11, parameter[2:3]),
    ):
        assembler.on_pdu(encode(packet_type, 1, chunk))
    assert len(pdus) > 0
    assert pdus[0] == (0x10, parameter)

    # A lone 0b11 PDU must not be reported.
    pdus.clear()
    parameter = bytes.fromhex("010203")
    assembler.on_pdu(encode(0b11, len(parameter), parameter))
    assert not pdus
205
206
def test_passthrough_commands():
    """Round-trip a PLAY/PRESSED pass-through command through bytes and back."""
    frame_class = avc.PassThroughCommandFrame

    original = frame_class(
        avc.CommandFrame.CommandType.CONTROL,
        avc.CommandFrame.SubunitType.PANEL,
        0,
        frame_class.StateFlag.PRESSED,
        frame_class.OperationId.PLAY,
        b'',
    )
    encoded = bytes(original)

    # Parsing the serialized form must give back an equivalent frame.
    decoded = avc.Frame.from_bytes(encoded)
    assert isinstance(decoded, frame_class)
    assert decoded.operation_id == frame_class.OperationId.PLAY
    assert bytes(decoded) == encoded
222
223
224# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_events():
    """Query supported events in both directions between two connected peers."""
    pair = TwoDevices()
    await pair.setup_connections()
    first, second = pair.protocols

    # Before the test assigns any delegate, the query made through the first
    # protocol is expected to return an empty list.
    events = await first.get_supported_events()
    assert events == []

    # Once the first protocol has a delegate declaring VOLUME_CHANGED, the
    # query made through the second protocol must report that event.
    first.delegate = avrcp.Delegate([avrcp.EventId.VOLUME_CHANGED])
    events = await second.get_supported_events()
    assert events == [avrcp.EventId.VOLUME_CHANGED]
237
238
239# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Allows running this file directly, without the pytest runner.
    test_frame_parser()
    test_vendor_dependent_command()
    test_avctp_message_assembler()
    test_avrcp_pdu_assembler()
    test_passthrough_commands()
    # test_get_supported_events is a coroutine function: calling it only
    # creates a coroutine object, so it has to be driven by an event loop
    # for its body (and its assertions) to actually run.
    asyncio.run(test_get_supported_events())
247