diff --git a/core/src/scalive/WebSocketMessage.scala b/core/src/scalive/WebSocketMessage.scala new file mode 100644 index 0000000..cabe74a --- /dev/null +++ b/core/src/scalive/WebSocketMessage.scala @@ -0,0 +1,102 @@ +package scalive + +import scalive.WebSocketMessage.LiveResponse +import scalive.WebSocketMessage.Payload +import scalive.WebSocketMessage.Payload.EventType +import zio.Chunk +import zio.json.* +import zio.json.ast.Json + +final case class WebSocketMessage( + // Live session ID, auto increment defined by the client on join + joinRef: Option[Int], + // Message ID, global auto increment defined by the client on every message + messageRef: Int, + // LiveView instance id + topic: String, + eventType: String, + payload: WebSocketMessage.Payload) +object WebSocketMessage: + given JsonCodec[WebSocketMessage] = JsonCodec[Json].transformOrFail( + { + case Json.Arr( + Chunk(joinRef, Json.Str(messageRef), Json.Str(topic), Json.Str(eventType), payload) + ) => + val payloadParsed = eventType match + case "heartbeat" => Right(Payload.Heartbeat) + case "phx_join" => payload.as[Payload.Join] + case "event" => payload.as[Payload.Event] + case s => Left(s"Unknown event type : $s") + + payloadParsed.map( + WebSocketMessage( + joinRef.asString.map(_.toInt), + messageRef.toInt, + topic, + eventType, + _ + ) + ) + case v => Left(s"Could not parse socket message ${v.toJson}") + }, + m => + Json.Arr( + m.joinRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null), + Json.Str(m.messageRef.toString), + Json.Str(m.topic), + Json.Str(m.eventType), + m.payload.match + case Payload.Heartbeat => Json.Obj.empty + case p: Payload.Join => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) + case p: Payload.Reply => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) + case p: Payload.Event => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) + ) + ) + + enum Payload: + case Heartbeat + case Join( + url: String, + // params: Map[String, String], + session: String, + static: Option[String], + sticky: Boolean) + case Reply(status: String, response: LiveResponse) + case Event(`type`: Payload.EventType, event: String, value: Map[String, String]) + object Payload: + given JsonCodec[Payload.Join] = JsonCodec.derived + given JsonEncoder[Payload.Reply] = JsonEncoder.derived + given JsonCodec[Payload.Event] = JsonCodec.derived + + enum EventType: + case Click + object EventType: + given JsonCodec[EventType] = JsonCodec[String].transformOrFail( + { + case "click" => Right(Click) + case s => Left(s"Unsupported event type: $s") + }, + { case Click => + "click" + } + ) + + enum LiveResponse: + case Empty + case InitDiff(rendered: scalive.Diff) + case Diff(diff: scalive.Diff) + object LiveResponse: + given JsonEncoder[LiveResponse] = + JsonEncoder[Json].contramap { + case Empty => Json.Obj.empty + case InitDiff(rendered) => + Json.Obj( + "liveview_version" -> Json.Str("1.1.8"), + "rendered" -> rendered.toJsonAST.getOrElse(throw new IllegalArgumentException()) + ) + case Diff(diff) => + Json.Obj( + "diff" -> diff.toJsonAST.getOrElse(throw new IllegalArgumentException()) + ) + } +end WebSocketMessage diff --git a/zio/src/scalive/LiveRouter.scala b/zio/src/scalive/LiveRouter.scala index fd660cb..d81b619 100644 --- a/zio/src/scalive/LiveRouter.scala +++ b/zio/src/scalive/LiveRouter.scala @@ -1,15 +1,13 @@ package scalive -import scalive.SocketMessage.LiveResponse -import scalive.SocketMessage.Payload -import scalive.SocketMessage.Payload.EventType +import scalive.WebSocketMessage.LiveResponse +import scalive.WebSocketMessage.Payload import zio.* import zio.http.* import zio.http.ChannelEvent.Read import zio.http.codec.PathCodec import zio.http.template.Html import zio.json.* -import zio.json.ast.Json import java.util.Base64 import scala.collection.mutable @@ -43,67 +41,82 @@ final case class LiveRoute[A, ClientEvt: JsonCodec, ServerEvt]( ) } -class LiveChannel(): - // TODO not thread safe +class LiveChannel(semaphore: Semaphore): private val sockets: mutable.Map[String, Socket[?, ?]] = mutable.Map.empty // TODO should check id isn't already present - def join[ClientEvt: JsonCodec](id: String, token: String, lv: LiveView[ClientEvt, ?]): Diff = - val socket = Socket(id, token, lv) - sockets.addOne(id, socket) - socket.diff + def join[ClientEvt: JsonCodec](id: String, token: String, lv: LiveView[ClientEvt, ?]): UIO[Diff] = + semaphore.withPermit { + ZIO.succeed { + val socket = Socket(id, token, lv) + sockets.addOne(id, socket) + socket.diff + } + } // TODO handle missing id - def event(id: String, value: String): Diff = - val s = sockets(id) - s.lv.handleClientEvent( - value - .fromJson(using s.clientEventCodec.decoder).getOrElse(throw new IllegalArgumentException()) - ) - s.diff + def event(id: String, value: String): UIO[Diff] = + semaphore.withPermit { + ZIO.succeed { + val s = sockets(id) + s.lv.handleClientEvent( + value + .fromJson(using s.clientEventCodec.decoder).getOrElse( + throw new IllegalArgumentException() + ) + ) + s.diff + } + } + +object LiveChannel: + def make(): UIO[LiveChannel] = + Semaphore.make(permits = 1).map(new LiveChannel(_)) class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRoute[?, ?, ?]]): private val socketApp: WebSocketApp[Any] = - val liveChannel = new LiveChannel() Handler.webSocket { channel => - channel - .receiveAll { - case Read(WebSocketFrame.Text(content)) => - for - message <- ZIO - .fromEither(content.fromJson[SocketMessage]) - .mapError(new IllegalArgumentException(_)) - reply <- handleMessage(message, liveChannel) - _ <- channel.send(Read(WebSocketFrame.text(reply.toJson))) - yield () - case _ => ZIO.unit - }.tapErrorCause(ZIO.logErrorCause(_)) + LiveChannel + .make().flatMap(liveChannel => + channel + .receiveAll { + case Read(WebSocketFrame.Text(content)) => + for + message <- ZIO + .fromEither(content.fromJson[WebSocketMessage]) + .mapError(new IllegalArgumentException(_)) + reply <- handleMessage(message, liveChannel) + _ <- channel.send(Read(WebSocketFrame.text(reply.toJson))) + yield () + case _ => ZIO.unit + }.tapErrorCause(ZIO.logErrorCause(_)) + ) } - private def handleMessage(message: SocketMessage, liveChannel: LiveChannel): Task[SocketMessage] = + private def handleMessage(message: WebSocketMessage, liveChannel: LiveChannel) + : Task[WebSocketMessage] = val reply = message.payload match case Payload.Heartbeat => ZIO.succeed(Payload.Reply("ok", LiveResponse.Empty)) case Payload.Join(url, session, static, sticky) => ZIO - .fromEither(URL.decode(url)).map(url => + .fromEither(URL.decode(url)).flatMap(url => val req = Request(url = url) liveRoutes .collectFirst { route => val pathParams = route.path.decode(req.path).getOrElse(???) val lv = route.liveviewBuilder(pathParams, req) - val diff = - liveChannel.join(message.topic, session, lv)(using route.clientEventCodec) - Payload.Reply("ok", LiveResponse.InitDiff(diff)) + liveChannel.join(message.topic, session, lv)(using route.clientEventCodec) }.getOrElse(???) - ) + ).map(diff => Payload.Reply("ok", LiveResponse.InitDiff(diff))) case Payload.Event(_, event, _) => - val diff = liveChannel.event(message.topic, event) - ZIO.succeed(Payload.Reply("ok", LiveResponse.Diff(diff))) + liveChannel + .event(message.topic, event) + .map(diff => Payload.Reply("ok", LiveResponse.Diff(diff))) case Payload.Reply(_, _) => ZIO.die(new IllegalArgumentException()) - reply.map(SocketMessage(message.joinRef, message.messageRef, message.topic, "phx_reply", _)) + reply.map(WebSocketMessage(message.joinRef, message.messageRef, message.topic, "phx_reply", _)) val routes: Routes[Any, Response] = Routes.fromIterable( @@ -114,97 +127,3 @@ class LiveRouter(rootLayout: HtmlElement => HtmlElement, liveRoutes: List[LiveRo ) ) end LiveRouter - -final case class SocketMessage( - // Live session ID, auto increment defined by the client on join - joinRef: Option[Int], - // Message ID, global auto increment defined by the client on every message - messageRef: Int, - // LiveView instance id - topic: String, - eventType: String, - payload: SocketMessage.Payload) -object SocketMessage: - given JsonCodec[SocketMessage] = JsonCodec[Json].transformOrFail( - { - case Json.Arr( - Chunk(joinRef, Json.Str(messageRef), Json.Str(topic), Json.Str(eventType), payload) - ) => - val payloadParsed = eventType match - case "heartbeat" => Right(Payload.Heartbeat) - case "phx_join" => payload.as[Payload.Join] - case "event" => payload.as[Payload.Event] - case s => Left(s"Unknown event type : $s") - - payloadParsed.map( - SocketMessage( - joinRef.asString.map(_.toInt), - messageRef.toInt, - topic, - eventType, - _ - ) - ) - case v => Left(s"Could not parse socket message ${v.toJson}") - }, - m => - Json.Arr( - m.joinRef.map(ref => Json.Str(ref.toString)).getOrElse(Json.Null), - Json.Str(m.messageRef.toString), - Json.Str(m.topic), - Json.Str(m.eventType), - m.payload.match - case Payload.Heartbeat => Json.Obj.empty - case p: Payload.Join => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) - case p: Payload.Reply => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) - case p: Payload.Event => p.toJsonAST.getOrElse(throw new IllegalArgumentException()) - ) - ) - - enum Payload: - case Heartbeat - case Join( - url: String, - // params: Map[String, String], - session: String, - static: Option[String], - sticky: Boolean) - case Reply(status: String, response: LiveResponse) - case Event(`type`: Payload.EventType, event: String, value: Map[String, String]) - object Payload: - given JsonCodec[Payload.Join] = JsonCodec.derived - given JsonEncoder[Payload.Reply] = JsonEncoder.derived - given JsonCodec[Payload.Event] = JsonCodec.derived - - enum EventType: - case Click - object EventType: - given JsonCodec[EventType] = JsonCodec[String].transformOrFail( - { - case "click" => Right(Click) - case s => Left(s"Unsupported event type: $s") - }, - { case Click => - "click" - } - ) - - enum LiveResponse: - case Empty - case InitDiff(rendered: scalive.Diff) - case Diff(diff: scalive.Diff) - object LiveResponse: - given JsonEncoder[LiveResponse] = - JsonEncoder[Json].contramap { - case Empty => Json.Obj.empty - case InitDiff(rendered) => - Json.Obj( - "liveview_version" -> Json.Str("1.1.8"), - "rendered" -> rendered.toJsonAST.getOrElse(throw new IllegalArgumentException()) - ) - case Diff(diff) => - Json.Obj( - "diff" -> diff.toJsonAST.getOrElse(throw new IllegalArgumentException()) - ) - } -end SocketMessage