Make LiveChannel thread safe

This commit is contained in:
Paul-Henri Froidmont 2025-08-29 01:39:00 +02:00
parent fca87a4263
commit dc3cc0ac07
Signed by: phfroidmont
GPG key ID: BE948AFD7E7873BE
2 changed files with 154 additions and 133 deletions

View file

@ -0,0 +1,102 @@
package scalive
import scalive.WebSocketMessage.LiveResponse
import scalive.WebSocketMessage.Payload
import scalive.WebSocketMessage.Payload.EventType
import zio.Chunk
import zio.json.*
import zio.json.ast.Json
final case class WebSocketMessage(
// Live session ID, auto increment defined by the client on join
joinRef: Option[Int],
// Message ID, global auto increment defined by the client on every message
messageRef: Int,
// LiveView instance id
topic: String,
eventType: String,
payload: WebSocketMessage.Payload)
object WebSocketMessage:
given JsonCodec[WebSocketMessage] = JsonCodec[Json].transformOrFail(
{
case Json.Arr(
Chunk(joinRef, Json.Str(messageRef), Json.Str(topic), Json.Str(eventType), payload)
) =>
val payloadParsed = eventType match
case "heartbeat" => Right(Payload.Heartbeat)
case "phx_join" => payload.as[Payload.Join]
case "event" => payload.as[Payload.Event]
case s => Left(s"Unknown event type : $s")
payloadParsed.map(
WebSocketMessage(
joinRef.asString.map(_.toInt),
messageRef.toInt,
topic,
eventType,
_
)
)
case v => Left(s"Could not parse socket message ${v.toJson}")
},
m =>
Json.Arr(
m.joinRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null),
Json.Str(m.messageRef.toString),
Json.Str(m.topic),
Json.Str(m.eventType),
m.payload.match
case Payload.Heartbeat => Json.Obj.empty
case p: Payload.Join => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
case p: Payload.Reply => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
case p: Payload.Event => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
)
)
enum Payload:
case Heartbeat
case Join(
url: String,
// params: Map[String, String],
session: String,
static: Option[String],
sticky: Boolean)
case Reply(status: String, response: LiveResponse)
case Event(`type`: Payload.EventType, event: String, value: Map[String, String])
object Payload:
given JsonCodec[Payload.Join] = JsonCodec.derived
given JsonEncoder[Payload.Reply] = JsonEncoder.derived
given JsonCodec[Payload.Event] = JsonCodec.derived
enum EventType:
case Click
object EventType:
given JsonCodec[EventType] = JsonCodec[String].transformOrFail(
{
case "click" => Right(Click)
case s => Left(s"Unsupported event type: $s")
},
{ case Click =>
"click"
}
)
enum LiveResponse:
case Empty
case InitDiff(rendered: scalive.Diff)
case Diff(diff: scalive.Diff)
object LiveResponse:
given JsonEncoder[LiveResponse] =
JsonEncoder[Json].contramap {
case Empty => Json.Obj.empty
case InitDiff(rendered) =>
Json.Obj(
"liveview_version" -> Json.Str("1.1.8"),
"rendered" -> rendered.toJsonAST.getOrElse(throw new IllegalArgumentException())
)
case Diff(diff) =>
Json.Obj(
"diff" -> diff.toJsonAST.getOrElse(throw new IllegalArgumentException())
)
}
end WebSocketMessage

View file

