/*
 * Copyright (c) 2011, 2020 Oracle and/or its affiliates. All rights reserved.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License v. 2.0, which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * This Source Code may also be made available under the following Secondary
 * Licenses when the conditions for such availability set forth in the
 * Eclipse Public License v. 2.0 are satisfied: GNU General Public License,
 * version 2 with the GNU Classpath Exception, which is available at
 * https://www.gnu.org/software/classpath/license.html.
 *
 * SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
 */

package org.glassfish.grizzly.websockets.rfc6455;

import java.net.URI;
import java.util.Locale;

import org.glassfish.grizzly.Buffer;
import org.glassfish.grizzly.http.HttpContent;
import org.glassfish.grizzly.http.HttpRequestPacket;
import org.glassfish.grizzly.websockets.Constants;
import org.glassfish.grizzly.websockets.DataFrame;
import org.glassfish.grizzly.websockets.FrameType;
import org.glassfish.grizzly.websockets.HandShake;
import org.glassfish.grizzly.websockets.Masker;
import org.glassfish.grizzly.websockets.ProtocolError;
import org.glassfish.grizzly.websockets.ProtocolHandler;
import org.glassfish.grizzly.websockets.frametypes.BinaryFrameType;
import org.glassfish.grizzly.websockets.frametypes.ClosingFrameType;
import org.glassfish.grizzly.websockets.frametypes.ContinuationFrameType;
import org.glassfish.grizzly.websockets.frametypes.PingFrameType;
import org.glassfish.grizzly.websockets.frametypes.PongFrameType;
import org.glassfish.grizzly.websockets.frametypes.TextFrameType;

/**
 * {@link ProtocolHandler} implementation for the WebSocket framing defined by RFC 6455.
 * <p>
 * Outgoing frames are serialized by {@link #frame(DataFrame)}. Incoming frames are decoded
 * incrementally by {@link #parse(Buffer)`}, which keeps its progress in a {@link ParsingState}.
 * This lets a frame that arrives split across several buffers be resumed when more bytes are
 * available.
 * <p>
 * NOTE(review): the parse state is a plain mutable field with no synchronization, so an instance
 * presumably must not parse from several threads concurrently; confirm with callers.
 */
public class RFC6455Handler extends ProtocolHandler {

    /** Progress of the frame currently being decoded; reset after every complete frame or error. */
    private final ParsingState state = new ParsingState();

    // ------------------------------------------------------------ Constructors

    /**
     * Creates an RFC 6455 protocol handler.
     *
     * @param mask forwarded to {@link ProtocolHandler}. Presumably it becomes the inherited
     *             {@code maskData} flag, which {@link #frame(DataFrame)} consults to decide whether
     *             outgoing payloads are masked (confirm in ProtocolHandler).
     */
    public RFC6455Handler(boolean mask) {
        super(mask);
    }

    // -------------------------------------------- Methods from ProtocolHandler

    /**
     * Creates the client side of the opening handshake for {@code uri}.
     *
     * @param uri the WebSocket endpoint to connect to
     * @return a new {@code RFC6455HandShake} for {@code uri}
     */
    @Override
    public HandShake createClientHandShake(URI uri) {
        return new RFC6455HandShake(uri);
    }

    /**
     * Creates the server side of the opening handshake from the client's upgrade request.
     *
     * @param requestContent HTTP content whose header is the client's request. It is cast to
     *                       {@link HttpRequestPacket}, so a non-request header raises
     *                       {@link ClassCastException}.
     * @return a new {@code RFC6455HandShake} built from that request
     */
    @Override
    public HandShake createServerHandShake(HttpContent requestContent) {
        return new RFC6455HandShake((HttpRequestPacket) requestContent.getHttpHeader());
    }

