package io.micronaut.http.netty.websocket;
import io.micronaut.context.annotation.Requires;
import io.micronaut.http.MediaType;
import io.micronaut.websocket.WebSocketBroadcaster;
import io.micronaut.websocket.WebSocketSession;
import io.micronaut.websocket.exceptions.WebSocketSessionException;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.Attribute;
import io.reactivex.BackpressureStrategy;
import io.reactivex.Flowable;
import javax.inject.Singleton;
import java.util.function.Predicate;
/**
 * Netty-backed {@link WebSocketBroadcaster}.
 *
 * <p>It encodes a message once with the {@link WebSocketMessageEncoder}. It then writes the
 * resulting frame to the channels in the {@link WebSocketSessionRepository}'s channel group
 * that meet three conditions: a {@link NettyRxWebSocketSession} is bound to the channel, that
 * session is open, and the caller-supplied filter accepts it.
 */
@Singleton
@Requires(beans = WebSocketSessionRepository.class)
public class NettyServerWebSocketBroadcaster implements WebSocketBroadcaster {

    private final WebSocketMessageEncoder webSocketMessageEncoder;
    private final WebSocketSessionRepository webSocketSessionRepository;

    /**
     * @param webSocketMessageEncoder    encodes broadcast messages into WebSocket frames
     * @param webSocketSessionRepository supplies the channel group that frames are written to
     */
    public NettyServerWebSocketBroadcaster(WebSocketMessageEncoder webSocketMessageEncoder,
                                           WebSocketSessionRepository webSocketSessionRepository) {
        this.webSocketMessageEncoder = webSocketMessageEncoder;
        this.webSocketSessionRepository = webSocketSessionRepository;
    }

    /**
     * Encodes {@code message} and writes it to every matching session. The method blocks until
     * the channel-group write future completes.
     *
     * @param message   the message to broadcast
     * @param mediaType the media type used to encode the message
     * @param filter    selects which open sessions receive the message
     * @param <T>       the message type
     * @throws WebSocketSessionException if the calling thread is interrupted while waiting. The
     *                                   thread's interrupt status is restored before throwing.
     */
    @Override
    public <T> void broadcastSync(T message, MediaType mediaType, Predicate<WebSocketSession> filter) {
        WebSocketFrame frame = webSocketMessageEncoder.encodeMessage(message, mediaType);
        try {
            webSocketSessionRepository.getChannelGroup()
                    .writeAndFlush(frame, ch -> isBroadcastTarget(ch.attr(NettyRxWebSocketSession.WEB_SOCKET_SESSION_KEY), filter))
                    .sync();
        } catch (InterruptedException e) {
            // Re-assert the interrupt so code further up the stack (e.g. executors) can still
            // observe it. Also keep the original exception as the cause.
            Thread.currentThread().interrupt();
            throw new WebSocketSessionException("Broadcast Interrupted", e);
        }
    }

    /**
     * Returns a {@link Flowable} that performs the broadcast when subscribed.
     *
     * <p>On success the flowable emits {@code message} once and then completes. If encoding or
     * writing fails, it signals a {@link WebSocketSessionException} that wraps the underlying
     * cause.
     *
     * @param message   the message to broadcast
     * @param mediaType the media type used to encode the message
     * @param filter    selects which open sessions receive the message
     * @param <T>       the message type
     * @return a flowable emitting {@code message} once the group write has succeeded
     */
    @Override
    public <T> Flowable<T> broadcast(T message, MediaType mediaType, Predicate<WebSocketSession> filter) {
        return Flowable.create(emitter -> {
            try {
                WebSocketFrame frame = webSocketMessageEncoder.encodeMessage(message, mediaType);
                webSocketSessionRepository.getChannelGroup()
                        .writeAndFlush(frame, ch -> isBroadcastTarget(ch.attr(NettyRxWebSocketSession.WEB_SOCKET_SESSION_KEY), filter))
                        .addListener(future -> {
                            if (future.isSuccess()) {
                                emitter.onNext(message);
                                emitter.onComplete();
                            } else {
                                Throwable cause = future.cause();
                                emitter.onError(new WebSocketSessionException("Broadcast Failure: " + cause.getMessage(), cause));
                            }
                        });
            } catch (Throwable e) {
                // Covers synchronous failures (e.g. encoding) before any write is attempted.
                emitter.onError(new WebSocketSessionException("Broadcast Failure: " + e.getMessage(), e));
            }
        }, BackpressureStrategy.BUFFER);
    }

    /**
     * Decides whether a channel should receive a broadcast frame. Both broadcast paths share
     * this check so their matching rules cannot drift apart.
     *
     * @param sessionAttr the channel's WebSocket session attribute
     * @param filter      caller-supplied session filter (only consulted for open sessions)
     * @return {@code true} only if a session is bound to the channel, that session is open,
     *         and {@code filter} accepts it
     */
    private static boolean isBroadcastTarget(Attribute<NettyRxWebSocketSession> sessionAttr,
                                             Predicate<WebSocketSession> filter) {
        NettyRxWebSocketSession session = sessionAttr.get();
        return session != null && session.isOpen() && filter.test(session);
    }
}