Ejemplo n.º 1
0
 /**
  * Prepends to {@code init}'s body, for every parameter bound to an array constant, an allocation
  * of that array followed by assignments of its constant elements.
  *
  * <p>Only rectangular multi-dimensional arrays are handled: the extent of each dimension is taken
  * from the first element at the enclosing level. A zero-length dimension ends the shape walk, so
  * an empty array constant yields a zero-sized allocation instead of an exception.
  *
  * <p>Parameters are visited from last to first and every statement is prepended, so in the final
  * body parameter 0's statements come first. For each parameter the allocation precedes its
  * element assignments.
  *
  * @param init the init method whose body receives the generated statements
  * @param constants lookup from each {@link JFormalParameter} of {@code init} to its bound value;
  *     array values are (possibly nested) {@code Object[]} whose leaves are handed to dumpAssign
  */
 public static void acceptInit(JMethodDeclaration init, Hashtable constants) {
   JBlock body = init.getBody();
   JFormalParameter[] params = init.getParameters();
   for (int i = params.length - 1; i >= 0; i--) {
     LinkedList<JIntLiteral> dims = new LinkedList<JIntLiteral>();
     Object val = constants.get(params[i]);
     // Derive the array shape by descending through the first element of each level.
     Object temp = val;
     while (temp instanceof Object[]) {
       Object[] level = (Object[]) temp;
       dims.add(new JIntLiteral(level.length));
       if (level.length == 0) {
         // Empty dimension: there is no first element to descend into.
         break;
       }
       temp = level[0];
     }
     if (!dims.isEmpty()) {
       // Element assignments are prepended first, so that the allocation (prepended next)
       // precedes them in the resulting body.
       dumpAssign(val, body, new JLocalVariableExpression(null, params[i]));
       body.addStatementFirst(
           new JExpressionStatement(
               null,
               new JAssignmentExpression(
                   null,
                   new JLocalVariableExpression(null, params[i]),
                   new JNewArrayExpression(
                       null, params[i].getType(), dims.toArray(new JExpression[0]), null)),
               null));
     }
   }
 }
Ejemplo n.º 2
0
  /**
   * generate the code for the steady state loop. Inline the work function inside of an infinite
   * while loop.
   *
   * @param filter The single fused filter of the application.
   * @return A JStatement that is a while loop with the work function inlined.
   */
  /**
   * Builds the steady-state loop: a {@code while (true)} statement whose body contains a deep copy
   * of the filter's work-function body (the work function is inlined, not called).
   *
   * @param filter the single fused filter of the application
   * @return a JWhileStatement with a constant-true condition wrapping the cloned work body
   */
  JStatement generateSteadyStateLoop(SIRFilter filter) {
    JBlock loopBody = new JBlock(null, new JStatement[0], null);

    // Inline a clone rather than the original body, so the loop does not share AST nodes
    // with filter.getWork().
    JBlock clonedWork = (JBlock) ObjectDeepCloner.deepCopy(filter.getWork().getBody());
    loopBody.addStatement(clonedWork);

    JBooleanLiteral forever = new JBooleanLiteral(null, true);
    return new JWhileStatement(null, forever, loopBody, null);
  }
Ejemplo n.º 3
0
  /**
   * Construct the main function.
   *
   * @param filter The single SIRFilter of the application.
   * @return A JBlock with the statements of the main function.
   */
  /**
   * Builds the body of the generated main function. In order, it contains:
   *
   * <ol>
   *   <li>a call to the filter's init method, passing the filter's parameters;
   *   <li>a call to the init-work method, emitted only when the filter is a SIRTwoStageFilter;
   *   <li>the infinite steady-state loop.
   * </ol>
   *
   * @param filter the single SIRFilter of the application
   * @return a JBlock holding the main function's statements in the order listed above
   */
  private JBlock mainFunction(SIRFilter filter) {
    JBlock main = new JBlock(null, new JStatement[0], null);

    // NOTE(review): an earlier comment warned that calling toArray() on this list breaks a later
    // pass, yet toArray() is exactly what is used here -- confirm whether that warning still holds.
    List initParams = filter.getParams();
    JExpression[] initArgs =
        (initParams != null && initParams.size() > 0)
            ? (JExpression[]) initParams.toArray(new JExpression[0])
            : new JExpression[0];

    // this.<init>(params...)
    String initName = filter.getInit().getName();
    main.addStatement(
        new JExpressionStatement(
            null,
            new JMethodCallExpression(null, new JThisExpression(null), initName, initArgs),
            null));

    // this.<initWork>() -- only two-stage filters have a separate init-work method.
    if (filter instanceof SIRTwoStageFilter) {
      String initWorkName = ((SIRTwoStageFilter) filter).getInitWork().getName();
      main.addStatement(
          new JExpressionStatement(
              null,
              new JMethodCallExpression(
                  null, new JThisExpression(null), initWorkName, new JExpression[0]),
              null));
    }

    // while (true) { <cloned work body> }
    main.addStatement(generateSteadyStateLoop(filter));
    return main;
  }