    /**
     * Serializes {@code frame} into a single RFC 6455 wire frame.
     * <p>
     * The layout is:
     * <ol>
     *   <li>one FIN/opcode byte, as adjusted by the inherited {@code checkForLastFrame};</li>
     *   <li>the encoded payload length;</li>
     *   <li>when {@code maskData} is set, a {@code Constants.MASK_SIZE}-byte masking key;</li>
     *   <li>the (optionally masked) payload.</li>
     * </ol>
     *
     * @param frame the frame to serialize
     * @return the bytes to write to the wire
     * @throws ProtocolError if the frame's type has no RFC 6455 opcode (see {@link #getOpcode})
     */
    @Override
    public byte[] frame(DataFrame frame) {
        // checkForLastFrame may update fragmentation bookkeeping, so it runs first as before.
        final byte opcode = checkForLastFrame(frame, getOpcode(frame.getType()));
        final byte[] payload = frame.getType().getBytes(frame);
        final byte[] lengthBytes = encodeLength(payload.length);

        // Header = opcode byte + length bytes (+ masking key when masking is enabled).
        final int maskSize = maskData ? Constants.MASK_SIZE : 0;
        final int payloadStart = 1 + lengthBytes.length + maskSize;
        final byte[] packet = new byte[payloadStart + payload.length];
        packet[0] = opcode;
        System.arraycopy(lengthBytes, 0, packet, 1, lengthBytes.length);
        if (maskData) {
            final Masker masker = new Masker();
            // The MASK flag is the high bit of the first length byte (the same bit parse() checks).
            packet[1] |= 0x80;
            masker.mask(packet, payloadStart, payload);
            // The masking key sits immediately before the payload.
            System.arraycopy(masker.getMask(), 0, packet, payloadStart - Constants.MASK_SIZE, Constants.MASK_SIZE);
        } else {
            System.arraycopy(payload, 0, packet, payloadStart, payload.length);
        }
        return packet;
    }

