Exemplo n.º 1
0
 /**
  * Quickly loads a model from its binary cache.
  *
  * <p>Values are read in the same order that {@code create(String)} writes them: correction
  * constant, correction parameter, outcome labels, outcome patterns, predicate labels, the
  * predicate-trie values, the serialized trie, and finally the context parameters.
  *
  * <p>No exceptions are caught here; anything thrown by {@code byteArray} while reading
  * propagates to the caller.
  *
  * @param byteArray binary cache positioned at the start of a serialized model
  * @return the reconstructed model
  */
 public static MaxEntModel create(ByteArray byteArray) {
   MaxEntModel m = new MaxEntModel();
   m.correctionConstant = byteArray.nextInt();
   m.correctionParam = byteArray.nextDouble();

   // outcome labels
   int numOutcomes = byteArray.nextInt();
   String[] outcomeLabels = new String[numOutcomes];
   m.outcomeNames = outcomeLabels;
   for (int i = 0; i < numOutcomes; i++) {
     outcomeLabels[i] = byteArray.nextString();
   }

   // outcome patterns: element 0 is how many consecutive contexts share the pattern,
   // the remaining elements form the pattern handed to each of those contexts
   int numOCTypes = byteArray.nextInt();
   int[][] outcomePatterns = new int[numOCTypes][];
   for (int i = 0; i < numOCTypes; i++) {
     int[] infoInts = new int[byteArray.nextInt()];
     for (int j = 0; j < infoInts.length; j++) {
       infoInts[j] = byteArray.nextInt();
     }
     outcomePatterns[i] = infoInts;
   }

   // predicates: the labels are not needed when loading from the cache, but they are
   // stored in it, so they must still be consumed to advance the stream
   int numPreds = byteArray.nextInt();
   m.pmap = new DoubleArrayTrie<Integer>();
   for (int i = 0; i < numPreds; i++) {
     byteArray.nextString();
   }
   Integer[] values = new Integer[numPreds];
   for (int i = 0; i < values.length; i++) {
     values[i] = byteArray.nextInt();
   }
   m.pmap.load(byteArray, values);

   // context parameters, one Context per predicate, grouped by outcome pattern
   Context[] params = new Context[numPreds];
   int pid = 0;
   for (int[] pattern : outcomePatterns) {
     int numParams = pattern.length - 1;
     // shared (not copied) by every context in this group
     int[] outcomePattern = new int[numParams];
     System.arraycopy(pattern, 1, outcomePattern, 0, numParams);
     for (int j = 0; j < pattern[0]; j++) {
       double[] contextParameters = new double[numParams];
       for (int k = 0; k < numParams; k++) {
         contextParameters[k] = byteArray.nextDouble();
       }
       params[pid++] = new Context(outcomePattern, contextParameters);
     }
   }

   // prior
   m.prior = new UniformPrior();
   m.prior.setLabels(outcomeLabels);
   // evaluation parameters
   m.evalParams =
       new EvalParameters(params, m.correctionParam, m.correctionConstant, outcomeLabels.length);
   return m;
 }
Exemplo n.º 2
0
 /**
  * Loads a model from its text form and writes a binary cache of it alongside.
  *
  * <p>The text file is read as UTF-8, line by line, in this order:
  * <ul>
  *   <li>a type line (ignored);</li>
  *   <li>the correction constant and the correction parameter;</li>
  *   <li>the outcome labels, then the outcome patterns, then the predicate labels;</li>
  *   <li>one context parameter per line.</li>
  * </ul>
  * Every value is mirrored into {@code path + Predefine.BIN_EXT} in the layout that
  * {@code create(ByteArray)} reads back.
  *
  * <p>Both streams are closed before returning. On failure, three things happen:
  * <ul>
  *   <li>the error is logged;</li>
  *   <li>the partially written cache file (if it was opened) is deleted, so a truncated cache is
  *       not picked up later;</li>
  *   <li>{@code null} is returned.</li>
  * </ul>
  *
  * @param path path of the text model file
  * @return the loaded model, or {@code null} if the text model could not be read or the cache
  *     could not be written
  */
 public static MaxEntModel create(String path) {
   MaxEntModel m = new MaxEntModel();
   String binPath = path + Predefine.BIN_EXT;
   BufferedReader br = null;
   DataOutputStream out = null;
   boolean loaded = false;
   try {
     br = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8"));
     // Buffer the cache: without it every primitive write goes straight to the file stream.
     out =
         new DataOutputStream(new java.io.BufferedOutputStream(new FileOutputStream(binPath)));
     br.readLine(); // model type, not used
     m.correctionConstant = Integer.parseInt(br.readLine());
     out.writeInt(m.correctionConstant);
     m.correctionParam = Double.parseDouble(br.readLine());
     out.writeDouble(m.correctionParam);

     // outcome labels
     int numOutcomes = Integer.parseInt(br.readLine());
     out.writeInt(numOutcomes);
     String[] outcomeLabels = new String[numOutcomes];
     m.outcomeNames = outcomeLabels;
     for (int i = 0; i < numOutcomes; i++) {
       outcomeLabels[i] = br.readLine();
       TextUtility.writeString(outcomeLabels[i], out);
     }

     // outcome patterns: one space-separated line of ints each; element 0 is how many
     // consecutive contexts share the pattern, the rest is the pattern given to each Context
     int numOCTypes = Integer.parseInt(br.readLine());
     out.writeInt(numOCTypes);
     int[][] outcomePatterns = new int[numOCTypes][];
     for (int i = 0; i < numOCTypes; i++) {
       StringTokenizer tok = new StringTokenizer(br.readLine(), " ");
       int[] infoInts = new int[tok.countTokens()];
       out.writeInt(infoInts.length);
       for (int j = 0; tok.hasMoreTokens(); j++) {
         infoInts[j] = Integer.parseInt(tok.nextToken());
         out.writeInt(infoInts[j]);
       }
       outcomePatterns[i] = infoInts;
     }

     // predicates
     int numPreds = Integer.parseInt(br.readLine());
     out.writeInt(numPreds);
     m.pmap = new DoubleArrayTrie<Integer>();
     TreeMap<String, Integer> tmpMap = new TreeMap<String, Integer>();
     for (int i = 0; i < numPreds; i++) {
       String predLabel = br.readLine();
       TextUtility.writeString(predLabel, out);
       tmpMap.put(predLabel, i);
     }
     m.pmap.build(tmpMap);
     // NOTE(review): if a predicate label appears twice, tmpMap holds fewer than numPreds
     // entries, so fewer ints are written here than create(ByteArray) reads back and the cache
     // becomes misaligned. Presumably labels are unique — confirm.
     for (Map.Entry<String, Integer> entry : tmpMap.entrySet()) {
       out.writeInt(entry.getValue());
     }
     m.pmap.save(out);

     // context parameters, one Context per predicate, grouped by outcome pattern
     Context[] params = new Context[numPreds];
     int pid = 0;
     for (int[] pattern : outcomePatterns) {
       int numParams = pattern.length - 1;
       // shared (not copied) by every context in this group
       int[] outcomePattern = new int[numParams];
       System.arraycopy(pattern, 1, outcomePattern, 0, numParams);
       for (int j = 0; j < pattern[0]; j++) {
         double[] contextParameters = new double[numParams];
         for (int k = 0; k < numParams; k++) {
           contextParameters[k] = Double.parseDouble(br.readLine());
           out.writeDouble(contextParameters[k]);
         }
         params[pid++] = new Context(outcomePattern, contextParameters);
       }
     }

     // prior
     m.prior = new UniformPrior();
     m.prior.setLabels(outcomeLabels);
     // evaluation parameters
     m.evalParams =
         new EvalParameters(params, m.correctionParam, m.correctionConstant, outcomeLabels.length);

     // Flush inside the try so a failed cache write is still reported as a load failure.
     out.flush();
     loaded = true;
   } catch (Exception e) {
     logger.severe("从" + path + "加载最大熵模型失败!" + TextUtility.exceptionToString(e));
   } finally {
     if (br != null) {
       try {
         br.close();
       } catch (java.io.IOException ignored) {
         // nothing useful to do if closing the input fails
       }
     }
     if (out != null) {
       try {
         out.close();
       } catch (java.io.IOException ignored) {
         // on success all data was already flushed above
       }
       if (!loaded) {
         // the cache file was truncated when opened and is now incomplete
         new java.io.File(binPath).delete();
       }
     }
   }
   return loaded ? m : null;
 }