Ejemplo n.º 4
0
 /**
  * Prepends to {@code body} one assignment per constant leaf of {@code array}. Each assignment
  * targets {@code prefix} indexed by the leaf's position.
  *
  * <p>Elements are visited from last to first and each assignment is prepended, so the resulting
  * statements appear in ascending (row-major) index order.
  *
  * @param array a JExpression leaf or a (possibly nested) {@code Object[]} of such leaves; any
  *     other value only produces a warning on stderr
  * @param body the block that receives the assignments
  * @param prefix the lvalue expression addressing {@code array}
  */
 private static void dumpAssign(Object array, JBlock body, JExpression prefix) {
   if (array instanceof JExpression) {
     JExpression value = (JExpression) array;
     // NOTE(review): non-constant leaves are skipped silently -- confirm that is intended.
     if (value.isConstant()) {
       body.addStatementFirst(
           new JExpressionStatement(null, new JAssignmentExpression(null, prefix, value), null));
     }
   } else if (array instanceof Object[]) {
     Object[] elements = (Object[]) array;
     for (int i = elements.length - 1; i >= 0; i--) {
       dumpAssign(elements[i], body, new JArrayAccessExpression(null, prefix, new JIntLiteral(i)));
     }
   } else {
     // Separator added so the offending value is not glued onto the method name.
     System.err.println("WARNING: Non Array input to dumpAssign: " + array);
   }
 }
Ejemplo n.º 5
0
  /**
   * Make a work function for a joiner
   *
   * @param joiner the InputSliceNode that we are generating code for.
   * @param backEndBits way to refer to other portions of backend
   * @param joiner_code place to put code
   */
  /**
   * Make the work function for a joiner and register it with {@code joiner_code} as both the work
   * method and an additional method.
   *
   * <p>The generated body is always:
   *
   * <pre>
   * T tmp;
   * tmp = producer();
   * push(tmp);
   * </pre>
   *
   * <p>{@code producer} is one of two methods:
   *
   * <ul>
   *   <li>the joiner state-machine method already stored in {@code joiner_code}, when the slice
   *       needs joiner code;
   *   <li>otherwise, the pop method of the single upstream channel, since the joiner then merely
   *       transfers from the upstream buffer to the downstream buffer.
   * </ul>
   *
   * <p>TODO: inline the joiner code at the call site; if inlining, delete the joiner code
   * afterwards, leaving only this method in the helper.
   *
   * @param joiner the InputSliceNode that we are generating code for
   * @param backEndBits way to refer to other portions of the backend
   * @param joiner_code place to put code; when joiner code is needed it must already contain
   *     exactly one method (checked by an assertion)
   */
  public static void makeJoinerWork(
      InputSliceNode joiner, BackEndFactory backEndBits, CodeStoreHelper joiner_code) {
    ALocalVariable tmp = ALocalVariable.makeTmp(joiner.getEdgeToNext().getType());
    Channel downstream = backEndBits.getChannel(joiner.getEdgeToNext());

    // Name of the zero-argument method whose result is forwarded downstream.
    String producerName;
    if (backEndBits.sliceNeedsJoinerCode(joiner.getParent())) {
      assert joiner_code.getMethods().length == 1;
      producerName = joiner_code.getMethods()[0].getName();
    } else {
      assert joiner.getWidth() == 1;
      Channel upstream = backEndBits.getChannel(joiner.getSingleEdge());
      producerName = upstream.popMethodName();
    }

    JBlock body = new JBlock();
    body.addStatement(tmp.getDecl());
    JExpression produce = new JMethodCallExpression(producerName, new JExpression[0]);
    body.addStatement(new JExpressionStatement(new JAssignmentExpression(tmp.getRef(), produce)));
    JExpression[] pushArgs = {tmp.getRef()};
    body.addStatement(
        new JExpressionStatement(
            new JMethodCallExpression(downstream.pushMethodName(), pushArgs)));

    String workName = "_joinerWork_" + joiner.getNextFilter().getFilter().getName();
    JMethodDeclaration joinerWork =
        new JMethodDeclaration(CStdType.Void, workName, JFormalParameter.EMPTY, body);
    joiner_code.setWorkMethod(joinerWork);
    joiner_code.addMethod(joinerWork);
  }
