Implement subscriptions

This commit is contained in:
Paul-Henri Froidmont 2025-09-11 17:30:26 +02:00
parent 08036ab5aa
commit 5da0b64c3e
Signed by: phfroidmont
GPG key ID: BE948AFD7E7873BE
7 changed files with 103 additions and 45 deletions

View file

@ -3,6 +3,7 @@ package playground
import scalive.* import scalive.*
import zio.* import zio.*
import zio.stream.ZStream
import TestView.* import TestView.*
class TestView extends LiveView[Msg, Model]: class TestView extends LiveView[Msg, Model]:
@ -29,6 +30,8 @@ class TestView extends LiveView[Msg, Model]:
) )
) )
def subscriptions(model: Model) = ZStream.empty
object TestView: object TestView:
enum Msg: enum Msg:

View file

@ -16,6 +16,14 @@ enum Diff:
case Dynamic(key: String, diff: Diff) case Dynamic(key: String, diff: Diff)
case Deleted case Deleted
extension (diff: Diff)
def isEmpty: Boolean = diff match
case Diff.Tag(static, dynamic) => static.isEmpty && dynamic.isEmpty
case _: Diff.Comprehension => false
case _: Diff.Value => false
case _: Diff.Dynamic => false
case Diff.Deleted => false
object Diff: object Diff:
given JsonEncoder[Diff] = JsonEncoder[Json].contramap(toJson(_)) given JsonEncoder[Diff] = JsonEncoder[Json].contramap(toJson(_))

View file

@ -1,9 +1,10 @@
package scalive package scalive
import zio.* import zio.*
import zio.stream.*
trait LiveView[Msg, Model]: trait LiveView[Msg, Model]:
def init: Task[Model] def init: Task[Model]
def update(model: Model): Msg => Task[Model] def update(model: Model): Msg => Task[Model]
def view(model: Dyn[Model]): HtmlElement def view(model: Dyn[Model]): HtmlElement
// def subscriptions(model: Model): ZStream[Any, Nothing, Msg] def subscriptions(model: Model): ZStream[Any, Nothing, Msg]

View file

@ -11,18 +11,19 @@ final case class WebSocketMessage(
// Live session ID, auto increment defined by the client on join // Live session ID, auto increment defined by the client on join
joinRef: Option[Int], joinRef: Option[Int],
// Message ID, global auto increment defined by the client on every message // Message ID, global auto increment defined by the client on every message
messageRef: Int, messageRef: Option[Int],
// LiveView instance id // LiveView instance id
topic: String, topic: String,
eventType: String, eventType: String,
payload: WebSocketMessage.Payload): payload: WebSocketMessage.Payload):
val meta = WebSocketMessage.Meta(joinRef, messageRef, topic) val meta = WebSocketMessage.Meta(joinRef, messageRef, topic, eventType)
object WebSocketMessage: object WebSocketMessage:
final case class Meta( final case class Meta(
joinRef: Option[Int], joinRef: Option[Int],
messageRef: Int, messageRef: Option[Int],
topic: String) topic: String,
eventType: String)
given JsonCodec[WebSocketMessage] = JsonCodec[Json].transformOrFail( given JsonCodec[WebSocketMessage] = JsonCodec[Json].transformOrFail(
{ {
@ -38,7 +39,7 @@ object WebSocketMessage:
payloadParsed.map( payloadParsed.map(
WebSocketMessage( WebSocketMessage(
joinRef.asString.map(_.toInt), joinRef.asString.map(_.toInt),
messageRef.toInt, Some(messageRef.toInt),
topic, topic,
eventType, eventType,
_ _
@ -49,14 +50,15 @@ object WebSocketMessage:
m => m =>
Json.Arr( Json.Arr(
m.joinRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null), m.joinRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null),
Json.Str(m.messageRef.toString), m.messageRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null),
Json.Str(m.topic), Json.Str(m.topic),
Json.Str(m.eventType), Json.Str(m.eventType),
m.payload.match m.payload match
case Payload.Heartbeat => Json.Obj.empty case Payload.Heartbeat => Json.Obj.empty
case p: Payload.Join => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) case p: Payload.Join => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
case p: Payload.Reply => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) case p: Payload.Reply => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
case p: Payload.Event => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) case p: Payload.Event => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
case p: Payload.Diff => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
) )
) )
@ -69,11 +71,16 @@ object WebSocketMessage:
static: Option[String], static: Option[String],
sticky: Boolean) sticky: Boolean)
case Reply(status: String, response: LiveResponse) case Reply(status: String, response: LiveResponse)
case Diff(diff: scalive.Diff)
case Event(`type`: Payload.EventType, event: String, value: Map[String, String]) case Event(`type`: Payload.EventType, event: String, value: Map[String, String])
object Payload: object Payload:
given JsonCodec[Payload.Join] = JsonCodec.derived given JsonCodec[Payload.Join] = JsonCodec.derived
given JsonEncoder[Payload.Reply] = JsonEncoder.derived given JsonEncoder[Payload.Reply] = JsonEncoder.derived
given JsonCodec[Payload.Event] = JsonCodec.derived given JsonCodec[Payload.Event] = JsonCodec.derived
given JsonEncoder[Payload.Diff] = JsonEncoder[scalive.Diff].contramap(_.diff)
def okReply(response: LiveResponse) =
Payload.Reply("ok", response)
enum EventType: enum EventType:
case Click case Click

