package com.zeroturnaround.javarebel.util;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

/**
 * Associates arbitrary key/value pairs with a {@link ClassLoader}.
 *
 * <p>For a non-null loader, the entries live in a {@link Map} held in a public static field of a
 * small holder class. That class is generated with ASM and defined directly in that loader
 * through reflective access to {@code ClassLoader.defineClass}. The only per-loader state kept by
 * this class is the holder's class name, stored in a {@link WeakHashMap} keyed by the loader.
 * Presumably this is so the per-loader entries are reachable only through the loader's own
 * classes and do not pin the loader in memory (intent inferred, not stated in the original).
 *
 * <p>Entries for the {@code null} loader are stored in a single global map instead.
 *
 * <p>Operations on a non-null loader synchronize on the loader instance. The global map is a
 * {@link Collections#synchronizedMap synchronized map}.
 *
 * <p>NOTE(review): this class calls {@code setAccessible(true)} on protected {@code ClassLoader}
 * methods. JDKs that strongly encapsulate {@code java.lang} may refuse this — confirm the
 * supported runtimes.
 */
public class ClassLoaderLocalMap implements Opcodes {
  /** Maps each class loader (weakly) to the binary name of its generated holder class. */
  private static final Map classLoaderToHolderClassName =
      Collections.synchronizedMap(new WeakHashMap());

  /** Entries stored for the {@code null} class loader. */
  private static final Map globalMap = Collections.synchronizedMap(new HashMap());

  /** Next numeric suffix for holder class names; guarded by the lock of {@link #nextHolderName}. */
  private static int counter = 1;

  private static final String NAME_PREFIX = "GEN$$ClassLoaderProperties";

  /**
   * Name of the {@code public static final Map} field declared by every generated holder class.
   * The bytecode generator and the reflective lookup both use this constant, so they cannot
   * disagree about the field name.
   */
  private static final String HOLDER_FIELD_NAME = "localMap";

  /** Reflective handle to the protected {@code ClassLoader.defineClass(String, byte[], int, int)}. */
  private static Method defineMethod;

  /** Reflective handle to the protected {@code ClassLoader.findLoadedClass(String)}. */
  private static Method findLoadedClass;

