/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.websockets.core.protocol;

import io.undertow.util.Headers;
import io.undertow.websockets.WebSocketExtension;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketVersion;
import io.undertow.websockets.extensions.ExtensionFunction;
import io.undertow.websockets.extensions.ExtensionHandshake;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import org.xnio.IoFuture;
import io.undertow.connector.ByteBufferPool;
import org.xnio.StreamConnection;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;

Abstract base class for doing a WebSocket Handshake.
Author:Mike Brock
/** * Abstract base class for doing a WebSocket Handshake. * * @author Mike Brock */
public abstract class Handshake { private final WebSocketVersion version; private final String hashAlgorithm; private final String magicNumber; protected final Set<String> subprotocols; private static final byte[] EMPTY = new byte[0]; private static final Pattern PATTERN = Pattern.compile("\\s*,\\s*"); protected Set<ExtensionHandshake> availableExtensions = new HashSet<>(); protected boolean allowExtensions; protected Handshake(WebSocketVersion version, String hashAlgorithm, String magicNumber, final Set<String> subprotocols) { this.version = version; this.hashAlgorithm = hashAlgorithm; this.magicNumber = magicNumber; this.subprotocols = subprotocols; }
Return the version for which the Handshake can be used.
/** * Return the version for which the {@link Handshake} can be used. */
public WebSocketVersion getVersion() { return version; }
Return the algorithm that is used to hash during the handshake
/** * Return the algorithm that is used to hash during the handshake */
public String getHashAlgorithm() { return hashAlgorithm; }
Return the magic number which will be mixed in
/** * Return the magic number which will be mixed in */
public String getMagicNumber() { return magicNumber; }
Return the full url of the websocket location of the given WebSocketHttpExchange
/** * Return the full url of the websocket location of the given {@link WebSocketHttpExchange} */
protected static String getWebSocketLocation(WebSocketHttpExchange exchange) { String scheme; if ("https".equals(exchange.getRequestScheme())) { scheme = "wss"; } else { scheme = "ws"; } return scheme + "://" + exchange.getRequestHeader(Headers.HOST_STRING) + exchange.getRequestURI(); }
Issue the WebSocket upgrade
Params:
/** * Issue the WebSocket upgrade * * @param exchange The {@link WebSocketHttpExchange} for which the handshake and upgrade should occur. */
public final void handshake(final WebSocketHttpExchange exchange) { exchange.putAttachment(WebSocketVersion.ATTACHMENT_KEY, version); handshakeInternal(exchange); } protected abstract void handshakeInternal(final WebSocketHttpExchange exchange);
Return true if this implementation can be used to issue a handshake.
/** * Return {@code true} if this implementation can be used to issue a handshake. */
public abstract boolean matches(WebSocketHttpExchange exchange); /** * Create the {@link WebSocketChannel} from the {@link WebSocketHttpExchange} */ public abstract WebSocketChannel createChannel(WebSocketHttpExchange exchange, final StreamConnection channel, final ByteBufferPool pool);
convenience method to perform the upgrade
/** * convenience method to perform the upgrade */
protected final void performUpgrade(final WebSocketHttpExchange exchange, final byte[] data) { exchange.setResponseHeader(Headers.CONTENT_LENGTH_STRING, String.valueOf(data.length)); exchange.setResponseHeader(Headers.UPGRADE_STRING, "WebSocket"); exchange.setResponseHeader(Headers.CONNECTION_STRING, "Upgrade"); upgradeChannel(exchange, data); } protected void upgradeChannel(final WebSocketHttpExchange exchange, final byte[] data) { if (data.length > 0) { writePayload(exchange, ByteBuffer.wrap(data)); } else { exchange.endExchange(); } } private static void writePayload(final WebSocketHttpExchange exchange, final ByteBuffer payload) { exchange.sendData(payload).addNotifier(new IoFuture.Notifier<Void, Object>() { @Override public void notify(final IoFuture<? extends Void> ioFuture, final Object attachment) { if (ioFuture.getStatus() == IoFuture.Status.DONE) { exchange.endExchange(); } else { exchange.close(); } } }, null); }
Perform the upgrade using no payload
/** * Perform the upgrade using no payload */
protected final void performUpgrade(final WebSocketHttpExchange exchange) { performUpgrade(exchange, EMPTY); }
Selects the first matching supported sub protocol and add it the the headers of the exchange.
/** * Selects the first matching supported sub protocol and add it the the headers of the exchange. * */
protected final void selectSubprotocol(final WebSocketHttpExchange exchange) { String requestedSubprotocols = exchange.getRequestHeader(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING); if (requestedSubprotocols == null) { return; } String[] requestedSubprotocolArray = PATTERN.split(requestedSubprotocols); String subProtocol = supportedSubprotols(requestedSubprotocolArray); if (subProtocol != null && !subProtocol.isEmpty()) { exchange.setResponseHeader(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING, subProtocol); } } protected final void selectExtensions(final WebSocketHttpExchange exchange) { List<WebSocketExtension> requestedExtensions = WebSocketExtension.parse(exchange.getRequestHeader(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING)); List<WebSocketExtension> extensions = selectedExtension(requestedExtensions); if (extensions != null && !extensions.isEmpty()) { exchange.setResponseHeader(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING, WebSocketExtension.toExtensionHeader(extensions)); } } protected String supportedSubprotols(String[] requestedSubprotocolArray) { for (String p : requestedSubprotocolArray) { String requestedSubprotocol = p.trim(); for (String supportedSubprotocol : subprotocols) { if (requestedSubprotocol.equals(supportedSubprotocol)) { return supportedSubprotocol; } } } return null; } protected List<WebSocketExtension> selectedExtension(List<WebSocketExtension> extensionList) { List<WebSocketExtension> selected = new ArrayList<>(); List<ExtensionHandshake> configured = new ArrayList<>(); for (WebSocketExtension ext : extensionList) { for (ExtensionHandshake extHandshake : availableExtensions) { WebSocketExtension negotiated = extHandshake.accept(ext); if (negotiated != null && !extHandshake.isIncompatible(configured)) { selected.add(negotiated); configured.add(extHandshake); } } } return selected; }
Add a new WebSocket Extension handshake to the list of available extensions.
Params:
  • extension – a new ExtensionHandshake
/** * Add a new WebSocket Extension handshake to the list of available extensions. * * @param extension a new {@code ExtensionHandshake} */
public final void addExtension(ExtensionHandshake extension) { availableExtensions.add(extension); allowExtensions = true; }
Create the ExtensionFunction list associated with the negotiated extensions defined in the exchange's response.
Params:
  • exchange – the exchange used to retrieve negotiated extensions
Returns: a list of ExtensionFunction with the implementation of the extensions
/** * Create the {@code ExtensionFunction} list associated with the negotiated extensions defined in the exchange's response. * * @param exchange the exchange used to retrieve negotiated extensions * @return a list of {@code ExtensionFunction} with the implementation of the extensions */
protected final List<ExtensionFunction> initExtensions(final WebSocketHttpExchange exchange) { String extHeader = exchange.getResponseHeaders().get(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING) != null ? exchange.getResponseHeaders().get(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING).get(0) : null; List<ExtensionFunction> negotiated = new ArrayList<>(); if (extHeader != null) { List<WebSocketExtension> extensions = WebSocketExtension.parse(extHeader); if (extensions != null && !extensions.isEmpty()) { for (WebSocketExtension ext : extensions) { for (ExtensionHandshake extHandshake : availableExtensions) { if (extHandshake.getName().equals(ext.getName())) { negotiated.add(extHandshake.create()); } } } } } return negotiated; } }