Building a Regression Prediction Model with TensorFlow, Part 8: Integrating the Model with an External Interface

In the previous post, we discussed model compression: converting the model file from the standard TensorFlow format to the tflite format, which greatly reduces its size.

In this post, we will show how to call a tflite model from standard C/C++.

The changes are described step by step below.

I. Changes to the BUILD file:

# Description:
# TensorFlow Lite A/C example for Traffic Assist.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")

exports_files(glob([
    "testdata/*.txt",
]))

tf_cc_binary(
    name = "ta_ac",
    srcs = [
        "get_ac_settings.h",
        "get_ac_settings_impl.h",
        "ta_ac.cc",
        "arm_caller.cc",
        "ta_ac.h",
    ],
    linkopts = tflite_linkopts() + select({
        "//tensorflow:android": [
            "-pie",  # Android 5.0 and later supports only PIE
            "-lm",  # some builtin ops, e.g., tanh, need -lm
        ],
        "//conditions:default": [],
    }),
    deps = [
        ":data_helpers",
        "//tensorflow/contrib/lite:framework",
        "//tensorflow/contrib/lite:string_util",
        "//tensorflow/contrib/lite/kernels:builtin_ops",
    ],
)

cc_library(
    name = "data_helpers",
    srcs = ["data_helpers.cc"],
    hdrs = [
        "data_helpers.h",
        "data_helpers_impl.h",
        "ta_ac.h",
    ],
    deps = [
        "//tensorflow/contrib/lite:builtin_op_data",
        "//tensorflow/contrib/lite:framework",
        "//tensorflow/contrib/lite:schema_fbs_version",
        "//tensorflow/contrib/lite:string",
        "//tensorflow/contrib/lite:string_util",
        "//tensorflow/contrib/lite/kernels:builtin_ops",
        "//tensorflow/contrib/lite/schema:schema_fbs",
    ],
)

cc_test(
    name = "ta_ac_test",
    srcs = [
        "get_ac_settings.h",
        "get_ac_settings_impl.h",
        "ta_ac_test.cc",
    ],
    data = [
        "testdata/ac_data_input.txt",
    ],
    tags = ["no_oss"],
    deps = [
        ":data_helpers",
        "@com_google_googletest//:gtest",
    ],
)

A new source file, arm_caller.cc, is added to the tf_cc_binary target; it simulates the external caller that drives the model.
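
With the BUILD file updated, the target builds like the other TensorFlow Lite examples. The commands below are a sketch, not taken from the original post: they assume the files live under tensorflow/contrib/lite/examples/ta_ac (the path used in the #include directives) and that the usual TensorFlow Bazel configs are available; adjust paths and Android flags to your own checkout and toolchain.

# host build (assumed package path)
bazel build -c opt //tensorflow/contrib/lite/examples/ta_ac:ta_ac

# run the unit test target declared above
bazel test //tensorflow/contrib/lite/examples/ta_ac:ta_ac_test

# cross-compile for an Android device (assumed flags, mirroring the label_image example)
bazel build -c opt --config=android_arm --cxxopt='--std=c++11' //tensorflow/contrib/lite/examples/ta_ac:ta_ac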

II. Changes to ta_ac.h:

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_TA_AC_TA_AC_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_TA_AC_TA_AC_H

#include "tensorflow/contrib/lite/string.h"

namespace tflite {
namespace ta_ac {

struct Settings {
  bool verbose = false;
  bool accel = false;
  bool input_floating = false;
  bool profiling = false;
  int loop_count = 1;
  float input_mean = 127.5f;
  float input_std = 127.5f;
  string model_name = "./model_ac.tflite";
  string input_data_name = "./ac_data_input.txt";
  string labels_file_name = "./ac_labels.txt";
  string input_layer_type = "uint8_t";
  int number_of_threads = 4;
};

struct ac_settings {
    float temp = 0.0f;
    int direct = 0;
    int power = 0;
};

extern ac_settings RunInference(Settings* s,std::vector<float> ac_input,int data_width,int data_height,int data_channels);
extern Settings getopt(int argc, char** argv);

} // namespace ta_ac
}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_EXAMPLES_TA_AC_TA_AC_H

Inside the namespace, one struct and two interfaces are added so that external code can call into the model:

1. The new struct:

struct ac_settings {
    float temp = 0.0f;   // A/C temperature
    int direct = 0;      // A/C airflow direction
    int power = 0;       // A/C fan power
};

2. The inference interface:

extern ac_settings RunInference(Settings* s,std::vector<float> ac_input,int data_width,int data_height,int data_channels);

3. The parameter-initialization interface, which sets the options before inference:

extern Settings getopt(int argc, char** argv);
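
Because the two interfaces above are plain C++ functions inside a namespace, a C caller cannot use them directly. The sketch below shows one possible extern "C" shim on top of them; the file name and all ta_ac_* identifiers are hypothetical and are not part of the original example.

// ta_ac_c_api.cc -- hypothetical C shim, a minimal sketch only
#include <vector>

#include "tensorflow/contrib/lite/examples/ta_ac/ta_ac.h"

extern "C" {

// Plain-C mirror of tflite::ta_ac::ac_settings.
typedef struct {
  float temp;
  int direct;
  int power;
} ta_ac_result;

// `input` points to width * height * channels floats, laid out exactly like
// the ac_in vector in arm_caller.cc shown later in this post.
ta_ac_result ta_ac_run(const float* input, int width, int height, int channels) {
  tflite::ta_ac::Settings s;  // defaults: ./model_ac.tflite, 4 threads
  std::vector<float> in(input, input + width * height * channels);
  tflite::ta_ac::ac_settings ac =
      tflite::ta_ac::RunInference(&s, in, width, height, channels);
  ta_ac_result r;
  r.temp = ac.temp;
  r.direct = ac.direct;
  r.power = ac.power;
  return r;
}

}  // extern "C"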

III. Changes to ta_ac.cc:

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/optional_debug_tools.h"
#include "tensorflow/contrib/lite/string_util.h"

#include "tensorflow/contrib/lite/examples/ta_ac/data_helpers.h"
#include "tensorflow/contrib/lite/examples/ta_ac/data_helpers_impl.h"
#include "tensorflow/contrib/lite/examples/ta_ac/get_ac_settings.h"

#define LOG(x) std::cerr

namespace tflite {
namespace ta_ac {

double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
TfLiteStatus ReadLabelsFile(const string& file_name,
                            std::vector<string>* result,
                            size_t* found_label_count) {
  std::ifstream file(file_name);
  if (!file) {
    LOG(FATAL) << "Labels file " << file_name << " not found\n";
    return kTfLiteError;
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  *found_label_count = result->size();
  const int padding = 16;
  while (result->size() % padding) {
    result->emplace_back();
  }
  return kTfLiteOk;
}

void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index,
                        TfLiteRegistration registration) {
  // output something like
  // time (ms) , Node xxx, OpCode xxx, symbolic name
  //      5.352, Node   5, OpCode   4, DEPTHWISE_CONV_2D

  LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
            << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0
            << ", Node " << std::setw(3) << std::setprecision(3) << op_index
            << ", OpCode " << std::setw(3) << std::setprecision(3)
            << registration.builtin_code << ", "
            << EnumNameBuiltinOperator(
                   static_cast<BuiltinOperator>(registration.builtin_code))
            << "\n";
}

ac_settings RunInference(Settings* s,std::vector<float> ac_input,int data_width,int data_height,int data_channels) {
  ac_settings ac;
  if (s->model_name.empty()) {
    LOG(ERROR) << "no model file name\n";
    exit(-1);
  }

  std::unique_ptr<tflite::FlatBufferModel> model;
  std::unique_ptr<tflite::Interpreter> interpreter;
  model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
  if (!model) {
    LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
    exit(-1);
  }
  LOG(INFO) << "Loaded tensorflow lite model " << s->model_name << "\n";
  model->error_reporter();
  //LOG(INFO) << "resolved reporter\n";

  tflite::ops::builtin::BuiltinOpResolver resolver;

  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) {
    LOG(FATAL) << "Failed to construct interpreter\n";
    exit(-1);
  }

  interpreter->UseNNAPI(s->accel);

  if (s->verbose) {
    LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
    LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
    LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
    LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";

    int t_size = interpreter->tensors_size();
    for (int i = 0; i < t_size; i++) {
      if (interpreter->tensor(i)->name)
        LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                  << interpreter->tensor(i)->bytes << ", "
                  << interpreter->tensor(i)->type << ", "
                  << interpreter->tensor(i)->params.scale << ", "
                  << interpreter->tensor(i)->params.zero_point << "\n";
    }
  }

  if (s->number_of_threads != -1) {
    interpreter->SetNumThreads(s->number_of_threads);
  }

  int input = interpreter->inputs()[0];

  if (s->verbose)
  {
      LOG(INFO) << "input: " << input << "\n";
  }

  const std::vector<int> inputs = interpreter->inputs();
  const std::vector<int> outputs = interpreter->outputs();

  if (s->verbose)
  {
    LOG(INFO) << "number of inputs: " << inputs.size() << "\n";;
    LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
  }

  if (interpreter->AllocateTensors() != kTfLiteOk)
  {
    LOG(FATAL) << "Failed to allocate tensors!";
  }

  if (s->verbose)
  {
      PrintInterpreterState(interpreter.get());
  }

  //LOG(INFO) << "interpreter->tensor(input)->type: " << interpreter->tensor(input)->type << "\n";
  switch (interpreter->tensor(input)->type) {
    case kTfLiteFloat32:
      s->input_floating = true;
      setdata<float>(interpreter->typed_tensor<float>(input), ac_input.data(), data_height, data_width, data_channels, s);
      break;
    case kTfLiteUInt8:
      setdata<uint8_t>(interpreter->typed_tensor<uint8_t>(input), ac_input.data(), data_height, data_width, data_channels, s);
      break;
    default:
      LOG(FATAL) << "cannot handle input type " << interpreter->tensor(input)->type << " yet";
      exit(-1);
  }

  profiling::Profiler* profiler = new profiling::Profiler();
  interpreter->SetProfiler(profiler);

  if (s->profiling) profiler->StartProfiling();

  struct timeval start_time, stop_time;
  gettimeofday(&start_time, nullptr);
  for (int i = 0; i < s->loop_count; i++) {
    if (interpreter->Invoke() != kTfLiteOk) {
      LOG(FATAL) << "Failed to invoke tflite!\n";
    }
  }
  gettimeofday(&stop_time, nullptr);
  LOG(INFO) << "invoked \n";
  LOG(INFO) << "inference time: "
            << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
            << " ms \n";

  if (s->profiling) {
    profiler->StopProfiling();
    auto profile_events = profiler->GetProfileEvents();
    for (int i = 0; i < profile_events.size(); i++) {
      auto op_index = profile_events[i]->event_metadata;
      const auto node_and_registration =
          interpreter->node_and_registration(op_index);
      const TfLiteRegistration registration = node_and_registration->second;
      PrintProfilingInfo(profile_events[i], op_index, registration);
    }
  }

  int output = interpreter->outputs()[0];
  //LOG(INFO) << "RunInference interpreter->tensor(output)->type: " << interpreter->tensor(output)->type << "\n";

  float temp = interpreter->typed_output_tensor<float>(0)[0];
  float direct = interpreter->typed_output_tensor<float>(1)[0];
  float power = interpreter->typed_output_tensor<float>(2)[0];

  //LOG(INFO) << "RunInference temp: " << temp << "\n";
  //LOG(INFO) << "RunInference direct: " << direct << "\n";
  //LOG(INFO) << "RunInference power: " << power << "\n";

  LOG(INFO) << "temp:" << int(temp) << " C\n";
  LOG(INFO) << "direct:" << int(direct) << " direction(0:head,1:body,2:leg)\n";
  LOG(INFO) << "power:" << int(power) << " power(0:auto,9:large)\n";

  ac.temp = temp;
  ac.direct = direct;
  ac.power = power;
  LOG(INFO) << "RunInference ac.temp: " << ac.temp << "\n";
  LOG(INFO) << "RunInference ac.direct: " << ac.direct << "\n";
  LOG(INFO) << "RunInference ac.power: " << ac.power << "\n";

  return ac;
}

void display_usage() {
  LOG(INFO) << "ta_ac\n"
            << "--accelerated, -a: [0|1], use Android NNAPI or not\n"
            << "--count, -c: loop interpreter->Invoke() for certain times\n"
            << "--input_mean, -b: input mean\n"
            << "--input_std, -s: input standard deviation\n"
            << "--data, -d: data_name.txt\n"
            << "--labels, -l: labels for the model\n"
            << "--tflite_model, -m: model_name.tflite\n"
            << "--profiling, -p: [0|1], profiling or not\n"
            << "--threads, -t: number of threads\n"
            << "--verbose, -v: [0|1] print more information\n"
            << "\n";
}

Settings getopt(int argc, char** argv)
{
  Settings s;
  int c;
  while (1) {
    static struct option long_options[] = {
        {"accelerated", required_argument, nullptr, ‘a‘},
        {"count", required_argument, nullptr, ‘c‘},
        {"verbose", required_argument, nullptr, ‘v‘},
        {"data", required_argument, nullptr, ‘d‘},
        {"labels", required_argument, nullptr, ‘l‘},
        {"tflite_model", required_argument, nullptr, ‘m‘},
        {"profiling", required_argument, nullptr, ‘p‘},
        {"threads", required_argument, nullptr, ‘t‘},
        {"input_mean", required_argument, nullptr, ‘b‘},
        {"input_std", required_argument, nullptr, ‘s‘},
        {nullptr, 0, nullptr, 0}};

    /* getopt_long stores the option index here. */
    int option_index = 0;

    c = getopt_long(argc, argv, "a:b:c:d:f:l:m:p:s:t:v:", long_options, &option_index);

    /* Detect the end of the options. */
    if (c == -1) break;

    switch (c) {
      case 'a':
        s.accel = strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'b':
        s.input_mean = strtod(optarg, nullptr);
        break;
      case 'c':
        s.loop_count =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'd':
        s.input_data_name = optarg;
        break;
      case 'l':
        s.labels_file_name = optarg;
        break;
      case 'm':
        s.model_name = optarg;
        break;
      case 'p':
        s.profiling =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 's':
        s.input_std = strtod(optarg, nullptr);
        break;
      case 't':
        s.number_of_threads = strtol(  // NOLINT(runtime/deprecated_fn)
            optarg, nullptr, 10);
        break;
      case 'v':
        s.verbose =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'h':
      case '?':
        /* getopt_long already printed an error message. */
        display_usage();
        exit(-1);
      default:
        exit(-1);
    }
  }
  return s;
}

#if 0
int Main(int argc, char** argv) {
  Settings s;
    ac_settings ac;
    s = getopt(argc, argv);

//input data
  int data_width = 4;
  int data_height = 3;
  int data_channels = 1;

  std::vector<float> ac_in(data_height * data_width * data_channels);
// test code {2018,11,16,14.88,31.21549,121.30741,15.18,31.20742,121.44468,14.5,14.4,14});
  ac_in[0] = 2018;
  ac_in[1] = 11;
  ac_in[2] = 16;
  ac_in[3] = 14.88f;
  ac_in[4] = 31.21549f;
  ac_in[5] = 121.30741f;
  ac_in[6] = 15.18;
  ac_in[7] = 31.20742;
  ac_in[8] = 121.44468;
  ac_in[9] = 14.5f;
  ac_in[10] = 14.4f;
  ac_in[11] = 14;

//inference
  ac = RunInference(&s,ac_in,data_width,data_height,data_channels);

//output data
  LOG(INFO) << "Main ac.temp: " << ac.temp << "\n";
  LOG(INFO) << "Main ac.direct: " << ac.direct << "\n";
  LOG(INFO) << "Main ac.power: " << ac.power << "\n";
  return 0;
}
#endif

}  // namespace ta_ac
}  // namespace tflite

#if 0
int main(int argc, char** argv) {
  printf("-----------------------\n");
  printf("-         ta_ac      --\n");
  printf("-     tflite   ok!   --\n");
  printf("-----------------------\n");
  return tflite::ta_ac::Main(argc, argv);
}

#endif

Two interfaces, RunInference and getopt, are exposed, and the input data required for inference and the inference results are packed into aggregate types (a std::vector and the ac_settings struct), which makes the model much easier to call from external code.

The Main function inside the namespace and the C main function are disabled (kept under #if 0), and the program entry point is moved into arm_caller.cc.

IV. A new file: arm_caller.cc

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <string>
#include <vector>
#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include "tensorflow/contrib/lite/examples/ta_ac/ta_ac.h"

int main(int argc, char** argv) {
  printf("-----------------------\n");
  printf("-         ta_ac      --\n");
  printf("-     tflite   ok!   --\n");
  printf("-----------------------\n");
  tflite::ta_ac::Settings s;
  tflite::ta_ac::ac_settings ac;
  s = tflite::ta_ac::getopt(argc, argv);

//input data
  int data_width = 4;
  int data_height = 3;
  int data_channels = 1;

  std::vector<float> ac_in(data_height * data_width * data_channels);
// test code {2018,11,16,14.88,31.21549,121.30741,15.18,31.20742,121.44468,14.5,14.4,14});
  ac_in[0] = 2018;
  ac_in[1] = 11;
  ac_in[2] = 16;
  ac_in[3] = 14.88f;
  ac_in[4] = 31.21549f;
  ac_in[5] = 121.30741f;
  ac_in[6] = 15.18;
  ac_in[7] = 31.20742;
  ac_in[8] = 121.44468;
  ac_in[9] = 14.5f;
  ac_in[10] = 14.4f;
  ac_in[11] = 14;

//inference
  ac = tflite::ta_ac::RunInference(&s,ac_in,data_width,data_height,data_channels);

//output data
  printf("arm_caller ac.temp:%d\n", int(ac.temp));
  printf("arm_caller ac.direct:%d\n", ac.direct);
  printf("arm_caller ac.power:%d\n", ac.power);

  return 0;
}

1. Include the standard system headers.

2. Note in particular that the ta_ac.h header must be included:

#include "tensorflow/contrib/lite/examples/ta_ac/ta_ac.h"

Otherwise the file will not compile, and the interfaces cannot be called.

3. Simulate the input data:

//input data
  int data_width = 4;
  int data_height = 3;
  int data_channels = 1;
  std::vector<float> ac_in(data_height * data_width * data_channels);
  ac_in[0] = 2018;
  ac_in[1] = 11;
  ac_in[2] = 16;
  ac_in[3] = 14.88f;
  ac_in[4] = 31.21549f;
  ac_in[5] = 121.30741f;
  ac_in[6] = 15.18;
  ac_in[7] = 31.20742;
  ac_in[8] = 121.44468;
  ac_in[9] = 14.5f;
  ac_in[10] = 14.4f;
  ac_in[11] = 14;

4. Initialize the parameters from the command line (an example invocation is shown after this list):

tflite::ta_ac::Settings s;
tflite::ta_ac::ac_settings ac;
s = tflite::ta_ac::getopt(argc, argv);

5. Call the inference interface:

ac = tflite::ta_ac::RunInference(&s,ac_in,data_width,data_height,data_channels);

6. Print the inference results:

printf("arm_caller ac.temp:%d\n", int(ac.temp));
  printf("arm_caller ac.direct:%d\n", ac.direct);
  printf("arm_caller ac.power:%d\n", ac.power);

Previous post:

  Building a Regression Prediction Model with TensorFlow, Part 7: Model Compression

Original article: https://www.cnblogs.com/jimchen1218/p/11813598.html
