Задать вопрос
@daanila01

Обучаю сеть на tensorflow js но моя ошибка почти не падает какое есть решение?

Я знакомлюсь с tensorflowjs решил попробовать создать LSTM модель для работы с языками, хотел написать небольшого чат бота, но либо в моем коде ошибка либо что-то еще но ошибка сети после нескольких эпох держится около 1400+-

const natural   = require('natural')
const input     = require('input')
const tf        = require('@tensorflow/tfjs-node')
const fs        = require('fs')

// Params Network
const epochs    = 10000

// Tokenizers
const tokens = { normal: { }, reverse: { }, max: null, min: null }
const tokenizer = new natural.WordPunctTokenizer()


const maxLength = 32

const text = `Однажды в лесу, где все деревья были высокими и зелеными, жил маленький зайчик по имени Тимми.`

const data = [ ]
for ( let i = 0; i < text?.length; i+=1 ) {

    const input = text?.slice(i, i + maxLength)
    const output = text?.slice(i + maxLength, i + maxLength + maxLength)

    if ( input?.length !== maxLength || output?.length !== maxLength ) {
        continue
    }

    data.push({
        input,
        output
    })

}

for ( let str of data ) {

    str = str?.input + ' ' + str?.output

    const context = str?.replace(/ /gi, ' _ ')
    const sequence = tokenizer.tokenize(context)
    for ( const char of sequence ) {
        if ( !tokens?.normal?.[char] ) {
            tokens.normal[char] = Object.keys(tokens.normal)?.length + 1
            tokens.reverse[tokens?.normal?.[char]] = char
        }
    }
  
}

tokens.max = Math.max(...Object.keys(tokens?.normal)?.map(key => tokens?.normal?.[key]))
tokens.min = Math.min(...Object.keys(tokens?.normal)?.map(key => tokens?.normal?.[key]))

const encode = (context) => {

    context = context?.replace(/ /gi, ' _ ')
    const sequence = tokenizer.tokenize(context)

    const sequences = [ ]
    for ( const char of sequence ) {
        if ( tokens?.normal?.[char] ) {
            sequences.push(tokens?.normal?.[char])
        }
    }

    const paddedSequence = [ ...sequences?.slice(-maxLength) ]
    while ( paddedSequence?.length < maxLength ) {
        paddedSequence.push(0)
    }

    return paddedSequence

}

const decode = sequence => {

    const paddedSequence = [ ]
    for ( const char of sequence ) {
        if ( tokens?.reverse?.[char] ) {
          paddedSequence.push(tokens?.reverse?.[char] || '')
        }
    }
  
    const string = paddedSequence?.join('')?.replace(/\_/gi, ' ')?.trim()
    return string
  
}

const ask = (context, model) => {

    const encodedInput = tf.tensor(encode(context), [1, maxLength])
    const predict = model.predict(encodedInput)
    const output = predict.arraySync()[0]
    const paddedOutput = output?.map(p => Math.floor(p * tokens?.max))

    return decode(paddedOutput)

}

    ;(async () => {

        const xsData = tf.tensor(data.map(d => encode(d.input)))
        const ysData = tf.tensor(data.map(d => encode(d.output)))

        const model = tf.sequential({
            layers: [
                tf.layers.embedding({
                    inputDim: tokens?.max,
                    outputDim: maxLength,
                    inputLength: maxLength
                }),
                tf.layers.lstm({
                    units: 64,
                    activation: 'tanh',
                    returnSequences: true
                }),
                tf.layers.dropout({ rate: 0.2 }),
                tf.layers.lstm({
                    units: 32,
                    activation: 'tanh'
                }),
                tf.layers.dropout({ rate: 0.2 }),
                tf.layers.dense({ units: maxLength, activation: 'softmax' })
            ]
        })

        model.compile({
            loss: 'categoricalCrossentropy',
            optimizer: tf.train.adam(0.005),
            metrics: ['accuracy']
        })

        console.clear()

        let startTime = Date.now()
        let time = Date.now()

        await model.fit(xsData, ysData, {
            epochs,
            batchSize: 128,
            validationSplit: 0.2,
            // shuffle: true,
            callbacks: {
                onEpochEnd: async (epoche, logs) => {

                    console.clear()
                    
                    let newTime = Date.now()
                    epoche++

                    let str = ''
                    const difference = newTime - time
                    if ( difference < 1e3 ) {
                        str = `${ difference }ms`
                    } else {
                        const seconds = Math.floor(difference / 1e3)
                        str = `${ seconds }s ${ difference - seconds * 1e3 }ms`
                    }

                    console.log(`Step: ${ str }`)
                    console.log(`Remaining: ~${ Math.ceil(((epochs - epoche) * (newTime - startTime)) / epoche / (1e3 * 6)) / 10 }min`)
                    console.log(`Epoch: ${ epoche }/${ epochs } ${ Math.floor(epoche / epochs * 1000) / 10 }%`)
                    console.log(`Error: ${ logs?.loss }`)
                    console.log('')

                    time = newTime
                    await model.save('file://./model')

                }
            }
        })

        await model.save('file://./model')
        
        console.log(ask('Привет', model))
        console.log(ask('Как дела?', model))
        console.log(ask('Как дела', model))

    })()
  • Вопрос задан
  • 139 просмотров
Подписаться 2 Простой Комментировать
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Похожие вопросы