updates
This commit is contained in:
parent
88499d5576
commit
724d6e3c2a
3 changed files with 84 additions and 3 deletions
|
@ -72,9 +72,6 @@ func NewFromArrays(a [][]float64) *Matrix {
|
|||
return m
|
||||
}
|
||||
|
||||
// Vector is just an array of values.
|
||||
type Vector []float64
|
||||
|
||||
// Matrix represents a 2d dense array of floats.
|
||||
type Matrix struct {
|
||||
epsilon float64
|
||||
|
@ -82,6 +79,17 @@ type Matrix struct {
|
|||
rows, cols int
|
||||
}
|
||||
|
||||
// Epsilon returns the maximum precision for math operations.
|
||||
func (m *Matrix) Epsilon() float64 {
|
||||
return m.epsilon
|
||||
}
|
||||
|
||||
// WithEpsilon sets the epsilon on the matrix and returns a reference to the matrix.
|
||||
func (m *Matrix) WithEpsilon(epsilon float64) *Matrix {
|
||||
m.epsilon = epsilon
|
||||
return m
|
||||
}
|
||||
|
||||
// Arrays returns the matrix as a two dimensional jagged array.
|
||||
func (m *Matrix) Arrays() [][]float64 {
|
||||
a := make([][]float64, m.rows, m.cols)
|
||||
|
|
56
matrix/matrix_test.go
Normal file
56
matrix/matrix_test.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package matrix
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
assert "github.com/blendlabs/go-assert"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
m := New(10, 5)
|
||||
rows, cols := m.Size()
|
||||
assert.Equal(10, rows)
|
||||
assert.Equal(5, cols)
|
||||
assert.Zero(m.Get(0, 0))
|
||||
assert.Zero(m.Get(9, 4))
|
||||
}
|
||||
|
||||
func TestIdentitiy(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
id := Identity(5)
|
||||
rows, cols := id.Size()
|
||||
assert.Equal(5, rows)
|
||||
assert.Equal(5, cols)
|
||||
assert.Equal(1, id.Get(0, 0))
|
||||
assert.Equal(1, id.Get(1, 1))
|
||||
assert.Equal(1, id.Get(2, 2))
|
||||
assert.Equal(1, id.Get(3, 3))
|
||||
assert.Equal(1, id.Get(4, 4))
|
||||
assert.Equal(0, id.Get(0, 1))
|
||||
assert.Equal(0, id.Get(1, 0))
|
||||
assert.Equal(0, id.Get(4, 0))
|
||||
assert.Equal(0, id.Get(0, 4))
|
||||
}
|
||||
|
||||
func TestNewFromArrays(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
}
|
||||
|
||||
func TestOnes(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ones := Ones(5, 10)
|
||||
rows, cols := ones.Size()
|
||||
assert.Equal(5, rows)
|
||||
assert.Equal(10, cols)
|
||||
|
||||
for row := 0; row < rows; row++ {
|
||||
for col := 0; col < cols; col++ {
|
||||
assert.Equal(1, ones.Get(row, col))
|
||||
}
|
||||
}
|
||||
}
|
17
matrix/vector.go
Normal file
17
matrix/vector.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package matrix
|
||||
|
||||
// Vector is just an array of values.
|
||||
type Vector []float64
|
||||
|
||||
// DotProduct returns the dot product of two vectors.
|
||||
func (v Vector) DotProduct(v2 Vector) (result float64, err error) {
|
||||
if len(v) != len(v2) {
|
||||
err = ErrDimensionMismatch
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(v); i++ {
|
||||
result = result + (v[i] * v2[i])
|
||||
}
|
||||
return
|
||||
}
|
Loading…
Reference in a new issue