xref: /aosp_15_r20/frameworks/base/packages/SystemUI/compose/core/src/com/android/compose/gesture/NestedDraggable.kt (revision d57664e9bc4670b3ecf6748a746a57c557b6bc9e)
1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.compose.gesture
18 
19 import androidx.compose.foundation.OverscrollEffect
20 import androidx.compose.foundation.gestures.Orientation
21 import androidx.compose.foundation.gestures.awaitEachGesture
22 import androidx.compose.foundation.gestures.awaitFirstDown
23 import androidx.compose.foundation.gestures.awaitHorizontalTouchSlopOrCancellation
24 import androidx.compose.foundation.gestures.awaitVerticalTouchSlopOrCancellation
25 import androidx.compose.foundation.gestures.horizontalDrag
26 import androidx.compose.foundation.gestures.verticalDrag
27 import androidx.compose.foundation.overscroll
28 import androidx.compose.ui.Modifier
29 import androidx.compose.ui.geometry.Offset
30 import androidx.compose.ui.input.nestedscroll.NestedScrollConnection
31 import androidx.compose.ui.input.nestedscroll.NestedScrollDispatcher
32 import androidx.compose.ui.input.nestedscroll.NestedScrollSource
33 import androidx.compose.ui.input.nestedscroll.nestedScrollModifierNode
34 import androidx.compose.ui.input.pointer.AwaitPointerEventScope
35 import androidx.compose.ui.input.pointer.PointerEvent
36 import androidx.compose.ui.input.pointer.PointerEventPass
37 import androidx.compose.ui.input.pointer.PointerId
38 import androidx.compose.ui.input.pointer.PointerInputChange
39 import androidx.compose.ui.input.pointer.PointerInputScope
40 import androidx.compose.ui.input.pointer.SuspendingPointerInputModifierNode
41 import androidx.compose.ui.input.pointer.changedToDownIgnoreConsumed
42 import androidx.compose.ui.input.pointer.changedToUpIgnoreConsumed
43 import androidx.compose.ui.input.pointer.positionChange
44 import androidx.compose.ui.input.pointer.util.VelocityTracker
45 import androidx.compose.ui.input.pointer.util.addPointerInputChange
46 import androidx.compose.ui.node.CompositionLocalConsumerModifierNode
47 import androidx.compose.ui.node.DelegatingNode
48 import androidx.compose.ui.node.ModifierNodeElement
49 import androidx.compose.ui.node.PointerInputModifierNode
50 import androidx.compose.ui.node.currentValueOf
51 import androidx.compose.ui.platform.LocalViewConfiguration
52 import androidx.compose.ui.unit.IntSize
53 import androidx.compose.ui.unit.Velocity
54 import androidx.compose.ui.util.fastAny
55 import androidx.compose.ui.util.fastSumBy
56 import com.android.compose.modifiers.thenIf
57 import kotlin.math.sign
58 import kotlinx.coroutines.CoroutineScope
59 import kotlinx.coroutines.CoroutineStart
60 import kotlinx.coroutines.launch
61 
/**
 * A draggable that cooperates with the nested scroll system.
 *
 * Use it for a draggable placed inside a scrollable, for a draggable that itself contains a
 * scrollable, or for both at once.
 */
interface NestedDraggable {
    /**
     * Starts a new drag and returns the [Controller] that will receive its deltas and its end.
     *
     * @param position the position where the drag started, i.e. *before* the touch slop was
     *   crossed.
     * @param sign the direction of the drag, expressed as a sign along the drag orientation.
     * @param pointersDown the number of pointers that were down when the touch slop was detected.
     */
    fun onDragStarted(position: Offset, sign: Float, pointersDown: Int): Controller

    /**
     * Returns whether this draggable wants to consume scroll deltas of the given [sign] that a
     * nested scrollable left unconsumed.
     *
     * It is asked every time a nested scrollable leaves some scroll amount unconsumed. Returning
     * `true` makes this draggable call [onDragStarted] and take priority: it then consumes every
     * following event during preScroll, until the nested scroll ends.
     */
    fun shouldConsumeNestedScroll(sign: Float): Boolean

    interface Controller {
        /**
         * Moves the drag by [delta] pixels.
         *
         * @return the part of [delta] that was consumed. Whatever is left is handed to the next
         *   nested scroll connection, so that composables higher up in the hierarchy can consume
         *   it. When the drag happens directly on this draggable (rather than through a nested
         *   scrollable), the leftover is used to overscroll this draggable.
         */
        fun onDrag(delta: Float): Float

