This commit is contained in:
Will Charczuk 2017-04-17 13:56:19 -07:00
parent 724d6e3c2a
commit 8c4ccc3bb6
3 changed files with 291 additions and 17 deletions

View file

@ -14,16 +14,26 @@ const (
var ( var (
// ErrDimensionMismatch is a typical error. // ErrDimensionMismatch is a typical error.
ErrDimensionMismatch = errors.New("matrix is not square, cannot invert") ErrDimensionMismatch = errors.New("dimension mismatch")
) )
// New returns a new matrix. // New returns a new matrix.
func New(rows, cols int) *Matrix { func New(rows, cols int, values ...float64) *Matrix {
if len(values) == 0 {
return &Matrix{
rows: rows,
cols: cols,
epsilon: DefaultEpsilon,
elements: make([]float64, rows*cols),
}
}
elems := make([]float64, rows*cols)
copy(elems, values)
return &Matrix{ return &Matrix{
rows: rows, rows: rows,
cols: cols, cols: cols,
epsilon: DefaultEpsilon, epsilon: DefaultEpsilon,
elements: make([]float64, rows*cols), elements: elems,
} }
} }
@ -92,9 +102,11 @@ func (m *Matrix) WithEpsilon(epsilon float64) *Matrix {
// Arrays returns the matrix as a two dimensional jagged array. // Arrays returns the matrix as a two dimensional jagged array.
func (m *Matrix) Arrays() [][]float64 { func (m *Matrix) Arrays() [][]float64 {
a := make([][]float64, m.rows, m.cols) a := make([][]float64, m.rows)
for row := 0; row < m.rows; row++ { for row := 0; row < m.rows; row++ {
a[row] = make([]float64, m.cols)
for col := 0; col < m.cols; col++ { for col := 0; col < m.cols; col++ {
a[row][col] = m.Get(row, col) a[row][col] = m.Get(row, col)
} }
@ -145,7 +157,7 @@ func (m *Matrix) Set(row, col int, val float64) {
func (m *Matrix) Col(col int) Vector { func (m *Matrix) Col(col int) Vector {
values := make([]float64, m.rows) values := make([]float64, m.rows)
for row := 0; row < m.rows; row++ { for row := 0; row < m.rows; row++ {
values[col] = m.Get(row, col) values[row] = m.Get(row, col)
} }
return Vector(values) return Vector(values)
} }
@ -181,6 +193,17 @@ func (m *Matrix) DiagonalVector() Vector {
return Vector(values) return Vector(values)
} }
// Diagonal returns a matrix from the diagonal of a matrix.
func (m *Matrix) Diagonal() *Matrix {
rank := minInt(m.rows, m.cols)
m2 := New(rank, rank)
for index := 0; index < rank; index++ {
m2.Set(index, index, m.Get(index, index))
}
return m2
}
// Equals returns if a matrix equals another matrix. // Equals returns if a matrix equals another matrix.
func (m *Matrix) Equals(other *Matrix) bool { func (m *Matrix) Equals(other *Matrix) bool {
if other == nil && m != nil { if other == nil && m != nil {
@ -226,17 +249,6 @@ func (m *Matrix) U() *Matrix {
return m2 return m2
} }
// Diagonal returns a matrix from the diagonal of a matrix.
func (m *Matrix) Diagonal() *Matrix {
rank := minInt(m.rows, m.cols)
m2 := New(rank, rank)
for index := 0; index < rank; index++ {
m2.Set(index, index, m.Get(index, index))
}
return m2
}
// String returns a string representation of the matrix. // String returns a string representation of the matrix.
func (m *Matrix) String() string { func (m *Matrix) String() string {
buffer := bytes.NewBuffer(nil) buffer := bytes.NewBuffer(nil)

View file

@ -17,6 +17,17 @@ func TestNew(t *testing.T) {
assert.Zero(m.Get(9, 4)) assert.Zero(m.Get(9, 4))
} }
func TestNewWithValues(t *testing.T) {
assert := assert.New(t)
m := New(5, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
rows, cols := m.Size()
assert.Equal(5, rows)
assert.Equal(2, cols)
assert.Equal(1, m.Get(0, 0))
assert.Equal(10, m.Get(4, 1))
}
func TestIdentitiy(t *testing.T) { func TestIdentitiy(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
@ -38,6 +49,15 @@ func TestIdentitiy(t *testing.T) {
func TestNewFromArrays(t *testing.T) { func TestNewFromArrays(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3, 4},
{5, 6, 7, 8},
})
assert.NotNil(m)
rows, cols := m.Size()
assert.Equal(2, rows)
assert.Equal(4, cols)
} }
func TestOnes(t *testing.T) { func TestOnes(t *testing.T) {
@ -54,3 +74,245 @@ func TestOnes(t *testing.T) {
} }
} }
} }
func TestMatrixEpsilon(t *testing.T) {
assert := assert.New(t)
ones := Ones(2, 2)
ones = ones.WithEpsilon(0.001)
assert.Equal(0.001, ones.Epsilon())
}
func TestMatrixArrays(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
})
assert.NotNil(m)
arrays := m.Arrays()
assert.Equal(arrays, [][]float64{
{1, 2, 3},
{4, 5, 6},
})
}
func TestMatrixIsSquare(t *testing.T) {
assert := assert.New(t)
assert.False(NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
}).IsSquare())
assert.False(NewFromArrays([][]float64{
{1, 2},
{3, 4},
{5, 6},
}).IsSquare())
assert.True(NewFromArrays([][]float64{
{1, 2},
{3, 4},
}).IsSquare())
}
func TestMatrixIsSymmetric(t *testing.T) {
assert := assert.New(t)
assert.False(NewFromArrays([][]float64{
{1, 2, 3},
{2, 1, 2},
}).IsSymmetric())
assert.False(NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
}).IsSymmetric())
assert.True(NewFromArrays([][]float64{
{1, 2, 3},
{2, 1, 2},
{3, 2, 1},
}).IsSymmetric())
}
func TestMatrixGet(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
})
assert.Equal(1, m.Get(0, 0))
assert.Equal(2, m.Get(0, 1))
assert.Equal(3, m.Get(0, 2))
assert.Equal(4, m.Get(1, 0))
assert.Equal(5, m.Get(1, 1))
assert.Equal(6, m.Get(1, 2))
assert.Equal(7, m.Get(2, 0))
assert.Equal(8, m.Get(2, 1))
assert.Equal(9, m.Get(2, 2))
}
func TestMatrixSet(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
})
m.Set(1, 1, 99)
assert.Equal(99, m.Get(1, 1))
}
func TestMatrixCol(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
})
assert.Equal([]float64{1, 4, 7}, m.Col(0))
assert.Equal([]float64{2, 5, 8}, m.Col(1))
assert.Equal([]float64{3, 6, 9}, m.Col(2))
}
func TestMatrixRow(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
})
assert.Equal([]float64{1, 2, 3}, m.Row(0))
assert.Equal([]float64{4, 5, 6}, m.Row(1))
assert.Equal([]float64{7, 8, 9}, m.Row(2))
}
func TestMatrixCopy(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
})
m2 := m.Copy()
assert.False(m == m2)
assert.True(m.Equals(m2))
}
func TestMatrixDiagonalVector(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 4, 7},
{4, 2, 8},
{7, 8, 3},
})
diag := m.DiagonalVector()
assert.Equal([]float64{1, 2, 3}, diag)
}
func TestMatrixDiagonalVectorLandscape(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 4, 7, 99},
{4, 2, 8, 99},
})
diag := m.DiagonalVector()
assert.Equal([]float64{1, 2}, diag)
}
func TestMatrixDiagonalVectorPortrait(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 4},
{4, 2},
{99, 99},
})
diag := m.DiagonalVector()
assert.Equal([]float64{1, 2}, diag)
}
func TestMatrixDiagonal(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 4, 7},
{4, 2, 8},
{7, 8, 3},
})
m2 := NewFromArrays([][]float64{
{1, 0, 0},
{0, 2, 0},
{0, 0, 3},
})
assert.True(m.Diagonal().Equals(m2))
}
func TestMatrixEquals(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 4, 7},
{4, 2, 8},
{7, 8, 3},
})
assert.False(m.Equals(nil))
var nilMatrix *Matrix
assert.True(nilMatrix.Equals(nil))
assert.False(m.Equals(New(1, 1)))
assert.False(m.Equals(New(3, 3)))
assert.True(m.Equals(New(3, 3, 1, 4, 7, 4, 2, 8, 7, 8, 3)))
}
func TestMatrixL(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
})
l := m.L()
assert.True(l.Equals(New(3, 3, 1, 2, 3, 0, 5, 6, 0, 0, 9)))
}
func TestMatrixU(t *testing.T) {
assert := assert.New(t)
m := NewFromArrays([][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
})
u := m.U()
assert.True(u.Equals(New(3, 3, 0, 0, 0, 4, 0, 0, 7, 8, 0)))
}

View file

@ -14,4 +14,4 @@ func (v Vector) DotProduct(v2 Vector) (result float64, err error) {
result = result + (v[i] * v2[i]) result = result + (v[i] * v2[i])
} }
return return
} }