    /**
     * Decodes the next frame from {@code buffer}, resuming any partially decoded frame.
     * <p>
     * Decoding proceeds through four steps. If {@code buffer} lacks the bytes needed for the
     * current step, {@code null} is returned and the progress is kept in {@link #state}, so the
     * next call continues from that step:
     * <ol>
     *   <li>state 0: FIN/RSV/opcode byte and the 7-bit length code (carrying the MASK flag);</li>
     *   <li>state 1: the 16-bit or 64-bit extended payload length, if the code is 126 or 127;</li>
     *   <li>state 2: the masking key, if the frame is masked;</li>
     *   <li>state 3: the payload itself.</li>
     * </ol>
     * The inherited fragmentation bookkeeping ({@code processingFragment},
     * {@code inFragmentedType}) is updated as soon as a data frame's header is read. It is cleared
     * once the final fragment's payload has been decoded.
     * <p>
     * Any exception resets the per-frame parse state before propagating. Runtime exceptions are
     * rethrown as-is; anything else is wrapped in a {@link RuntimeException}.
     *
     * @param buffer bytes received from the peer
     * @return the decoded frame, or {@code null} if more bytes are required
     * @throws ProtocolError if the frame violates RFC 6455. This covers reserved bits set, a
     *         fragmented or over-long control frame, an out-of-sequence continuation, an invalid
     *         extended payload length, or an unknown opcode.
     */
    @Override
    public DataFrame parse(Buffer buffer) {

        DataFrame dataFrame;
        try {
            switch (state.state) {
            case 0:
                if (buffer.remaining() < 2) {
                    // Don't have enough bytes to read opcode and lengthCode
                    return null;
                }

                byte opcode = buffer.get();
                // Bits 6..4 of the first byte are RSV1..RSV3; no extension is negotiated here.
                boolean rsvBitSet = isBitSet(opcode, 6) || isBitSet(opcode, 5) || isBitSet(opcode, 4);
                if (rsvBitSet) {
                    throw new ProtocolError("RSV bit(s) incorrectly set.");
                }
                state.finalFragment = isBitSet(opcode, 7);
                state.controlFrame = isControlFrame(opcode);
                // With RSV bits known to be zero, stripping FIN leaves just the 4-bit opcode.
                state.opcode = (byte) (opcode & 0x7f);
                state.frameType = valueOf(inFragmentedType, state.opcode);
                if (!state.finalFragment && state.controlFrame) {
                    throw new ProtocolError("Fragmented control frame");
                }

                if (!state.controlFrame) {
                    if (isContinuationFrame(state.opcode) && !processingFragment) {
                        throw new ProtocolError("End fragment sent, but wasn't processing any previous fragments");
                    }
                    if (processingFragment && !isContinuationFrame(state.opcode)) {
                        throw new ProtocolError("Fragment sent but opcode was not 0");
                    }
                    if (!state.finalFragment && !isContinuationFrame(state.opcode)) {
                        processingFragment = true;
                    }
                    // Remember the opcode of the first fragment so continuations know their type.
                    if (!state.finalFragment && inFragmentedType == 0) {
                        inFragmentedType = state.opcode;
                    }
                }
                byte lengthCode = buffer.get();

                state.masked = (lengthCode & 0x80) == 0x80;
                state.masker = new Masker(buffer);
                if (state.masked) {
                    lengthCode ^= 0x80;
                }
                state.lengthCode = lengthCode;

                state.state++;
                // fall through

            case 1:
                if (state.lengthCode <= 125) {
                    state.length = state.lengthCode;
                } else {
                    if (state.controlFrame) {
                        throw new ProtocolError("Control frame payloads must be no greater than 125 bytes.");
                    }

                    // Length code 126 -> 16-bit extended length, 127 -> 64-bit extended length.
                    final int lengthBytes = state.lengthCode == 126 ? 2 : 8;
                    if (buffer.remaining() < lengthBytes) {
                        // Don't have enough bytes to read length
                        return null;
                    }
                    state.masker.setBuffer(buffer);
                    // The masking key follows the length on the wire and has not been read yet.
                    // Presumably unmask() therefore returns these bytes unchanged; confirm in Masker.
                    final long decodedLength = decodeLength(state.masker.unmask(lengthBytes));
                    // RFC 6455 section 5.2: the most significant bit of the 64-bit length MUST be 0,
                    // so a negative value is a protocol violation. The payload is also read into a
                    // single byte[] in state 3, so lengths above Integer.MAX_VALUE cannot be handled.
                    if (decodedLength < 0 || decodedLength > Integer.MAX_VALUE) {
                        throw new ProtocolError("Invalid payload length: " + decodedLength);
                    }
                    state.length = decodedLength;
                }
                state.state++;
                // fall through

            case 2:
                if (state.masked) {
                    if (buffer.remaining() < Constants.MASK_SIZE) {
                        // Don't have enough bytes to read mask
                        return null;
                    }
                    state.masker.setBuffer(buffer);
                    state.masker.readMask();
                }
                state.state++;
                // fall through

            case 3:
                if (buffer.remaining() < state.length) {
                    return null;
                }

                state.masker.setBuffer(buffer);
                // Safe narrowing: state 1 guarantees 0 <= length <= Integer.MAX_VALUE.
                final byte[] data = state.masker.unmask((int) state.length);
                if (data.length != state.length) {
                    throw new ProtocolError(String.format("Data read (%s) is not the expected size (%s)", data.length, state.length));
                }
                dataFrame = state.frameType.create(state.finalFragment, data);

                // Text payloads, including continuations of a text message, go through UTF-8 decoding.
                if (!state.controlFrame && (isTextFrame(state.opcode) || inFragmentedType == 1)) {
                    utf8Decode(state.finalFragment, data, dataFrame);
                }

                if (!state.controlFrame && state.finalFragment) {
                    inFragmentedType = 0;
                    processingFragment = false;
                }
                state.recycle();

                break;
            default:
                // Should never get here
                throw new IllegalStateException("Unexpected state: " + state.state);
            }
        } catch (Exception e) {
            state.recycle();
            if (e instanceof RuntimeException) {
                throw (RuntimeException) e;
            } else {
                throw new RuntimeException(e);
            }
        }

        return dataFrame;

    }