  static {
    try {
      defineMethod = ClassLoader.class.getDeclaredMethod("defineClass",
          new Class[] { String.class, byte[].class, int.class, int.class });
      defineMethod.setAccessible(true);

      findLoadedClass = ClassLoader.class.getDeclaredMethod("findLoadedClass",
          new Class[] { String.class });
      findLoadedClass.setAccessible(true);
    } catch (NoSuchMethodException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Tells whether {@code key} has a mapping for the given loader.
   *
   * <p>If no holder class has been defined in {@code cl} yet, this returns {@code false} without
   * generating one.
   *
   * @param cl the class loader whose entries to consult; {@code null} selects the global map
   * @param key the key to look up
   * @return {@code true} if a mapping for {@code key} exists
   */
  public static boolean containsKey(ClassLoader cl, Object key) {
    if (cl == null) {
      return globalMap.containsKey(key);
    }

    // Lock on the loader so the holder lookup cannot race with another thread defining it.
    synchronized (cl) {
      if (!hasHolder(cl)) {
        return false;
      }
      return getLocalMap(cl).containsKey(key);
    }
  }

  /**
   * Stores {@code value} under {@code key} for the given loader.
   *
   * <p>The first call for a non-null loader generates and defines that loader's holder class.
   *
   * @param cl the class loader to associate the entry with; {@code null} selects the global map
   * @param key the key
   * @param value the value
   */
  public static void put(ClassLoader cl, Object key, Object value) {
    if (cl == null) {
      globalMap.put(key, value);
      return;
    }

    // Lock on the loader so only one thread can create and define its holder class.
    synchronized (cl) {
      getLocalMap(cl).put(key, value);
    }
  }

  /**
   * Returns the value stored under {@code key} for the given loader.
   *
   * <p>If no holder class has been defined in {@code cl} yet, this returns {@code null} without
   * generating one. A fresh holder would have an empty map, so the result is the same.
   *
   * @param cl the class loader whose entries to consult; {@code null} selects the global map
   * @param key the key to look up
   * @return the mapped value, or {@code null} if there is none
   */
  public static Object get(ClassLoader cl, Object key) {
    if (cl == null) {
      return globalMap.get(key);
    }

    // Lock on the loader so the holder lookup cannot race with another thread defining it.
    synchronized (cl) {
      if (!hasHolder(cl)) {
        return null;
      }
      return getLocalMap(cl).get(key);
    }
  }

  /**
   * Tells whether a holder class name is recorded for {@code cl} and that class is already
   * defined in {@code cl}.
   */
  private static boolean hasHolder(ClassLoader cl) {
    String holderClassName = (String) classLoaderToHolderClassName.get(cl);
    return holderClassName != null && findLoaded(cl, holderClassName) != null;
  }

  /**
   * Returns the map held by {@code cl}'s holder class.
   *
   * <p>If needed, this first assigns a holder class name and defines the holder class. Callers
   * must hold the lock on {@code cl}.
   */
  private static Map getLocalMap(ClassLoader cl) {
    String holderClassName = (String) classLoaderToHolderClassName.get(cl);
    if (holderClassName == null) {
      holderClassName = nextHolderName();
      classLoaderToHolderClassName.put(cl, holderClassName);
    }

    Class holderClass = findLoaded(cl, holderClassName);
    if (holderClass == null) {
      holderClass = defineHolder(cl, holderClassName);
    }

    try {
      // Reading the static field initializes the holder class. Its <clinit> creates the map.
      return (Map) holderClass.getDeclaredField(HOLDER_FIELD_NAME).get(null);
    } catch (Throwable e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Calls the protected {@code cl.findLoadedClass(className)} reflectively.
   *
   * @return the class already defined or recorded in {@code cl} under that name, or {@code null}
   */
  private static Class findLoaded(ClassLoader cl, String className) {
    try {
      return (Class) findLoadedClass.invoke(cl, new Object[] { className });
    } catch (IllegalArgumentException e) {
      throw new RuntimeException(e);
    } catch (IllegalAccessException e) {
      throw new RuntimeException(e);
    } catch (InvocationTargetException e) {
      throw new RuntimeException(e.getTargetException());
    }
  }

  /** Generates the holder bytecode and defines it in {@code cl} via the protected {@code defineClass}. */
  private static Class defineHolder(ClassLoader cl, String holderClassName) {
    byte[] classBytes = buildHolderByteCode(holderClassName);
    try {
      return (Class) defineMethod.invoke(cl, new Object[] {
          holderClassName, classBytes, new Integer(0), new Integer(classBytes.length) });
    } catch (InvocationTargetException e) {
      throw new RuntimeException(e.getTargetException());
    } catch (Throwable e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Builds the bytecode of a holder class equivalent to the following source:
   *
   * <pre>
   * public class &lt;holderClassName&gt; {
   *   public static final java.util.Map localMap = new java.util.HashMap();
   * }
   * </pre>
   *
   * <p>The field name is always {@link #HOLDER_FIELD_NAME}. Both the declaration and the
   * {@code PUTSTATIC} in {@code <clinit>} must use that name; a mismatch fails at class
   * initialization.
   */
  private static byte[] buildHolderByteCode(String holderClassName) {
    ClassWriter cw = new ClassWriter(0);
    FieldVisitor fv;
    MethodVisitor mv;

    cw.visit(V1_2, ACC_PUBLIC + ACC_SUPER, holderClassName, null, "java/lang/Object", null);

    {
      fv = cw.visitField(ACC_PUBLIC + ACC_FINAL + ACC_STATIC, HOLDER_FIELD_NAME,
          "Ljava/util/Map;", null, null);
      fv.visitEnd();
    }
    {
      // static { localMap = new HashMap(); }
      mv = cw.visitMethod(ACC_STATIC, "<clinit>", "()V", null, null);
      mv.visitCode();
      mv.visitTypeInsn(NEW, "java/util/HashMap");
      mv.visitInsn(DUP);
      mv.visitMethodInsn(INVOKESPECIAL, "java/util/HashMap", "<init>", "()V");
      mv.visitFieldInsn(PUTSTATIC, holderClassName, HOLDER_FIELD_NAME, "Ljava/util/Map;");
      mv.visitInsn(RETURN);
      mv.visitMaxs(2, 0);
      mv.visitEnd();
    }
    {
      // public <init>() { super(); }
      mv = cw.visitMethod(ACC_PUBLIC, "<init>", "()V", null, null);
      mv.visitCode();
      mv.visitVarInsn(ALOAD, 0);
      mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V");
      mv.visitInsn(RETURN);
      mv.visitMaxs(1, 1);
      mv.visitEnd();
    }
    cw.visitEnd();

    return cw.toByteArray();
  }

  /**
   * Returns a new holder class name. Synchronized because {@code counter++} is not atomic, and
   * callers hold different loader locks, so they can run concurrently.
   */
  private static synchronized String nextHolderName() {
    return NAME_PREFIX + counter++;
  }
}
