diff --git a/matrix/matrix.go b/matrix/matrix.go index 837402b..81fc20e 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -72,9 +72,6 @@ func NewFromArrays(a [][]float64) *Matrix { return m } -// Vector is just an array of values. -type Vector []float64 - // Matrix represents a 2d dense array of floats. type Matrix struct { epsilon float64 @@ -82,6 +79,17 @@ type Matrix struct { rows, cols int } +// Epsilon returns the maximum precision for math operations. +func (m *Matrix) Epsilon() float64 { + return m.epsilon +} + +// WithEpsilon sets the epsilon on the matrix and returns a reference to the matrix. +func (m *Matrix) WithEpsilon(epsilon float64) *Matrix { + m.epsilon = epsilon + return m +} + // Arrays returns the matrix as a two dimensional jagged array. func (m *Matrix) Arrays() [][]float64 { a := make([][]float64, m.rows, m.cols) diff --git a/matrix/matrix_test.go b/matrix/matrix_test.go new file mode 100644 index 0000000..3216838 --- /dev/null +++ b/matrix/matrix_test.go @@ -0,0 +1,56 @@ +package matrix + +import ( + "testing" + + assert "github.com/blendlabs/go-assert" +) + +func TestNew(t *testing.T) { + assert := assert.New(t) + + m := New(10, 5) + rows, cols := m.Size() + assert.Equal(10, rows) + assert.Equal(5, cols) + assert.Zero(m.Get(0, 0)) + assert.Zero(m.Get(9, 4)) +} + +func TestIdentitiy(t *testing.T) { + assert := assert.New(t) + + id := Identity(5) + rows, cols := id.Size() + assert.Equal(5, rows) + assert.Equal(5, cols) + assert.Equal(1, id.Get(0, 0)) + assert.Equal(1, id.Get(1, 1)) + assert.Equal(1, id.Get(2, 2)) + assert.Equal(1, id.Get(3, 3)) + assert.Equal(1, id.Get(4, 4)) + assert.Equal(0, id.Get(0, 1)) + assert.Equal(0, id.Get(1, 0)) + assert.Equal(0, id.Get(4, 0)) + assert.Equal(0, id.Get(0, 4)) +} + +func TestNewFromArrays(t *testing.T) { + assert := assert.New(t) + +} + +func TestOnes(t *testing.T) { + assert := assert.New(t) + + ones := Ones(5, 10) + rows, cols := ones.Size() + assert.Equal(5, rows) + assert.Equal(10, cols) + + for row := 0; row < rows; row++ { + for col := 0; col < cols; col++ { + assert.Equal(1, ones.Get(row, col)) + } + } +} diff --git a/matrix/vector.go b/matrix/vector.go new file mode 100644 index 0000000..7141ea0 --- /dev/null +++ b/matrix/vector.go @@ -0,0 +1,17 @@ +package matrix + +// Vector is just an array of values. +type Vector []float64 + +// DotProduct returns the dot product of two vectors. +func (v Vector) DotProduct(v2 Vector) (result float64, err error) { + if len(v) != len(v2) { + err = ErrDimensionMismatch + return + } + + for i := 0; i < len(v); i++ { + result = result + (v[i] * v2[i]) + } + return +} \ No newline at end of file