From 04022d24858026a1e0c6b80baed2a31bb3a5fe4e Mon Sep 17 00:00:00 2001
From: Cem Aksoylar <caksoylar@users.noreply.github.com>
Date: Sun, 15 Oct 2023 23:27:45 -0700
Subject: [PATCH] feat(mouse): Split move/scroll & x/y acceleration

---
 app/include/zmk/events/mouse_tick.h |  6 +--
 app/include/zmk/mouse.h             |  7 +++
 app/src/mouse/key_listener.c        | 70 ++++++++++++++++++++++-------
 app/src/mouse/tick_listener.c       | 21 +++++----
 4 files changed, 77 insertions(+), 27 deletions(-)

diff --git a/app/include/zmk/events/mouse_tick.h b/app/include/zmk/events/mouse_tick.h
index 2f69b045..75a041e9 100644
--- a/app/include/zmk/events/mouse_tick.h
+++ b/app/include/zmk/events/mouse_tick.h
@@ -17,7 +17,7 @@ struct zmk_mouse_tick {
     struct vector2d max_scroll;
     struct mouse_config move_config;
     struct mouse_config scroll_config;
-    int64_t *start_time;
+    struct mouse_times start_times;
     int64_t timestamp;
 };
 
@@ -27,13 +27,13 @@ static inline struct zmk_mouse_tick_event *zmk_mouse_tick(struct vector2d max_mo
                                                           struct vector2d max_scroll,
                                                           struct mouse_config move_config,
                                                           struct mouse_config scroll_config,
-                                                          int64_t *movement_start) {
+                                                          struct mouse_times movement_start) {
     return new_zmk_mouse_tick((struct zmk_mouse_tick){
         .max_move = max_move,
         .max_scroll = max_scroll,
         .move_config = move_config,
         .scroll_config = scroll_config,
-        .start_time = movement_start,
+        .start_times = movement_start,
         .timestamp = k_uptime_get(),
     });
 }
diff --git a/app/include/zmk/mouse.h b/app/include/zmk/mouse.h
index 1944ea47..c2ae6252 100644
--- a/app/include/zmk/mouse.h
+++ b/app/include/zmk/mouse.h
@@ -25,5 +25,12 @@ struct vector2d {
     float y;
 };
 
+struct mouse_times {
+    uint64_t m_x;
+    uint64_t m_y;
+    uint64_t s_x;
+    uint64_t s_y;
+};
+
 struct k_work_q *zmk_mouse_work_q();
 int zmk_mouse_init();
diff --git a/app/src/mouse/key_listener.c b/app/src/mouse/key_listener.c
index 91c4c5ce..bd06efeb 100644
--- a/app/src/mouse/key_listener.c
+++ b/app/src/mouse/key_listener.c
@@ -23,7 +23,7 @@ static struct vector2d move_speed = {0};
 static struct vector2d scroll_speed = {0};
 static struct mouse_config move_config = (struct mouse_config){0};
 static struct mouse_config scroll_config = (struct mouse_config){0};
