/**
   * Scans one class file and records where it uses the calls and annotations being tracked.
   *
   * <p>The constant pool is walked first: every entry for which the resolved reference reports a
   * non-null class name is handed to {@code fillCalls}, which populates a map keyed by
   * constant-pool index. If that map and the {@code annotations} field are both empty the method
   * returns immediately. Otherwise the class-level annotations are checked, and then, for every
   * method that has a code attribute:
   *
   * <ul>
   *   <li>method annotations are checked once, using the line of the first instruction;
   *   <li>each invoke instruction whose constant-pool operand is a key of the call map adds a
   *       {@code CodeLine} to the matching set in {@code report};
   *   <li>return, parameter and exception types of non-bridge methods are passed to
   *       {@code handleMethodSignature};
   *   <li>visible parameter annotations whose type name is a key of {@code annotations} are
   *       passed to {@code checkAnnotation}.
   * </ul>
   *
   * <p>Any exception raised while processing a single method is printed and swallowed so that the
   * remaining methods are still examined.
   *
   * @param file the class file to inspect
   * @throws Exception if scanning the constant pool or checking the class-level annotations fails
   */
  protected void checkClassFile(ClassFile file) throws Exception {
    Map<Integer, Triple> calls = new HashMap<>();

    ConstPool pool = file.getConstPool();
    // Constant-pool entries are numbered from 1.
    for (int i = 1; i < pool.getSize(); ++i) {
      BytecodeUtils.Ref ref = BytecodeUtils.getRef(pool, i);
      String className = ref.getClassName(pool, i);
      // a non-null class name is treated as "this entry is a call reference"
      if (className != null) {
        String methodName = ref.getName(pool, i);
        String methodDesc = ref.getDesc(pool, i);
        fillCalls(i, className, methodName, methodDesc, calls);
      }
    }

    // nothing to look for in this class
    if (calls.isEmpty() && annotations.isEmpty()) {
      return;
    }

    String className = file.getName();

    AnnotationsAttribute faa =
        (AnnotationsAttribute) file.getAttribute(AnnotationsAttribute.visibleTag);
    checkAnnotations(className, TYPE_USAGE.getMethodName(), faa, -1);

    List<MethodInfo> methods = file.getMethods();
    for (MethodInfo m : methods) {
      try {
        // ignore abstract methods
        if (m.getCodeAttribute() == null) {
          continue;
        }

        AnnotationsAttribute maa =
            (AnnotationsAttribute) m.getAttribute(AnnotationsAttribute.visibleTag);
        boolean annotationsChecked = false;
        // stays -1 when the method body contains no instruction at all
        int firstLine = -1;

        CodeIterator it = m.getCodeAttribute().iterator();
        while (it.hasNext()) {
          // loop through the bytecode
          final int index = it.next();
          final int line = m.getLineNumber(index);

          // method annotations are checked once, on the first instruction
          if (!annotationsChecked) {
            annotationsChecked = true;
            firstLine = line;
            checkAnnotations(
                className, m.getName(), maa, line - 2); // -2 to get the line above the method
          }

          int op = it.byteAt(index);
          // if the bytecode is a method invocation
          if (op == CodeIterator.INVOKEVIRTUAL
              || op == CodeIterator.INVOKESTATIC
              || op == CodeIterator.INVOKEINTERFACE
              || op == CodeIterator.INVOKESPECIAL) {
            // The operand is an unsigned 16-bit constant-pool index. Reading it as a signed
            // value would turn every index above 32767 negative and the lookup would miss it.
            int val = it.u16bitAt(index + 1);
            Triple triple = calls.get(val);
            if (triple != null) {
              // NOTE(review): assumes fillCalls only registers calls whose class and tuple are
              // already present in report; otherwise these lookups return null - confirm.
              Map<Tuple, Set<CodeLine>> map = report.get(triple.className);
              Set<CodeLine> set = map.get(triple.tuple);
              CodeLine cl = new CodeLine(className, m.getName(), line);
              set.add(cl.modify()); // check for .jsp, etc
            }
          }
        }

        if (!BaseMethodExclusion.isBridge(m)) {
          SignatureAttribute.MethodSignature signature =
              SignatureAttribute.toMethodSignature(m.getDescriptor());
          handleMethodSignature(className, m.getName(), firstLine - 1, signature.getReturnType());
          handleMethodSignature(
              className, m.getName(), firstLine - 1, signature.getParameterTypes());
          handleMethodSignature(
              className, m.getName(), firstLine - 1, signature.getExceptionTypes());
        }

        ParameterAnnotationsAttribute paa =
            (ParameterAnnotationsAttribute)
                m.getAttribute(ParameterAnnotationsAttribute.visibleTag);
        if (paa != null) {
          Annotation[][] paas = paa.getAnnotations();
          if (paas != null) {
            for (Annotation[] params : paas) {
              for (Annotation a : params) {
                // report every tracked annotation type found on a parameter
                for (Map.Entry<String, Boolean> entry : annotations.entrySet()) {
                  if (entry.getKey().equals(a.getTypeName())) {
                    checkAnnotation(
                        className, m.getName(), firstLine - 1, entry.getValue(), entry.getKey(), a);
                  }
                }
              }
            }
          }
        }

        m.getCodeAttribute().computeMaxStack();
      } catch (Exception e) {
        // best effort: a failure in one method must not stop the scan of the others
        e.printStackTrace();
      }
    }
  }
  /**
   * Rewrites every {@code java.lang.reflect.Method.invoke} call site in the class so that it
   * first asks {@code org.fakereplace.reflection.MethodReflection.fakeCallRequired} whether the
   * call must be redirected.
   *
   * <p>The constant pool is scanned for method references to {@code Method.invoke}. If at least
   * one is found, each {@code invokevirtual} instruction that uses one of those references is
   * replaced by NOPs followed by an inserted sequence that calls {@code fakeCallRequired} and then
   * either the static {@code MethodReflection.invoke} or the original {@code Method.invoke}.
   *
   * <p>Failures while rewriting a single method are logged and swallowed; the remaining methods
   * are still processed.
   *
   * @param file the class file to transform
   * @param loader not used by this method
   * @param modifiableClass not used by this method
   * @return {@code true} if the constant pool references {@code Method.invoke} (the methods were
   *     then scanned for call sites), {@code false} otherwise
   */
  public boolean transformClass(ClassFile file, ClassLoader loader, boolean modifiableClass) {
    Set<Integer> methodCallLocations = new HashSet<Integer>();
    Integer newCallLocation = null;
    Integer methodReflectionLocation = null;
    // first we need to scan the constant pool looking for
    // CONSTANT_method_info_ref structures
    ConstPool pool = file.getConstPool();
    for (int i = 1; i < pool.getSize(); ++i) {
      // we have a method call
      if (pool.getTag(i) == ConstPool.CONST_Methodref) {
        String className = pool.getMethodrefClassName(i);
        String methodName = pool.getMethodrefName(i);

        if (className.equals(Method.class.getName())) {
          if (methodName.equals("invoke")) {
            // store the location in the const pool of the method ref
            methodCallLocations.add(i);
            // we have found a method call

            // if we have not already stored a reference to our new
            // method in the const pool
            if (newCallLocation == null) {
              methodReflectionLocation =
                  pool.addClassInfo("org.fakereplace.reflection.MethodReflection");
              int nt = pool.addNameAndTypeInfo("fakeCallRequired", "(Ljava/lang/reflect/Method;)Z");
              // the returned index is not needed here; the call is kept so the entry is
              // still added to the pool exactly as before
              pool.addMethodrefInfo(methodReflectionLocation, nt);
              newCallLocation = pool.addNameAndTypeInfo(METHOD_NAME, REPLACED_METHOD_DESCRIPTOR);
            }
          }
        }
      }
    }

    // this means we found an instance of the call, now we have to iterate
    // through the methods and replace instances of the call
    if (newCallLocation != null) {
      List<MethodInfo> methods = file.getMethods();
      for (MethodInfo m : methods) {
        try {
          // ignore abstract methods
          if (m.getCodeAttribute() == null) {
            continue;
          }
          CodeIterator it = m.getCodeAttribute().iterator();
          while (it.hasNext()) {
            // loop through the bytecode
            int index = it.next();
            int op = it.byteAt(index);
            // if the bytecode is a method invocation
            if (op == CodeIterator.INVOKEVIRTUAL) {
              // The operand is an unsigned 16-bit constant-pool index. Reading it as a signed
              // value would make indices above 32767 negative, so those call sites would
              // never match and would silently stay unpatched.
              int val = it.u16bitAt(index + 1);
              // if the method call is one of the methods we are
              // replacing
              if (methodCallLocations.contains(val)) {
                Bytecode b = new Bytecode(file.getConstPool());
                // our stack looks like Method, instance,params
                // we need Method, instance, params , Method
                // (three DUP_X2 with a POP after the first two rotate the top three
                // slots full circle and leave a copy of Method on top)
                b.add(Opcode.DUP_X2);
                b.add(Opcode.POP);
                b.add(Opcode.DUP_X2);
                b.add(Opcode.POP);
                b.add(Opcode.DUP_X2);
                b.addInvokestatic(
                    methodReflectionLocation, "fakeCallRequired", "(Ljava/lang/reflect/Method;)Z");
                // false (0) -> skip the fake call and perform the real one
                b.add(Opcode.IFEQ);
                JumpMarker performRealCall = JumpUtils.addJumpInstruction(b);
                // now perform the fake call
                b.addInvokestatic(methodReflectionLocation, "invoke", REPLACED_METHOD_DESCRIPTOR);
                b.add(Opcode.GOTO);
                JumpMarker finish = JumpUtils.addJumpInstruction(b);
                performRealCall.mark();
                b.addInvokevirtual(Method.class.getName(), METHOD_NAME, METHOD_DESCRIPTOR);
                finish.mark();
                // blank out the original 3-byte invokevirtual, then insert the replacement
                it.writeByte(CodeIterator.NOP, index);
                it.writeByte(CodeIterator.NOP, index + 1);
                it.writeByte(CodeIterator.NOP, index + 2);
                it.insert(b.get());
              }
            }
          }
          // the inserted sequence pushes an extra value, so the stack depth is recomputed
          m.getCodeAttribute().computeMaxStack();
        } catch (Exception e) {
          log.error("Bad byte code transforming " + file.getName());
          e.printStackTrace();
        }
      }
      return true;
    } else {
      return false;
    }
  }
  // Example no. 3
  // 0
  /**
   * Gets a string representation of the bytecode instruction at the specified position.
   *
   * <p>The result is the mnemonic taken from the {@code opcodes} table, followed, for
   * instructions that carry an operand, by a space and a rendering of that operand: a literal
   * value, a local variable index, an absolute branch target ({@code pos} plus the encoded
   * offset) or a description of the referenced constant-pool entry.
   *
   * @param iter the iterator used to read the code bytes
   * @param pos the offset of the instruction's opcode byte
   * @param pool the constant pool used to resolve constant, field, method and class operands
   * @return the textual form of the instruction
   * @throws IllegalArgumentException if the opcode has no entry in the {@code opcodes} table
   */
  public static String instructionString(CodeIterator iter, int pos, ConstPool pool) {
    int opcode = iter.byteAt(pos);

    // Valid table indices are 0 .. opcodes.length - 1, so the upper bound must be rejected
    // with >=; a plain > would let opcode == opcodes.length through and fail on the lookup.
    if (opcode >= opcodes.length || opcode < 0) {
      throw new IllegalArgumentException("Invalid opcode, opcode: " + opcode + " pos: " + pos);
    }

    String opstring = opcodes[opcode];
    switch (opcode) {
      case BIPUSH:
        // the bipush immediate is a signed byte; the cast restores the sign (255 -> -1)
        return opstring + " " + (byte) iter.byteAt(pos + 1);
      case SIPUSH:
        return opstring + " " + iter.s16bitAt(pos + 1);
      case LDC:
        // ldc carries a one-byte constant-pool index
        return opstring + " " + ldc(pool, iter.byteAt(pos + 1));
      case LDC_W:
      case LDC2_W:
        return opstring + " " + ldc(pool, iter.u16bitAt(pos + 1));
      case ILOAD:
      case LLOAD:
      case FLOAD:
      case DLOAD:
      case ALOAD:
      case ISTORE:
      case LSTORE:
      case FSTORE:
      case DSTORE:
      case ASTORE:
        // local variable index
        return opstring + " " + iter.byteAt(pos + 1);
      case IFEQ:
      case IFGE:
      case IFGT:
      case IFLE:
      case IFLT:
      case IFNE:
      case IFNONNULL:
      case IFNULL:
      case IF_ACMPEQ:
      case IF_ACMPNE:
      case IF_ICMPEQ:
      case IF_ICMPGE:
      case IF_ICMPGT:
      case IF_ICMPLE:
      case IF_ICMPLT:
      case IF_ICMPNE:
        // signed 16-bit offset relative to this instruction -> absolute target
        return opstring + " " + (iter.s16bitAt(pos + 1) + pos);
      case IINC:
        // only the byte following the opcode is printed
        return opstring + " " + iter.byteAt(pos + 1);
      case GOTO:
      case JSR:
        return opstring + " " + (iter.s16bitAt(pos + 1) + pos);
      case RET:
        return opstring + " " + iter.byteAt(pos + 1);
      case TABLESWITCH:
        return tableSwitch(iter, pos);
      case LOOKUPSWITCH:
        return lookupSwitch(iter, pos);
      case GETSTATIC:
      case PUTSTATIC:
      case GETFIELD:
      case PUTFIELD:
        return opstring + " " + fieldInfo(pool, iter.u16bitAt(pos + 1));
      case INVOKEVIRTUAL:
      case INVOKESPECIAL:
      case INVOKESTATIC:
        return opstring + " " + methodInfo(pool, iter.u16bitAt(pos + 1));
      case INVOKEINTERFACE:
        return opstring + " " + interfaceMethodInfo(pool, iter.u16bitAt(pos + 1));
      case INVOKEDYNAMIC:
        // the raw constant-pool index is printed, the entry is not resolved
        return opstring + " " + iter.u16bitAt(pos + 1);
      case NEW:
        return opstring + " " + classInfo(pool, iter.u16bitAt(pos + 1));
      case NEWARRAY:
        return opstring + " " + arrayInfo(iter.byteAt(pos + 1));
      case ANEWARRAY:
      case CHECKCAST:
        return opstring + " " + classInfo(pool, iter.u16bitAt(pos + 1));
      case WIDE:
        return wide(iter, pos);
      case MULTIANEWARRAY:
        return opstring + " " + classInfo(pool, iter.u16bitAt(pos + 1));
      case GOTO_W:
      case JSR_W:
        // wide branches use a signed 32-bit offset
        return opstring + " " + (iter.s32bitAt(pos + 1) + pos);
      default:
        // instruction without operands (or one not rendered specially)
        return opstring;
    }
  }