        /**
         * Ends the current drag with the given [velocity].
         *
         * @return the part of [velocity] that was consumed. Whatever is left is handed to the next
         *   nested scroll connection, so that composables higher up in the hierarchy can consume
         *   it. When the drag happens directly on this draggable (rather than through a nested
         *   scrollable), the leftover animates the overscroll of this draggable.
         */
        suspend fun onDragStopped(velocity: Float): Float
    }
}
108 
/**
 * Makes this element draggable along [orientation], driven by [draggable], while cooperating with
 * nested scrolling.
 *
 * When [overscrollEffect] is not null, it is applied to this element and receives the deltas and
 * velocities that were not consumed. When [enabled] is false, pointer events on this element are
 * ignored by the drag detection.
 *
 * @see NestedDraggable
 */
fun Modifier.nestedDraggable(
    draggable: NestedDraggable,
    orientation: Orientation,
    overscrollEffect: OverscrollEffect? = null,
    enabled: Boolean = true,
): Modifier {
    val draggableElement = NestedDraggableElement(draggable, orientation, overscrollEffect, enabled)
    val withOverscroll = thenIf(overscrollEffect != null) { Modifier.overscroll(overscrollEffect) }
    return withOverscroll.then(draggableElement)
}
123 
/** Modifier element that creates and updates a [NestedDraggableNode] with its parameters. */
private data class NestedDraggableElement(
    private val draggable: NestedDraggable,
    private val orientation: Orientation,
    private val overscrollEffect: OverscrollEffect?,
    private val enabled: Boolean,
) : ModifierNodeElement<NestedDraggableNode>() {
    override fun create(): NestedDraggableNode =
        NestedDraggableNode(
            draggable = draggable,
            orientation = orientation,
            overscrollEffect = overscrollEffect,
            enabled = enabled,
        )

    override fun update(node: NestedDraggableNode) =
        node.update(
            draggable = draggable,
            orientation = orientation,
            overscrollEffect = overscrollEffect,
            enabled = enabled,
        )
}
138 
/**
 * The modifier node behind [nestedDraggable].
 *
 * Drags reach the [draggable] from two sources:
 * - pointer drags performed directly on this node (see [detectDrags]), and
 * - scroll deltas left unconsumed by a nested scrollable child (see [onPostScroll]).
 */
