From 5da0b64c3eccd5a25d84e17328abacf746f36650 Mon Sep 17 00:00:00 2001 From: Paul-Henri Froidmont Date: Thu, 11 Sep 2025 17:30:26 +0200 Subject: [PATCH] Implement subscriptions --- core/src/TestLiveView.scala | 3 ++ core/src/scalive/Diff.scala | 8 +++ core/src/scalive/LiveView.scala | 3 +- core/src/scalive/WebSocketMessage.scala | 21 +++++--- example/src/ExampleLiveView.scala | 7 ++- zio/src/scalive/LiveRouter.scala | 37 +++++++------ zio/src/scalive/Socket.scala | 69 +++++++++++++++++-------- 7 files changed, 103 insertions(+), 45 deletions(-) diff --git a/core/src/TestLiveView.scala b/core/src/TestLiveView.scala index 6ada1bb..9e3b24a 100644 --- a/core/src/TestLiveView.scala +++ b/core/src/TestLiveView.scala @@ -3,6 +3,7 @@ package playground import scalive.* import zio.* +import zio.stream.ZStream import TestView.* class TestView extends LiveView[Msg, Model]: @@ -29,6 +30,8 @@ class TestView extends LiveView[Msg, Model]: ) ) + def subscriptions(model: Model) = ZStream.empty + object TestView: enum Msg: diff --git a/core/src/scalive/Diff.scala b/core/src/scalive/Diff.scala index 86071e3..95ee3fd 100644 --- a/core/src/scalive/Diff.scala +++ b/core/src/scalive/Diff.scala @@ -16,6 +16,14 @@ enum Diff: case Dynamic(key: String, diff: Diff) case Deleted +extension (diff: Diff) + def isEmpty: Boolean = diff match + case Diff.Tag(static, dynamic) => static.isEmpty && dynamic.isEmpty + case _: Diff.Comprehension => false + case _: Diff.Value => false + case _: Diff.Dynamic => false + case Diff.Deleted => false + object Diff: given JsonEncoder[Diff] = JsonEncoder[Json].contramap(toJson(_)) diff --git a/core/src/scalive/LiveView.scala b/core/src/scalive/LiveView.scala index 360e4fe..0073564 100644 --- a/core/src/scalive/LiveView.scala +++ b/core/src/scalive/LiveView.scala @@ -1,9 +1,10 @@ package scalive import zio.* +import zio.stream.* trait LiveView[Msg, Model]: def init: Task[Model] def update(model: Model): Msg => Task[Model] def view(model: Dyn[Model]): HtmlElement - // def subscriptions(model: Model): ZStream[Any, Nothing, Msg] + def subscriptions(model: Model): ZStream[Any, Nothing, Msg] diff --git a/core/src/scalive/WebSocketMessage.scala b/core/src/scalive/WebSocketMessage.scala index 3d215ef..75c9533 100644 --- a/core/src/scalive/WebSocketMessage.scala +++ b/core/src/scalive/WebSocketMessage.scala @@ -11,18 +11,19 @@ final case class WebSocketMessage( // Live session ID, auto increment defined by the client on join joinRef: Option[Int], // Message ID, global auto increment defined by the client on every message - messageRef: Int, + messageRef: Option[Int], // LiveView instance id topic: String, eventType: String, payload: WebSocketMessage.Payload): - val meta = WebSocketMessage.Meta(joinRef, messageRef, topic) + val meta = WebSocketMessage.Meta(joinRef, messageRef, topic, eventType) object WebSocketMessage: final case class Meta( joinRef: Option[Int], - messageRef: Int, - topic: String) + messageRef: Option[Int], + topic: String, + eventType: String) given JsonCodec[WebSocketMessage] = JsonCodec[Json].transformOrFail( { @@ -38,7 +39,7 @@ object WebSocketMessage: payloadParsed.map( WebSocketMessage( joinRef.asString.map(_.toInt), - messageRef.toInt, + Some(messageRef.toInt), topic, eventType, _ @@ -49,14 +50,15 @@ object WebSocketMessage: m => Json.Arr( m.joinRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null), - Json.Str(m.messageRef.toString), + m.messageRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null), Json.Str(m.topic), Json.Str(m.eventType), - m.payload.match + m.payload match case Payload.Heartbeat => Json.Obj.empty case p: Payload.Join => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) case p: Payload.Reply => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) case p: Payload.Event => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) + case p: Payload.Diff => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) ) ) @@ -69,11 +71,16 @@ object WebSocketMessage: static: Option[String], sticky: Boolean) case Reply(status: String, response: LiveResponse) + case Diff(diff: scalive.Diff) case Event(`type`: Payload.EventType, event: String, value: Map[String, String]) object Payload: given JsonCodec[Payload.Join] = JsonCodec.derived given JsonEncoder[Payload.Reply] = JsonEncoder.derived given JsonCodec[Payload.Event] = JsonCodec.derived + given JsonEncoder[Payload.Diff] = JsonEncoder[scalive.Diff].contramap(_.diff) + + def okReply(response: LiveResponse) = + Payload.Reply("ok", response) enum EventType: case Click diff --git a/example/src/ExampleLiveView.scala b/example/src/ExampleLiveView.scala index dadf79e..4f2a3e6 100644 --- a/example/src/ExampleLiveView.scala +++ b/example/src/ExampleLiveView.scala @@ -1,9 +1,10 @@ +import ExampleLiveView.* import monocle.syntax.all.* import scalive.* import zio.* import zio.json.* +import zio.stream.ZStream -import ExampleLiveView.* class ExampleLiveView(someParam: String) extends LiveView[Msg, Model]: def init = ZIO.succeed( @@ -65,6 +66,10 @@ class ExampleLiveView(someParam: String) extends LiveView[Msg, Model]: ) ) ) + + def subscriptions(model: Model) = + ZStream.tick(1.second).map(_ => Msg.IncCounter).drop(1) + end ExampleLiveView object ExampleLiveView: diff --git a/zio/src/scalive/LiveRouter.scala b/zio/src/scalive/LiveRouter.scala index 7a528b4..7898394 100644 --- a/zio/src/scalive/LiveRouter.scala +++ b/zio/src/scalive/LiveRouter.scala @@ -48,12 +48,12 @@ final case class LiveRoute[A, Msg: JsonCodec, Model]( end LiveRoute class LiveChannel(private val sockets: SubscriptionRef[Map[String, Socket[?, ?]]]): - def diffsStream: ZStream[Any, Nothing, (LiveResponse, Meta)] = + def diffsStream: ZStream[Any, Nothing, (Payload, Meta)] = sockets.changes .map(m => ZStream .mergeAllUnbounded()(m.values.map(_.outbox).toList*) - ).flatMapParSwitch(1, 1)(identity) + ).flatMapParSwitch(1)(identity) def join[Msg: JsonCodec, Model]( id: String, @@ -102,21 +102,27 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo ZIO.scoped(for liveChannel <- LiveChannel.make() _ <- liveChannel.diffsStream - .foreach((diff, meta) => - channel.send( - Read( - WebSocketFrame.text( - WebSocketMessage( - meta.joinRef, - meta.messageRef, - meta.topic, - "phx_reply", - Payload.Reply("ok", diff) - ).toJson + .runForeach((payload, meta) => + channel + .send( + Read( + WebSocketFrame.text( + WebSocketMessage( + meta.joinRef, + meta.messageRef, + meta.topic, + payload match + case Payload.Diff(_) => "diff" + case _ => "phx_reply", + payload + ).toJson + ) ) ) - ) - ).fork + ) + .tapErrorCause(c => ZIO.logErrorCause("diffsStream pipeline failed", c)) + .ensuring(ZIO.logWarning("WS out fiber terminated")) + .fork _ <- channel .receiveAll { case Read(WebSocketFrame.Text(content)) => @@ -169,6 +175,7 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo .event(message.topic, event, message.meta) .map(_ => None) case Payload.Reply(_, _) => ZIO.die(new IllegalArgumentException()) + case Payload.Diff(_) => ZIO.die(new IllegalArgumentException()) end match end handleMessage diff --git a/zio/src/scalive/Socket.scala b/zio/src/scalive/Socket.scala index 1d0ca3d..b722628 100644 --- a/zio/src/scalive/Socket.scala +++ b/zio/src/scalive/Socket.scala @@ -5,13 +5,14 @@ import zio.* import zio.Queue import zio.json.* import zio.stream.ZStream +import zio.stream.SubscriptionRef +import scalive.WebSocketMessage.Payload final case class Socket[Msg: JsonCodec, Model] private ( id: String, token: String, inbox: Queue[(Msg, WebSocketMessage.Meta)], - outbox: ZStream[Any, Nothing, (LiveResponse, WebSocketMessage.Meta)], - fiber: Fiber.Runtime[Throwable, Unit], + outbox: ZStream[Any, Nothing, (Payload, WebSocketMessage.Meta)], shutdown: UIO[Unit]): val messageCodec = JsonCodec[Msg] @@ -23,28 +24,54 @@ object Socket: meta: WebSocketMessage.Meta ): RIO[Scope, Socket[Msg, Model]] = for - inbox <- Queue.bounded[(Msg, WebSocketMessage.Meta)](4) - outHub <- Hub.bounded[(LiveResponse, WebSocketMessage.Meta)](4) + inbox <- Queue.bounded[(Msg, WebSocketMessage.Meta)](4) + outHub <- Hub.unbounded[(Payload, WebSocketMessage.Meta)] + initModel <- lv.init modelVar = Var(initModel) el = lv.view(modelVar) ref <- Ref.make((modelVar, el)) + initDiff = el.diff(trackUpdates = false) - fiber <- ZStream - .fromQueue(inbox) - .mapZIO { (msg, meta) => - for - (modelVar, el) <- ref.get - updatedModel <- lv.update(modelVar.currentValue)(msg) - _ = modelVar.set(updatedModel) - diff = el.diff() - _ <- outHub.publish(LiveResponse.Diff(diff) -> meta) - yield () - } - .runDrain - .forkScoped - stop = inbox.shutdown *> outHub.shutdown *> fiber.interrupt.unit - diffStream <- ZStream.fromHubScoped(outHub) - outbox = ZStream.succeed(LiveResponse.InitDiff(initDiff) -> meta) ++ diffStream - yield Socket[Msg, Model](id, token, inbox, outbox, fiber, stop) + + lvStreamRef <- SubscriptionRef.make(lv.subscriptions(initModel)) + + clientMsgStream = ZStream.fromQueue(inbox) + serverMsgStream = (ZStream.fromZIO(lvStreamRef.get) ++ lvStreamRef.changes) + .flatMapParSwitch(1, 1)(identity) + .map(_ -> meta.copy(messageRef = None, eventType = "diff")) + + clientFiber <- clientMsgStream.runForeach { (msg, meta) => + for + (modelVar, el) <- ref.get + updatedModel <- lv.update(modelVar.currentValue)(msg) + _ = modelVar.set(updatedModel) + _ <- lvStreamRef.set(lv.subscriptions(updatedModel)) + diff = el.diff() + payload = Payload.okReply(LiveResponse.Diff(diff)) + _ <- outHub.publish(payload -> meta) + yield () + }.fork + serverFiber <- serverMsgStream.runForeach { (msg, meta) => + for + (modelVar, el) <- ref.get + updatedModel <- lv.update(modelVar.currentValue)(msg) + _ = modelVar.set(updatedModel) + diff = el.diff() + payload = Payload.Diff(diff) + _ <- outHub.publish(payload -> meta) + yield () + }.fork + stop = + inbox.shutdown *> outHub.shutdown *> clientFiber.interrupt.unit *> serverFiber.interrupt.unit + outbox = + ZStream.succeed( + Payload.okReply(LiveResponse.InitDiff(initDiff)) -> meta + ) ++ ZStream.unwrapScoped(ZStream.fromHubScoped(outHub)).filterNot { + case (Payload.Diff(diff), _) => diff.isEmpty + case _ => false + } + yield Socket[Msg, Model](id, token, inbox, outbox, stop) + end for + end start end Socket