@ -1,15 +1,13 @@
package scalive package scalive
import scalive.SocketMessage.LiveResponse import scalive.WebSocketMessage.LiveResponse
import scalive.SocketMessage.Payload import scalive.WebSocketMessage.Payload
import scalive.SocketMessage.Payload.EventType
import zio.* import zio.*
import zio.http.* import zio.http.*
import zio.http.ChannelEvent.Read import zio.http.ChannelEvent.Read
import zio.http.codec.PathCodec import zio.http.codec.PathCodec
import zio.http.template.Html import zio.http.template.Html
import zio.json.* import zio.json.*
import zio.json.ast.Json
import java.util.Base64 import java.util.Base64
import scala.collection.mutable import scala.collection.mutable
@ -43,67 +41,82 @@ final case class LiveRoute[A, ClientEvt: JsonCodec, ServerEvt](
) )
} }
class LiveChannel(): class LiveChannel(semaphore: Semaphore):
// TODO not thread safe
private val sockets: mutable.Map[String, Socket[?, ?]] = mutable.Map.empty private val sockets: mutable.Map[String, Socket[?, ?]] = mutable.Map.empty
// TODO should check id isn't already present // TODO should check id isn't already present
def join[ClientEvt: JsonCodec](id: String, token: String, lv: LiveView[ClientEvt, ?]): Diff = def join[ClientEvt: JsonCodec](id: String, token: String, lv: LiveView[ClientEvt, ?]): UIO[Diff] =
val socket = Socket(id, token, lv) semaphore.withPermit {
sockets.addOne(id, socket) ZIO.succeed {
socket.diff val socket = Socket(id, token, lv)
sockets.addOne(id, socket)
socket.diff
}
}
// TODO handle missing id // TODO handle missing id
def event(id: String, value: String): Diff = def event(id: String, value: String): UIO[Diff] =
val s = sockets(id) semaphore.withPermit {
s.lv.handleClientEvent( ZIO.succeed {
value val s = sockets(id)
.fromJson(using s.clientEventCodec.decoder).getOrElse(throw new IllegalArgumentException()) s.lv.handleClientEvent(
) value
s.diff .fromJson(using s.clientEventCodec.decoder).getOrElse(
throw new IllegalArgumentException()
)
)
s.diff
}
}
object LiveChannel:
def make(): UIO[LiveChannel] =
Semaphore.make(permits = 1).map(new LiveChannel(_))
class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRoute[?, ?, ?]]): class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRoute[?, ?, ?]]):
private val socketApp: WebSocketApp[Any] = private val socketApp: WebSocketApp[Any] =
val liveChannel = new LiveChannel()
Handler.webSocket { channel => Handler.webSocket { channel =>
channel LiveChannel
.receiveAll { .make().flatMap(liveChannel =>
case Read(WebSocketFrame.Text(content)) => channel
for .receiveAll {
message <- ZIO case Read(WebSocketFrame.Text(content)) =>
.fromEither(content.fromJson[SocketMessage]) for
.mapError(new IllegalArgumentException(_)) message <- ZIO
reply <- handleMessage(message, liveChannel) .fromEither(content.fromJson[WebSocketMessage])
_ <- channel.send(Read(WebSocketFrame.text(reply.toJson))) .mapError(new IllegalArgumentException(_))
yield () reply <- handleMessage(message, liveChannel)
case _ => ZIO.unit _ <- channel.send(Read(WebSocketFrame.text(reply.toJson)))
}.tapErrorCause(ZIO.logErrorCause(_)) yield ()
case _ => ZIO.unit
}.tapErrorCause(ZIO.logErrorCause(_))
)
} }
private def handleMessage(message: SocketMessage, liveChannel: LiveChannel): Task[SocketMessage] = private def handleMessage(message: WebSocketMessage, liveChannel: LiveChannel)
: Task[WebSocketMessage] =
val reply = message.payload match val reply = message.payload match
case Payload.Heartbeat => ZIO.succeed(Payload.Reply("ok", LiveResponse.Empty)) case Payload.Heartbeat => ZIO.succeed(Payload.Reply("ok", LiveResponse.Empty))
case Payload.Join(url, session, static, sticky) => case Payload.Join(url, session, static, sticky) =>
ZIO ZIO
.fromEither(URL.decode(url)).map(url => .fromEither(URL.decode(url)).flatMap(url =>
val req = Request(url = url) val req = Request(url = url)
liveRoutes liveRoutes
.collectFirst { route => .collectFirst { route =>
val pathParams = route.path.decode(req.path).getOrElse(???) val pathParams = route.path.decode(req.path).getOrElse(???)
val lv = route.liveviewBuilder(pathParams, req) val lv = route.liveviewBuilder(pathParams, req)
val diff = liveChannel.join(message.topic, session, lv)(using route.clientEventCodec)
liveChannel.join(message.topic, session, lv)(using route.clientEventCodec)
Payload.Reply("ok", LiveResponse.InitDiff(diff))
}.getOrElse(???) }.getOrElse(???)
) ).map(diff => Payload.Reply("ok", LiveResponse.InitDiff(diff)))
case Payload.Event(_, event, _) => case Payload.Event(_, event, _) =>
val diff = liveChannel.event(message.topic, event) liveChannel
ZIO.succeed(Payload.Reply("ok", LiveResponse.Diff(diff))) .event(message.topic, event)
.map(diff => Payload.Reply("ok", LiveResponse.Diff(diff)))
case Payload.Reply(_, _) => ZIO.die(new IllegalArgumentException()) case Payload.Reply(_, _) => ZIO.die(new IllegalArgumentException())
reply.map(SocketMessage(message.joinRef, message.messageRef, message.topic, "phx_reply", _)) reply.map(WebSocketMessage(message.joinRef, message.messageRef, message.topic, "phx_reply", _))
val routes: Routes[Any, Response] = val routes: Routes[Any, Response] =
Routes.fromIterable( Routes.fromIterable(
@ -114,97 +127,3 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo
) )
) )
end LiveRouter end LiveRouter
final case class SocketMessage(
// Live session ID, auto increment defined by the client on join
joinRef: Option[Int],
// Message ID, global auto increment defined by the client on every message
messageRef: Int,
// LiveView instance id
topic: String,
eventType: String,
payload: SocketMessage.Payload)
object SocketMessage:
given JsonCodec[SocketMessage] = JsonCodec[Json].transformOrFail(
{
case Json.Arr(
Chunk(joinRef, Json.Str(messageRef), Json.Str(topic), Json.Str(eventType), payload)
) =>
val payloadParsed = eventType match
case "heartbeat" => Right(Payload.Heartbeat)
case "phx_join" => payload.as[Payload.Join]
case "event" => payload.as[Payload.Event]
case s => Left(s"Unknown event type : $s")
payloadParsed.map(
SocketMessage(
joinRef.asString.map(_.toInt),
messageRef.toInt,
topic,
eventType,
_
)
)
case v => Left(s"Could not parse socket message ${v.toJson}")
},
m =>
Json.Arr(
m.joinRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null),
Json.Str(m.messageRef.toString),
Json.Str(m.topic),
Json.Str(m.eventType),
m.payload.match
case Payload.Heartbeat => Json.Obj.empty
case p: Payload.Join => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
case p: Payload.Reply => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
case p: Payload.Event => p.toJsonAST.getOrElse(throw new IllegalArgumentException())
)
)
enum Payload:
case Heartbeat
case Join(
url: String,
// params: Map[String, String],
session: String,
static: Option[String],
sticky: Boolean)
case Reply(status: String, response: LiveResponse)
case Event(`type`: Payload.EventType, event: String, value: Map[String, String])
object Payload:
given JsonCodec[Payload.Join] = JsonCodec.derived
given JsonEncoder[Payload.Reply] = JsonEncoder.derived
given JsonCodec[Payload.Event] = JsonCodec.derived
enum EventType:
case Click
object EventType:
given JsonCodec[EventType] = JsonCodec[String].transformOrFail(
{
case "click" => Right(Click)
case s => Left(s"Unsupported event type: $s")
},
{ case Click =>
"click"
}
)
enum LiveResponse:
case Empty
case InitDiff(rendered: scalive.Diff)
case Diff(diff: scalive.Diff)
object LiveResponse:
given JsonEncoder[LiveResponse] =
JsonEncoder[Json].contramap {
case Empty => Json.Obj.empty
case InitDiff(rendered) =>
Json.Obj(
"liveview_version" -> Json.Str("1.1.8"),
"rendered" -> rendered.toJsonAST.getOrElse(throw new IllegalArgumentException())
)
case Diff(diff) =>
Json.Obj(
"diff" -> diff.toJsonAST.getOrElse(throw new IllegalArgumentException())
)
}
end SocketMessage