/*
 * Copyright 2015 Terracotta, Inc., a Software AG company.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.terracotta.offheapstore.storage.portability;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.OutputStream;
import java.io.Serializable;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.terracotta.offheapstore.util.ByteBufferInputStream;
import org.terracotta.offheapstore.util.FindbugsSuppressWarnings;

/*
 * A trivially compressed, Java-serialization-based portability.
 *
 * Class descriptors in the resulting bytes are encoded as integers. The
 * mappings from each integer to its ObjectStreamClass, and from each Class to
 * its integer, are stored together in a single on-heap map.
 *
 * Author: Chris Dennis
 */
/** * A trivially compressed Java serialization based portability. * <p> * Class descriptors in the resultant bytes are encoded as integers. Mappings * between the integer representation and the {@link ObjectStreamClass}, and the * {@code Class} and the integer representation are stored in a single on-heap * map. * * @author Chris Dennis */
public class SerializablePortability implements Portability<Serializable> { protected int nextStreamIndex = 0; protected final ConcurrentMap<Object, Object> lookup = new ConcurrentHashMap<>(); private final ClassLoader loader; public SerializablePortability() { this(null); } public SerializablePortability(ClassLoader loader) { this.loader = loader; } @Override public ByteBuffer encode(Serializable object) { try { ByteArrayOutputStream bout = new ByteArrayOutputStream(); try (ObjectOutputStream oout = getObjectOutputStream(bout)) { oout.writeObject(object); } return ByteBuffer.wrap(bout.toByteArray()); } catch (IOException e) { throw new AssertionError(e); } } @Override public Serializable decode(ByteBuffer buffer) { try { try (ObjectInputStream oin = getObjectInputStream(new ByteBufferInputStream(buffer))) { return (Serializable) oin.readObject(); } } catch (ClassNotFoundException | IOException e) { throw new AssertionError(e); } } public ObjectOutputStream getObjectOutputStream(OutputStream out) throws IOException { return new OOS(out); } public ObjectInputStream getObjectInputStream(InputStream input) throws IOException { return new OIS(input, loader); } @Override public boolean equals(Object value, ByteBuffer readBuffer) { return value.equals(decode(readBuffer)); } // protected int getOrAddMapping(ObjectStreamClass desc) throws IOException { // SerializableDataKey probe = new SerializableDataKey(desc); // Integer rep = (Integer) lookup.get(probe); // if (rep == null) { // ObjectStreamClass disconnected = disconnect(desc); // SerializableDataKey key = new SerializableDataKey(disconnected); // rep = nextStreamIndex.getAndIncrement(); // // ObjectStreamClass existingOsc = (ObjectStreamClass) lookup.putIfAbsent(rep, disconnected); // if (existingOsc == null) { // Integer existingRep = (Integer) lookup.putIfAbsent(key, rep); // if (existingRep == null) { // return rep.intValue(); // } else { // /* // * A racing thread established a mapping already. 
We must clean up // * our half complete mapping. // */ // lookup.remove(rep); // return existingRep.intValue(); // } // } else { // //impossible as governed by AtomicInteger - excluding wrap-around == 2^32 types) // throw new AssertionError(); // } // } else { // return rep.intValue(); // } // } @FindbugsSuppressWarnings("JLM_JSR166_UTILCONCURRENT_MONITORENTER") protected int getOrAddMapping(ObjectStreamClass desc) throws IOException { SerializableDataKey probe = new SerializableDataKey(desc, false); Integer rep = (Integer) lookup.get(probe); if (rep == null) { synchronized (lookup) { rep = (Integer) lookup.get(probe); if (rep == null) { ObjectStreamClass disconnected = disconnect(desc); SerializableDataKey key = new SerializableDataKey(disconnected, true); rep = nextStreamIndex++; ObjectStreamClass existingOsc = (ObjectStreamClass) lookup.putIfAbsent(rep, disconnected); if (existingOsc != null) { throw new AssertionError("Existing mapping for this index detected : " + rep + " => " + existingOsc.getName()); } Integer existingRep = (Integer) lookup.putIfAbsent(key, rep); if (existingRep != null) { throw new AssertionError("Existing mapping to this type detected : " + existingRep + " => " + disconnected.getName()); } addedMapping(rep, disconnected); } } } return rep; } protected void addedMapping(Integer rep, ObjectStreamClass disconnected) { //no-op } class OOS extends ObjectOutputStream { public OOS(OutputStream out) throws IOException { super(out); } @Override protected void writeClassDescriptor(final ObjectStreamClass desc) throws IOException { writeInt(getOrAddMapping(desc)); } } class OIS extends ObjectInputStream { private final ClassLoader loader; public OIS(InputStream in, ClassLoader loader) throws IOException { super(in); this.loader = loader; } @Override protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException { return (ObjectStreamClass) lookup.get(readInt()); } @Override protected Class<?> 
resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { try { final ClassLoader cl = loader == null ? Thread.currentThread().getContextClassLoader() : loader; if (cl == null) { return super.resolveClass(desc); } else { try { return Class.forName(desc.getName(), false, cl); } catch (ClassNotFoundException e) { return super.resolveClass(desc); } } } catch (SecurityException ex) { return super.resolveClass(desc); } } } protected static class SerializableDataKey { private final ObjectStreamClass osc; private final int hashCode; private transient WeakReference<Class<?>> klazz; public SerializableDataKey(ObjectStreamClass desc, boolean store) throws IOException { Class<?> forClass = desc.forClass(); if (forClass != null) { if (store) { throw new AssertionError("Must not store ObjectStreamClass instances with strong references to classes"); } else if (ObjectStreamClass.lookup(forClass) == desc) { this.klazz = new WeakReference<>(forClass); } } this.hashCode = (3 * desc.getName().hashCode()) ^ (7 * (int) (desc.getSerialVersionUID() >>> 32)) ^ (11 * (int) desc.getSerialVersionUID()); this.osc = desc; } @Override public boolean equals(Object o) { return o instanceof SerializableDataKey && SerializablePortability.equals(this, (SerializableDataKey) o); } @Override public int hashCode() { return hashCode; } public Class<?> forClass() { if (klazz == null) { return null; } else { return klazz.get(); } } public void setClass(Class<?> clazz) { klazz = new WeakReference<>(clazz); } public ObjectStreamClass getObjectStreamClass() { return osc; } } private static boolean equals(SerializableDataKey k1, SerializableDataKey k2) { Class<?> k1Clazz = k1.forClass(); Class<?> k2Clazz = k2.forClass(); if (k1Clazz != null && k2Clazz != null) { return k1Clazz == k2Clazz; } else if (SerializablePortability.equals(k1.getObjectStreamClass(), k2.getObjectStreamClass())) { if (k1Clazz != null) { k2.setClass(k1Clazz); } else if (k2Clazz != null) { k1.setClass(k2Clazz); } 
return true; } else { return false; } } private static boolean equals(ObjectStreamClass osc1, ObjectStreamClass osc2) { if (osc1 == osc2) { return true; } else if (osc1.getName().equals(osc2.getName()) && osc1.getSerialVersionUID() == osc2.getSerialVersionUID() && osc1.getFields().length == osc2.getFields().length) { try { return Arrays.equals(getSerializedForm(osc1), getSerializedForm(osc2)); } catch (IOException e) { throw new AssertionError(e); } } else { return false; } } protected static ObjectStreamClass disconnect(ObjectStreamClass desc) { try { ObjectInputStream oin = new ObjectInputStream(new ByteArrayInputStream(getSerializedForm(desc))) { @Override protected Class<?> resolveClass(ObjectStreamClass osc) throws IOException, ClassNotFoundException { //Our stored OSC instances should not reference classes - doing so could cause perm-gen leaks return null; } }; return (ObjectStreamClass) oin.readObject(); } catch (ClassNotFoundException | IOException e) { throw new AssertionError(e); } } private static byte[] getSerializedForm(ObjectStreamClass desc) throws IOException { ByteArrayOutputStream bout = new ByteArrayOutputStream(); try { try (ObjectOutputStream oout = new ObjectOutputStream(bout)) { oout.writeObject(desc); } } finally { bout.close(); } return bout.toByteArray(); } }