From fcc5f1799e47d2b04493517c4c7d7d5144afa31e Mon Sep 17 00:00:00 2001 From: Paul-Henri Froidmont Date: Wed, 3 Sep 2025 04:14:50 +0200 Subject: [PATCH] Stream events and responses --- core/src/TestLiveView.scala | 10 +- core/src/main.scala | 46 ++++--- core/src/scalive/LiveView.scala | 11 +- core/src/scalive/Socket.scala | 39 ------ core/src/scalive/WebSocketMessage.scala | 9 +- core/test/src/scalive/LiveViewSpec.scala | 49 +++---- example/src/ExampleLiveView.scala | 17 ++- zio/src/scalive/LiveRouter.scala | 160 +++++++++++++++-------- zio/src/scalive/Socket.scala | 45 +++++++ 9 files changed, 233 insertions(+), 153 deletions(-) delete mode 100644 core/src/scalive/Socket.scala create mode 100644 zio/src/scalive/Socket.scala diff --git a/core/src/TestLiveView.scala b/core/src/TestLiveView.scala index 0cad63c..2781ffc 100644 --- a/core/src/TestLiveView.scala +++ b/core/src/TestLiveView.scala @@ -1,3 +1,6 @@ +package scalive +package playground + import scalive.* final case class MyModel( @@ -6,14 +9,13 @@ final case class MyModel( elems: List[Elem] = List.empty) final case class Elem(name: String, age: Int) -class TestView(initialModel: MyModel) extends LiveView[String, TestView.Event]: +class TestView(initialModel: MyModel) extends LiveView[TestView.Event]: import TestView.Event.* private val modelVar = Var[MyModel](initialModel) - override def handleServerEvent(e: TestView.Event): Unit = - e match - case UpdateModel(f) => modelVar.update(f) + def handleEvent = + case UpdateModel(f) => modelVar.update(f) val el: HtmlElement = div( diff --git a/core/src/main.scala b/core/src/main.scala index f2f9e32..83a4e92 100644 --- a/core/src/main.scala +++ b/core/src/main.scala @@ -1,5 +1,12 @@ +package scalive +package playground + import scalive.* -import zio.json.JsonCodec +import zio.json.* + +extension (lv: LiveView[?]) + def renderHtml: String = + HtmlBuilder.build(lv.el) @main def main = @@ -10,20 +17,19 @@ def main = Elem("c", 30) ) ) - val s = Socket("", "", TestView(initModel)) + val lv = TestView(initModel) println("Init") - println(s.renderHtml()) - s.syncClient - s.syncClient + println(lv.renderHtml) + println(lv.diff().toJsonPretty) println("Edit class attribue") - s.lv.handleServerEvent( + lv.handleEvent( TestView.Event.UpdateModel(_.copy(cls = "text-lg")) ) - s.syncClient + println(lv.diff().toJsonPretty) println("Edit first and last") - s.lv.handleServerEvent( + lv.handleEvent( TestView.Event.UpdateModel( _.copy(elems = List( @@ -34,11 +40,11 @@ def main = ) ) ) - s.syncClient - println(s.renderHtml()) + println(lv.diff().toJsonPretty) + println(lv.diff().toJsonPretty) println("Add one") - s.lv.handleServerEvent( + lv.handleEvent( TestView.Event.UpdateModel( _.copy(elems = List( @@ -50,11 +56,11 @@ def main = ) ) ) - s.syncClient - println(s.renderHtml()) + println(lv.diff().toJsonPretty) + println(lv.renderHtml) println("Remove first") - s.lv.handleServerEvent( + lv.handleEvent( TestView.Event.UpdateModel( _.copy(elems = List( @@ -65,11 +71,11 @@ def main = ) ) ) - s.syncClient - println(s.renderHtml()) + println(lv.diff().toJsonPretty) + println(lv.renderHtml) println("Remove all") - s.lv.handleServerEvent( + lv.handleEvent( TestView.Event.UpdateModel( _.copy( cls = "text-lg", @@ -78,7 +84,7 @@ def main = ) ) ) - s.syncClient - s.syncClient - println(s.renderHtml()) + println(lv.diff().toJsonPretty) + println(lv.diff().toJsonPretty) + println(lv.renderHtml) end main diff --git a/core/src/scalive/LiveView.scala b/core/src/scalive/LiveView.scala index d646d6a..12e7dd6 100644 --- a/core/src/scalive/LiveView.scala +++ b/core/src/scalive/LiveView.scala @@ -1,6 +1,11 @@ package scalive -trait LiveView[ClientEvt, ServerEvent]: - def handleClientEvent(evt: ClientEvt): Unit = () - def handleServerEvent(evt: ServerEvent): Unit = () +trait LiveView[Event]: + def handleEvent: Event => Unit val el: HtmlElement + + private[scalive] def diff(trackUpdates: Boolean = true): Diff = + el.syncAll() + val diff = DiffBuilder.build(el, trackUpdates = trackUpdates) + el.setAllUnchanged() + diff diff --git a/core/src/scalive/Socket.scala b/core/src/scalive/Socket.scala deleted file mode 100644 index 6a29e4f..0000000 --- a/core/src/scalive/Socket.scala +++ /dev/null @@ -1,39 +0,0 @@ -package scalive - -import zio.json.* - -final case class Socket[CliEvt: JsonCodec, SrvEvt]( - id: String, - token: String, - lv: LiveView[CliEvt, SrvEvt]): - val clientEventCodec = JsonCodec[CliEvt] - - private var clientInitialized = false - - lv.el.syncAll() - - def renderHtml(rootLayout: HtmlElement => HtmlElement = identity): String = - lv.el.syncAll() - HtmlBuilder.build( - rootLayout( - div( - idAttr := id, - phx.session := token, - lv.el - ) - ) - ) - - def syncClient: Unit = - lv.el.syncAll() - println(DiffBuilder.build(lv.el, trackUpdates = clientInitialized).toJsonPretty) - clientInitialized = true - lv.el.setAllUnchanged() - - def diff: Diff = - lv.el.syncAll() - val diff = DiffBuilder.build(lv.el, trackUpdates = clientInitialized) - clientInitialized = true - lv.el.setAllUnchanged() - diff -end Socket diff --git a/core/src/scalive/WebSocketMessage.scala b/core/src/scalive/WebSocketMessage.scala index cabe74a..3d215ef 100644 --- a/core/src/scalive/WebSocketMessage.scala +++ b/core/src/scalive/WebSocketMessage.scala @@ -15,8 +15,15 @@ final case class WebSocketMessage( // LiveView instance id topic: String, eventType: String, - payload: WebSocketMessage.Payload) + payload: WebSocketMessage.Payload): + val meta = WebSocketMessage.Meta(joinRef, messageRef, topic) object WebSocketMessage: + + final case class Meta( + joinRef: Option[Int], + messageRef: Int, + topic: String) + given JsonCodec[WebSocketMessage] = JsonCodec[Json].transformOrFail( { case Json.Arr( diff --git a/core/test/src/scalive/LiveViewSpec.scala b/core/test/src/scalive/LiveViewSpec.scala index 2036b18..ac516e9 100644 --- a/core/test/src/scalive/LiveViewSpec.scala +++ b/core/test/src/scalive/LiveViewSpec.scala @@ -27,8 +27,9 @@ object LiveViewSpec extends TestSuite: test("Static only") { val lv = - new LiveView[String, Unit]: - val el = div("Static string") + new LiveView[Nothing]: + val el = div("Static string") + def handleEvent = _ => () lv.el.syncAll() test("init") { @@ -47,14 +48,14 @@ object LiveViewSpec extends TestSuite: test("Dynamic string") { val lv = - new LiveView[UpdateEvent, Nothing]: + new LiveView[UpdateEvent]: val model = Var(TestModel()) val el = div( h1(model(_.title)), p(model(_.otherString)) ) - override def handleClientEvent(evt: UpdateEvent): Unit = model.update(evt.f) + def handleEvent = evt => model.update(evt.f) lv.el.syncAll() lv.el.setAllUnchanged() @@ -75,19 +76,19 @@ object LiveViewSpec extends TestSuite: assertEqualsDiff(lv.el, emptyDiff) } test("diff with update") { - lv.handleClientEvent(UpdateEvent(_.copy(title = "title updated"))) + lv.handleEvent(UpdateEvent(_.copy(title = "title updated"))) assertEqualsDiff( lv.el, Json.Obj("0" -> Json.Str("title updated")) ) } test("diff with update and no change") { - lv.handleClientEvent(UpdateEvent(_.copy(title = "title value"))) + lv.handleEvent(UpdateEvent(_.copy(title = "title value"))) assertEqualsDiff(lv.el, emptyDiff) } test("diff with update in multiple commands") { - lv.handleClientEvent(UpdateEvent(_.copy(title = "title updated"))) - lv.handleClientEvent(UpdateEvent(_.copy(otherString = "other string updated"))) + lv.handleEvent(UpdateEvent(_.copy(title = "title updated"))) + lv.handleEvent(UpdateEvent(_.copy(otherString = "other string updated"))) assertEqualsDiff( lv.el, Json @@ -101,11 +102,11 @@ object LiveViewSpec extends TestSuite: test("Dynamic attribute") { val lv = - new LiveView[UpdateEvent, Nothing]: + new LiveView[UpdateEvent]: val model = Var(TestModel()) val el = div(cls := model(_.cls)) - override def handleClientEvent(evt: UpdateEvent): Unit = model.update(evt.f) + def handleEvent = evt => model.update(evt.f) lv.el.syncAll() lv.el.setAllUnchanged() @@ -126,7 +127,7 @@ object LiveViewSpec extends TestSuite: assertEqualsDiff(lv.el, emptyDiff) } test("diff with update") { - lv.handleClientEvent(UpdateEvent(_.copy(cls = "text-md"))) + lv.handleEvent(UpdateEvent(_.copy(cls = "text-md"))) assertEqualsDiff( lv.el, Json.Obj("0" -> Json.Str("text-md")) @@ -136,7 +137,7 @@ object LiveViewSpec extends TestSuite: test("when mod") { val lv = - new LiveView[UpdateEvent, Nothing]: + new LiveView[UpdateEvent]: val model = Var(TestModel()) val el = div( @@ -144,7 +145,7 @@ object LiveViewSpec extends TestSuite: div("static string", model(_.nestedTitle)) ) ) - override def handleClientEvent(evt: UpdateEvent): Unit = model.update(evt.f) + def handleEvent = evt => model.update(evt.f) lv.el.syncAll() lv.el.setAllUnchanged() @@ -164,11 +165,11 @@ object LiveViewSpec extends TestSuite: assertEqualsDiff(lv.el, emptyDiff) } test("diff with unrelated update") { - lv.handleClientEvent(UpdateEvent(_.copy(title = "title updated"))) + lv.handleEvent(UpdateEvent(_.copy(title = "title updated"))) assertEqualsDiff(lv.el, emptyDiff) } test("diff when true and nested update") { - lv.handleClientEvent(UpdateEvent(_.copy(bool = true))) + lv.handleEvent(UpdateEvent(_.copy(bool = true))) assertEqualsDiff( lv.el, Json.Obj( @@ -183,10 +184,10 @@ object LiveViewSpec extends TestSuite: ) } test("diff when nested change") { - lv.handleClientEvent(UpdateEvent(_.copy(bool = true))) + lv.handleEvent(UpdateEvent(_.copy(bool = true))) lv.el.syncAll() lv.el.setAllUnchanged() - lv.handleClientEvent(UpdateEvent(_.copy(bool = true, nestedTitle = "nested title updated"))) + lv.handleEvent(UpdateEvent(_.copy(bool = true, nestedTitle = "nested title updated"))) assertEqualsDiff( lv.el, Json.Obj( @@ -209,7 +210,7 @@ object LiveViewSpec extends TestSuite: ) ) val lv = - new LiveView[UpdateEvent, Nothing]: + new LiveView[UpdateEvent]: val model = Var(initModel) val el = div( @@ -224,7 +225,7 @@ object LiveViewSpec extends TestSuite: ) ) ) - override def handleClientEvent(evt: UpdateEvent): Unit = model.update(evt.f) + def handleEvent = evt => model.update(evt.f) lv.el.syncAll() lv.el.setAllUnchanged() @@ -265,11 +266,11 @@ object LiveViewSpec extends TestSuite: assertEqualsDiff(lv.el, emptyDiff) } test("diff with unrelated update") { - lv.handleClientEvent(UpdateEvent(_.copy(title = "title updated"))) + lv.handleEvent(UpdateEvent(_.copy(title = "title updated"))) assertEqualsDiff(lv.el, emptyDiff) } test("diff with item changed") { - lv.handleClientEvent( + lv.handleEvent( UpdateEvent(_.copy(items = initModel.items.updated(2, NestedModel("c", 99)))) ) assertEqualsDiff( @@ -289,7 +290,7 @@ object LiveViewSpec extends TestSuite: ) } test("diff with item added") { - lv.handleClientEvent( + lv.handleEvent( UpdateEvent( _.copy(items = initModel.items.appended(NestedModel("d", 35))) ) @@ -312,7 +313,7 @@ object LiveViewSpec extends TestSuite: ) } test("diff with first item removed") { - lv.handleClientEvent( + lv.handleEvent( UpdateEvent( _.copy(items = initModel.items.tail) ) @@ -339,7 +340,7 @@ object LiveViewSpec extends TestSuite: ) } test("diff all removed") { - lv.handleClientEvent(UpdateEvent(_.copy(items = List.empty))) + lv.handleEvent(UpdateEvent(_.copy(items = List.empty))) assertEqualsDiff( lv.el, Json.Obj( diff --git a/example/src/ExampleLiveView.scala b/example/src/ExampleLiveView.scala index 4c84290..60b77d6 100644 --- a/example/src/ExampleLiveView.scala +++ b/example/src/ExampleLiveView.scala @@ -1,4 +1,4 @@ -import ExampleLiveView.Evt +import ExampleLiveView.Event import monocle.syntax.all.* import scalive.* import zio.json.* @@ -6,7 +6,7 @@ import zio.json.* final case class ExampleModel(elems: List[NestedModel], cls: String = "text-xs") final case class NestedModel(name: String, age: Int) -class ExampleLiveView(someParam: String) extends LiveView[Evt, String]: +class ExampleLiveView(someParam: String) extends LiveView[Event]: val model = Var( ExampleModel( @@ -18,10 +18,9 @@ class ExampleLiveView(someParam: String) extends LiveView[Evt, String]: ) ) - override def handleClientEvent(evt: Evt): Unit = - evt match - case Evt.IncAge(value) => - model.update(_.focus(_.elems.index(2).age).modify(_ + value)) + def handleEvent = + case Event.Event(value) => + model.update(_.focus(_.elems.index(2).age).modify(_ + value)) val el = div( @@ -40,12 +39,12 @@ class ExampleLiveView(someParam: String) extends LiveView[Evt, String]: ) ), button( - phx.click := Evt.IncAge(1), + phx.click := Event.Event(1), "Inc age" ) ) end ExampleLiveView object ExampleLiveView: - enum Evt derives JsonCodec: - case IncAge(value: Int) + enum Event derives JsonCodec: + case Event(value: Int) diff --git a/zio/src/scalive/LiveRouter.scala b/zio/src/scalive/LiveRouter.scala index d81b619..b0b6f06 100644 --- a/zio/src/scalive/LiveRouter.scala +++ b/zio/src/scalive/LiveRouter.scala @@ -1,6 +1,7 @@ package scalive import scalive.WebSocketMessage.LiveResponse +import scalive.WebSocketMessage.Meta import scalive.WebSocketMessage.Payload import zio.* import zio.http.* @@ -8,15 +9,16 @@ import zio.http.ChannelEvent.Read import zio.http.codec.PathCodec import zio.http.template.Html import zio.json.* +import zio.stream.SubscriptionRef +import zio.stream.ZStream import java.util.Base64 -import scala.collection.mutable import scala.util.Random -final case class LiveRoute[A, ClientEvt: JsonCodec, ServerEvt]( +final case class LiveRoute[A, Event: JsonCodec]( path: PathCodec[A], - liveviewBuilder: (A, Request) => LiveView[ClientEvt, ServerEvt]): - val clientEventCodec = JsonCodec[ClientEvt] + liveviewBuilder: (A, Request) => LiveView[Event]): + val eventCodec = JsonCodec[Event] def toZioRoute(rootLayout: HtmlElement => HtmlElement): Route[Any, Nothing] = Method.GET / path -> handler { (params: A, req: Request) => @@ -41,63 +43,112 @@ final case class LiveRoute[A, ClientEvt: JsonCodec, ServerEvt]( ) } -class LiveChannel(semaphore: Semaphore): - private val sockets: mutable.Map[String, Socket[?, ?]] = mutable.Map.empty +class LiveChannel(private val sockets: SubscriptionRef[Map[String, Socket[?]]]): + def diffsStream: ZStream[Any, Nothing, (Diff, Meta)] = + sockets.changes + .map(m => + ZStream + .mergeAllUnbounded()( + m.values + .map(_.outbox).map(ZStream.fromHub(_)).toList* + ) + ).flatMapParSwitch(1, 1)(identity) + + def join[Event: JsonCodec]( + id: String, + token: String, + lv: LiveView[Event], + meta: WebSocketMessage.Meta + ): URIO[Scope, Unit] = + sockets.updateZIO { m => + m.get(id) match + case Some(socket) => + socket.shutdown *> + Socket + .start(id, token, lv, meta) + .map(m.updated(id, _)) + case None => + Socket + .start(id, token, lv, meta) + .map(m.updated(id, _)) - // TODO should check id isn't already present - def join[ClientEvt: JsonCodec](id: String, token: String, lv: LiveView[ClientEvt, ?]): UIO[Diff] = - semaphore.withPermit { - ZIO.succeed { - val socket = Socket(id, token, lv) - sockets.addOne(id, socket) - socket.diff - } } - // TODO handle missing id - def event(id: String, value: String): UIO[Diff] = - semaphore.withPermit { - ZIO.succeed { - val s = sockets(id) - s.lv.handleClientEvent( - value - .fromJson(using s.clientEventCodec.decoder).getOrElse( - throw new IllegalArgumentException() + def event(id: String, value: String, meta: WebSocketMessage.Meta): UIO[Unit] = + sockets.get.map { m => + m.get(id) match + case Some(socket) => + socket.inbox + .offer( + value + .fromJson(using socket.clientEventCodec.decoder) + .getOrElse(throw new IllegalArgumentException()) + -> meta ) - ) - s.diff - } - } + case None => ZIO.unit + }.unit + +end LiveChannel object LiveChannel: def make(): UIO[LiveChannel] = - Semaphore.make(permits = 1).map(new LiveChannel(_)) + SubscriptionRef.make(Map.empty).map(new LiveChannel(_)) -class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRoute[?, ?, ?]]): +class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRoute[?, ?]]): private val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => - LiveChannel - .make().flatMap(liveChannel => - channel - .receiveAll { - case Read(WebSocketFrame.Text(content)) => - for - message <- ZIO - .fromEither(content.fromJson[WebSocketMessage]) - .mapError(new IllegalArgumentException(_)) - reply <- handleMessage(message, liveChannel) - _ <- channel.send(Read(WebSocketFrame.text(reply.toJson))) - yield () - case _ => ZIO.unit - }.tapErrorCause(ZIO.logErrorCause(_)) - ) + ZIO.scoped(for + liveChannel <- LiveChannel.make() + _ <- liveChannel.diffsStream + .foreach((diff, meta) => + channel.send( + Read( + WebSocketFrame.text( + WebSocketMessage( + meta.joinRef, + meta.messageRef, + meta.topic, + "phx_reply", + Payload.Reply("OK", LiveResponse.Diff(diff)) + ).toJson + ) + ) + ) + ).fork + _ <- channel + .receiveAll { + case Read(WebSocketFrame.Text(content)) => + for + message <- ZIO + .fromEither(content.fromJson[WebSocketMessage]) + .mapError(new IllegalArgumentException(_)) + reply <- handleMessage(message, liveChannel) + _ <- reply match + case Some(r) => channel.send(Read(WebSocketFrame.text(r.toJson))) + case None => ZIO.unit + yield () + case _ => ZIO.unit + }.tapErrorCause(ZIO.logErrorCause(_)) + yield ()) + } private def handleMessage(message: WebSocketMessage, liveChannel: LiveChannel) - : Task[WebSocketMessage] = - val reply = message.payload match - case Payload.Heartbeat => ZIO.succeed(Payload.Reply("ok", LiveResponse.Empty)) + : RIO[Scope, Option[WebSocketMessage]] = + message.payload match + case Payload.Heartbeat => + ZIO.succeed( + Some( + WebSocketMessage( + message.joinRef, + message.messageRef, + message.topic, + "phx_reply", + Payload.Reply("ok", LiveResponse.Empty) + ) + ) + ) case Payload.Join(url, session, static, sticky) => ZIO .fromEither(URL.decode(url)).flatMap(url => @@ -106,17 +157,20 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo .collectFirst { route => val pathParams = route.path.decode(req.path).getOrElse(???) val lv = route.liveviewBuilder(pathParams, req) - liveChannel.join(message.topic, session, lv)(using route.clientEventCodec) + liveChannel + .join(message.topic, session, lv, message.meta)(using route.eventCodec) + .map(_ => None) - }.getOrElse(???) - ).map(diff => Payload.Reply("ok", LiveResponse.InitDiff(diff))) + }.getOrElse(ZIO.succeed(None)) + ) case Payload.Event(_, event, _) => liveChannel - .event(message.topic, event) - .map(diff => Payload.Reply("ok", LiveResponse.Diff(diff))) + .event(message.topic, event, message.meta) + .map(_ => None) case Payload.Reply(_, _) => ZIO.die(new IllegalArgumentException()) + end match - reply.map(WebSocketMessage(message.joinRef, message.messageRef, message.topic, "phx_reply", _)) + end handleMessage val routes: Routes[Any, Response] = Routes.fromIterable( diff --git a/zio/src/scalive/Socket.scala b/zio/src/scalive/Socket.scala new file mode 100644 index 0000000..d3aa3ac --- /dev/null +++ b/zio/src/scalive/Socket.scala @@ -0,0 +1,45 @@ +package scalive + +import zio.* +import zio.json.* +import zio.Queue +import zio.stream.ZStream + +final case class Socket[Event: JsonCodec] private ( + id: String, + token: String, + // lv: LiveView[CliEvt, SrvEvt], + inbox: Queue[(Event, WebSocketMessage.Meta)], + outbox: Hub[(Diff, WebSocketMessage.Meta)], + fiber: Fiber.Runtime[Nothing, Unit], + shutdown: UIO[Unit]): + val clientEventCodec = JsonCodec[Event] + +object Socket: + def start[Event: JsonCodec]( + id: String, + token: String, + lv: LiveView[Event], + meta: WebSocketMessage.Meta + ): URIO[Scope, Socket[Event]] = + for + inbox <- Queue.bounded[(Event, WebSocketMessage.Meta)](4) + outbox <- Hub.bounded[(Diff, WebSocketMessage.Meta)](4) + initDiff = lv.diff(trackUpdates = false) + _ <- outbox.publish(initDiff -> meta).unit + _ <- outbox.size.flatMap(s => ZIO.log(s.toString)) // FIXME + lvRef <- Ref.make(lv) + fiber <- ZStream + .fromQueue(inbox) + .mapZIO { (msg, meta) => + for + lv <- lvRef.get + _ = lv.handleEvent(msg) + diff = lv.diff() + _ <- outbox.publish(diff -> meta) + yield () + } + .runDrain + .forkScoped + stop = inbox.shutdown *> outbox.shutdown *> fiber.interrupt.unit + yield Socket[Event](id, token, inbox, outbox, fiber, stop)