Ejemplo n.º 6
0
  /**
   * Create the fields and the method implementing a joiner. Requesting a joiner whose weights are
   * all 0 is an error that is caught by an assertion (so only when assertions are enabled), rather
   * than producing nonsensical Kopi code. Note that this <b>always</b> creates new code; to reuse
   * existing code call {@link #getJoinerCode(InputSliceNode, BackEndFactory) getJoinerCode}
   * instead.
   *
   * <pre>
   * A joiner as a state machine, driven off arrays. The example has arity 4 and weights, but no
   * duplication; a single-edge joiner would instead just be delegated to a channel:
   *
   * T pop_1_M() {fprintf(stderr, "pop_1_M\n"); return 0;}
   * T pop_2_M() {fprintf(stderr, "pop_2_M\n"); return 0;}
   * T pop_4_M() {fprintf(stderr, "pop_4_M\n"); return 0;}
   *
   * static int joiner_M_edge = 4 - 1;
   * static int joiner_M_weight = 0;
   *
   * static inline T joiner_M() {
   *
   *   / * Placing const here either applies it to the function or gives a parse error.
   *       Do we need to move this to file scope to convince the inliner to work on joiner_M? * /
   *   static T (*pops[4])() = {
   *     pop_1_M,
   *     pop_2_M,
   *     0,              / * 0-weight edge * /
   *     pop_4_M
   *   };
   *
   *   static const int weights[4] = {2, 1, 0, 2};
   *
   *   while (joiner_M_weight == 0) { / * "if" if we do not generate for 0-weight edges * /
   *     joiner_M_edge = (joiner_M_edge + 1) % 4;
   *     joiner_M_weight = weights[joiner_M_edge];
   *   }
   *   joiner_M_weight--;
   *
   *   return pops[joiner_M_edge]();
   * }
   *
   * A joiner as a case statement, which is what we implement (0-weight edges dropped,
   * weights stored minus one):
   *
   * static int joiner_M_unrolled_edge = 3 - 1;
   * static int joiner_M_unrolled_weight = 0;
   *
   * static inline T joiner_M_unrolled() {
   *
   *   static const int weights[3] = {2-1, 1-1, 2-1};
   *
   *   if (--joiner_M_unrolled_weight < 0) {
   *     joiner_M_unrolled_edge = (joiner_M_unrolled_edge + 1) % 3;
   *     joiner_M_unrolled_weight = weights[joiner_M_unrolled_edge];
   *   }
   *
   *   switch (joiner_M_unrolled_edge) {
   *   case 0:
   *     return pop_1_M();
   *   case 1:
   *     return pop_2_M();
   *   case 2:
   *     return pop_4_M();
   *   }
   * }
   * </pre>
   *
   * @param joiner an InputSliceNode specifying joiner weights and edges
   * @param backEndBits to get info from the appropriate BackEndFactory
   * @param helper CodeStoreHelper that receives the fields and the method implementing the joiner
   */
  private static void makeJoinerCode(
      InputSliceNode joiner, BackEndFactory backEndBits, CodeStoreHelper helper) {
    String joiner_name = "_joiner_" + ProcessFilterSliceNode.getUid();
    String joiner_method_name = joiner_name + joiner.getNextFilter().getFilter().getName();

    // Compact the indices of all non-zero-weight edges to the front of liveEdges.
    // Zero-weight edges get neither a weights[] slot nor a switch case.
    int[] allWeights = joiner.getWeights();
    int[] liveEdges = new int[allWeights.length];
    int size = 0;
    for (int edge = 0; edge < allWeights.length; edge++) {
      if (allWeights[edge] != 0) {
        liveEdges[size++] = edge;
      }
    }

    assert size > 0 : "asking for code generation for null joiner";

    // Joiner state lives in two static fields. The edge counter starts at size - 1, and the
    // weight counter starts at 0, so the first call advances (wrapping) to edge 0.
    String edge_name = joiner_name + "_edge";
    String weight_name = joiner_name + "_weight";

    JFieldDeclaration edgeDecl =
        new JFieldDeclaration(
            new JVariableDefinition(
                at.dms.kjc.Constants.ACC_STATIC,
                CStdType.Integer,
                edge_name,
                new JIntLiteral(size - 1)));
    JFieldAccessExpression edgeExpr = new JFieldAccessExpression(edge_name);

    JFieldDeclaration weightDecl =
        new JFieldDeclaration(
            new JVariableDefinition(
                at.dms.kjc.Constants.ACC_STATIC,
                CStdType.Integer,
                weight_name,
                new JIntLiteral(0)));
    JFieldAccessExpression weightExpr = new JFieldAccessExpression(weight_name);

    // Weights are stored minus one: the generated test pre-decrements the counter and moves to
    // the next edge only once it drops below zero, giving exactly w pops per edge.
    JIntLiteral[] weightVals = new JIntLiteral[size];
    for (int k = 0; k < size; k++) {
      weightVals[k] = new JIntLiteral(allWeights[liveEdges[k]] - 1);
    }

    JVariableDefinition weightsArray =
        new JVariableDefinition(
            at.dms.kjc.Constants.ACC_STATIC | at.dms.kjc.Constants.ACC_FINAL, // static const in C
            new CArrayType(CStdType.Integer, 1, new JExpression[] {new JIntLiteral(size)}),
            "weights",
            new JArrayInitializer(weightVals));
    JLocalVariableExpression weightsExpr = new JLocalVariableExpression(weightsArray);

    // if (--weight < 0) { edge = (edge + 1) % size; weight = weights[edge]; }
    JExpression weightExhausted =
        new JRelationalExpression(
            at.dms.kjc.Constants.OPE_LT,
            new JPrefixExpression(null, at.dms.kjc.Constants.OPE_PREDEC, weightExpr),
            new JIntLiteral(0));
    JStatement advanceEdge =
        new JExpressionStatement(
            new JAssignmentExpression(
                edgeExpr,
                new JModuloExpression(
                    null, new JAddExpression(edgeExpr, new JIntLiteral(1)), new JIntLiteral(size))));
    JStatement reloadWeight =
        new JExpressionStatement(
            new JAssignmentExpression(weightExpr, new JArrayAccessExpression(weightsExpr, edgeExpr)));
    JStatement next_edge_weight_stmt =
        new JIfStatement(
            null,
            weightExhausted,
            new JBlock(new JStatement[] {advanceEdge, reloadWeight}),
            new JEmptyStatement(),
            null);

    // switch (edge) { case k: return <pop of k-th live edge>(); ... }
    JSwitchGroup[] cases = new JSwitchGroup[size]; // populated by the loop below
    JStatement switch_on_edge_stmt = new JSwitchStatement(null, edgeExpr, cases, null);
    for (int k = 0; k < size; k++) {
      String popName = backEndBits.getChannel(joiner.getSources()[liveEdges[k]]).popMethodName();
      JMethodCallExpression pop = new JMethodCallExpression(popName, new JExpression[0]);
      pop.setType(joiner.getType());

      JSwitchLabel[] labels = {new JSwitchLabel(null, new JIntLiteral(k))};
      JStatement[] caseBody = {new JReturnStatement(null, pop, null)};
      cases[k] = new JSwitchGroup(null, labels, caseBody);
    }

    JMethodDeclaration joiner_method =
        new JMethodDeclaration(
            null,
            at.dms.kjc.Constants.ACC_STATIC | at.dms.kjc.Constants.ACC_INLINE,
            joiner.getType(),
            joiner_method_name,
            new JFormalParameter[] {},
            new CClassType[] {},
            new JBlock(),
            null,
            null);

    JBlock joiner_block = joiner_method.getBody();
    joiner_block.addStatement(
        new JVariableDeclarationStatement(new JVariableDefinition[] {weightsArray}));
    joiner_block.addStatement(next_edge_weight_stmt);
    joiner_block.addStatement(switch_on_edge_stmt);

    helper.addFields(new JFieldDeclaration[] {edgeDecl, weightDecl});
    helper.addMethod(joiner_method);
  }