stride not rows + cols
This commit is contained in:
parent
846133c6bf
commit
e4f05d3863
3 changed files with 141 additions and 61 deletions
184
matrix/matrix.go
184
matrix/matrix.go
|
@ -20,8 +20,7 @@ var (
|
||||||
func New(rows, cols int, values ...float64) *Matrix {
|
func New(rows, cols int, values ...float64) *Matrix {
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
return &Matrix{
|
return &Matrix{
|
||||||
rows: rows,
|
stride: cols,
|
||||||
cols: cols,
|
|
||||||
epsilon: DefaultEpsilon,
|
epsilon: DefaultEpsilon,
|
||||||
elements: make([]float64, rows*cols),
|
elements: make([]float64, rows*cols),
|
||||||
}
|
}
|
||||||
|
@ -29,8 +28,7 @@ func New(rows, cols int, values ...float64) *Matrix {
|
||||||
elems := make([]float64, rows*cols)
|
elems := make([]float64, rows*cols)
|
||||||
copy(elems, values)
|
copy(elems, values)
|
||||||
return &Matrix{
|
return &Matrix{
|
||||||
rows: rows,
|
stride: cols,
|
||||||
cols: cols,
|
|
||||||
epsilon: DefaultEpsilon,
|
epsilon: DefaultEpsilon,
|
||||||
elements: elems,
|
elements: elems,
|
||||||
}
|
}
|
||||||
|
@ -45,8 +43,8 @@ func Identity(order int) *Matrix {
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// Zeros returns a matrix of a given size zeroed.
|
// Zero returns a matrix of a given size zeroed.
|
||||||
func Zeros(rows, cols int) *Matrix {
|
func Zero(rows, cols int) *Matrix {
|
||||||
return New(rows, cols)
|
return New(rows, cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,13 +56,21 @@ func Ones(rows, cols int) *Matrix {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Matrix{
|
return &Matrix{
|
||||||
rows: rows,
|
stride: cols,
|
||||||
cols: cols,
|
|
||||||
epsilon: DefaultEpsilon,
|
epsilon: DefaultEpsilon,
|
||||||
elements: ones,
|
elements: ones,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Eye returns the eye matrix.
|
||||||
|
func Eye(n int) *Matrix {
|
||||||
|
m := Zero(n, n)
|
||||||
|
for i := 0; i < len(m.elements); i += n + 1 {
|
||||||
|
m.elements[i] = 1
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// NewFromArrays creates a matrix from a jagged array set.
|
// NewFromArrays creates a matrix from a jagged array set.
|
||||||
func NewFromArrays(a [][]float64) *Matrix {
|
func NewFromArrays(a [][]float64) *Matrix {
|
||||||
rows := len(a)
|
rows := len(a)
|
||||||
|
@ -83,9 +89,24 @@ func NewFromArrays(a [][]float64) *Matrix {
|
||||||
|
|
||||||
// Matrix represents a 2d dense array of floats.
|
// Matrix represents a 2d dense array of floats.
|
||||||
type Matrix struct {
|
type Matrix struct {
|
||||||
epsilon float64
|
epsilon float64
|
||||||
elements []float64
|
elements []float64
|
||||||
rows, cols int
|
stride int
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the matrix.
|
||||||
|
func (m *Matrix) String() string {
|
||||||
|
buffer := bytes.NewBuffer(nil)
|
||||||
|
rows, cols := m.Size()
|
||||||
|
|
||||||
|
for row := 0; row < rows; row++ {
|
||||||
|
for col := 0; col < cols; col++ {
|
||||||
|
buffer.WriteString(f64s(m.Get(row, col)))
|
||||||
|
buffer.WriteRune(' ')
|
||||||
|
}
|
||||||
|
buffer.WriteRune('\n')
|
||||||
|
}
|
||||||
|
return buffer.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Epsilon returns the maximum precision for math operations.
|
// Epsilon returns the maximum precision for math operations.
|
||||||
|
@ -99,14 +120,38 @@ func (m *Matrix) WithEpsilon(epsilon float64) *Matrix {
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Each applies the action to each element of the matrix in
|
||||||
|
// rows => cols order.
|
||||||
|
func (m *Matrix) Each(action func(row, cow int, value float64)) {
|
||||||
|
rows, cols := m.Size()
|
||||||
|
for row := 0; row < rows; row++ {
|
||||||
|
for col := 0; col < cols; col++ {
|
||||||
|
action(row, col, m.Get(row, col))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Round rounds all the values in a matrix to it epsilon,
|
||||||
|
// returning a reference to the original
|
||||||
|
func (m *Matrix) Round() *Matrix {
|
||||||
|
rows, cols := m.Size()
|
||||||
|
for row := 0; row < rows; row++ {
|
||||||
|
for col := 0; col < cols; col++ {
|
||||||
|
m.Set(row, col, roundToEpsilon(m.Get(row, col), m.epsilon))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// Arrays returns the matrix as a two dimensional jagged array.
|
// Arrays returns the matrix as a two dimensional jagged array.
|
||||||
func (m *Matrix) Arrays() [][]float64 {
|
func (m *Matrix) Arrays() [][]float64 {
|
||||||
a := make([][]float64, m.rows)
|
rows, cols := m.Size()
|
||||||
|
a := make([][]float64, rows)
|
||||||
|
|
||||||
for row := 0; row < m.rows; row++ {
|
for row := 0; row < rows; row++ {
|
||||||
a[row] = make([]float64, m.cols)
|
a[row] = make([]float64, cols)
|
||||||
|
|
||||||
for col := 0; col < m.cols; col++ {
|
for col := 0; col < cols; col++ {
|
||||||
a[row][col] = m.Get(row, col)
|
a[row][col] = m.Get(row, col)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -115,22 +160,25 @@ func (m *Matrix) Arrays() [][]float64 {
|
||||||
|
|
||||||
// Size returns the dimensions of the matrix.
|
// Size returns the dimensions of the matrix.
|
||||||
func (m *Matrix) Size() (rows, cols int) {
|
func (m *Matrix) Size() (rows, cols int) {
|
||||||
rows = m.rows
|
rows = len(m.elements) / m.stride
|
||||||
cols = m.cols
|
cols = m.stride
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSquare returns if the row count is equal to the column count.
|
// IsSquare returns if the row count is equal to the column count.
|
||||||
func (m *Matrix) IsSquare() bool {
|
func (m *Matrix) IsSquare() bool {
|
||||||
return m.rows == m.cols
|
return m.stride == (len(m.elements) / m.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSymmetric returns if the matrix is symmetric about its diagonal.
|
// IsSymmetric returns if the matrix is symmetric about its diagonal.
|
||||||
func (m *Matrix) IsSymmetric() bool {
|
func (m *Matrix) IsSymmetric() bool {
|
||||||
if m.rows != m.cols {
|
rows, cols := m.Size()
|
||||||
|
|
||||||
|
if rows != cols {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for i := 0; i < m.rows; i++ {
|
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
for j := 0; j < i; j++ {
|
for j := 0; j < i; j++ {
|
||||||
if m.Get(i, j) != m.Get(j, i) {
|
if m.Get(i, j) != m.Get(j, i) {
|
||||||
return false
|
return false
|
||||||
|
@ -142,20 +190,21 @@ func (m *Matrix) IsSymmetric() bool {
|
||||||
|
|
||||||
// Get returns the element at the given row, col.
|
// Get returns the element at the given row, col.
|
||||||
func (m *Matrix) Get(row, col int) float64 {
|
func (m *Matrix) Get(row, col int) float64 {
|
||||||
index := (m.cols * row) + col
|
index := (m.stride * row) + col
|
||||||
return m.elements[index]
|
return m.elements[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set sets a value.
|
// Set sets a value.
|
||||||
func (m *Matrix) Set(row, col int, val float64) {
|
func (m *Matrix) Set(row, col int, val float64) {
|
||||||
index := (m.cols * row) + col
|
index := (m.stride * row) + col
|
||||||
m.elements[index] = val
|
m.elements[index] = val
|
||||||
}
|
}
|
||||||
|
|
||||||
// Col returns a column of the matrix as a vector.
|
// Col returns a column of the matrix as a vector.
|
||||||
func (m *Matrix) Col(col int) Vector {
|
func (m *Matrix) Col(col int) Vector {
|
||||||
values := make([]float64, m.rows)
|
rows, _ := m.Size()
|
||||||
for row := 0; row < m.rows; row++ {
|
values := make([]float64, rows)
|
||||||
|
for row := 0; row < rows; row++ {
|
||||||
values[row] = m.Get(row, col)
|
values[row] = m.Get(row, col)
|
||||||
}
|
}
|
||||||
return Vector(values)
|
return Vector(values)
|
||||||
|
@ -163,8 +212,9 @@ func (m *Matrix) Col(col int) Vector {
|
||||||
|
|
||||||
// Row returns a row of the matrix as a vector.
|
// Row returns a row of the matrix as a vector.
|
||||||
func (m *Matrix) Row(row int) Vector {
|
func (m *Matrix) Row(row int) Vector {
|
||||||
values := make([]float64, m.cols)
|
_, cols := m.Size()
|
||||||
for col := 0; col < m.cols; col++ {
|
values := make([]float64, cols)
|
||||||
|
for col := 0; col < cols; col++ {
|
||||||
values[col] = m.Get(row, col)
|
values[col] = m.Get(row, col)
|
||||||
}
|
}
|
||||||
return Vector(values)
|
return Vector(values)
|
||||||
|
@ -172,18 +222,15 @@ func (m *Matrix) Row(row int) Vector {
|
||||||
|
|
||||||
// Copy returns a duplicate of a given matrix.
|
// Copy returns a duplicate of a given matrix.
|
||||||
func (m *Matrix) Copy() *Matrix {
|
func (m *Matrix) Copy() *Matrix {
|
||||||
m2 := New(m.rows, m.cols)
|
m2 := &Matrix{stride: m.stride, epsilon: m.epsilon, elements: make([]float64, len(m.elements))}
|
||||||
for row := 0; row < m.rows; row++ {
|
copy(m2.elements, m.elements)
|
||||||
for col := 0; col < m.cols; col++ {
|
|
||||||
m2.Set(row, col, m.Get(row, col))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m2
|
return m2
|
||||||
}
|
}
|
||||||
|
|
||||||
// DiagonalVector returns a vector from the diagonal of a matrix.
|
// DiagonalVector returns a vector from the diagonal of a matrix.
|
||||||
func (m *Matrix) DiagonalVector() Vector {
|
func (m *Matrix) DiagonalVector() Vector {
|
||||||
rank := minInt(m.rows, m.cols)
|
rows, cols := m.Size()
|
||||||
|
rank := minInt(rows, cols)
|
||||||
values := make([]float64, rank)
|
values := make([]float64, rank)
|
||||||
|
|
||||||
for index := 0; index < rank; index++ {
|
for index := 0; index < rank; index++ {
|
||||||
|
@ -194,7 +241,8 @@ func (m *Matrix) DiagonalVector() Vector {
|
||||||
|
|
||||||
// Diagonal returns a matrix from the diagonal of a matrix.
|
// Diagonal returns a matrix from the diagonal of a matrix.
|
||||||
func (m *Matrix) Diagonal() *Matrix {
|
func (m *Matrix) Diagonal() *Matrix {
|
||||||
rank := minInt(m.rows, m.cols)
|
rows, cols := m.Size()
|
||||||
|
rank := minInt(rows, cols)
|
||||||
m2 := New(rank, rank)
|
m2 := New(rank, rank)
|
||||||
|
|
||||||
for index := 0; index < rank; index++ {
|
for index := 0; index < rank; index++ {
|
||||||
|
@ -211,16 +259,20 @@ func (m *Matrix) Equals(other *Matrix) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if otherRows, otherCols := other.Size(); otherRows != m.rows || otherCols != m.cols {
|
if m.stride != other.stride {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for row := 0; row < m.rows; row++ {
|
msize := len(m.elements)
|
||||||
for col := 0; col < m.cols; col++ {
|
m2size := len(other.elements)
|
||||||
if m.Get(row, col) != other.Get(row, col) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if msize != m2size {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < msize; i++ {
|
||||||
|
if m.elements[i] != other.elements[i] {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
@ -228,9 +280,10 @@ func (m *Matrix) Equals(other *Matrix) bool {
|
||||||
|
|
||||||
// L returns the matrix with zeros below the diagonal.
|
// L returns the matrix with zeros below the diagonal.
|
||||||
func (m *Matrix) L() *Matrix {
|
func (m *Matrix) L() *Matrix {
|
||||||
m2 := New(m.rows, m.cols)
|
rows, cols := m.Size()
|
||||||
for row := 0; row < m.rows; row++ {
|
m2 := New(rows, cols)
|
||||||
for col := row; col < m.cols; col++ {
|
for row := 0; row < rows; row++ {
|
||||||
|
for col := row; col < cols; col++ {
|
||||||
m2.Set(row, col, m.Get(row, col))
|
m2.Set(row, col, m.Get(row, col))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -240,37 +293,46 @@ func (m *Matrix) L() *Matrix {
|
||||||
// U returns the matrix with zeros above the diagonal.
|
// U returns the matrix with zeros above the diagonal.
|
||||||
// Does not include the diagonal.
|
// Does not include the diagonal.
|
||||||
func (m *Matrix) U() *Matrix {
|
func (m *Matrix) U() *Matrix {
|
||||||
m2 := New(m.rows, m.cols)
|
rows, cols := m.Size()
|
||||||
for row := 0; row < m.rows; row++ {
|
m2 := New(rows, cols)
|
||||||
for col := 0; col < row && col < m.cols; col++ {
|
for row := 0; row < rows; row++ {
|
||||||
|
for col := 0; col < row && col < cols; col++ {
|
||||||
m2.Set(row, col, m.Get(row, col))
|
m2.Set(row, col, m.Get(row, col))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return m2
|
return m2
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns a string representation of the matrix.
|
// math operations
|
||||||
func (m *Matrix) String() string {
|
|
||||||
buffer := bytes.NewBuffer(nil)
|
// Multiply multiplies two matrices.
|
||||||
for row := 0; row < m.rows; row++ {
|
func (m *Matrix) Multiply(m2 *Matrix) (m3 *Matrix, err error) {
|
||||||
for col := 0; col < m.cols; col++ {
|
if m.stride*m2.stride != len(m2.elements) {
|
||||||
buffer.WriteString(f64s(m.Get(row, col)))
|
return nil, ErrDimensionMismatch
|
||||||
buffer.WriteRune(' ')
|
|
||||||
}
|
|
||||||
buffer.WriteRune('\n')
|
|
||||||
}
|
}
|
||||||
return buffer.String()
|
|
||||||
|
m3 = &Matrix{epsilon: m.epsilon, stride: m2.stride, elements: make([]float64, (len(m.elements)/m.stride)*m2.stride)}
|
||||||
|
for m1c0, m3x := 0, 0; m1c0 < len(m.elements); m1c0 += m.stride {
|
||||||
|
for m2r0 := 0; m2r0 < m2.stride; m2r0++ {
|
||||||
|
for m1x, m2x := m1c0, m2r0; m2x < len(m2.elements); m2x += m2.stride {
|
||||||
|
m3.elements[m3x] += m.elements[m1x] * m2.elements[m2x]
|
||||||
|
m1x++
|
||||||
|
}
|
||||||
|
m3x++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decompositions
|
// Decompositions
|
||||||
|
|
||||||
// LU returns the LU decomposition of a matrix.
|
|
||||||
func (m *Matrix) LU() (l, u, p *Matrix) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// QR performs the qr decomposition.
|
// QR performs the qr decomposition.
|
||||||
func (m *Matrix) QR() (q, r *Matrix) {
|
func (m *Matrix) QR() (q, r *Matrix) {
|
||||||
|
defer func() {
|
||||||
|
q = q.Round()
|
||||||
|
r = r.Round()
|
||||||
|
}()
|
||||||
|
|
||||||
rows, cols := m.Size()
|
rows, cols := m.Size()
|
||||||
qr := m.Copy()
|
qr := m.Copy()
|
||||||
q = New(rows, cols)
|
q = New(rows, cols)
|
||||||
|
|
|
@ -328,3 +328,17 @@ func TestMatrixString(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal("1 2 3 \n4 5 6 \n7 8 9 \n", m.String())
|
assert.Equal("1 2 3 \n4 5 6 \n7 8 9 \n", m.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMatrixQR(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
m := NewFromArrays([][]float64{
|
||||||
|
{12, -51, 4},
|
||||||
|
{6, 167, -68},
|
||||||
|
{-4, 24, -41},
|
||||||
|
})
|
||||||
|
|
||||||
|
q, r := m.QR()
|
||||||
|
assert.NotNil(q)
|
||||||
|
assert.NotNil(r)
|
||||||
|
}
|
||||||
|
|
|
@ -30,3 +30,7 @@ func maxInt(values ...int) int {
|
||||||
func f64s(v float64) string {
|
func f64s(v float64) string {
|
||||||
return strconv.FormatFloat(v, 'f', -1, 64)
|
return strconv.FormatFloat(v, 'f', -1, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func roundToEpsilon(value, epsilon float64) float64 {
|
||||||
|
return math.Nextafter(value, value)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue