예제 #1
0
 /**
  * Returns a new vector holding the element-wise sum this + that.
  * The result is built in a fresh SparseVector; only that new vector receives put calls.
  *
  * @param that the vector to add; must have the same length N as this vector
  * @return a new SparseVector equal to this + that
  * @throws IllegalArgumentException if the two vectors have different lengths
  */
 public SparseVector plus(SparseVector that) {
   if (this.N != that.N) {
     throw new IllegalArgumentException("Vector lengths disagree");
   }
   SparseVector c = new SparseVector(N);
   // c = this
   for (int i : this.st.keys()) {
     c.put(i, this.get(i));
   }
   // c = c + that: indices stored in both vectors accumulate, the rest are copied over
   for (int i : that.st.keys()) {
     c.put(i, that.get(i) + c.get(i));
   }
   return c;
 }
예제 #2
0
  /**
   * Returns the dot product of this vector with that vector.
   * Only the vector with fewer stored entries is traversed. Each of its indices is
   * probed in the other vector, and indices not stored in both are skipped.
   *
   * @param that the other vector; must have the same length N as this vector
   * @return the sum of this.get(i) * that.get(i) over indices i stored in both vectors
   * @throws IllegalArgumentException if the two vectors have different lengths
   */
  public double dot(SparseVector that) {
    if (this.N != that.N) {
      throw new IllegalArgumentException("Vector lengths disagree");
    }

    // iterate over the vector with the fewest stored entries (ties go to this),
    // probing the other one for a matching index
    SparseVector smaller;
    SparseVector larger;
    if (this.st.size() <= that.st.size()) {
      smaller = this;
      larger = that;
    } else {
      smaller = that;
      larger = this;
    }

    double sum = 0.0;
    for (int i : smaller.st.keys()) {
      if (larger.st.contains(i)) {
        sum += this.get(i) * that.get(i);
      }
    }
    return sum;
  }
예제 #3
0
 /**
  * Test client: fills two sparse vectors of length 10, then prints both of them,
  * their dot product, and their sum.
  *
  * @param args unused
  */
 public static void main(String[] args) {
   final int length = 10;
   SparseVector a = new SparseVector(length);
   SparseVector b = new SparseVector(length);

   // Entries are applied in array order; index 6 of a is written twice,
   // first with 0.11 and then overwritten with 0.00.
   int[] aIndices = {3, 9, 6, 6};
   double[] aValues = {0.50, 0.75, 0.11, 0.00};
   for (int k = 0; k < aIndices.length; k++) {
     a.put(aIndices[k], aValues[k]);
   }

   int[] bIndices = {3, 4};
   double[] bValues = {0.60, 0.90};
   for (int k = 0; k < bIndices.length; k++) {
     b.put(bIndices[k], bValues[k]);
   }

   StdOut.println("a = " + a);
   StdOut.println("b = " + b);
   StdOut.println("a dot b = " + a.dot(b));
   StdOut.println("a + b   = " + a.plus(b));
 }
예제 #4
0
 /**
  * Returns a new vector equal to alpha * this.
  * Only the indices stored in this vector's symbol table are visited.
  * Each index is written into the result with its value multiplied by alpha.
  *
  * @param alpha the scalar multiplier
  * @return a new SparseVector of the same length holding the scaled entries
  */
 public SparseVector scale(double alpha) {
   SparseVector scaled = new SparseVector(this.N);
   for (int index : this.st.keys()) {
     double value = this.get(index);
     scaled.put(index, alpha * value);
   }
   return scaled;
 }
예제 #5
0
 /**
  * Returns the Euclidean (2-) norm of this vector.
  * The norm is computed as the square root of the vector's dot product with itself.
  *
  * @return sqrt(this dot this)
  */
 public double norm() {
   double squaredLength = this.dot(this);
   return Math.sqrt(squaredLength);
 }