private class NestedDraggableNode(
    private var draggable: NestedDraggable,
    override var orientation: Orientation,
    private var overscrollEffect: OverscrollEffect?,
    private var enabled: Boolean,
) :
    DelegatingNode(),
    PointerInputModifierNode,
    NestedScrollConnection,
    CompositionLocalConsumerModifierNode,
    OrientationAware {
    private val nestedScrollDispatcher = NestedScrollDispatcher()

    /** Pointer input node running [trackDownPosition]; (un)delegated whenever it is replaced. */
    private var trackDownPositionDelegate: SuspendingPointerInputModifierNode? = null
        set(value) {
            field?.let { undelegate(it) }
            field = value?.also { delegate(it) }
        }

    /** Pointer input node running [detectDrags]; (un)delegated whenever it is replaced. */
    private var detectDragsDelegate: SuspendingPointerInputModifierNode? = null
        set(value) {
            field?.let { undelegate(it) }
            field = value?.also { delegate(it) }
        }

    /**
     * The controller created by the nested scroll logic (and *not* the drag logic).
     *
     * This is deliberately a plain property without a setter that stops the previous controller.
     * [onPreFling] clears it and *then* stops the controller with the real fling velocity. A setter
     * calling [WrappedController.ensureOnDragStoppedIsCalled] would stop it with a zero velocity
     * first (synchronously, because of `CoroutineStart.UNDISPATCHED`), and the real velocity would
     * then be ignored by [WrappedController.onDragStopped].
     */
    private var nestedScrollController: WrappedController? = null

    /**
     * The position of the first pointer that went down since the last time all pointers were up.
     *
     * This is used as the start position of a drag started from a nested scrollable.
     */
    private var lastFirstDown: Offset? = null

    /** The number of pointers currently down, as tracked by [trackDownPosition]. */
    private var pointersDownCount = 0

    init {
        delegate(nestedScrollModifierNode(this, nestedScrollDispatcher))
    }

    override fun onDetach() {
        nestedScrollController?.ensureOnDragStoppedIsCalled()
    }

    /**
     * Applies new parameters to this node.
     *
     * The pointer input handlers are restarted and any drag started by a nested scrollable is
     * stopped. When [enabled] is false, the pointer input delegates are removed. They are
     * recreated lazily by [onPointerEvent] once the node is enabled again.
     */
    fun update(
        draggable: NestedDraggable,
        orientation: Orientation,
        overscrollEffect: OverscrollEffect?,
        enabled: Boolean,
    ) {
        this.draggable = draggable
        this.orientation = orientation
        this.overscrollEffect = overscrollEffect
        this.enabled = enabled

        trackDownPositionDelegate?.resetPointerInputHandler()
        detectDragsDelegate?.resetPointerInputHandler()
        nestedScrollController?.ensureOnDragStoppedIsCalled()

        if (!enabled && trackDownPositionDelegate != null) {
            check(detectDragsDelegate != null)
            trackDownPositionDelegate = null
            detectDragsDelegate = null
        }
    }

    override fun onPointerEvent(
        pointerEvent: PointerEvent,
        pass: PointerEventPass,
        bounds: IntSize,
    ) {
        if (!enabled) return

        // The pointer input delegates are created lazily, on the first event received while
        // enabled. trackDownPosition() must receive each event before detectDrags(), because
        // detectDrags() checks the state that trackDownPosition() writes.
        if (trackDownPositionDelegate == null) {
            check(detectDragsDelegate == null)
            trackDownPositionDelegate = SuspendingPointerInputModifierNode { trackDownPosition() }
            detectDragsDelegate = SuspendingPointerInputModifierNode { detectDrags() }
        }

        checkNotNull(trackDownPositionDelegate).onPointerEvent(pointerEvent, pass, bounds)
        checkNotNull(detectDragsDelegate).onPointerEvent(pointerEvent, pass, bounds)
    }

    override fun onCancelPointerInput() {
        trackDownPositionDelegate?.onCancelPointerInput()
        detectDragsDelegate?.onCancelPointerInput()
    }

    /*
     * ======================================
     * ===== Pointer input (drag) logic =====
     * ======================================
     */

    /** Detects pointer drags along [orientation] and forwards them to [draggable]. */
    private suspend fun PointerInputScope.detectDrags() {
        // A new velocity tracker is created every time the pointer input handler (re)starts.
        val velocityTracker = VelocityTracker()

        awaitEachGesture {
            val down = awaitFirstDown(requireUnconsumed = false)
            check(down.position == lastFirstDown) {
                "Position from detectDrags() is not the same as position in trackDownPosition()"
            }
            check(pointersDownCount == 1) { "pointersDownCount is equal to $pointersDownCount" }

            var overSlop = 0f
            val onTouchSlopReached = { change: PointerInputChange, over: Float ->
                change.consume()
                overSlop = over
            }

            suspend fun AwaitPointerEventScope.awaitTouchSlopOrCancellation(
                pointerId: PointerId
            ): PointerInputChange? {
                return when (orientation) {
                    Orientation.Horizontal ->
                        awaitHorizontalTouchSlopOrCancellation(pointerId, onTouchSlopReached)
                    Orientation.Vertical ->
                        awaitVerticalTouchSlopOrCancellation(pointerId, onTouchSlopReached)
                }
            }

            var drag = awaitTouchSlopOrCancellation(down.id)

            // We try to pick-up the drag gesture in case the touch slop swipe was consumed by a
            // nested scrollable child that disappeared.
            // This was copied from http://shortn/_10L8U02IoL.
            // TODO(b/380838584): Reuse detect(Horizontal|Vertical)DragGestures() instead.
            while (drag == null && currentEvent.changes.fastAny { it.pressed }) {
                var event: PointerEvent
                do {
                    event = awaitPointerEvent()
                } while (
                    event.changes.fastAny { it.isConsumed } && event.changes.fastAny { it.pressed }
                )

                // An event was not consumed and there's still a pointer in the screen.
                if (event.changes.fastAny { it.pressed }) {
                    // Await touch slop again, using the initial down as starting point.
                    // For most cases this should return immediately since we probably moved
                    // far enough from the initial down event.
                    drag = awaitTouchSlopOrCancellation(down.id)
                }
            }

            if (drag != null) {
                velocityTracker.resetTracking()
                val sign = (drag.position - down.position).toFloat().sign
                check(pointersDownCount > 0) { "pointersDownCount is equal to $pointersDownCount" }
                val wrappedController =
                    WrappedController(
                        coroutineScope,
                        draggable.onDragStarted(down.position, sign, pointersDownCount),
                    )
                if (overSlop != 0f) {
                    onDrag(wrappedController, drag, overSlop, velocityTracker)
                }

                // If a drag was started, we cancel any other drag started by a nested scrollable.
                //
                // Note: we cancel the nested drag here *after* starting the new drag so that in the
                // STL case, the cancelled drag will not change the current scene of the STL.
                nestedScrollController?.ensureOnDragStoppedIsCalled()

                val isSuccessful =
                    try {
                        val onDrag = { change: PointerInputChange ->
                            onDrag(
                                wrappedController,
                                change,
                                change.positionChange().toFloat(),
                                velocityTracker,
                            )
                            change.consume()
                        }

                        when (orientation) {
                            Orientation.Horizontal -> horizontalDrag(drag.id, onDrag)
                            Orientation.Vertical -> verticalDrag(drag.id, onDrag)
                        }
                    } catch (t: Throwable) {
                        // Make sure the controller is stopped even if the gesture is cancelled.
                        wrappedController.ensureOnDragStoppedIsCalled()
                        throw t
                    }

                if (isSuccessful) {
                    val maxVelocity = currentValueOf(LocalViewConfiguration).maximumFlingVelocity
                    val velocity =
                        velocityTracker
                            .calculateVelocity(Velocity(maxVelocity, maxVelocity))
                            .toFloat()
                    onDragStopped(wrappedController, velocity)
                } else {
                    onDragStopped(wrappedController, velocity = 0f)
                }
            }
        }
    }

    /**
     * Records [change] in [velocityTracker] and dispatches [delta] through overscroll, then nested
     * scroll, and finally [controller].
     */
    private fun onDrag(
        controller: NestedDraggable.Controller,
        change: PointerInputChange,
        delta: Float,
        velocityTracker: VelocityTracker,
    ) {
        velocityTracker.addPointerInputChange(change)

        scrollWithOverscroll(delta) { deltaFromOverscroll ->
            scrollWithNestedScroll(deltaFromOverscroll) { deltaFromNestedScroll ->
                controller.onDrag(deltaFromNestedScroll)
            }
        }
    }

    /**
     * Flings with [velocity] through overscroll, then nested scroll, and finally [controller].
     *
     * The controller is guaranteed to be stopped once the fling completes or is cancelled.
     */
    private fun onDragStopped(controller: WrappedController, velocity: Float) {
        coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) {
            try {
                flingWithOverscroll(velocity) { velocityFromOverscroll ->
                    flingWithNestedScroll(velocityFromOverscroll) { velocityFromNestedScroll ->
                        controller.onDragStopped(velocityFromNestedScroll)
                    }
                }
            } finally {
                controller.ensureOnDragStoppedIsCalled()
            }
        }
    }

    /**
     * Runs [performScroll] through [overscrollEffect] when there is one.
     *
     * @return the delta consumed: the value returned by `applyToScroll`, or by [performScroll]
     *   directly when there is no overscroll effect.
     */
    private fun scrollWithOverscroll(delta: Float, performScroll: (Float) -> Float): Float {
        val effect = overscrollEffect
        return if (effect != null) {
            effect
                .applyToScroll(delta.toOffset(), source = NestedScrollSource.UserInput) {
                    performScroll(it.toFloat()).toOffset()
                }
                .toFloat()
        } else {
            performScroll(delta)
        }
    }

    /**
     * Dispatches [delta] as pre-scroll to the parents, then calls [performScroll] with what is
     * left, and finally dispatches the remainder as post-scroll.
     *
     * @return the total delta consumed by the parents (pre and post) and by [performScroll].
     */
    private fun scrollWithNestedScroll(delta: Float, performScroll: (Float) -> Float): Float {
        val preConsumed =
            nestedScrollDispatcher
                .dispatchPreScroll(
                    available = delta.toOffset(),
                    source = NestedScrollSource.UserInput,
                )
                .toFloat()
        val available = delta - preConsumed
        val consumed = performScroll(available)
        val left = available - consumed
        val postConsumed =
            nestedScrollDispatcher
                .dispatchPostScroll(
                    consumed = (preConsumed + consumed).toOffset(),
                    available = left.toOffset(),
                    source = NestedScrollSource.UserInput,
                )
                .toFloat()
        return consumed + preConsumed + postConsumed
    }

    /**
     * Runs [performFling] through [overscrollEffect] when there is one.
     *
     * [performFling] must return the consumed velocity; the overscroll effect uses the unconsumed
     * remainder.
     */
    private suspend fun flingWithOverscroll(
        velocity: Float,
        performFling: suspend (Float) -> Float,
    ) {
        val effect = overscrollEffect
        if (effect != null) {
            effect.applyToFling(velocity.toVelocity()) { performFling(it.toFloat()).toVelocity() }
        } else {
            performFling(velocity)
        }
    }

    /**
     * Dispatches [velocity] as pre-fling to the parents, then calls [performFling] with what is
     * left, and finally dispatches the remainder as post-fling.
     *
     * @return the total velocity consumed by the parents (pre and post) and by [performFling].
     *   This mirrors [scrollWithNestedScroll]. It is the value an overscroll effect expects from its
     *   `performFling` callback, so that velocity already consumed by the draggable does not also
     *   animate the overscroll.
     */
    private suspend fun flingWithNestedScroll(
        velocity: Float,
        performFling: suspend (Float) -> Float,
    ): Float {
        val preConsumed = nestedScrollDispatcher.dispatchPreFling(available = velocity.toVelocity())
        val available = velocity - preConsumed.toFloat()
        val consumed = performFling(available)
        val left = available - consumed
        val postConsumed =
            nestedScrollDispatcher
                .dispatchPostFling(
                    consumed = consumed.toVelocity() + preConsumed,
                    available = left.toVelocity(),
                )
                .toFloat()
        return preConsumed.toFloat() + consumed + postConsumed
    }

    /*
     * ===============================
     * ===== Nested scroll logic =====
     * ===============================
     */

    /** Keeps [lastFirstDown] and [pointersDownCount] up to date for every gesture. */
    private suspend fun PointerInputScope.trackDownPosition() {
        awaitEachGesture {
            val down = awaitFirstDown(requireUnconsumed = false)
            lastFirstDown = down.position
            pointersDownCount = 1

            do {
                pointersDownCount +=
                    awaitPointerEvent().changes.fastSumBy { change ->
                        when {
                            change.changedToDownIgnoreConsumed() -> 1
                            change.changedToUpIgnoreConsumed() -> -1
                            else -> 0
                        }
                    }
            } while (pointersDownCount > 0)
        }
    }

    override fun onPreScroll(available: Offset, source: NestedScrollSource): Offset {
        // Once a nested drag is ongoing, this draggable gets the first chance to consume.
        val controller = nestedScrollController ?: return Offset.Zero
        val consumed = controller.onDrag(available.toFloat())
        return consumed.toOffset()
    }

    override fun onPostScroll(
        consumed: Offset,
        available: Offset,
        source: NestedScrollSource,
    ): Offset {
        if (source == NestedScrollSource.SideEffect) {
            check(nestedScrollController == null)
            return Offset.Zero
        }

        val offset = available.toFloat()
        if (offset == 0f) {
            return Offset.Zero
        }

        // NOTE(review): `enabled` is not checked here, so a disabled draggable can still start a
        // drag from a nested scroll. Confirm this is intended.
        val sign = offset.sign
        if (nestedScrollController == null && draggable.shouldConsumeNestedScroll(sign)) {
            val startedPosition = checkNotNull(lastFirstDown) { "lastFirstDown is not set" }

            // TODO(b/382665591): Replace this by check(pointersDownCount > 0).
            val pointersDown = pointersDownCount.coerceAtLeast(1)
            nestedScrollController =
                WrappedController(
                    coroutineScope,
                    draggable.onDragStarted(startedPosition, sign, pointersDown),
                )
        }

        val controller = nestedScrollController ?: return Offset.Zero
        return controller.onDrag(offset).toOffset()
    }

    override suspend fun onPreFling(available: Velocity): Velocity {
        val controller = nestedScrollController ?: return Velocity.Zero

        // Clear the controller before stopping it, so that a new nested scroll can start its own
        // drag while this one flings. This assignment has no side effect on `controller`, so the
        // call below is the one that delivers the real fling velocity.
        nestedScrollController = null

        val consumed = controller.onDragStopped(available.toFloat())
        return consumed.toVelocity()
    }
}
506 
/**
 * Wraps [delegate] so that [onDragStopped] reaches it at most once.
 *
 * After the first call to [onDragStopped], further calls to [onDrag] and [onDragStopped] do not
 * reach [delegate] and report that nothing was consumed.
 */
private class WrappedController(
    private val coroutineScope: CoroutineScope,
    private val delegate: NestedDraggable.Controller,
) : NestedDraggable.Controller by delegate {
    /** Becomes `true` the first time [onDragStopped] is invoked. */
    private var isStopped = false

    override fun onDrag(delta: Float): Float = if (isStopped) 0f else delegate.onDrag(delta)

    override suspend fun onDragStopped(velocity: Float): Float {
        if (isStopped) {
            return 0f
        }
        isStopped = true
        return delegate.onDragStopped(velocity)
    }

    /** Stops the drag with a zero velocity, unless it was already stopped. */
    fun ensureOnDragStoppedIsCalled() {
        // UNDISPATCHED runs onDragStopped() right away, up to its first suspension point, even
        // when coroutineScope is already cancelled.
        coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) { onDragStopped(velocity = 0f) }
    }
}
534