View file

@ -1,9 +1,10 @@
import ExampleLiveView.*
import monocle.syntax.all.* import monocle.syntax.all.*
import scalive.* import scalive.*
import zio.* import zio.*
import zio.json.* import zio.json.*
import zio.stream.ZStream
import ExampleLiveView.*
class ExampleLiveView(someParam: String) extends LiveView[Msg, Model]: class ExampleLiveView(someParam: String) extends LiveView[Msg, Model]:
def init = ZIO.succeed( def init = ZIO.succeed(
@ -65,6 +66,10 @@ class ExampleLiveView(someParam: String) extends LiveView[Msg, Model]:
) )
) )
) )
def subscriptions(model: Model) =
ZStream.tick(1.second).map(_ => Msg.IncCounter).drop(1)
end ExampleLiveView end ExampleLiveView
object ExampleLiveView: object ExampleLiveView:

View file

@ -48,12 +48,12 @@ final case class LiveRoute[A, Msg: JsonCodec, Model](
end LiveRoute end LiveRoute
class LiveChannel(private val sockets: SubscriptionRef[Map[String, Socket[?, ?]]]): class LiveChannel(private val sockets: SubscriptionRef[Map[String, Socket[?, ?]]]):
def diffsStream: ZStream[Any, Nothing, (LiveResponse, Meta)] = def diffsStream: ZStream[Any, Nothing, (Payload, Meta)] =
sockets.changes sockets.changes
.map(m => .map(m =>
ZStream ZStream
.mergeAllUnbounded()(m.values.map(_.outbox).toList*) .mergeAllUnbounded()(m.values.map(_.outbox).toList*)
).flatMapParSwitch(1, 1)(identity) ).flatMapParSwitch(1)(identity)
def join[Msg: JsonCodec, Model]( def join[Msg: JsonCodec, Model](
id: String, id: String,
@ -102,21 +102,27 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo
ZIO.scoped(for ZIO.scoped(for
liveChannel <- LiveChannel.make() liveChannel <- LiveChannel.make()
_ <- liveChannel.diffsStream _ <- liveChannel.diffsStream
.foreach((diff, meta) => .runForeach((payload, meta) =>
channel.send( channel
Read( .send(
WebSocketFrame.text( Read(
WebSocketMessage( WebSocketFrame.text(
meta.joinRef, WebSocketMessage(
meta.messageRef, meta.joinRef,
meta.topic, meta.messageRef,
"phx_reply", meta.topic,
Payload.Reply("ok", diff) payload match
).toJson case Payload.Diff(_) => "diff"
case _ => "phx_reply",
payload
).toJson
)
) )
) )
) )
).fork .tapErrorCause(c => ZIO.logErrorCause("diffsStream pipeline failed", c))
.ensuring(ZIO.logWarning("WS out fiber terminated"))
.fork
_ <- channel _ <- channel
.receiveAll { .receiveAll {
case Read(WebSocketFrame.Text(content)) => case Read(WebSocketFrame.Text(content)) =>
@ -169,6 +175,7 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo
.event(message.topic, event, message.meta) .event(message.topic, event, message.meta)
.map(_ => None) .map(_ => None)
case Payload.Reply(_, _) => ZIO.die(new IllegalArgumentException()) case Payload.Reply(_, _) => ZIO.die(new IllegalArgumentException())
case Payload.Diff(_) => ZIO.die(new IllegalArgumentException())
end match end match
end handleMessage end handleMessage

View file

@ -5,13 +5,14 @@ import zio.*
import zio.Queue import zio.Queue
import zio.json.* import zio.json.*
import zio.stream.ZStream import zio.stream.ZStream
import zio.stream.SubscriptionRef
import scalive.WebSocketMessage.Payload
final case class Socket[Msg: JsonCodec, Model] private ( final case class Socket[Msg: JsonCodec, Model] private (
id: String, id: String,
token: String, token: String,
inbox: Queue[(Msg, WebSocketMessage.Meta)], inbox: Queue[(Msg, WebSocketMessage.Meta)],
outbox: ZStream[Any, Nothing, (LiveResponse, WebSocketMessage.Meta)], outbox: ZStream[Any, Nothing, (Payload, WebSocketMessage.Meta)],
fiber: Fiber.Runtime[Throwable, Unit],
shutdown: UIO[Unit]): shutdown: UIO[Unit]):
val messageCodec = JsonCodec[Msg] val messageCodec = JsonCodec[Msg]
@ -23,28 +24,54 @@ object Socket:
meta: WebSocketMessage.Meta meta: WebSocketMessage.Meta
): RIO[Scope, Socket[Msg, Model]] = ): RIO[Scope, Socket[Msg, Model]] =
for for
inbox <- Queue.bounded[(Msg, WebSocketMessage.Meta)](4) inbox <- Queue.bounded[(Msg, WebSocketMessage.Meta)](4)
outHub <- Hub.bounded[(LiveResponse, WebSocketMessage.Meta)](4) outHub <- Hub.unbounded[(Payload, WebSocketMessage.Meta)]
initModel <- lv.init initModel <- lv.init
modelVar = Var(initModel) modelVar = Var(initModel)
el = lv.view(modelVar) el = lv.view(modelVar)
ref <- Ref.make((modelVar, el)) ref <- Ref.make((modelVar, el))
initDiff = el.diff(trackUpdates = false) initDiff = el.diff(trackUpdates = false)
fiber <- ZStream
.fromQueue(inbox) lvStreamRef <- SubscriptionRef.make(lv.subscriptions(initModel))
.mapZIO { (msg, meta) =>
for clientMsgStream = ZStream.fromQueue(inbox)
(modelVar, el) <- ref.get serverMsgStream = (ZStream.fromZIO(lvStreamRef.get) ++ lvStreamRef.changes)
updatedModel <- lv.update(modelVar.currentValue)(msg) .flatMapParSwitch(1, 1)(identity)
_ = modelVar.set(updatedModel) .map(_ -> meta.copy(messageRef = None, eventType = "diff"))
diff = el.diff()
_ <- outHub.publish(LiveResponse.Diff(diff) -> meta) clientFiber <- clientMsgStream.runForeach { (msg, meta) =>
yield () for
} (modelVar, el) <- ref.get
.runDrain updatedModel <- lv.update(modelVar.currentValue)(msg)
.forkScoped _ = modelVar.set(updatedModel)
stop = inbox.shutdown *> outHub.shutdown *> fiber.interrupt.unit _ <- lvStreamRef.set(lv.subscriptions(updatedModel))
diffStream <- ZStream.fromHubScoped(outHub) diff = el.diff()
outbox = ZStream.succeed(LiveResponse.InitDiff(initDiff) -> meta) ++ diffStream payload = Payload.okReply(LiveResponse.Diff(diff))
yield Socket[Msg, Model](id, token, inbox, outbox, fiber, stop) _ <- outHub.publish(payload -> meta)
yield ()
}.fork
serverFiber <- serverMsgStream.runForeach { (msg, meta) =>
for
(modelVar, el) <- ref.get
updatedModel <- lv.update(modelVar.currentValue)(msg)
_ = modelVar.set(updatedModel)
diff = el.diff()
payload = Payload.Diff(diff)
_ <- outHub.publish(payload -> meta)
yield ()
}.fork
stop =
inbox.shutdown *> outHub.shutdown *> clientFiber.interrupt.unit *> serverFiber.interrupt.unit
outbox =
ZStream.succeed(
Payload.okReply(LiveResponse.InitDiff(initDiff)) -> meta
) ++ ZStream.unwrapScoped(ZStream.fromHubScoped(outHub)).filterNot {
case (Payload.Diff(diff), _) => diff.isEmpty
case _ => false
}
yield Socket[Msg, Model](id, token, inbox, outbox, stop)
end for
end start
end Socket end Socket