diff --git a/build.mill b/build.mill index 0c4c86d..027e22e 100644 --- a/build.mill +++ b/build.mill @@ -53,6 +53,9 @@ object scalive extends Module: def mvnDeps = Seq(mvn"dev.zio::zio-http:3.4.0") def moduleDeps = Seq(core) + object test extends ScalaTests with scalalib.TestModule.ZioTest: + def zioTestVersion = "2.1.23" + object example extends ScalaCommon: def moduleDeps = Seq(scalive.zio) def mvnDeps = Seq(mvn"dev.optics::monocle-core:3.1.0", mvn"dev.zio::zio-logging:2.5.1") diff --git a/scalive/core/src/scalive/LiveView.scala b/scalive/core/src/scalive/LiveView.scala index f3c3bab..1ec868f 100644 --- a/scalive/core/src/scalive/LiveView.scala +++ b/scalive/core/src/scalive/LiveView.scala @@ -3,8 +3,12 @@ package scalive import zio.* import zio.stream.* +final case class LiveContext(staticChanged: Boolean) +object LiveContext: + def staticChanged: URIO[LiveContext, Boolean] = ZIO.serviceWith[LiveContext](_.staticChanged) + trait LiveView[Msg, Model]: - def init: Model | Task[Model] - def update(model: Model): Msg => Model | Task[Model] + def init: Model | RIO[LiveContext, Model] + def update(model: Model): Msg => Model | RIO[LiveContext, Model] def view(model: Dyn[Model]): HtmlElement - def subscriptions(model: Model): ZStream[Any, Nothing, Msg] + def subscriptions(model: Model): ZStream[LiveContext, Nothing, Msg] diff --git a/scalive/core/src/scalive/StaticTracking.scala b/scalive/core/src/scalive/StaticTracking.scala new file mode 100644 index 0000000..769bcb3 --- /dev/null +++ b/scalive/core/src/scalive/StaticTracking.scala @@ -0,0 +1,46 @@ +package scalive + +import scala.collection.mutable.ListBuffer + +import zio.json.* +import zio.json.ast.Json + +object StaticTracking: + private val attrName = phx.trackStatic.name + private val urlAttrNames = List(href.name, src.name) + + def collect(el: HtmlElement): List[String] = + val urls = ListBuffer.empty[String] + + def hasTrack(mods: Seq[Mod.Attr]): Boolean = + mods.exists { + case Mod.Attr.Static(`attrName`, _) => true + case Mod.Attr.StaticValueAsPresence(`attrName`, v) => v + case Mod.Attr.DynValueAsPresence(`attrName`, dyn) => dyn.currentValue + case _ => false + } + + def loop(node: HtmlElement): Unit = + val attrs = node.attrMods + if hasTrack(attrs) then + attrs.foreach { + case Mod.Attr.Static(name, value) if urlAttrNames.contains(name) => + urls += value + case Mod.Attr.Dyn(name, dyn, _) if urlAttrNames.contains(name) => + urls += dyn.currentValue + case _ => () + } + node.contentMods.foreach { + case Mod.Content.Tag(child) => loop(child) + case _ => () + } + + loop(el) + urls.toList + + def clientListFromParams(params: Option[Map[String, Json]]): Option[List[String]] = + params.flatMap(_.get("_track_static")).flatMap(_.as[List[String]].toOption) + + def staticChanged(client: Option[List[String]], server: List[String]): Boolean = + client.exists(_ != server) +end StaticTracking diff --git a/scalive/core/test/src/scalive/StaticTrackingSpec.scala b/scalive/core/test/src/scalive/StaticTrackingSpec.scala new file mode 100644 index 0000000..4855268 --- /dev/null +++ b/scalive/core/test/src/scalive/StaticTrackingSpec.scala @@ -0,0 +1,31 @@ +package scalive + +import utest.* +import zio.json.ast.Json + +object StaticTrackingSpec extends TestSuite: + val tests = Tests { + test("collects href and src from tracked tags") { + val el = div( + scriptTag(phx.trackStatic := true, src := "/static/app.js"), + linkTag(phx.trackStatic := true, href := "/static/app.css"), + div() + ) + + val urls = StaticTracking.collect(el) + assert(urls == List("/static/app.js", "/static/app.css")) + } + + test("extracts _track_static from params") { + val params = Map("_track_static" -> Json.Arr(Json.Str("/a.js"), Json.Str("/b.css"))) + assert(StaticTracking.clientListFromParams(Some(params)) == Some(List("/a.js", "/b.css"))) + assert(StaticTracking.clientListFromParams(None).isEmpty) + } + + test("detects static changes when lists differ") { + val server = List("/a.js", "/b.css") + assert(!StaticTracking.staticChanged(Some(server), server)) + assert(StaticTracking.staticChanged(Some(List("/a.js")), server)) + assert(!StaticTracking.staticChanged(None, server)) + } + } diff --git a/scalive/zio/src/scalive/LiveRouter.scala b/scalive/zio/src/scalive/LiveRouter.scala index c71f221..532c806 100644 --- a/scalive/zio/src/scalive/LiveRouter.scala +++ b/scalive/zio/src/scalive/LiveRouter.scala @@ -12,6 +12,7 @@ import zio.json.* import zio.stream.SubscriptionRef import zio.stream.ZStream +import scalive.* import scalive.WebSocketMessage.Meta import scalive.WebSocketMessage.Payload @@ -25,8 +26,9 @@ final case class LiveRoute[A, Msg, Model]( val id: String = s"phx-${Base64.getUrlEncoder().withoutPadding().encodeToString(Random().nextBytes(12))}" val token = Token.sign("secret", id, "") + val ctx = LiveContext(staticChanged = false) for - initModel <- normalize(lv.init) + initModel <- normalize(lv.init, ctx) el = lv.view(Var(initModel)) _ = el.syncAll() yield Response.html( @@ -57,6 +59,7 @@ class LiveChannel(private val sockets: SubscriptionRef[Map[String, Socket[?, ?]] id: String, token: String, lv: LiveView[Msg, Model], + ctx: LiveContext, meta: WebSocketMessage.Meta ): RIO[Scope, Unit] = sockets @@ -65,11 +68,11 @@ class LiveChannel(private val sockets: SubscriptionRef[Map[String, Socket[?, ?]] case Some(socket) => socket.shutdown *> Socket - .start(id, token, lv, meta) + .start(id, token, lv, ctx, meta) .map(m.updated(id, _)) case None => Socket - .start(id, token, lv, meta) + .start(id, token, lv, ctx, meta) .map(m.updated(id, _)) }.flatMap(_ => ZIO.logDebug(s"LiveView joined $id")) @@ -101,6 +104,8 @@ object LiveChannel: class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRoute[?, ?, ?]]): + private val trackedStatic = StaticTracking.collect(rootLayout(div())) + private val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => ZIO @@ -154,7 +159,9 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo : RIO[Scope, Option[WebSocketMessage]] = message.payload match case Payload.Heartbeat => ZIO.succeed(Some(message.okReply)) - case Payload.Join(url, redirect, session, static, sticky) => + case Payload.Join(url, redirect, session, static, params, sticky) => + val clientStatics = static.orElse(StaticTracking.clientListFromParams(params)) + val ctx = LiveContext(StaticTracking.staticChanged(clientStatics, trackedStatic)) ZIO .fromEither(URL.decode(url.orElse(redirect).getOrElse(???))).flatMap(url => val req = Request(url = url) @@ -164,9 +171,9 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo .decode(req.path) .toOption .map(route.liveviewBuilder(_, req)) - .map( + .map(lv => ZIO.logDebug(s"Joining LiveView ${route.path.toString} ${message.topic}") *> - liveChannel.join(message.topic, session, _, message.meta) + liveChannel.join(message.topic, session, lv, ctx, message.meta) ) ) .collectFirst { case Some(join) => join.map(_ => None) } diff --git a/scalive/zio/src/scalive/Socket.scala b/scalive/zio/src/scalive/Socket.scala index 327aa44..55efcfc 100644 --- a/scalive/zio/src/scalive/Socket.scala +++ b/scalive/zio/src/scalive/Socket.scala @@ -20,6 +20,7 @@ object Socket: id: String, token: String, lv: LiveView[Msg, Model], + ctx: LiveContext, meta: WebSocketMessage.Meta ): RIO[Scope, Socket[Msg, Model]] = ZIO.logAnnotate("lv", id) { @@ -27,14 +28,15 @@ object Socket: inbox <- Queue.bounded[(Payload.Event, WebSocketMessage.Meta)](4) outHub <- Hub.unbounded[(Payload, WebSocketMessage.Meta)] - initModel <- normalize(lv.init) + initModel <- normalize(lv.init, ctx) modelVar = Var(initModel) el = lv.view(modelVar) ref <- Ref.make((modelVar, el)) initDiff = el.diff(trackUpdates = false) - lvStreamRef <- SubscriptionRef.make(lv.subscriptions(initModel)) + lvStreamRef <- + SubscriptionRef.make(lv.subscriptions(initModel).provideLayer(ZLayer.succeed(ctx))) clientMsgStream = ZStream.fromQueue(inbox) serverMsgStream = (ZStream.fromZIO(lvStreamRef.get) ++ lvStreamRef.changes) @@ -53,9 +55,11 @@ object Socket: ) ) updatedModel <- - normalize(lv.update(modelVar.currentValue)(f(event.params))) + normalize(lv.update(modelVar.currentValue)(f(event.params)), ctx) _ = modelVar.set(updatedModel) - _ <- lvStreamRef.set(lv.subscriptions(updatedModel)) + _ <- lvStreamRef.set( + lv.subscriptions(updatedModel).provideLayer(ZLayer.succeed(ctx)) + ) diff = el.diff() payload = Payload.okReply(LiveResponse.Diff(diff)) _ <- outHub.publish(payload -> meta) @@ -64,7 +68,7 @@ object Socket: serverFiber <- serverMsgStream.runForeach { (msg, meta) => for (modelVar, el) <- ref.get - updatedModel <- normalize(lv.update(modelVar.currentValue)(msg)) + updatedModel <- normalize(lv.update(modelVar.currentValue)(msg), ctx) _ = modelVar.set(updatedModel) diff = el.diff() payload = Payload.Diff(diff) diff --git a/scalive/zio/src/scalive/WebSocketMessage.scala b/scalive/zio/src/scalive/WebSocketMessage.scala index ca88db1..936a390 100644 --- a/scalive/zio/src/scalive/WebSocketMessage.scala +++ b/scalive/zio/src/scalive/WebSocketMessage.scala @@ -82,7 +82,8 @@ object WebSocketMessage: redirect: Option[String], // params: Map[String, String], session: String, - static: Option[String], + static: Option[List[String]], + params: Option[Map[String, Json]], sticky: Boolean) case Leave case Close diff --git a/scalive/zio/src/scalive/ZIOHelpers.scala b/scalive/zio/src/scalive/ZIOHelpers.scala index c791944..8f4ee69 100644 --- a/scalive/zio/src/scalive/ZIOHelpers.scala +++ b/scalive/zio/src/scalive/ZIOHelpers.scala @@ -2,7 +2,7 @@ package scalive import zio.* -private def normalize[A](value: A | Task[A]): Task[A] = +private def normalize[A](value: A | RIO[LiveContext, A], ctx: LiveContext): Task[A] = value match - case t: Task[?] @unchecked => t.asInstanceOf[Task[A]] - case v => ZIO.succeed(v.asInstanceOf[A]) + case t: ZIO[LiveContext, Throwable, A] @unchecked => t.provide(ZLayer.succeed(ctx)) + case v => ZIO.succeed(v.asInstanceOf[A]) diff --git a/scalive/zio/test/src/scalive/SocketSpec.scala b/scalive/zio/test/src/scalive/SocketSpec.scala new file mode 100644 index 0000000..8c15845 --- /dev/null +++ b/scalive/zio/test/src/scalive/SocketSpec.scala @@ -0,0 +1,76 @@ +package scalive + +import zio.* +import zio.stream.ZStream +import zio.test.* + +import scalive.WebSocketMessage.LiveResponse +import scalive.WebSocketMessage.Payload + +object SocketSpec extends ZIOSpecDefault: + + enum Msg: + case FromClient + case FromServer + + final case class Model(counter: Int = 0, staticFlag: Option[Boolean] = None) + + private val meta = WebSocketMessage.Meta(None, None, topic = "t", eventType = "event") + + private def makeLiveView(serverStream: ZStream[LiveContext, Nothing, Msg]) = + new LiveView[Msg, Model]: + def init: Model | RIO[LiveContext, Model] = + LiveContext.staticChanged.map(flag => Model(staticFlag = Some(flag))) + + def update(model: Model): Msg => Model | RIO[LiveContext, Model] = { + case Msg.FromClient => ZIO.succeed(model.copy(counter = model.counter + 1)) + case Msg.FromServer => ZIO.succeed(model.copy(counter = model.counter + 10)) + } + + def view(model: Dyn[Model]): HtmlElement = + div( + idAttr := "root", + phx.onClick(Msg.FromClient), + model(_.counter.toString) + ) + + def subscriptions(model: Model): ZStream[LiveContext, Nothing, Msg] = serverStream + + private def makeSocket(ctx: LiveContext, lv: LiveView[Msg, Model]) = + Socket.start("id", "token", lv, ctx, meta) + + override def spec = suite("SocketSpec")( + test("emits init diff and uses LiveContext") { + val ctx = LiveContext(staticChanged = true) + val lv = makeLiveView(ZStream.empty) + for + socket <- makeSocket(ctx, lv) + msgs <- socket.outbox.take(1).runHead + yield assertTrue( + msgs.size == 1, + msgs.head._1 match + case Payload.Reply("ok", LiveResponse.InitDiff(_)) => true + case _ => false + , + msgs.head._2 == meta + ) + }, + test("server stream emits diff") { + val ctx = LiveContext(staticChanged = false) + val lv = makeLiveView(ZStream.succeed(Msg.FromServer)) + for + socket <- makeSocket(ctx, lv) + diff <- socket.outbox.drop(1).runHead.some + yield assertTrue(diff._1.isInstanceOf[Payload.Diff]) + }, + test("shutdown stops outbox") { + val ctx = LiveContext(staticChanged = false) + val lv = makeLiveView(ZStream.empty) + for + socket <- makeSocket(ctx, lv) + _ <- socket.shutdown + res <- socket.outbox.runCollect + yield assertTrue(res.nonEmpty) + } + ) +end SocketSpec