-static int64_t start_time = 0;
+static struct mouse_times start_times = (struct mouse_times){0};
 
 bool equals(const struct mouse_config *one, const struct mouse_config *other) {
     return one->delay_ms == other->delay_ms &&
@@ -34,7 +34,7 @@ bool equals(const struct mouse_config *one, const struct mouse_config *other) {
 static void clear_mouse_state(struct k_work *work) {
     move_speed = (struct vector2d){0};
     scroll_speed = (struct vector2d){0};
-    start_time = 0;
+    start_times = (struct mouse_times){0};
     zmk_hid_mouse_movement_set(0, 0);
     zmk_hid_mouse_scroll_set(0, 0);
     LOG_DBG("Clearing state");
@@ -51,7 +51,7 @@ static void mouse_tick_timer_handler(struct k_work *work) {
     zmk_hid_mouse_scroll_set(0, 0);
     LOG_DBG("Raising mouse tick event");
     ZMK_EVENT_RAISE(
-        zmk_mouse_tick(move_speed, scroll_speed, move_config, scroll_config, &start_time));
+        zmk_mouse_tick(move_speed, scroll_speed, move_config, scroll_config, start_times));
     zmk_endpoints_send_mouse_report();
 }
 
@@ -64,21 +64,59 @@ void mouse_timer_cb(struct k_timer *dummy) {
 
 K_TIMER_DEFINE(mouse_timer, mouse_timer_cb, mouse_clear_cb);
 
-static int mouse_timer_ref_count = 0;
+static struct {
+    int m_x;
+    int m_y;
+    int s_x;
+    int s_y;
+} mouse_timer_ref_counts = {0, 0, 0, 0};
 
-void mouse_timer_ref() {
-    if (mouse_timer_ref_count == 0) {
-        start_time = k_uptime_get();
+void mouse_timer_ref(bool m_x, bool m_y, bool s_x, bool s_y) {
+    if (m_x && mouse_timer_ref_counts.m_x == 0) {
+        start_times.m_x = k_uptime_get();
+    }
+    if (m_y && mouse_timer_ref_counts.m_y == 0) {
+        start_times.m_y = k_uptime_get();
+    }
+    if (s_x && mouse_timer_ref_counts.s_x == 0) {
+        start_times.s_x = k_uptime_get();
+    }
+    if (s_y && mouse_timer_ref_counts.s_y == 0) {
+        start_times.s_y = k_uptime_get();
+    }
+    if (mouse_timer_ref_counts.m_x == 0 && mouse_timer_ref_counts.m_y == 0 &&
+        mouse_timer_ref_counts.s_x == 0 && mouse_timer_ref_counts.s_y == 0) {
         k_timer_start(&mouse_timer, K_NO_WAIT, K_MSEC(CONFIG_ZMK_MOUSE_TICK_DURATION));
     }
-    mouse_timer_ref_count += 1;
+    if (m_x) {
+        mouse_timer_ref_counts.m_x++;
+    }
+    if (m_y) {
+        mouse_timer_ref_counts.m_y++;
+    }
+    if (s_x) {
+        mouse_timer_ref_counts.s_x++;
+    }
+    if (s_y) {
+        mouse_timer_ref_counts.s_y++;
+    }
 }
 
-void mouse_timer_unref() {
-    if (mouse_timer_ref_count > 0) {
-        mouse_timer_ref_count--;
+void mouse_timer_unref(bool m_x, bool m_y, bool s_x, bool s_y) {
+    if (m_x && mouse_timer_ref_counts.m_x > 0) {
+        mouse_timer_ref_counts.m_x--;
     }
-    if (mouse_timer_ref_count == 0) {
+    if (m_y && mouse_timer_ref_counts.m_y > 0) {
+        mouse_timer_ref_counts.m_y--;
+    }
+    if (s_x && mouse_timer_ref_counts.s_x > 0) {
+        mouse_timer_ref_counts.s_x--;
+    }
+    if (s_y && mouse_timer_ref_counts.s_y > 0) {
+        mouse_timer_ref_counts.s_y--;
+    }
+    if (mouse_timer_ref_counts.m_x == 0 && mouse_timer_ref_counts.m_y == 0 &&
+        mouse_timer_ref_counts.s_x == 0 && mouse_timer_ref_counts.s_y == 0) {
         k_timer_stop(&mouse_timer);
     }
 }
@@ -86,25 +124,25 @@ void mouse_timer_unref() {
 static void listener_mouse_move_pressed(const struct zmk_mouse_move_state_changed *ev) {
     move_speed.x += ev->max_speed.x;
     move_speed.y += ev->max_speed.y;
-    mouse_timer_ref();
+    mouse_timer_ref(ev->max_speed.x != 0, ev->max_speed.y != 0, false, false);
 }
 
 static void listener_mouse_move_released(const struct zmk_mouse_move_state_changed *ev) {
     move_speed.x -= ev->max_speed.x;
     move_speed.y -= ev->max_speed.y;
-    mouse_timer_unref();
+    mouse_timer_unref(ev->max_speed.x != 0, ev->max_speed.y != 0, false, false);
 }
 
 static void listener_mouse_scroll_pressed(const struct zmk_mouse_scroll_state_changed *ev) {
     scroll_speed.x += ev->max_speed.x;
     scroll_speed.y += ev->max_speed.y;
-    mouse_timer_ref();
+    mouse_timer_ref(false, false, ev->max_speed.x != 0, ev->max_speed.y != 0);
 }
 
 static void listener_mouse_scroll_released(const struct zmk_mouse_scroll_state_changed *ev) {
     scroll_speed.x -= ev->max_speed.x;
     scroll_speed.y -= ev->max_speed.y;
-    mouse_timer_unref();
+    mouse_timer_unref(false, false, ev->max_speed.x != 0, ev->max_speed.y != 0);
 }
 
 static void listener_mouse_button_pressed(const struct zmk_mouse_button_state_changed *ev) {
diff --git a/app/src/mouse/tick_listener.c b/app/src/mouse/tick_listener.c
index f84b29f7..90aaf111 100644
--- a/app/src/mouse/tick_listener.c
+++ b/app/src/mouse/tick_listener.c
@@ -61,17 +61,18 @@ static void track_remainder(float *move, float *remainder) {
 
 static struct vector2d update_movement(struct vector2d *remainder,
                                        const struct mouse_config *config, struct vector2d max_speed,
-                                       int64_t now, int64_t *start_time) {
+                                       int64_t now, int64_t start_time_x, int64_t start_time_y) {
     struct vector2d move = {0};
     if (max_speed.x == 0 && max_speed.y == 0) {
         *remainder = (struct vector2d){0};
         return move;
     }
 
-    int64_t move_duration = ms_since_start(*start_time, now, config->delay_ms);
+    int64_t move_duration_x = ms_since_start(start_time_x, now, config->delay_ms);
+    int64_t move_duration_y = ms_since_start(start_time_y, now, config->delay_ms);
     move = (struct vector2d){
-        .x = speed(config, max_speed.x, move_duration) * CONFIG_ZMK_MOUSE_TICK_DURATION / 1000,
-        .y = speed(config, max_speed.y, move_duration) * CONFIG_ZMK_MOUSE_TICK_DURATION / 1000,
+        .x = speed(config, max_speed.x, move_duration_x) * CONFIG_ZMK_MOUSE_TICK_DURATION / 1000,
+        .y = speed(config, max_speed.y, move_duration_y) * CONFIG_ZMK_MOUSE_TICK_DURATION / 1000,
     };
 
     track_remainder(&(move.x), &(remainder->x));
@@ -81,12 +82,16 @@ static struct vector2d update_movement(struct vector2d *remainder,
 }
 
 static void mouse_tick_handler(const struct zmk_mouse_tick *tick) {
-    struct vector2d move = update_movement(&move_remainder, &(tick->move_config), tick->max_move,
-                                           tick->timestamp, tick->start_time);
+    LOG_INF("tick start times: %lld %lld %lld %lld", tick->start_times.m_x, tick->start_times.m_y,
+            tick->start_times.s_x, tick->start_times.s_y);
+    struct vector2d move =
+        update_movement(&move_remainder, &(tick->move_config), tick->max_move, tick->timestamp,
+                        tick->start_times.m_x, tick->start_times.m_y);
     zmk_hid_mouse_movement_update((int16_t)CLAMP(move.x, INT16_MIN, INT16_MAX),
                                   (int16_t)CLAMP(move.y, INT16_MIN, INT16_MAX));
-    struct vector2d scroll = update_movement(&scroll_remainder, &(tick->scroll_config),
-                                             tick->max_scroll, tick->timestamp, tick->start_time);
+    struct vector2d scroll =
+        update_movement(&scroll_remainder, &(tick->scroll_config), tick->max_scroll,
+                        tick->timestamp, tick->start_times.s_x, tick->start_times.s_y);
     zmk_hid_mouse_scroll_update((int8_t)CLAMP(scroll.x, INT8_MIN, INT8_MAX),
                                 (int8_t)CLAMP(scroll.y, INT8_MIN, INT8_MAX));
 }