protected void checkClassFile(ClassFile file) throws Exception { Map<Integer, Triple> calls = new HashMap<>(); ConstPool pool = file.getConstPool(); for (int i = 1; i < pool.getSize(); ++i) { // we have a method call BytecodeUtils.Ref ref = BytecodeUtils.getRef(pool, i); String className = ref.getClassName(pool, i); if (className != null) { String methodName = ref.getName(pool, i); String methodDesc = ref.getDesc(pool, i); fillCalls(i, className, methodName, methodDesc, calls); } } if (calls.isEmpty() && annotations.isEmpty()) { return; } String className = file.getName(); AnnotationsAttribute faa = (AnnotationsAttribute) file.getAttribute(AnnotationsAttribute.visibleTag); checkAnnotations(className, TYPE_USAGE.getMethodName(), faa, -1); List<MethodInfo> methods = file.getMethods(); for (MethodInfo m : methods) { try { // ignore abstract methods if (m.getCodeAttribute() == null) { continue; } AnnotationsAttribute maa = (AnnotationsAttribute) m.getAttribute(AnnotationsAttribute.visibleTag); boolean annotationsChecked = false; int firstLine = -1; CodeIterator it = m.getCodeAttribute().iterator(); while (it.hasNext()) { // loop through the bytecode final int index = it.next(); final int line = m.getLineNumber(index); if (annotationsChecked == false) { annotationsChecked = true; firstLine = line; checkAnnotations( className, m.getName(), maa, line - 2); // -2 to get the line above the method } int op = it.byteAt(index); // if the bytecode is a method invocation if (op == CodeIterator.INVOKEVIRTUAL || op == CodeIterator.INVOKESTATIC || op == CodeIterator.INVOKEINTERFACE || op == CodeIterator.INVOKESPECIAL) { int val = it.s16bitAt(index + 1); Triple triple = calls.get(val); if (triple != null) { Map<Tuple, Set<CodeLine>> map = report.get(triple.className); Set<CodeLine> set = map.get(triple.tuple); CodeLine cl = new CodeLine(className, m.getName(), line); set.add(cl.modify()); // check for .jsp, etc } } } if (BaseMethodExclusion.isBridge(m) == false) { 
SignatureAttribute.MethodSignature signature = SignatureAttribute.toMethodSignature(m.getDescriptor()); handleMethodSignature(className, m.getName(), firstLine - 1, signature.getReturnType()); handleMethodSignature( className, m.getName(), firstLine - 1, signature.getParameterTypes()); handleMethodSignature( className, m.getName(), firstLine - 1, signature.getExceptionTypes()); } ParameterAnnotationsAttribute paa = (ParameterAnnotationsAttribute) m.getAttribute(ParameterAnnotationsAttribute.visibleTag); if (paa != null) { Annotation[][] paas = paa.getAnnotations(); if (paas != null) { for (Annotation[] params : paas) { for (Annotation a : params) { for (Map.Entry<String, Boolean> entry : annotations.entrySet()) { if (entry.getKey().equals(a.getTypeName())) { checkAnnotation( className, m.getName(), firstLine - 1, entry.getValue(), entry.getKey(), a); } } } } } } m.getCodeAttribute().computeMaxStack(); } catch (Exception e) { e.printStackTrace(); } } }
public boolean transformClass(ClassFile file, ClassLoader loader, boolean modifiableClass) { Set<Integer> methodCallLocations = new HashSet<Integer>(); Integer newCallLocation = null; Integer methodReflectionLocation = null; Integer fakeCallRequiredLocation = null; // first we need to scan the constant pool looking for // CONSTANT_method_info_ref structures ConstPool pool = file.getConstPool(); for (int i = 1; i < pool.getSize(); ++i) { // we have a method call if (pool.getTag(i) == ConstPool.CONST_Methodref) { String className = pool.getMethodrefClassName(i); String methodName = pool.getMethodrefName(i); if (className.equals(Method.class.getName())) { if (methodName.equals("invoke")) { // store the location in the const pool of the method ref methodCallLocations.add(i); // we have found a method call // if we have not already stored a reference to our new // method in the const pool if (newCallLocation == null) { methodReflectionLocation = pool.addClassInfo("org.fakereplace.reflection.MethodReflection"); int nt = pool.addNameAndTypeInfo("fakeCallRequired", "(Ljava/lang/reflect/Method;)Z"); fakeCallRequiredLocation = pool.addMethodrefInfo(methodReflectionLocation, nt); newCallLocation = pool.addNameAndTypeInfo(METHOD_NAME, REPLACED_METHOD_DESCRIPTOR); } } } } } // this means we found an instance of the call, now we have to iterate // through the methods and replace instances of the call if (newCallLocation != null) { List<MethodInfo> methods = file.getMethods(); for (MethodInfo m : methods) { try { // ignore abstract methods if (m.getCodeAttribute() == null) { continue; } CodeIterator it = m.getCodeAttribute().iterator(); while (it.hasNext()) { // loop through the bytecode int index = it.next(); int op = it.byteAt(index); // if the bytecode is a method invocation if (op == CodeIterator.INVOKEVIRTUAL) { int val = it.s16bitAt(index + 1); // if the method call is one of the methods we are // replacing if (methodCallLocations.contains(val)) { Bytecode b = new 
Bytecode(file.getConstPool()); // our stack looks like Method, instance,params // we need Method, instance, params , Method b.add(Opcode.DUP_X2); b.add(Opcode.POP); b.add(Opcode.DUP_X2); b.add(Opcode.POP); b.add(Opcode.DUP_X2); b.addInvokestatic( methodReflectionLocation, "fakeCallRequired", "(Ljava/lang/reflect/Method;)Z"); b.add(Opcode.IFEQ); JumpMarker performRealCall = JumpUtils.addJumpInstruction(b); // now perform the fake call b.addInvokestatic(methodReflectionLocation, "invoke", REPLACED_METHOD_DESCRIPTOR); b.add(Opcode.GOTO); JumpMarker finish = JumpUtils.addJumpInstruction(b); performRealCall.mark(); b.addInvokevirtual(Method.class.getName(), METHOD_NAME, METHOD_DESCRIPTOR); finish.mark(); it.writeByte(CodeIterator.NOP, index); it.writeByte(CodeIterator.NOP, index + 1); it.writeByte(CodeIterator.NOP, index + 2); it.insert(b.get()); } } } m.getCodeAttribute().computeMaxStack(); } catch (Exception e) { log.error("Bad byte code transforming " + file.getName()); e.printStackTrace(); } } return true; } else { return false; } }
/**
 * Gets a string representation of the bytecode instruction at the specified position.
 *
 * <p>The result is the instruction mnemonic, followed by a rendering of its operand for the
 * opcodes handled explicitly below: immediate values, local-variable indexes, absolute branch
 * targets (relative offset plus {@code pos}) and resolved constant-pool entries. Opcodes not
 * listed are rendered as the bare mnemonic.
 *
 * @param iter the code iterator used to read the instruction bytes
 * @param pos the position of the opcode within the code
 * @param pool the constant pool used to resolve constant-pool operands
 * @return the textual form of the instruction
 * @throws IllegalArgumentException if the byte at {@code pos} is not a known opcode
 */
public static String instructionString(CodeIterator iter, int pos, ConstPool pool) {
    int opcode = iter.byteAt(pos);

    // Valid indexes are 0 .. opcodes.length - 1; using ">" here would let
    // opcode == opcodes.length through and fail with ArrayIndexOutOfBoundsException.
    if (opcode >= opcodes.length || opcode < 0)
        throw new IllegalArgumentException(
                "Invalid opcode, opcode: " + opcode + " pos: " + pos);

    String opstring = opcodes[opcode];
    switch (opcode) {
        case BIPUSH:
            // The bipush operand is a signed byte; the cast keeps negative values
            // (e.g. -1) from being printed as 255.
            return opstring + " " + (byte) iter.byteAt(pos + 1);
        case SIPUSH:
            return opstring + " " + iter.s16bitAt(pos + 1);
        case LDC:
            return opstring + " " + ldc(pool, iter.byteAt(pos + 1));
        case LDC_W:
        case LDC2_W:
            return opstring + " " + ldc(pool, iter.u16bitAt(pos + 1));
        case ILOAD:
        case LLOAD:
        case FLOAD:
        case DLOAD:
        case ALOAD:
        case ISTORE:
        case LSTORE:
        case FSTORE:
        case DSTORE:
        case ASTORE:
            return opstring + " " + iter.byteAt(pos + 1);
        case IFEQ:
        case IFGE:
        case IFGT:
        case IFLE:
        case IFLT:
        case IFNE:
        case IFNONNULL:
        case IFNULL:
        case IF_ACMPEQ:
        case IF_ACMPNE:
        case IF_ICMPEQ:
        case IF_ICMPGE:
        case IF_ICMPGT:
        case IF_ICMPLE:
        case IF_ICMPLT:
        case IF_ICMPNE:
            // branch offsets are relative to the opcode position
            return opstring + " " + (iter.s16bitAt(pos + 1) + pos);
        case IINC:
            return opstring + " " + iter.byteAt(pos + 1);
        case GOTO:
        case JSR:
            return opstring + " " + (iter.s16bitAt(pos + 1) + pos);
        case RET:
            return opstring + " " + iter.byteAt(pos + 1);
        case TABLESWITCH:
            return tableSwitch(iter, pos);
        case LOOKUPSWITCH:
            return lookupSwitch(iter, pos);
        case GETSTATIC:
        case PUTSTATIC:
        case GETFIELD:
        case PUTFIELD:
            return opstring + " " + fieldInfo(pool, iter.u16bitAt(pos + 1));
        case INVOKEVIRTUAL:
        case INVOKESPECIAL:
        case INVOKESTATIC:
            return opstring + " " + methodInfo(pool, iter.u16bitAt(pos + 1));
        case INVOKEINTERFACE:
            return opstring + " " + interfaceMethodInfo(pool, iter.u16bitAt(pos + 1));
        case INVOKEDYNAMIC:
            return opstring + " " + iter.u16bitAt(pos + 1);
        case NEW:
            return opstring + " " + classInfo(pool, iter.u16bitAt(pos + 1));
        case NEWARRAY:
            return opstring + " " + arrayInfo(iter.byteAt(pos + 1));
        case ANEWARRAY:
        case CHECKCAST:
            return opstring + " " + classInfo(pool, iter.u16bitAt(pos + 1));
        case WIDE:
            return wide(iter, pos);
        case MULTIANEWARRAY:
            return opstring + " " + classInfo(pool, iter.u16bitAt(pos + 1));
        case GOTO_W:
        case JSR_W:
            return opstring + " " + (iter.s32bitAt(pos + 1) + pos);
        default:
            return opstring;
    }
}