    /**
     * Reports whether {@code opcode} denotes a control frame, i.e. whether bit 3 (0x08) is set.
     *
     * @param opcode the raw first frame byte or the bare opcode; only bit 3 is examined
     * @return {@code true} for control frames (close, ping, pong and reserved 0xB-0xF)
     */
    @Override
    protected boolean isControlFrame(byte opcode) {
        return (opcode & 0x08) == 0x08;
    }

    /** Returns whether bit number {@code bit} (0 = least significant) of {@code b} is set. */
    private boolean isBitSet(final byte b, int bit) {
        return ((b >> bit) & 1) != 0;
    }

    /** Returns whether {@code opcode} is the continuation opcode (0x0). */
    private boolean isContinuationFrame(byte opcode) {
        return opcode == 0;
    }

    /** Returns whether {@code opcode} is the text opcode (0x1). */
    private boolean isTextFrame(byte opcode) {
        return opcode == 1;
    }

    /**
     * Maps an outgoing {@link FrameType} to its RFC 6455 opcode.
     *
     * @param type the frame type to encode
     * @return 0x01 text, 0x02 binary, 0x08 close, 0x09 ping or 0x0A pong
     * @throws ProtocolError for any other type, including {@link ContinuationFrameType}
     */
    private byte getOpcode(FrameType type) {
        if (type instanceof TextFrameType) {
            return 0x01;
        } else if (type instanceof BinaryFrameType) {
            return 0x02;
        } else if (type instanceof ClosingFrameType) {
            return 0x08;
        } else if (type instanceof PingFrameType) {
            return 0x09;
        } else if (type instanceof PongFrameType) {
            return 0x0A;
        }

        throw new ProtocolError("Unknown frame type: " + type.getClass().getName());
    }

    /**
     * Maps an incoming opcode to a new {@link FrameType} instance.
     *
     * @param fragmentType opcode of the first fragment of the message in progress (0 if none).
     *                     Continuation frames receive {@code true} when its low bit is set,
     *                     i.e. when that message started as a text frame (0x01).
     * @param value        the opcode; only its low four bits are used
     * @return the frame type for the opcode
     * @throws ProtocolError for reserved or unknown opcodes
     */
    private FrameType valueOf(byte fragmentType, byte value) {
        final int opcode = value & 0xF;
        switch (opcode) {
        case 0x00:
            return new ContinuationFrameType((fragmentType & 0x01) == 0x01);
        case 0x01:
            return new TextFrameType();
        case 0x02:
            return new BinaryFrameType();
        case 0x08:
            return new ClosingFrameType();
        case 0x09:
            return new PingFrameType();
        case 0x0A:
            return new PongFrameType();
        default:
            throw new ProtocolError(String.format("Unknown frame type: %s, %s", Integer.toHexString(opcode & 0xFF).toUpperCase(Locale.US), connection));
        }
    }

    // ---------------------------------------------------------- Nested Classes

    /**
     * Mutable per-frame decoding progress used by {@link RFC6455Handler#parse(Buffer)}.
     */
    private static class ParsingState {
        /** Current decoding step, 0..3 (see {@code parse}). */
        int state = 0;
        /** Bare 4-bit opcode of the frame being decoded, or -1 when none. */
        byte opcode = (byte) -1;
        /** Payload length in bytes, or -1 until known. */
        long length = -1;
        /** Frame type resolved from the opcode. */
        FrameType frameType;
        /** Whether the MASK flag was set on the frame. */
        boolean masked;
        /** Reads (and, if masked, unmasks) bytes from the current buffer. */
        Masker masker;
        /** Whether the FIN bit was set. */
        boolean finalFragment;
        /** Whether the opcode denotes a control frame. */
        boolean controlFrame;
        /** 7-bit length code with the MASK flag removed, or -1 until read. */
        private byte lengthCode = -1;

        /** Resets every field to its initial value so the next frame starts at step 0. */
        void recycle() {
            state = 0;
            opcode = (byte) -1;
            length = -1;
            lengthCode = -1;
            masked = false;
            masker = null;
            finalFragment = false;
            controlFrame = false;
            frameType = null;
        }
    }

}